모델을 학습시키는 방법에 대한 함수를 정의합니다.
코드 6-84 모델 학습 함수 정의
def train(model, iterator, optimizer, criterion, scheduler, device):
epoch_loss = 0
epoch_acc_1 = 0
epoch_acc_5 = 0
model.train()
for (x, y) in iterator:
x = x.to(device)
y = y.to(device)
optimizer.zero_grad()
y_pred = model(x)
loss = criterion(y_pred[0], y)
acc_1, acc_5 = calculate_topk_accuracy(y_pred[0], y)
loss.backward()
optimizer.step()
epoch_loss += loss.item()
epoch_acc_1 += acc_1.item() ------ 모델이 첫 번째로 예측한 레이블이 붙여집니다.
epoch_acc_5 += acc_5.item() ------ 이미지에 정확한 레이블이 붙여질 것이기 때문에 정확도가 100%일 것입니다.
epoch_loss /= len(iterator)
epoch_acc_1 /= len(iterator)
epoch_acc_5 /= len(iterator)
return epoch_loss, epoch_acc_1, epoch_acc_5