테스트 데이터셋을 이용해서 모델을 평가하기 위한 함수를 정의합니다.
코드 13-19 모델 평가 함수 정의
def test(epoch, model, test_loader):
model.eval()
test_loss = 0
with torch.no_grad():
for batch_idx, (x, _) in enumerate(test_loader):
x = x.view(batch_size, x_dim)
x = x.to(device)
x_hat, mean, log_var = model(x)
BCE, KLD = loss_function(x, x_hat, mean, log_var)
loss = BCE + KLD
writer.add_scalar("Test/Reconstruction Error", BCE.item(), batch_idx +
epoch * (len(test_loader.dataset)/batch_size)) ------ 테스트 데이터셋에 대해서도 오차를 로그에 저장
writer.add_scalar("Test/KL-Divergence", KLD.item(), batch_idx + epoch *
(len(test_loader.dataset)/batch_size))
writer.add_scalar("Test/Total Loss", loss.item(), batch_idx + epoch *
(len(test_loader.dataset)/batch_size))
test_loss += loss.item()
if batch_idx == 0:
n = min(x.size(0), 8)
comparison = torch.cat([x[:n], x_hat.view(batch_size, x_dim)[:n]])
grid = torchvision.utils.make_grid(comparison.cpu())
writer.add_image("Test image - Above: Real data, below: reconstruction data", grid, epoch)