더북(TheBook)

모델을 학습시킬 함수를 정의합니다. 학습 용도이기 때문에 model.train()을 사용합니다.

코드 6-16 모델 학습 함수 정의

def train_model(model, dataloader_dict, criterion, optimizer, num_epoch):
    since = time.time()
    best_acc = 0.0
 
    for epoch in range(num_epoch): ------ epoch를 10으로 설정했으므로 10회 반복
        print('Epoch {}/{}'.format(epoch+1, num_epoch))
        print('-'*20)
 
        for phase in ['train', 'val']:
        if phase == 'train':
            model.train() ------ 모델을 학습시키겠다는 의미
        else:
            model.eval()

        epoch_loss = 0.0
        epoch_corrects = 0

        for inputs, labels in tqdm(dataloader_dict[phase]): ------ 여기에서 dataloader_dict는 훈련 데이터셋(train_loader)을 의미
            inputs = inputs.to(device) ------ 훈련 데이터셋을 CPU에 할당
            labels = labels.to(device)
            optimizer.zero_grad() ------ 역전파 단계를 실행하기 전에 기울기(gradient)를 0으로 초기화

            with torch.set_grad_enabled(phase == 'train'):
                outputs = model(inputs)
                _, preds = torch.max(outputs, 1)
                loss = criterion(outputs, labels) ------ 손실 함수를 이용한 오차 계산

                if phase == 'train':
                    loss.backward() ------ 모델의 학습 가능한 모든 파라미터에 대해 기울기를 계산
                    optimizer.step() ------ optimizer의 step 함수를 호출하면 파라미터를 갱신

                epoch_loss += loss.item() * inputs.size(0) ------ ①
                epoch_corrects += torch.sum(preds == labels.data) ------ 정답과 예측이 일치하면 그것의 합계를 epoch_corrects에 저장

        epoch_loss = epoch_loss / len(dataloader_dict[phase].dataset) ------ 최종 오차 계산(오차를 데이터셋의 길이(개수)로 나누어서 계산)
        epoch_acc = epoch_corrects.double() / len(dataloader_dict[phase].dataset) ------ 최종 정확도(epoch_corrects를 데이터셋의 길이(개수)로 나누어서 계산) 

        print('{} Loss: {:.4f} Acc: {:.4f}'.format(phase, epoch_loss, epoch_acc))

        if phase == 'val' and epoch_acc > best_acc: ------ 검증 데이터셋에 대한 가장 최적의 정확도를 저장
            best_acc = epoch_acc
            best_model_wts = model.state_dict()

    time_elapsed = time.time() - since
    print('Training complete in {:.0f}m {:.0f}s'.format(time_elapsed // 60, time_elapsed % 60))
    print('Best val Acc: {:4f}'.format(best_acc))
    return model
신간 소식 구독하기
뉴스레터에 가입하시고 이메일로 신간 소식을 받아 보세요.