더북(TheBook)

버트 모델을 학습시키기 위한 함수를 정의합니다. 코드가 상당히 길어 보이지만 복잡하지는 않습니다. 옵티마이저 설정, 에포크 주기별로 오차를 기록하는 것이 학습 과정의 전부입니다. 단지 훈련과 검증이 하나의 함수에 포함되어 있어 길어 보일 뿐입니다.

코드 10-47 모델 훈련 함수 정의

def train(model,
    optimizer,
    criterion=nn.BCELoss(), ------ 영화 리뷰는 좋고 나쁨만 있으므로 BinaryCrossEntropy(BCELoss)를 사용
    num_epochs=5, ------ 에포크를 5로 사용
    eval_every=len(train_loader)//2,
    best_valid_loss=float("Inf")):

total_correct = 0.0
total_len = 0.0
running_loss = 0.0
valid_running_loss = 0.0
global_step = 0
train_loss_list = []
valid_loss_list = []
global_steps_list = []

model.train() ------ 모델 훈련
for epoch in range(num_epochs):
    for text, label in train_loader:
        optimizer.zero_grad()
        encoded_list = [tokenizer.encode(t, add_special_tokens=True) for t in text]
        padded_list = [e + [0] * (512-len(e)) for e in encoded_list] ------ 인코딩 결과에 제로패딩(zero-padding)을 적용
        sample = torch.tensor(padded_list)
        sample, label = sample.to(device), label.to(device)
        labels = torch.tensor(label)
        outputs = model(sample, labels=labels)
        loss, logits = outputs

        pred = torch.argmax(F.softmax(logits), dim=1) ------ ①
        correct = pred.eq(labels)
        total_correct += correct.sum().item()
        total_len += len(labels)
        running_loss += loss.item()
        loss.backward()
        optimizer.step()
        global_step += 1

        if global_step % eval_every == 0: ------ 모델 평가
            model.eval()
            with torch.no_grad():
                for text, label in valid_loader:
                    encoded_list = [tokenizer.encode(t, add_special_tokens=True) for t in text]
                    padded_list = [e + [0] * (512-len(e)) for e in encoded_list]
                    sample = torch.tensor(padded_list)
                    sample, label = sample.to(device), label.to(device)
                    labels = torch.tensor(label)
                    outputs = model(sample, labels=labels)
                    loss, logits = outputs
                    valid_running_loss += loss.item()

            average_train_loss = running_loss / eval_every
            average_valid_loss = valid_running_loss / len(valid_loader)
            train_loss_list.append(average_train_loss)
            valid_loss_list.append(average_valid_loss)
            global_steps_list.append(global_step)

            running_loss = 0.0
            valid_running_loss = 0.0
            model.train()

            print('Epoch [{}/{}], Step [{}/{}], Train Loss: {:.4f}, Valid Loss: {:.4f}'
.format(epoch+1, num_epochs, global_step, num_epochs*len(train_loader), average_train_loss, average_valid_loss))

            if best_valid_loss > average_valid_loss: ------ ②
                best_valid_loss = average_valid_loss
                save_checkpoint('../chap10/data/model.pt', model, best_valid_loss) ------ 오차가 작아지면 모델 저장
                save_metrics('../chap10/data/metrics.pt', train_loss_list, valid_loss_list, global_steps_list) ------ 평가에 사용된 훈련 오차, 검증 오차, 에포크(스텝)를 저장

    save_metrics('../chap10/data/metrics.pt', train_loss_list, valid_loss_list, global_steps_list) ------ 최종으로 사용된 훈련 오차, 검증 오차, 에포크(스텝)를 저장
    print('훈련 종료!')
신간 소식 구독하기
뉴스레터에 가입하시고 이메일로 신간 소식을 받아 보세요.