모델을 학습시킬 함수를 정의합니다. 학습 용도이기 때문에 model.train()을 사용합니다.
코드 6-16 모델 학습 함수 정의
def train_model(model, dataloader_dict, criterion, optimizer, num_epoch):
since = time.time()
best_acc = 0.0
for epoch in range(num_epoch): ------ epoch를 10으로 설정했으므로 10회 반복
print('Epoch {}/{}'.format(epoch+1, num_epoch))
print('-'*20)
for phase in ['train', 'val']:
if phase == 'train':
model.train() ------ 모델을 학습시키겠다는 의미
else:
model.eval()
epoch_loss = 0.0
epoch_corrects = 0
for inputs, labels in tqdm(dataloader_dict[phase]): ------ 여기에서 dataloader_dict는 훈련 데이터셋(train_loader)을 의미
inputs = inputs.to(device) ------ 훈련 데이터셋을 CPU에 할당
labels = labels.to(device)
optimizer.zero_grad() ------ 역전파 단계를 실행하기 전에 기울기(gradient)를 0으로 초기화
with torch.set_grad_enabled(phase == 'train'):
outputs = model(inputs)
_, preds = torch.max(outputs, 1)
loss = criterion(outputs, labels) ------ 손실 함수를 이용한 오차 계산
if phase == 'train':
loss.backward() ------ 모델의 학습 가능한 모든 파라미터에 대해 기울기를 계산
optimizer.step() ------ optimizer의 step 함수를 호출하면 파라미터를 갱신
epoch_loss += loss.item() * inputs.size(0) ------ ①
epoch_corrects += torch.sum(preds == labels.data) ------ 정답과 예측이 일치하면 그것의 합계를 epoch_corrects에 저장
epoch_loss = epoch_loss / len(dataloader_dict[phase].dataset) ------ 최종 오차 계산(오차를 데이터셋의 길이(개수)로 나누어서 계산)
epoch_acc = epoch_corrects.double() / len(dataloader_dict[phase].dataset) ------ 최종 정확도(epoch_corrects를 데이터셋의 길이(개수)로 나누어서 계산)
print('{} Loss: {:.4f} Acc: {:.4f}'.format(phase, epoch_loss, epoch_acc))
if phase == 'val' and epoch_acc > best_acc: ------ 검증 데이터셋에 대한 가장 최적의 정확도를 저장
best_acc = epoch_acc
best_model_wts = model.state_dict()
time_elapsed = time.time() - since
print('Training complete in {:.0f}m {:.0f}s'.format(time_elapsed // 60, time_elapsed % 60))
print('Best val Acc: {:4f}'.format(best_acc))
return model