이제 모델 학습에 필요한 함수를 정의합니다.
코드 13-18 모델 학습 함수 정의
# Directory where TensorBoard event files are written.
saved_loc = 'scalar/'

# ① SummaryWriter streams scalar logs to `saved_loc` for TensorBoard.
writer = SummaryWriter(saved_loc)

# Put the model in training mode (enables dropout / batch-norm updates).
# NOTE(review): also re-asserted inside train() so eval passes can't leak in.
model.train()
def train(epoch, model, train_loader, optimizer):
    """Run one training epoch of the VAE.

    For each batch: flatten the input, forward through the model, compute
    reconstruction (BCE) and KL-divergence losses, back-propagate, and log
    the per-batch scalars to TensorBoard via the module-level `writer`.

    Args:
        epoch: current epoch index (used for logging / the TensorBoard step).
        model: VAE returning (x_hat, mean, log_var) from a flattened batch.
        train_loader: DataLoader yielding (x, label) pairs; labels are unused.
        optimizer: optimizer over `model`'s parameters.

    Relies on module-level globals: writer, device, batch_size, x_dim,
    loss_function.
    """
    model.train()  # guarantee training mode even if eval() was called earlier
    train_loss = 0.0
    # Hoist the loop-invariant batches-per-epoch factor used for the
    # TensorBoard global step (② in the book's annotation).
    batches_per_epoch = len(train_loader.dataset) / batch_size
    for batch_idx, (x, _) in enumerate(train_loader):
        # Use the actual batch dimension: the final batch may be smaller
        # than batch_size, and view(batch_size, x_dim) would crash on it.
        x = x.view(x.size(0), x_dim)
        x = x.to(device)
        optimizer.zero_grad()
        x_hat, mean, log_var = model(x)
        BCE, KLD = loss_function(x, x_hat, mean, log_var)
        loss = BCE + KLD
        step = batch_idx + epoch * batches_per_epoch
        writer.add_scalar("Train/Reconstruction Error", BCE.item(), step)
        writer.add_scalar("Train/KL-Divergence", KLD.item(), step)
        writer.add_scalar("Train/Total Loss", loss.item(), step)
        train_loss += loss.item()
        loss.backward()
        optimizer.step()
        if batch_idx % 100 == 0:
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\t Loss: {:.6f}'.format(
                epoch, batch_idx * len(x), len(train_loader.dataset),
                100. * batch_idx / len(train_loader),
                loss.item() / len(x)))
    print("======> Epoch: {} Average loss: {:.4f}".format(
        epoch, train_loss / len(train_loader.dataset)))