더북(TheBook)

7.5.3 LSTM 계층 구현

필요한 라이브러리 및 데이터 호출은 RNN 셀에서의 수행과 동일하므로 생략하며, LSTM 계층을 이용한 네트워크 코드를 살펴보겠습니다.

코드 7-15 네트워크 생성

class LSTM_Build(tf.keras.Model):

    def __init__(self, units):
        super(LSTM_Build, self).__init__()

        self.embedding = tf.keras.layers.Embedding(total_words, embedding_len,
                                                   input_length=max_review_len)
        self.rnn = tf.keras.Sequential([
            tf.keras.layers.LSTM(units, dropout=0.5, return_sequences=True,
                                 unroll=True), ------ ①
            tf.keras.layers.LSTM(units, dropout=0.5, unroll=True)
        ])
        self.outlayer = tf.keras.layers.Dense(1)

    def call(self, inputs, training=None):
        x = inputs
        x = self.embedding(x)
        x = self.rnn(x)
        x = self.outlayer(x)
        prob = tf.sigmoid(x)

        return prob
신간 소식 구독하기
뉴스레터에 가입하시고 이메일로 신간 소식을 받아 보세요.