더북(TheBook)

이제 모델 학습에 필요한 함수를 정의합니다.

코드 13-18 모델 학습 함수 정의

saved_loc = 'scalar/' ------ 텐서보드에서 사용할 경로
writer = SummaryWriter(saved_loc) ------ ①

model.train()

def train(epoch, model, train_loader, optimizer):
    train_loss = 0
    for batch_idx, (x, _) in enumerate(train_loader):
        x = x.view(batch_size, x_dim)
        x = x.to(device)

        optimizer.zero_grad()
        x_hat, mean, log_var = model(x)
        BCE, KLD = loss_function(x, x_hat, mean, log_var)
        loss = BCE + KLD
        writer.add_scalar("Train/Reconstruction Error", BCE.item(), batch_idx + epoch *
                         (len(train_loader.dataset)/batch_size)) ------ ②
        writer.add_scalar("Train/KL-Divergence", KLD.item(), batch_idx + epoch *
                         (len(train_loader.dataset)/batch_size))
        writer.add_scalar("Train/Total Loss", loss.item(), batch_idx + epoch *
                         (len(train_loader.dataset)/batch_size))

        train_loss += loss.item()
        loss.backward()
        optimizer.step()

        if batch_idx % 100 == 0:
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\t Loss: {:.6f}'.format(
                  epoch, batch_idx * len(x), len(train_loader.dataset),
                  100. * batch_idx / len(train_loader),
                  loss.item() / len(x)))

    print("======> Epoch: {} Average loss: {:.4f}".format(
          epoch, train_loss / len(train_loader.dataset)))
신간 소식 구독하기
뉴스레터에 가입하시고 이메일로 신간 소식을 받아 보세요.