시각화에 사용될 손실(loss) 정보를 데이터프레임(DataFrame)에 저장합니다.
코드 13-14 손실 정보 저장
losses = pd.DataFrame(columns=['recon_loss', 'latent_loss']) ------ 손실 정보를 데이터프레임(dataframe)에 저장
데이터셋이 준비되었고, 네트워크가 생성되었기 때문에 이제 모델을 훈련시킵니다.
코드 13-15 모델 훈련
n_epochs = 50
for epoch in range(n_epochs):
for batch, train_x in tqdm(
zip(range(N_TRAIN_BATCHES), train_dataset), total=N_TRAIN_BATCHES):
model.train(train_x) ------ 훈련 데이터셋을 사용하여 훈련
loss = []
for batch, test_x in tqdm(
zip(range(N_TEST_BATCHES), test_dataset), total=N_TEST_BATCHES):
loss.append(model.loss_function(train_x))
losses.loc[len(losses)] = np.mean(loss, axis=0)
display.clear_output()
print(
"Epoch: {} | recon_loss: {} | latent_loss: {}".format(
epoch, losses.recon_loss.values[-1], losses.latent_loss.values[-1]
) ------ 재구성 오차(reconstruction_loss)와 인코더-디코더 사이의 오차(latent_loss) 출력
)
plot_reconstruction(model, example_data) ------ 결과를 시각화