코드 10-31 seq2seq 네트워크
class Seq2Seq(nn.Module):
def __init__(self, encoder, decoder, device, MAX_LENGTH=MAX_LENGTH):
super().__init__()
self.encoder = encoder ------ 인코더 초기화
self.decoder = decoder ------ 디코더 초기화
self.device = device
def forward(self, input_lang, output_lang, teacher_forcing_ratio=0.5):
input_length = input_lang.size(0) ------ 입력 문자 길이(문장의 단어 수)
batch_size = output_lang.shape[1]
target_length = output_lang.shape[0]
vocab_size = self.decoder.output_dim
outputs = torch.zeros(target_length, batch_size, vocab_size).to(self.device) ------ 예측된 출력을 저장하기 위한 변수 초기화
for i in range(input_length):
encoder_output, encoder_hidden = self.encoder(input_lang[i]) ------ 문장의 모든 단어를 인코딩
decoder_hidden = encoder_hidden.to(device) ------ 인코더의 은닉층을 디코더의 은닉층으로 사용
decoder_input = torch.tensor([SOS_token], device=device) ------ 첫 번째 예측 단어 앞에 토큰(SOS) 추가
for t in range(target_length): ------ 현재 단어에서 출력 단어를 예측
decoder_output, decoder_hidden = self.decoder(decoder_input, decoder_hidden)
outputs[t] = decoder_output
teacher_force = random.random() < teacher_forcing_ratio ------ ①
topv, topi = decoder_output.topk(1)
input = (output_lang[t] if teacher_force else topi) ------ teacher_force를 활성화하면 목표 단어를 다음 입력으로 사용
if(teacher_force == False and input.item() == EOS_token): ------ teacher_force를 활성화하지 않으면 자체 예측 값을 다음 입력으로 사용
break
return outputs