코드 12-8 모델 학습
num_episodes = 50
for i_episode in range(num_episodes):
env.reset() ------ 환경과 상태 초기화
last_screen = get_screen()
current_screen = get_screen()
state = current_screen - last_screen
for t in count():
action = select_action(state) ------ 행동 선택 및 실행
_, reward, done, _ = env.step(action.item()) ------ 선택한 행동(action)을 환경으로 보냅니다.
reward = torch.tensor([reward], device=device)
last_screen = current_screen
current_screen = get_screen()
if not done: ------ 새로운 상태 관찰(observe)
next_state = current_screen - last_screen
else:
next_state = None
memory.push(state, action, next_state, reward) ------ 상태 전이(state transition)를 메모리에 저장
state = next_state ------ 다음 상태로 이동
optimize_model() ------ 타깃(큐) 네트워크에 대해 최적화 진행
if done:
episode_durations.append(t + 1)
break
if i_episode % TARGET_UPDATE == 0:
target_net.load_state_dict(policy_net.state_dict()) ------ 큐 네트워크의 모든 가중치와 바이어스를 복사하여 타깃(큐) 네트워크를 업데이트합니다.
print('종료')
env.render() ------ 화면을 출력
env.close() ------ 화면을 종료
plt.show()