어떤가요? 모델의 네트워크가 한눈에 들어오지 않나요?
그럼 이제 모델 학습을 진행할 함수를 정의해 보겠습니다. 이 부분 역시 LeNet에서 사용했던 코드와 같기 때문에 설명은 생략합니다.
코드 6-33 모델 학습 함수 정의
def train_model(model, dataloader_dict, criterion, optimizer, num_epoch):
since = time.time()
best_acc = 0.0
for epoch in range(num_epoch):
print('Epoch {}/{}'.format(epoch+1, num_epoch))
print('-'*20)
for phase in ['train', 'val']:
if phase == 'train':
model.train()
else:
model.eval()
epoch_loss = 0.0
epoch_corrects = 0
for inputs, labels in tqdm(dataloader_dict[phase]):
inputs = inputs.to(device)
labels = labels.to(device)
optimizer.zero_grad()
with torch.set_grad_enabled(phase == 'train'):
outputs = model(inputs)
_, preds = torch.max(outputs, 1)
loss = criterion(outputs, labels)
if phase == 'train':
loss.backward()
optimizer.step()
epoch_loss += loss.item() * inputs.size(0)
epoch_corrects += torch.sum(preds == labels.data)
epoch_loss = epoch_loss / len(dataloader_dict[phase].dataset)
epoch_acc = epoch_corrects.double() / len(dataloader_dict[phase].dataset)
print('{} Loss: {:.4f} Acc: {:.4f}'.format(phase, epoch_loss, epoch_acc))
time_elapsed = time.time() - since
print('Training complete in {:.0f}m {:.0f}s'.format(
time_elapsed // 60, time_elapsed % 60))
return model