생성자 네트워크가 완료되었고, 이제 판별자 네트워크를 생성해 보겠습니다. 판별자는 이진 분류자라는 것을 고려하여 신경망을 구축해야 합니다.
코드 13-26 판별자 네트워크 생성
class Discriminator(nn.Module):
def __init__(self):
super(Discriminator, self).__init__()
self.n_input = 784 ------ 판별자의 입력 크기
self.main = nn.Sequential( ------ 판별자 역시 선형 계층과 리키렐루 활성화 함수로 구성
nn.Linear(self.n_input, 1024),
nn.LeakyReLU(0.2),
nn.Dropout(0.3),
nn.Linear(1024, 512),
nn.LeakyReLU(0.2),
nn.Dropout(0.3),
nn.Linear(512, 256),
nn.LeakyReLU(0.2),
nn.Dropout(0.3),
nn.Linear(256, 1),
nn.Sigmoid(),
)
def forward(self, x):
x = x.view(-1, 784)
return self.main(x) ------ 이미지가 진짜인지 가짜인지를 분류하는 값을 반환