모델에 대한 평가 함수를 정의합니다.
코드 6-85 모델 평가 함수 정의
def evaluate(model, iterator, criterion, device):
    """Evaluate ``model`` on every batch of ``iterator`` with gradients disabled.

    The model is put into eval mode. Each ``(x, y)`` batch is moved to
    ``device``, and the first element of the model output (``model(x)[0]``)
    is scored with ``criterion`` and ``calculate_topk_accuracy``.

    Args:
        model: Network under evaluation. Its output is indexed with ``[0]``,
            presumably because it returns a tuple whose first item is the
            logits (TODO: confirm against the model definition).
        iterator: Iterable of ``(x, y)`` batches that also supports ``len()``
            (e.g. a DataLoader).
        criterion: Loss function called as ``criterion(output, y)``; must
            return something with ``.item()``.
        device: Target device for inputs and labels.

    Returns:
        Tuple ``(loss, acc_1, acc_5)``. Each value is the sum of the per-batch
        values divided by ``len(iterator)``. The accuracies are whatever
        ``calculate_topk_accuracy`` returns, presumably top-1 and top-5.

    Raises:
        ZeroDivisionError: if ``len(iterator)`` is 0.
    """
    model.eval()
    running_loss = running_top1 = running_top5 = 0

    with torch.no_grad():
        for inputs, labels in iterator:
            inputs, labels = inputs.to(device), labels.to(device)
            logits = model(inputs)[0]

            batch_loss = criterion(logits, labels)
            top1, top5 = calculate_topk_accuracy(logits, labels)

            running_loss += batch_loss.item()
            running_top1 += top1.item()
            running_top5 += top5.item()

    # NOTE(review): every batch counts equally here. If criterion and
    # calculate_topk_accuracy return per-batch means, a smaller final batch
    # is over-weighted relative to a true per-sample average.
    num_batches = len(iterator)
    return (
        running_loss / num_batches,
        running_top1 / num_batches,
        running_top5 / num_batches,
    )