더북(TheBook)

LSTM 셀에 대한 네트워크를 구축합니다. 모델의 전반적인 네트워크가 아닌 LSTM 셀에 집중한 네트워크입니다.

코드 7-34 LSTM 셀 네트워크 구축

class LSTMCell(nn.Module): ------ LSTM 셀에 대한 더 자세한 설명을 원한다면 http://www.bioinf.jku.at/publications/older/2604.pdf 논문을 참고하세요.
    def __init__(self, input_size, hidden_size, bias=True):
        super(LSTMCell, self).__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.bias = bias
        self.x2h = nn.Linear(input_size, 4 * hidden_size, bias=bias) ------ ①
        self.h2h = nn.Linear(hidden_size, 4 * hidden_size, bias=bias) ------ ①′
        self.reset_parameters()

    def reset_parameters(self): ------ 모델의 파라미터 초기화
        std = 1.0 / math.sqrt(self.hidden_size)
        for w in self.parameters():
            w.data.uniform_(-std, std) ------ ②

    def forward(self, x, hidden):
        hx, cx = hidden
        x = x.view(-1, x.size(1))

        gates = self.x2h(x) + self.h2h(hx) ------ ①″
        gates = gates.squeeze() ------ ③
        ingate, forgetgate, cellgate, outgate = gates.chunk(4, 1) ------ ①‴

        ingate = F.sigmoid(ingate) ------ 입력 게이트에 시그모이드 활성화 함수 적용
        forgetgate = F.sigmoid(forgetgate) ------ 망각 게이트에 시그모이드 활성화 함수 적용
        cellgate = F.tanh(cellgate) ------ 셀 게이트에 탄젠트 활성화 함수 적용
        outgate = F.sigmoid(outgate) ------ 출력 게이트에 시그모이드 활성화 함수 적용

        cy = torch.mul(cx, forgetgate) + torch.mul(ingate, cellgate) ------ ④
        hy = torch.mul(outgate, F.tanh(cy)) ------ ④′
        return(hy, cy)
신간 소식 구독하기
뉴스레터에 가입하시고 이메일로 신간 소식을 받아 보세요.