어텐션 가중치를 시각화하기 위한 함수를 정의합니다.
코드 10-39 어텐션 가중치 시각화 함수
def plot_attention(attention, sentence, predicted_sentence):
fig = plt.figure(figsize=(10,10))
ax = fig.add_subplot(1, 1, 1)
ax.matshow(attention, cmap='viridis')
fontdict = {'fontsize': 14}
ax.set_xticklabels([''] + sentence, fontdict=fontdict, rotation=90)
ax.set_yticklabels([''] + predicted_sentence, fontdict=fontdict)
ax.xaxis.set_major_locator(ticker.MultipleLocator(1))
ax.yaxis.set_major_locator(ticker.MultipleLocator(1))
plt.show()
코드 10-40 번역을 위한 함수 정의 및 번역 문장 입력 함수
def translate(sentence):
result, sentence, attention_plot = evaluate(sentence)
print('Input: %s' % (sentence))
print('Predicted translation: {}'.format(result))
attention_plot = attention_plot[:len(result.split(' ')), :len(sentence.split(' '))]
plot_attention(attention_plot, sentence.split(' '), result.split(' ')) ------ 어텐션 가중치 매핑
checkpoint.restore(tf.train.latest_checkpoint(checkpoint_dir)) ------ ①
translate(u'esta es mi vida.') ------ 스페인어를 영어로 번역