7.6.2 GRU 셀 구현
필요한 라이브러리 및 데이터 호출은 RNN 셀에서의 수행과 동일하므로 생략하며, GRU 셀을 이용한 네트워크 코드를 살펴보겠습니다.
코드 7-18 네트워크 생성
class GRU_Build(tf.keras.Model):
def __init__(self, units):
super(GRU_Build, self).__init__()
self.state0 = [tf.zeros([batch_size, units])]
self.state1 = [tf.zeros([batch_size, units])]
self.embedding = tf.keras.layers.Embedding(total_words, embedding_len,
input_length=max_review_len)
self.RNNCell0 = tf.keras.layers.GRUCell(units, dropout=0.5) ------ ①
self.RNNCell1 = tf.keras.layers.GRUCell(units, dropout=0.5)
self.outlayer = tf.keras.layers.Dense(1)
def call(self, inputs, training=None):
x = inputs
x = self.embedding(x)
state0 = self.state0 ------ 초기 상태는 모두 0으로 설정
state1 = self.state1
for word in tf.unstack(x, axis=1):
out0, state0 = self.RNNCell0(word, state0, training)
out1, state1 = self.RNNCell1(out0, state1, training)
x = self.outlayer(out1)
prob = tf.sigmoid(x)
return prob