② 가짜 데이터에 대한 레이블을 생성합니다.
ⓐ 0 값을 가진 (b_size×1) 크기의 텐서를 생성
이제 생성자 학습을 위한 함수를 정의할 텐데, 상대적으로 판별자의 네트워크보다는 간단합니다.
코드 13-31 생성자 학습을 위한 함수
def train_generator(optimizer, data_fake):
b_size = data_fake.size(0)
real_label = torch.ones(b_size, 1).to(device) ------ ①
optimizer.zero_grad()
output = discriminator(data_fake)
loss = criterion(output, real_label)
loss.backward()
optimizer.step()
return loss
① 생성자 네트워크에서는 가짜 데이터만 사용하고 있는데, 생성자 입장에서는 가짜 데이터가 실제로 진짜라는 것에 주의할 필요가 있습니다.