① 파이토치는 매 계층마다 print 문을 사용하지 않더라도 hook 기능을 사용하여 각 계층의 활성화 함수 및 기울기 값을 확인할 수 있습니다. 따라서 register_forward_hook의 목적은 순전파 중에 각 네트워크 모듈의 입력 및 출력을 가져오는 것입니다. 예를 들어 다음과 같은 코드가 있다고 합시다.
import torch x = torch.Tensor([0,1,2,3]).requires_grad_() y = torch.Tensor([4,5,6,7]).requires_grad_() w = torch.Tensor([1,2,3,4]).requires_grad_() z = x + y; o = w.matmul(z) o.backward() print(x.grad, y.grad, z.grad, w.grad, o.grad)
이 코드를 실행하면 다음과 같이 출력됩니다.
tensor([2., 3., 4., 5.]) tensor([2., 3., 4., 5.]) None tensor([ 4., 6., 8., 10.]) None
코드에서 o와 z는 특정한 값으로 정의되지 않은 중간 변수(계산 결과에 따라 값이 달라질 수 있는 변수)입니다. 파이토치는 이러한 변수에 대해서는 기울기 값을 저장하지 않습니다. 하지만 이와 같은 중간 변수에 대해 z.register_hook(hook_fn)을 사용하면 기울기 값을 알 수 있습니다. 이와 같이 hook을 이용하면 중간 결괏값들을 확인할 수 있습니다.