더북(TheBook)

코드 7-13 워드 임베딩 및 RNN 셀 정의

class RNNCell_Encoder(nn.Module):
    def __init__(self, input_dim, hidden_size):
        super(RNNCell_Encoder, self).__init__()
        self.rnn = nn.RNNCell(input_dim, hidden_size) ------ ①

    def forward(self, inputs): ------ inputs는 입력 시퀀스로 (시퀀스 길이, 배치, 임베딩(seq,batch, embedding))의 형태를 갖습니다.
        bz = inputs.shape[1] ------ 배치를 가져옵니다.
        ht = torch.zeros((bz, hidden_size)).to(device) ------ 배치와 은닉층 뉴런의 크기를 0으로 초기화
        for word in inputs:
            ht = self.rnn(word, ht) ------ ②
        return ht

class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.em = nn.Embedding(len(TEXT.vocab.stoi), embeding_dim) ------ ③
        self.rnn = RNNCell_Encoder(embeding_dim, hidden_size)
        self.fc1 = nn.Linear(hidden_size, 256)
        self.fc2 = nn.Linear(256, 3)

    def forward(self, x):
        x = self.em(x)
        x = self.rnn(x)
        x = F.relu(self.fc1(x))
        x = self.fc2(x)
        return x
신간 소식 구독하기
뉴스레터에 가입하시고 이메일로 신간 소식을 받아 보세요.