이제 for 구문을 이용하여 200회 에포크만큼 모델을 학습시킵니다.
코드 13-32 모델 학습
generator.train() ------ 생성자를 학습 모드로 설정
discriminator.train() ------ 판별자를 학습 모드로 설정
for epoch in range(epochs):
loss_g = 0.0 ------ 생성자 오차를 추적(저장)하기 위한 변수
loss_d = 0.0 ------ 판별자 오차를 추적(저장)하기 위한 변수
for idx, data in tqdm(enumerate(train_loader), total=int(len(train_dataset)/train_loader.batch_size)):
image, _ = data ------ 학습을 위한 이미지 데이터를 가져옵니다.
image = image.to(device) ------ 데이터셋이 CPU/GPU 장치를 사용하도록 지정
b_size = len(image)
for step in range(k): ------ k(1) 스텝 수에 따라 판별자를 실행, 이때 k 수를 증가시킬 수 있지만 학습 시간이 길어질 수 있으므로 주의하세요.
data_fake = generator(torch.randn(b_size, nz).to(device)).detach() ------ ①
data_real = image
loss_d += train_discriminator(optim_d, data_real, data_fake) ------ ①′
data_fake = generator(torch.randn(b_size, nz).to(device))
loss_g += train_generator(optim_g, data_fake) ------ 생성자 학습
generated_img = generator(torch.randn(b_size, nz).to(device)).cpu().detach() ------ 생성자를 이용하여 새로운 이미지를 생성하고 CPU 장치를 이용하여 디스크에 저장
generated_img = make_grid(generated_img) ------ 이미지를 그리드 형태로 표현
save_generator_image(generated_img, f"../chap13/img/gen_img{epoch}.png") ------ 생성된 이미지(텐서)를 디스크에 저장
images.append(generated_img)
epoch_loss_g = loss_g / idx ------ 에포크에 대한 총 생성자 오차 계산
epoch_loss_d = loss_d / idxH ------ 에포크에 대한 총 판별자 오차 계산
losses_g.append(epoch_loss_g)
losses_d.append(epoch_loss_d)
print(f"Epoch {epoch} of {epochs}")
print(f"Generator loss: {epoch_loss_g:.8f}, Discriminator loss: {epoch_loss_d:.8f}")