이번에는 배치 정규화가 포함된 네트워크를 구축합니다.
코드 8-9 배치 정규화가 포함된 네트워크
class BNNet(nn.Module):
def __init__(self):
super(BNNet, self).__init__()
self.classifier = nn.Sequential(
nn.Linear(784, 48),
nn.BatchNorm1d(48), ------ ①
nn.ReLU(),
nn.Linear(48, 24),
nn.BatchNorm1d(24),
nn.ReLU(),
nn.Linear(24, 10)
)
def forward(self, x):
x = x.view(x.size(0), -1)
x = self.classifier(x)
return x
① 배치 정규화가 적용되는 부분입니다. BatchNorm1d에서 사용되는 파라미터는 특성 개수로 이전 계층의 출력 채널이 됩니다.