더북(TheBook)

이번에는 배치 정규화가 포함된 네트워크를 구축합니다.

코드 8-9 배치 정규화가 포함된 네트워크

class BNNet(nn.Module):
    def __init__(self):
        super(BNNet, self).__init__()
        self.classifier = nn.Sequential(
            nn.Linear(784, 48),
            nn.BatchNorm1d(48), ------ ①
            nn.ReLU(),
            nn.Linear(48, 24),
            nn.BatchNorm1d(24),
            nn.ReLU(),
            nn.Linear(24, 10)
        )

    def forward(self, x):
        x = x.view(x.size(0), -1)
        x = self.classifier(x)
        return x

① 배치 정규화가 적용되는 부분입니다. BatchNorm1d에서 사용되는 파라미터는 특성 개수로 이전 계층의 출력 채널이 됩니다.

신간 소식 구독하기
뉴스레터에 가입하시고 이메일로 신간 소식을 받아 보세요.