다음은 검증 데이터로 모델을 평가하여 평균 손실과 정확도(%)를 반환하는 함수입니다.
코드 8-32 모델 검증 함수
def validate(model, test_dataloader, val_dataset, criterion, device=None):
    """Evaluate ``model`` on ``test_dataloader`` and return its loss and accuracy.

    Args:
        model: The ``torch.nn.Module`` to evaluate. It is put in eval mode via
            ``model.eval()`` and is left in that mode on return.
        test_dataloader: Iterable of ``(inputs, targets)`` batches (e.g. a
            ``torch.utils.data.DataLoader``). Must support ``len()``, which is
            used as the progress-bar total.
        val_dataset: Dataset behind ``test_dataloader``. No longer read (the
            progress-bar total now comes from ``len(test_dataloader)``); kept
            so existing positional calls keep working.
        criterion: Loss function called as ``criterion(outputs, targets)``;
            its result must support ``.item()``.
        device: Device each batch is moved to. Defaults to the device of the
            model's first parameter, or CPU if the model has no parameters.

    Returns:
        ``(val_loss, val_accuracy)``: the mean of the per-batch losses, and the
        percentage (0-100) of samples whose arg-max prediction along dim 1
        equals the target.

    Raises:
        ValueError: If ``test_dataloader`` yields no samples.
    """
    print('Validating')
    model.eval()
    if device is None:
        # Infer the device from the model instead of relying on a module global.
        first_param = next(model.parameters(), None)
        device = first_param.device if first_param is not None else torch.device('cpu')

    val_running_loss = 0.0
    val_running_correct = 0
    counter = 0
    total = 0
    # tqdm progress bar to visualize validation progress. len(test_dataloader)
    # counts a final partial batch, unlike floor(len(dataset) / batch_size).
    prog_bar = tqdm(test_dataloader, total=len(test_dataloader))
    with torch.no_grad():
        for batch in prog_bar:
            counter += 1
            inputs, target = batch[0].to(device), batch[1].to(device)
            total += target.size(0)
            outputs = model(inputs)
            loss = criterion(outputs, target)
            val_running_loss += loss.item()
            # Predicted class = index of the largest output along dim 1.
            preds = outputs.argmax(dim=1)
            val_running_correct += (preds == target).sum().item()

    if total == 0:
        raise ValueError('test_dataloader yielded no samples; cannot compute validation metrics')
    # Unweighted mean of per-batch losses: a smaller final batch counts as
    # much as a full one.
    val_loss = val_running_loss / counter
    val_accuracy = 100. * val_running_correct / total
    return val_loss, val_accuracy