모델에서 사용할 옵티마이저와 손실 함수를 지정합니다.

    코드 8-13 옵티마이저, 손실 함수 지정

    loss_fn = nn.CrossEntropyLoss()
    opt = optim.SGD(model.parameters(), lr=0.01)
    opt_bn = optim.SGD(model_bn.parameters(), lr=0.01)

    이제 모델을 학습시켜 보겠습니다.

    코드 8-14 모델 학습

    loss_arr = []
    loss_bn_arr = []
    max_epochs = 2
    
    for epoch in range(max_epochs):
        for i, data in enumerate(trainloader, 0):
            inputs, labels = data
            opt.zero_grad() ------ 배치 정규화가 적용되지 않은 모델의 학습
            outputs = model(inputs)
            loss = loss_fn(outputs, labels)
            loss.backward()
            opt.step()
    
            opt_bn.zero_grad() ------ 배치 정규화가 적용된 모델의 학습
            outputs_bn = model_bn(inputs)
            loss_bn = loss_fn(outputs_bn, labels)
            loss_bn.backward()
            opt_bn.step()
    
            loss_arr.append(loss.item())
            loss_bn_arr.append(loss_bn.item())
    
        plt.plot(loss_arr, 'yellow', label='Normal')
        plt.plot(loss_bn_arr, 'blue', label='BatchNorm')
        plt.legend()
        plt.show()
    신간 소식 구독하기
    뉴스레터에 가입하시고 이메일로 신간 소식을 받아 보세요.