더북(TheBook)

코드 10-31 seq2seq 네트워크

class Seq2Seq(nn.Module):
    def __init__(self, encoder, decoder, device, MAX_LENGTH=MAX_LENGTH):
        super().__init__()

        self.encoder = encoder ------ 인코더 초기화
        self.decoder = decoder ------ 디코더 초기화
        self.device = device

    def forward(self, input_lang, output_lang, teacher_forcing_ratio=0.5):
        input_length = input_lang.size(0) ------ 입력 문자 길이(문장의 단어 수)
        batch_size = output_lang.shape[1]
        target_length = output_lang.shape[0]
        vocab_size = self.decoder.output_dim
        outputs = torch.zeros(target_length, batch_size, vocab_size).to(self.device) ------ 예측된 출력을 저장하기 위한 변수 초기화

        for i in range(input_length):
            encoder_output, encoder_hidden = self.encoder(input_lang[i]) ------ 문장의 모든 단어를 인코딩
        decoder_hidden = encoder_hidden.to(device) ------ 인코더의 은닉층을 디코더의 은닉층으로 사용
        decoder_input = torch.tensor([SOS_token], device=device) ------ 첫 번째 예측 단어 앞에 토큰(SOS) 추가

        for t in range(target_length): ------ 현재 단어에서 출력 단어를 예측
            decoder_output, decoder_hidden = self.decoder(decoder_input, decoder_hidden)
            outputs[t] = decoder_output
            teacher_force = random.random() < teacher_forcing_ratio ------ ①
            topv, topi = decoder_output.topk(1)
            input = (output_lang[t] if teacher_force else topi) ------ teacher_force를 활성화하면 목표 단어를 다음 입력으로 사용
            if(teacher_force == False and input.item() == EOS_token): ------ teacher_force를 활성화하지 않으면 자체 예측 값을 다음 입력으로 사용
                break
        return outputs
신간 소식 구독하기
뉴스레터에 가입하시고 이메일로 신간 소식을 받아 보세요.