더북(TheBook)

② 가짜 데이터에 대한 레이블을 생성합니다.

ⓐ 0 값을 가진 (b_size×1) 크기의 텐서를 생성

이제 생성자 학습을 위한 함수를 정의할 텐데, 상대적으로 판별자의 네트워크보다는 간단합니다.

코드 13-31 생성자 학습을 위한 함수

def train_generator(optimizer, data_fake):
    b_size = data_fake.size(0)
    real_label = torch.ones(b_size, 1).to(device) ------ ①
    optimizer.zero_grad()
    output = discriminator(data_fake)
    loss = criterion(output, real_label)
    loss.backward()
    optimizer.step()
    return loss

① 생성자 네트워크에서는 가짜 데이터만 사용하고 있는데, 생성자 입장에서는 가짜 데이터가 실제로 진짜라는 것에 주의할 필요가 있습니다.

신간 소식 구독하기
뉴스레터에 가입하시고 이메일로 신간 소식을 받아 보세요.