이미지가 들어온 후 모델이 어떤 흐름으로 진행되는지 확인할 수 있습니다.
이제 모델이 학습하며 모델의 예측 결과를 시각화하는 함수를 정의하고, 모델을 학습시키는 과정을 진행해보겠습니다.
def create_mask(pred_mask): # ①
pred_mask = tf.math.argmax(pred_mask, axis=-1)
pred_mask = pred_mask[..., tf.newaxis]
return pred_mask[0]
def show_predictions(dataset=None, num=1): # ②
if dataset:
for image, mask in dataset.take(num):
pred_mask = model.predict(image)
display([image[0], mask[0], create_mask(pred_mask)])
else:
display([sample_image, sample_mask,
create_mask(model.predict(sample_image[tf.newaxis, ...]))])
class DisplayCallback(tf.keras.callbacks.Callback): # ③
def on_epoch_end(self, epoch, logs=None):
clear_output(wait=True)
show_predictions()
print ('\nSample Prediction after epoch {}\n'.format(epoch+1))