판별자를 학습시키기 위한 함수를 정의합니다. 판별자의 학습은 진짜 데이터의 레이블과 가짜 데이터의 레이블을 모두 이용하여 학습합니다.
코드 13-30 판별자 학습을 위한 함수
def train_discriminator(optimizer, data_real, data_fake):
b_size = data_real.size(0) ------ 배치 크기 정보 얻기
real_label = torch.ones(b_size, 1).to(device) ------ ①
fake_label = torch.zeros(b_size, 1).to(device) ------ ②
optimizer.zero_grad()
output_real = discriminator(data_real)
loss_real = criterion(output_real, real_label) ------ 진짜 데이터를 판별자에 제공하여 학습한 결과와 진짜 데이터의 레이블을 이용하여 오차를 계산
output_fake = discriminator(data_fake)
loss_fake = criterion(output_fake, fake_label) ------ 가짜 데이터를 판별자에 제공하여 학습한 결과와 가짜 데이터의 레이블을 이용하여 오차를 계산
loss_real.backward()
loss_fake.backward()
optimizer.step()
return loss_real + loss_fake ------ 진짜 데이터와 가짜 데이터의 오차가 합쳐진 최종 오차를 반환
① GAN에서는 모델 훈련을 위해 진짜 이미지와 생성자에서 생성한 가짜 이미지가 필요합니다. 그뿐만 아니라 레이블 정보도 필요한데, 레이블 정보는 텐서 형태를 가져야 하며 배치 크기와도 동일해야 합니다. 먼저 진짜 데이터에 대한 레이블을 생성합니다.
ⓐ 1 값을 가진 (b_size×1) 크기의 텐서를 생성