다음은 테스트 데이터가 몇 건인지 보여 줍니다. 테스트용 데이터로 부족하지만 실습을 진행할 환경을 고려하여 일부만 사용합니다.
98
테스트용 데이터 준비가 완료되었습니다. 테스트 데이터 평가를 위한 함수를 생성합니다.
코드 5-24 테스트 데이터 평가 함수 생성
def eval_model(model, dataloaders, device):
since = time.time()
acc_history = []
best_acc = 0.0
saved_models = glob.glob('../chap05/data/catanddog/' + '*.pth') ------ ①
saved_models.sort() ------ 불러온 .pth 파일들을 정렬
print('saved_model', saved_models)
for model_path in saved_models:
print('Loading model', model_path)
model.load_state_dict(torch.load(model_path))
model.eval()
model.to(device)
running_corrects = 0
for inputs, labels in dataloaders: ------ 테스트 반복
inputs = inputs.to(device)
labels = labels.to(device)
with torch.no_grad(): ------ autograd를 사용하지 않겠다는 의미
outputs = model(inputs) ------ 데이터를 모델에 적용한 결과를 outputs에 저장
_, preds = torch.max(outputs.data, 1) ------ ②
preds[preds >= 0.5] = 1 ------ torch.max로 출력된 값이 0.5보다 크면 올바르게 예측
preds[preds < 0.5] = 0 ------ torch.max로 출력된 값이 0.5보다 작으면 틀리게 예측
running_corrects += preds.eq(labels.cpu()).int().sum() ------ ③
epoch_acc = running_corrects.double() / len(dataloaders.dataset) ------ 테스트 데이터의 정확도 계산
print('Acc: {:.4f}'.format(epoch_acc))
if epoch_acc > best_acc:
best_acc = epoch_acc
acc_history.append(epoch_acc.item())
print()
time_elapsed = time.time() - since
print('Validation complete in {:.0f}m {:.0f}s'.format(time_elapsed // 60, time_elapsed % 60))
print('Best Acc: {:4f}'.format(best_acc))
return acc_history ------ 계산된 정확도 반환