7.5.3 LSTM 계층 구현

    필요한 라이브러리 및 데이터 호출은 RNN 셀에서의 수행과 동일하므로 생략하며, LSTM 계층을 이용한 네트워크 코드를 살펴보겠습니다.

    코드 7-15 네트워크 생성

    class LSTM_Build(tf.keras.Model):
    
        def __init__(self, units):
            super(LSTM_Build, self).__init__()
    
            self.embedding = tf.keras.layers.Embedding(total_words, embedding_len,
                                                       input_length=max_review_len)
            self.rnn = tf.keras.Sequential([
                tf.keras.layers.LSTM(units, dropout=0.5, return_sequences=True,
                                     unroll=True), ------ ①
                tf.keras.layers.LSTM(units, dropout=0.5, unroll=True)
            ])
            self.outlayer = tf.keras.layers.Dense(1)
    
        def call(self, inputs, training=None):
            x = inputs
            x = self.embedding(x)
            x = self.rnn(x)
            x = self.outlayer(x)
            prob = tf.sigmoid(x)
    
            return prob
    신간 소식 구독하기
    뉴스레터에 가입하시고 이메일로 신간 소식을 받아 보세요.