더북(TheBook)

모델을 학습시키는 방법에 대한 함수를 정의합니다.

코드 6-84 모델 학습 함수 정의

def train(model, iterator, optimizer, criterion, scheduler, device):
    epoch_loss = 0
    epoch_acc_1 = 0
    epoch_acc_5 = 0

    model.train()
    for (x, y) in iterator:
        x = x.to(device)
        y = y.to(device)

        optimizer.zero_grad()
        y_pred = model(x)
        loss = criterion(y_pred[0], y)

        acc_1, acc_5 = calculate_topk_accuracy(y_pred[0], y)
        loss.backward()
        optimizer.step()

        epoch_loss += loss.item()
        epoch_acc_1 += acc_1.item() ------ 모델이 첫 번째로 예측한 레이블이 붙여집니다.
        epoch_acc_5 += acc_5.item() ------ 이미지에 정확한 레이블이 붙여질 것이기 때문에 정확도가 100%일 것입니다.

    epoch_loss /= len(iterator)
    epoch_acc_1 /= len(iterator)
    epoch_acc_5 /= len(iterator)
    return epoch_loss, epoch_acc_1, epoch_acc_5
신간 소식 구독하기
뉴스레터에 가입하시고 이메일로 신간 소식을 받아 보세요.