7.4.2 RNN 계층 구현
필요한 라이브러리 및 데이터 호출은 RNN 셀에서의 수행과 동일하므로 생략합니다.
바로 RNN 계층 네트워크(신경망)부터 생성하겠습니다.
코드 7-9 네트워크(신경망) 구축
class RNN_Build(tf.keras.Model):
def __init__(self, units):
super(RNN_Build, self).__init__()
self.embedding = tf.keras.layers.Embedding(total_words, embedding_len,
input_length=max_review_len)
self.rnn = tf.keras.Sequential([
tf.keras.layers.SimpleRNN(units, dropout=0.5, return_sequences=True), ------ ①
tf.keras.layers.SimpleRNN(units, dropout=0.5)
])
self.outlayer = tf.keras.layers.Dense(1)
def call(self, inputs, training=None):
x = inputs
x = self.embedding(x)
x = self.rnn(x)
x = self.outlayer(x)
prob = tf.sigmoid(x)
return prob