코드 7-13 워드 임베딩 및 RNN 셀 정의
class RNNCell_Encoder(nn.Module):
def __init__(self, input_dim, hidden_size):
super(RNNCell_Encoder, self).__init__()
self.rnn = nn.RNNCell(input_dim, hidden_size) ------ ①
def forward(self, inputs): ------ inputs는 입력 시퀀스로 (시퀀스 길이, 배치, 임베딩(seq,batch, embedding))의 형태를 갖습니다.
bz = inputs.shape[1] ------ 배치를 가져옵니다.
ht = torch.zeros((bz, hidden_size)).to(device) ------ 배치와 은닉층 뉴런의 크기를 0으로 초기화
for word in inputs:
ht = self.rnn(word, ht) ------ ②
return ht
class Net(nn.Module):
def __init__(self):
super(Net, self).__init__()
self.em = nn.Embedding(len(TEXT.vocab.stoi), embeding_dim) ------ ③
self.rnn = RNNCell_Encoder(embeding_dim, hidden_size)
self.fc1 = nn.Linear(hidden_size, 256)
self.fc2 = nn.Linear(256, 3)
def forward(self, x):
x = self.em(x)
x = self.rnn(x)
x = F.relu(self.fc1(x))
x = self.fc2(x)
return x