다음은 생성된 이미지를 시각적으로 출력하는 함수입니다.
코드 13-24 생성된 이미지 출력 함수
def generate_images(model, epoch, test_input):
predictions = model(test_input, training=False)
fig = plt.figure(figsize=(4,4))
for i in range(predictions.shape[0]):
plt.subplot(4, 4, i+1)
plt.imshow(predictions[i,:,:,0] * 127.5 + 127.5, cmap='rainbow')
이제 모든 준비가 완료되었으므로 모델을 훈련시키는 함수를 정의합니다.
코드 13-25 모델 훈련 함수
def train_GAN(dataset, epochs):
for epoch in range(epochs):
start = time.time() ------ 매 에포크마다 시작 시간 표시
for image_batch in dataset:
train_step(image_batch) ------ train_step() 함수를 호출하여 모델 훈련
if epoch % 10 == 0:
generate_images(generator, epoch+1, seed)
print('에포크 {} 은/는 {} 초'.format(epoch+1, time.time()-start)) ------ 매 에포크마다 걸린 시간을 표시
generate_images(generator, epochs, seed)