① LSTM 셀에서는 4를 곱했지만 GRU 셀에서는 세 개의 게이트가 사용되므로 3을 곱합니다. 엄밀히 게이트는 두 개(망각, 입력 게이트)이지만 탄젠트 활성화 함수가 적용되는 부분을 ‘새로운 게이트(newgate)’로 정의하여 총 3을 곱합니다.
개별적인 GRU 셀의 네트워크가 구성되었기 때문에 전반적인 네트워크에 대해 살펴봅시다.
코드 7-58 전반적인 네트워크 구조
class GRUModel(nn.Module):
def __init__(self, input_dim, hidden_dim, layer_dim, output_dim, bias=True):
super(GRUModel, self).__init__()
self.hidden_dim = hidden_dim
self.layer_dim = layer_dim
self.gru_cell = GRUCell(input_dim, hidden_dim, layer_dim) ------ 앞에서 정의한 GRUCell 함수를 불러옵니다.
self.fc = nn.Linear(hidden_dim, output_dim)
def forward(self, x):
if torch.cuda.is_available():
h0 = Variable(torch.zeros(self.layer_dim, x.size(0), self.hidden_dim).cuda())
else:
h0 = Variable(torch.zeros(self.layer_dim, x.size(0), self.hidden_dim))
outs = []
hn = h0[0,:,:] ------ LSTM 셀에서는 셀 상태에 대해서도 정의했었지만 GRU 셀에서는 셀은 사용되지 않습니다.
for seq in range(x.size(1)):
hn = self.gru_cell(x[:,seq,:], hn)
outs.append(hn)
out = outs[-1].squeeze()
out = self.fc(out)
return out