더북(TheBook)

다음은 최적화된 모델을 저장하기 위한 함수입니다.

코드 10-46 최적화 모델 저장

def save_checkpoint(save_path, model, valid_loss): ------ 모델 평가를 위해 훈련 과정을 저장
    if save_path == None:
        return
    state_dict = {'model_state_dict': model.state_dict(), 'valid_loss': valid_loss}
    torch.save(state_dict, save_path)
    print(f'Model saved to ==> {save_path}')

def load_checkpoint(load_path, model): ------ save_checkpoint 함수에서 저장된 모델을 가져옵니다.
    if load_path == None:
        return
    state_dict = torch.load(load_path, map_location=device)
    print(f'Model loaded from <== {load_path}')
    model.load_state_dict(state_dict['model_state_dict'])
    return state_dict['valid_loss']

def save_metrics(save_path, train_loss_list, valid_loss_list, global_steps_list): ------ 훈련, 검증에 대한 오차와 에포크를 저장
    if save_path == None:
        return
    state_dict = {'train_loss_list': train_loss_list,
                  'valid_loss_list': valid_loss_list,
                  'global_steps_list': global_steps_list}
    torch.save(state_dict, save_path)
    print(f'Model saved to ==> {save_path}')

def load_metrics(load_path): ------ save_metrics에 저장해 둔 정보를 불러옵니다.
    if load_path == None:
        return
    state_dict = torch.load(load_path, map_location=device)
    print(f'Model loaded from <== {load_path}')
    return state_dict['train_loss_list'], state_dict['valid_loss_list'], state_dict['global_steps_list']
신간 소식 구독하기
뉴스레터에 가입하시고 이메일로 신간 소식을 받아 보세요.