EPOCHS = 20
VAL_SUBSPLITS = 5
VALIDATION_STEPS = info.splits['test'].num_examples//BATCH_SIZE//VAL_SUBSPLITS
model_history = model.fit(train_batches, epochs=EPOCHS,
steps_per_epoch=STEPS_PER_EPOCH,
validation_steps=VALIDATION_STEPS,
validation_data=test_batches,
callbacks=[DisplayCallback()]) # ④
① create_mask 함수는 모델의 예측 결과에서 가장 높은 확률을 가진 클래스를 선택하여 최종 예측 마스크를 생성합니다. tf.math.argmax 함수를 사용하여 확률이 가장 높은 클래스의 인덱스를 선택하고, 이를 새로운 차원으로 확장하여 반환합니다.
② show_predictions 함수는 주어진 데이터 세트에서 모델의 예측 결과를 시각화합니다. dataset 매개변수가 주어지면, 해당 데이터 세트에서 샘플을 가져와 모델의 예측 결과를 표시합니다. dataset이 없는 경우, 샘플 이미지에 대한 모델의 예측을 표시합니다.