코드 7-13 워드 임베딩 및 RNN 셀 정의

    class RNNCell_Encoder(nn.Module):
        def __init__(self, input_dim, hidden_size):
            super(RNNCell_Encoder, self).__init__()
            self.rnn = nn.RNNCell(input_dim, hidden_size) ------ ①
    
        def forward(self, inputs): ------ inputs는 입력 시퀀스로 (시퀀스 길이, 배치, 임베딩(seq,batch, embedding))의 형태를 갖습니다.
            bz = inputs.shape[1] ------ 배치를 가져옵니다.
            ht = torch.zeros((bz, hidden_size)).to(device) ------ 배치와 은닉층 뉴런의 크기를 0으로 초기화
            for word in inputs:
                ht = self.rnn(word, ht) ------ ②
            return ht
    
    class Net(nn.Module):
        def __init__(self):
            super(Net, self).__init__()
            self.em = nn.Embedding(len(TEXT.vocab.stoi), embeding_dim) ------ ③
            self.rnn = RNNCell_Encoder(embeding_dim, hidden_size)
            self.fc1 = nn.Linear(hidden_size, 256)
            self.fc2 = nn.Linear(256, 3)
    
        def forward(self, x):
            x = self.em(x)
            x = self.rnn(x)
            x = F.relu(self.fc1(x))
            x = self.fc2(x)
            return x
    신간 소식 구독하기
    뉴스레터에 가입하시고 이메일로 신간 소식을 받아 보세요.