모델에서 사용할 옵티마이저와 손실 함수를 지정합니다.
코드 8-13 옵티마이저, 손실 함수 지정
loss_fn = nn.CrossEntropyLoss()
opt = optim.SGD(model.parameters(), lr=0.01)
opt_bn = optim.SGD(model_bn.parameters(), lr=0.01)
이제 모델을 학습시켜 보겠습니다.
코드 8-14 모델 학습
loss_arr = []
loss_bn_arr = []
max_epochs = 2
for epoch in range(max_epochs):
for i, data in enumerate(trainloader, 0):
inputs, labels = data
opt.zero_grad() ------ 배치 정규화가 적용되지 않은 모델의 학습
outputs = model(inputs)
loss = loss_fn(outputs, labels)
loss.backward()
opt.step()
opt_bn.zero_grad() ------ 배치 정규화가 적용된 모델의 학습
outputs_bn = model_bn(inputs)
loss_bn = loss_fn(outputs_bn, labels)
loss_bn.backward()
opt_bn.step()
loss_arr.append(loss.item())
loss_bn_arr.append(loss_bn.item())
plt.plot(loss_arr, 'yellow', label='Normal')
plt.plot(loss_bn_arr, 'blue', label='BatchNorm')
plt.legend()
plt.show()