이제 평균과 표준편차가 주어졌을 때 잠재 벡터 z를 만들기 위해 reparameterization()이라는 이름으로 함수를 생성해 보겠습니다.
코드 13-15 변형 오토인코더 네트워크
class Model(nn.Module):
def __init__(self, Encoder, Decoder):
super(Model, self).__init__()
self.Encoder = Encoder
self.Decoder = Decoder
def reparameterization(self, mean, var): ------ ①
epsilon = torch.randn_like(var).to(device)
z = mean + var * epsilon ------ z 값 구하기
return z
def forward(self, x):
mean, log_var = self.Encoder(x) ------ ②
z = self.reparameterization(mean, torch.exp(0.5 * log_var))
x_hat = self.Decoder(z)
return x_hat, mean, log_var ------ 디코더 결과와 평균, 표준편차(log를 취한 표준편차)를 반환
① reparameterization() 함수는 z 벡터를 샘플링하기 위한 용도입니다. z는 가우시안 분포라고 가정했기 때문에 인코더에서 받아 온 평균(μ)과 표준편차(σ)를 이용하여 z를 생성합니다. 그리고 z 벡터를 디코더에 다시 통과시켜서 입력과 동일한 데이터(x')를 만들어 내는 작업을 합니다.
② 인코더에서 받아 온 평균과 표준편차를 이용하지만 표준편차는 값을 그대로 사용하지 않습니다. 값이 음수가 되지 않도록 로그(log)를 취하는데, 다음과 같은 방식을 취합니다.
따라서 변수 이름도 var에서 log_var로 변경했습니다.