더북(TheBook)

다음 그림은 CartPole의 화면에 대한 출력 결과입니다.

▲ 그림 12-16 CartPole 화면 예시

DQN을 policy_net, target_net이라는 이름으로 모델을 객체화하고 손실 함수를 정의해서 학습을 위한 준비를 합니다.

코드 12-6 모델 객체화 및 손실 함수 정의

BATCH_SIZE = 128
GAMMA = 0.999
EPS_START = 0.9
EPS_END = 0.05
EPS_DECAY = 200
TARGET_UPDATE = 10
init_screen = get_screen() ------ ①
_, _, screen_height, screen_width = init_screen.shape
n_actions = env.action_space.n ------ gym에서 행동(action)에 대한 횟수를 가져옵니다.

policy_net = DQN(screen_height, screen_width, n_actions).to(device)
target_net = DQN(screen_height, screen_width, n_actions).to(device)
target_net.load_state_dict(policy_net.state_dict()) ------ ②
target_net.eval()

optimizer = optim.RMSprop(policy_net.parameters())
memory = ReplayMemory(10000)

steps_done = 0

def select_action(state):
    global steps_done
    sample = random.random()
    eps_threshold = EPS_END + (EPS_START - EPS_END) * \
        math.exp(-1. * steps_done / EPS_DECAY)
    steps_done += 1
    if sample > eps_threshold:
        with torch.no_grad():
            return policy_net(state).max(1)[1].view(1, 1) ------ max(1)은 각 행의 가장 큰 열 값을 반환
    else:
        return torch.tensor([[random.randrange(n_actions)]], device=device,  dtype=torch.long)

episode_durations = []
신간 소식 구독하기
뉴스레터에 가입하시고 이메일로 신간 소식을 받아 보세요.