더북(TheBook)

다음은 테스트 데이터가 몇 건인지 보여 줍니다. 테스트용 데이터로 부족하지만 실습을 진행할 환경을 고려하여 일부만 사용합니다.

98

테스트용 데이터 준비가 완료되었습니다. 테스트 데이터 평가를 위한 함수를 생성합니다.

코드 5-24 테스트 데이터 평가 함수 생성

def eval_model(model, dataloaders, device):
    since = time.time()
    acc_history = []
    best_acc = 0.0

    saved_models = glob.glob('../chap05/data/catanddog/' + '*.pth') ------ ①
    saved_models.sort() ------ 불러온 .pth 파일들을 정렬
    print('saved_model', saved_models)

    for model_path in saved_models:
        print('Loading model', model_path)

        model.load_state_dict(torch.load(model_path))
        model.eval()
        model.to(device)
        running_corrects = 0

        for inputs, labels in dataloaders: ------ 테스트 반복
            inputs = inputs.to(device)
            labels = labels.to(device)

            with torch.no_grad(): ------ autograd를 사용하지 않겠다는 의미
                 outputs = model(inputs) ------ 데이터를 모델에 적용한 결과를 outputs에 저장

            _, preds = torch.max(outputs.data, 1) ------ ②
            preds[preds >= 0.5] = 1 ------ torch.max로 출력된 값이 0.5보다 크면 올바르게 예측
            preds[preds < 0.5] = 0 ------ torch.max로 출력된 값이 0.5보다 작으면 틀리게 예측
            running_corrects += preds.eq(labels.cpu()).int().sum() ------ ③

        epoch_acc = running_corrects.double() / len(dataloaders.dataset) ------ 테스트 데이터의 정확도 계산
        print('Acc: {:.4f}'.format(epoch_acc))

        if epoch_acc > best_acc:
            best_acc = epoch_acc
            acc_history.append(epoch_acc.item())
            print()

        time_elapsed = time.time() - since
        print('Validation complete in {:.0f}m {:.0f}s'.format(time_elapsed // 60, time_elapsed % 60))
        print('Best Acc: {:4f}'.format(best_acc))

        return acc_history ------ 계산된 정확도 반환
신간 소식 구독하기
뉴스레터에 가입하시고 이메일로 신간 소식을 받아 보세요.