다음은 생성된 이미지를 시각적으로 출력하는 함수입니다.

    코드 13-24 생성된 이미지 출력 함수

    def generate_images(model, epoch, test_input):
        predictions = model(test_input, training=False)
        fig = plt.figure(figsize=(4,4))
    
        for i in range(predictions.shape[0]):
            plt.subplot(4, 4, i+1)
            plt.imshow(predictions[i,:,:,0] * 127.5 + 127.5, cmap='rainbow')

    이제 모든 준비가 완료되었으므로 모델을 훈련시키는 함수를 정의합니다.

    코드 13-25 모델 훈련 함수

    def train_GAN(dataset, epochs):
        for epoch in range(epochs):
            start = time.time() ------ 매 에포크마다 시작 시간 표시
            for image_batch in dataset:
                train_step(image_batch) ------ train_step() 함수를 호출하여 모델 훈련
            if epoch % 10 == 0:
                generate_images(generator, epoch+1, seed)
            print('에포크 {} 은/는 {} 초'.format(epoch+1, time.time()-start)) ------ 매 에포크마다 걸린 시간을 표시
        generate_images(generator, epochs, seed)
    신간 소식 구독하기
    뉴스레터에 가입하시고 이메일로 신간 소식을 받아 보세요.