Code 12-7 Defining the optimization function used to train the model

    def optimize_model():
        if len(memory) < BATCH_SIZE:
            return
    
        transitions = memory.sample(BATCH_SIZE)
        batch = Transition(*zip(*transitions)) ------ ①
        non_final_mask = torch.tensor(tuple(map(lambda s: s is not None, batch.next_state)), device=device, dtype=torch.bool) ------ ②
        non_final_next_states = torch.cat([s for s in batch.next_state if s is not None]) ------ concatenate the non-final s values using torch.cat
    
        state_batch = torch.cat(batch.state)
        action_batch = torch.cat(batch.action)
        reward_batch = torch.cat(batch.reward)
    
        state_action_values = policy_net(state_batch).gather(1, action_batch) ------ compute Q(s_t, a_t)
        next_state_values = torch.zeros(BATCH_SIZE, device=device)
        next_state_values[non_final_mask] = target_net(non_final_next_states).max(1)[0].detach() ------ ③
    
        expected_state_action_values = (next_state_values * GAMMA) + reward_batch ------ compute the expected Q values from V(s_t+1)
    
        loss = F.smooth_l1_loss(state_action_values, expected_state_action_values.unsqueeze(1)) ------ ④
        optimizer.zero_grad()
        loss.backward()
    
        for param in policy_net.parameters():
            param.grad.data.clamp_(-1, 1)
        optimizer.step()
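
optimize_model() relies on several objects defined in earlier listings of this chapter: the replay memory, the Transition namedtuple, the policy and target networks, the optimizer, and the hyperparameters BATCH_SIZE and GAMMA. The sketch below shows one minimal way those pieces could be set up so that the function runs on its own; the network architecture, hyperparameter values, and state/action dimensions here are illustrative assumptions, not the chapter's exact definitions.

    import random
    from collections import namedtuple, deque

    import torch
    import torch.nn as nn
    import torch.nn.functional as F
    import torch.optim as optim

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    BATCH_SIZE = 128   # mini-batch size (placeholder value)
    GAMMA = 0.999      # discount factor (placeholder value)

    # One step of experience: (state, action, next_state, reward)
    Transition = namedtuple('Transition', ('state', 'action', 'next_state', 'reward'))

    class ReplayMemory:
        """Fixed-size buffer that stores transitions and samples random mini-batches."""
        def __init__(self, capacity):
            self.memory = deque([], maxlen=capacity)

        def push(self, *args):
            self.memory.append(Transition(*args))

        def sample(self, batch_size):
            return random.sample(self.memory, batch_size)

        def __len__(self):
            return len(self.memory)

    # Placeholder Q-network: any nn.Module mapping a state to per-action Q-values works here
    class DQN(nn.Module):
        def __init__(self, n_observations, n_actions):
            super().__init__()
            self.net = nn.Sequential(
                nn.Linear(n_observations, 128), nn.ReLU(),
                nn.Linear(128, n_actions),
            )

        def forward(self, x):
            return self.net(x)

    n_observations, n_actions = 4, 2                       # e.g., CartPole dimensions (assumption)
    policy_net = DQN(n_observations, n_actions).to(device)
    target_net = DQN(n_observations, n_actions).to(device)
    target_net.load_state_dict(policy_net.state_dict())    # start from identical weights
    target_net.eval()

    optimizer = optim.RMSprop(policy_net.parameters())
    memory = ReplayMemory(10000)

In a typical training loop, each (state, action, next_state, reward) tuple is pushed into memory and optimize_model() is then called once per environment step; target_net is periodically synchronized with policy_net so that the targets computed in ③ stay stable.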