시각화에 사용될 손실(loss) 정보를 데이터프레임(DataFrame)에 저장합니다.

    코드 13-14 손실 정보 저장

    losses = pd.DataFrame(columns=['recon_loss', 'latent_loss']) ------ 손실 정보를 데이터프레임(dataframe)에 저장

    데이터셋이 준비되었고, 네트워크가 생성되었기 때문에 이제 모델을 훈련시킵니다.

    코드 13-15 모델 훈련

    n_epochs = 50
    
    for epoch in range(n_epochs):
        for batch, train_x in tqdm(
            zip(range(N_TRAIN_BATCHES), train_dataset), total=N_TRAIN_BATCHES):
            model.train(train_x) ------  훈련 데이터셋을 사용하여 훈련
            loss = []
    
        for batch, test_x in tqdm(
            zip(range(N_TEST_BATCHES), test_dataset), total=N_TEST_BATCHES):
            loss.append(model.loss_function(train_x))
        losses.loc[len(losses)] = np.mean(loss, axis=0)
        display.clear_output()
        print(
            "Epoch: {} | recon_loss: {} | latent_loss: {}".format(
                epoch, losses.recon_loss.values[-1], losses.latent_loss.values[-1]
            ) ------ 재구성 오차(reconstruction_loss)와 인코더-디코더 사이의 오차(latent_loss) 출력
        )
        plot_reconstruction(model, example_data) ------ 결과를 시각화
    신간 소식 구독하기
    뉴스레터에 가입하시고 이메일로 신간 소식을 받아 보세요.