더북(TheBook)

① 선형 계층의 입력은 합성곱층(Conv2d)의 출력과 입력 이미지의 크기에 따라 달라지므로 ①과 같이 계산해야 합니다.

이제 환경에서 이미지를 추출하고 처리하는 함수를 정의합니다. 이때 다양한 이미지 변환을 쉽게 처리할 수 있도록 torchvision 패키지를 사용합니다.

먼저 pyglet 패키지를 설치합니다.

> pip install pyglet

코드 12-5 이미지 추출 및 처리

import pyglet

resize = T.Compose([T.ToPILImage(),
                   T.Resize(40, interpolation=Image.CUBIC),
                   T.ToTensor()]) ------ 이미지 크기 및 텐서 변환

def get_cart_location(screen_width): ------ 카트의 위치 정보 가져오기
    world_width = env.x_threshold * 2
    scale = screen_width / world_width
    return int(env.state[0] * scale + screen_width / 2.0) ------ 카트의 중간(중앙) 위치

def get_screen(): ------ ①
    screen = env.render(mode='rgb_array').transpose((2, 0, 1)) ------ ②
    _, screen_height, screen_width = screen.shape
    screen = screen[:, int(screen_height*0.4):int(screen_height * 0.8)]
    view_width = int(screen_width * 0.6)
    cart_location = get_cart_location(screen_width)

    if cart_location < view_width // 2: ------ 카트는 출력 화면의 아래쪽 중앙에 존재하므로 화면의 위쪽과 아래쪽을 제거
        slice_range = slice(view_width) ------ ③
    elif cart_location > (screen_width - view_width // 2):
        slice_range = slice(-view_width, None)
    else:
        slice_range = slice(cart_location - view_width // 2, cart_location + view_width // 2)
    screen = screen[:, :, slice_range] ------ 카트가 화면의 중앙에 위치하도록 가장자리를 제거
    screen = np.ascontiguousarray(screen, dtype=np.float32) / 255 ------ ④
    screen = torch.from_numpy(screen) ------ 텐서로 변환
    return resize(screen).unsqueeze(0).to(device) ------ 출력 크기 조정 및 배치 차원 추가하여 데이터는 (배치, 채널, 높이, 너비)의 형태를 갖습니다.

env.reset() ------ 환경을 초기화
plt.figure()
plt.imshow(get_screen().cpu().squeeze(0).permute(1, 2, 0).numpy(), interpolation='none') ------ permute 함수는 transpose 함수처럼 차원을 바꾸어서 표현할 때 사용
plt.title('화면 예시')
plt.show()
신간 소식 구독하기
뉴스레터에 가입하시고 이메일로 신간 소식을 받아 보세요.