더북(TheBook)

데이터셋과 모델에 대한 준비가 완료되었습니다. 이제 모델을 학습시켜 보겠습니다.

코드 8-33 모델 학습

train_loss, train_accuracy = [], [] ------ 훈련 데이터셋을 이용한 모델 학습 결과(오차, 정확도)를 저장하기 위한 변수(리스트 형태를 갖습니다)

val_loss, val_accuracy = [], [] ------ 검증 데이터셋을 이용한 모델 성능 결과(오차, 정확도)를 저장하기 위한 변수(리스트 형태를 갖습니다)

start = time.time()
for epoch in range(epochs):
    print(f"Epoch {epoch+1} of {epochs}")
    train_epoch_loss, train_epoch_accuracy = training(
        model, train_dataloader, train_dataset, optimizer, criterion
    )
    val_epoch_loss, val_epoch_accuracy = validate(
        model, val_dataloader, val_dataset, criterion
    )
    train_loss.append(train_epoch_loss)
    train_accuracy.append(train_epoch_accuracy)
    val_loss.append(val_epoch_loss)
    val_accuracy.append(val_epoch_accuracy)
    if args['lr_scheduler']: ------ 인수 값이 lr_scheduler이면 다음을 실행
        lr_scheduler(val_epoch_loss)
    if args['early_stopping']: ------ 인수 값이 early_stopping이면 다음을 실행
        early_stopping(val_epoch_loss, model)
        if early_stopping.early_stop:
            break
    print(f"Train Loss: {train_epoch_loss:.4f}, Train Acc: {train_epoch_accuracy:.2f}")
    print(f'Val Loss: {val_epoch_loss:.4f}, Val Acc: {val_epoch_accuracy:.2f}')
end = time.time()
print(f"Training time: {(end-start)/60:.3f} minutes")
신간 소식 구독하기
뉴스레터에 가입하시고 이메일로 신간 소식을 받아 보세요.