더북(TheBook)

이제 for 구문을 이용하여 200회 에포크만큼 모델을 학습시킵니다.

코드 13-32 모델 학습

generator.train() ------ 생성자를 학습 모드로 설정
discriminator.train() ------ 판별자를 학습 모드로 설정

for epoch in range(epochs):
    loss_g = 0.0 ------ 생성자 오차를 추적(저장)하기 위한 변수
    loss_d = 0.0 ------ 판별자 오차를 추적(저장)하기 위한 변수
    for idx, data in tqdm(enumerate(train_loader), total=int(len(train_dataset)/train_loader.batch_size)):
        image, _ = data ------ 학습을 위한 이미지 데이터를 가져옵니다.
        image = image.to(device) ------ 데이터셋이 CPU/GPU 장치를 사용하도록 지정
        b_size = len(image)
        for step in range(k): ------ k(1) 스텝 수에 따라 판별자를 실행, 이때 k 수를 증가시킬 수 있지만 학습 시간이 길어질 수 있으므로 주의하세요.

            data_fake = generator(torch.randn(b_size, nz).to(device)).detach() ------ ①
            data_real = image
            loss_d += train_discriminator(optim_d, data_real, data_fake) ------ ①′
        data_fake = generator(torch.randn(b_size, nz).to(device))
        loss_g += train_generator(optim_g, data_fake) ------ 생성자 학습
    generated_img = generator(torch.randn(b_size, nz).to(device)).cpu().detach() ------ 생성자를 이용하여 새로운 이미지를 생성하고 CPU 장치를 이용하여 디스크에 저장
    generated_img = make_grid(generated_img) ------ 이미지를 그리드 형태로 표현
    save_generator_image(generated_img, f"../chap13/img/gen_img{epoch}.png") ------ 생성된 이미지(텐서)를 디스크에 저장
    images.append(generated_img)
    epoch_loss_g = loss_g / idx ------ 에포크에 대한 총 생성자 오차 계산
    epoch_loss_d = loss_d / idxH ------ 에포크에 대한 총 판별자 오차 계산
    losses_g.append(epoch_loss_g)
    losses_d.append(epoch_loss_d)

    print(f"Epoch {epoch} of {epochs}")
    print(f"Generator loss: {epoch_loss_g:.8f}, Discriminator loss: {epoch_loss_d:.8f}")
신간 소식 구독하기
뉴스레터에 가입하시고 이메일로 신간 소식을 받아 보세요.