다음 그림은 CartPole의 화면에 대한 출력 결과입니다.
▲ 그림 12-16 CartPole 화면 예시
DQN을 policy_net, target_net이라는 이름으로 모델을 객체화하고 손실 함수를 정의해서 학습을 위한 준비를 합니다.
코드 12-6 모델 객체화 및 손실 함수 정의
BATCH_SIZE = 128
GAMMA = 0.999
EPS_START = 0.9
EPS_END = 0.05
EPS_DECAY = 200
TARGET_UPDATE = 10
init_screen = get_screen() ------ ①
_, _, screen_height, screen_width = init_screen.shape
n_actions = env.action_space.n ------ gym에서 행동(action)에 대한 횟수를 가져옵니다.
policy_net = DQN(screen_height, screen_width, n_actions).to(device)
target_net = DQN(screen_height, screen_width, n_actions).to(device)
target_net.load_state_dict(policy_net.state_dict()) ------ ②
target_net.eval()
optimizer = optim.RMSprop(policy_net.parameters())
memory = ReplayMemory(10000)
steps_done = 0
def select_action(state):
global steps_done
sample = random.random()
eps_threshold = EPS_END + (EPS_START - EPS_END) * \
math.exp(-1. * steps_done / EPS_DECAY)
steps_done += 1
if sample > eps_threshold:
with torch.no_grad():
return policy_net(state).max(1)[1].view(1, 1) ------ max(1)은 각 행의 가장 큰 열 값을 반환
else:
return torch.tensor([[random.randrange(n_actions)]], device=device, dtype=torch.long)
episode_durations = []