더북(TheBook)

2 진위를 가려내는 장치, 판별자

 

이제 생성자에서 넘어온 이미지가 가짜인지 진짜인지를 판별해 주는 장치인 판별자(discriminator)를 만들 차례입니다. 이 부분은 컨볼루션 신경망의 구조를 그대로 가지고 와서 만들면 됩니다. 컨볼루션 신경망이란 원래 무언가를(예를 들어 개와 고양이 사진을) 구별하는 데 최적화된 알고리즘이기 때문에 그 목적 그대로 사용하면 되는 것이지요.

진짜(1) 아니면 가짜(0), 둘 중 하나를 결정하는 문제이므로 컴파일 부분은 14장에서 사용된 이진 로스 함수(binary_crossentropy)와 최적화 함수(adam)를 그대로 쓰겠습니다. 16장에서 배웠던 드롭아웃(Dropout(0.3))도 다시 사용하고, 앞 절에서 다룬 배치 정규화와 패딩도 그대로 넣어 줍니다.

주의할 점은 이 판별자는 가짜인지 진짜인지를 판별만 해 줄 뿐, 자기 자신이 학습을 해서는 안 된다는 것입니다. 판별자가 얻은 가중치는 판별자 자신이 학습하는 데 쓰이는 것이 아니라 생성자로 넘겨주어 생성자가 업데이트된 이미지를 만들도록 해야 합니다. 따라서 판별자를 만들 때는 가중치를 저장하는 학습 기능을 꺼 주어야 합니다.

모든 과정을 코드로 정리해 보면 다음과 같습니다.

# 모델 이름을 discriminator로 정하고 Sequential() 함수를 호출합니다.
discriminator = Sequential()
discriminator.add(Conv2D(64, kernel_size=5, strides=2, input_shape=(28,28,1), padding="same")) ----- ➊
discriminator.add(Activation(LeakyReLU(0.2))) ----- ➋
discriminator.add(Dropout(0.3)) ----- ➌
discriminator.add(Conv2D(128, kernel_size=5, strides=2, padding="same")) ----- ➍
discriminator.add(Activation(LeakyReLU(0.2))) ----- ➎
discriminator.add(Dropout(0.3)) ----- ➏
discriminator.add(Flatten()) ----- ➐
discriminator.add(Dense(1, activation='sigmoid')) ----- ➑
discriminator.compile(loss='binary_crossentropy', optimizer='adam') ----- ➒
discriminator.trainable = False ----- ➓
신간 소식 구독하기
뉴스레터에 가입하시고 이메일로 신간 소식을 받아 보세요.