이 예제에서는 특성 맵의 시각화에 대해 살펴볼 예정이므로 특성 맵의 결과를 확인할 수 있는 함수를 정의해야 합니다. 특성 맵은 합성곱층을 입력 이미지와 필터를 연산하여 얻은 결과입니다. 따라서 합성곱층에서 입력과 출력을 알 수 있다면 특성 맵에 대한 값들을 확인할 수 있다는 의미이기도 합니다. 예를 들어 코드 5-32의 출력 결과인 (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), bias=False)에 대한 특성 맵을 확인하기 위한 클래스를 먼저 정의합니다.
코드 5-33 특성 맵을 확인하기 위한 클래스 정의
class LayerActivations:
features = []
def __init__(self, model, layer_num):
self.hook = model[layer_num].register_forward_hook(self.hook_fn) ------ ①
def hook_fn(self, module, input, output):
self.features = output.detach().numpy()
def remove(self): ------ hook 삭제
self.hook.remove()