LSTM 셀에 대한 네트워크를 구축합니다. 모델의 전반적인 네트워크가 아닌 LSTM 셀에 집중한 네트워크입니다.
코드 7-34 LSTM 셀 네트워크 구축
class LSTMCell(nn.Module): ------ LSTM 셀에 대한 더 자세한 설명을 원한다면 http://www.bioinf.jku.at/publications/older/2604.pdf 논문을 참고하세요.
def __init__(self, input_size, hidden_size, bias=True):
super(LSTMCell, self).__init__()
self.input_size = input_size
self.hidden_size = hidden_size
self.bias = bias
self.x2h = nn.Linear(input_size, 4 * hidden_size, bias=bias) ------ ①
self.h2h = nn.Linear(hidden_size, 4 * hidden_size, bias=bias) ------ ①′
self.reset_parameters()
def reset_parameters(self): ------ 모델의 파라미터 초기화
std = 1.0 / math.sqrt(self.hidden_size)
for w in self.parameters():
w.data.uniform_(-std, std) ------ ②
def forward(self, x, hidden):
hx, cx = hidden
x = x.view(-1, x.size(1))
gates = self.x2h(x) + self.h2h(hx) ------ ①″
gates = gates.squeeze() ------ ③
ingate, forgetgate, cellgate, outgate = gates.chunk(4, 1) ------ ①‴
ingate = F.sigmoid(ingate) ------ 입력 게이트에 시그모이드 활성화 함수 적용
forgetgate = F.sigmoid(forgetgate) ------ 망각 게이트에 시그모이드 활성화 함수 적용
cellgate = F.tanh(cellgate) ------ 셀 게이트에 탄젠트 활성화 함수 적용
outgate = F.sigmoid(outgate) ------ 출력 게이트에 시그모이드 활성화 함수 적용
cy = torch.mul(cx, forgetgate) + torch.mul(ingate, cellgate) ------ ④
hy = torch.mul(outgate, F.tanh(cy)) ------ ④′
return(hy, cy)