훈련 데이터셋을 이용한 모델 학습 함수를 정의합니다.
코드 8-31 모델 학습 함수
def training(model, train_dataloader, train_dataset, optimizer, criterion, *, device=None):
    """Train ``model`` for one epoch over ``train_dataloader``.

    Puts the model in training mode, then for every mini-batch runs the
    forward pass, computes the loss, back-propagates and steps the
    optimizer, while accumulating loss and classification accuracy.

    Args:
        model: The network to train; its parameters are updated in place by
            ``optimizer``.
        train_dataloader: Iterable of batches where ``batch[0]`` holds the
            inputs and ``batch[1]`` the integer class targets. Must support
            ``len()`` (e.g. a ``torch.utils.data.DataLoader``).
        train_dataset: Kept only for backward compatibility; no longer used
            (the progress-bar length now comes from ``len(train_dataloader)``).
        optimizer: Optimizer over ``model``'s parameters.
        criterion: Loss function called as ``criterion(outputs, targets)``.
        device: Device each batch is moved to. Defaults to the device of the
            model's first parameter.

    Returns:
        ``(train_loss, train_accuracy)``: the unweighted mean of the per-batch
        losses, and the percentage of correctly classified samples.

    Raises:
        ValueError: If ``train_dataloader`` yields no batches.
    """
    print('Training')
    model.train()

    if device is None:
        # Inputs must live on the same device as the model's weights.
        device = next(model.parameters()).device

    running_loss = 0.0
    running_correct = 0
    num_batches = 0
    num_samples = 0

    # Visualize training progress with a progress bar. len(train_dataloader)
    # counts the final partial batch (and honours drop_last), whereas
    # int(len(dataset) / batch_size) truncates and under-counts.
    prog_bar = tqdm(train_dataloader, total=len(train_dataloader))
    for batch in prog_bar:
        inputs, targets = batch[0].to(device), batch[1].to(device)
        num_batches += 1
        num_samples += targets.size(0)

        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, targets)
        running_loss += loss.item()

        # Predicted class = index of the highest score along dim 1
        # (assumes outputs is (batch, num_classes) — TODO confirm).
        # argmax yields an integer tensor, so no autograd graph is kept.
        preds = outputs.argmax(dim=1)
        running_correct += (preds == targets).sum().item()

        loss.backward()
        optimizer.step()

    if num_batches == 0:
        raise ValueError('train_dataloader yielded no batches')

    train_loss = running_loss / num_batches
    train_accuracy = 100.0 * running_correct / num_samples
    return train_loss, train_accuracy