더북(TheBook)

이제 평균과 표준편차가 주어졌을 때 잠재 벡터 z를 만들기 위해 reparameterization()이라는 이름으로 함수를 생성해 보겠습니다.

코드 13-15 변형 오토인코더 네트워크

class Model(nn.Module):
    def __init__(self, Encoder, Decoder):
        super(Model, self).__init__()
        self.Encoder = Encoder
        self.Decoder = Decoder

    def reparameterization(self, mean, var): ------ ①
        epsilon = torch.randn_like(var).to(device)
        z = mean + var * epsilon ------ z 값 구하기
        return z

    def forward(self, x):
        mean, log_var = self.Encoder(x) ------ ②
        z = self.reparameterization(mean, torch.exp(0.5 * log_var))
        x_hat = self.Decoder(z)
        return x_hat, mean, log_var ------ 디코더 결과와 평균, 표준편차(log를 취한 표준편차)를 반환

reparameterization() 함수는 z 벡터를 샘플링하기 위한 용도입니다. z는 가우시안 분포라고 가정했기 때문에 인코더에서 받아 온 평균(μ)과 표준편차(σ)를 이용하여 z를 생성합니다. 그리고 z 벡터를 디코더에 다시 통과시켜서 입력과 동일한 데이터(x')를 만들어 내는 작업을 합니다.

② 인코더에서 받아 온 평균과 표준편차를 이용하지만 표준편차는 값을 그대로 사용하지 않습니다. 값이 음수가 되지 않도록 로그(log)를 취하는데, 다음과 같은 방식을 취합니다.

따라서 변수 이름도 var에서 log_var로 변경했습니다.

신간 소식 구독하기
뉴스레터에 가입하시고 이메일로 신간 소식을 받아 보세요.