더북(TheBook)

① 파이토치는 매 계층마다 print 문을 사용하지 않더라도 hook 기능을 사용하여 각 계층의 활성화 함수 및 기울기 값을 확인할 수 있습니다. 따라서 register_forward_hook의 목적은 순전파 중에 각 네트워크 모듈의 입력 및 출력을 가져오는 것입니다. 예를 들어 다음과 같은 코드가 있다고 합시다.

import torch
x = torch.Tensor([0,1,2,3]).requires_grad_()
y = torch.Tensor([4,5,6,7]).requires_grad_()
w = torch.Tensor([1,2,3,4]).requires_grad_()
z = x + y;
o = w.matmul(z)
o.backward()
print(x.grad, y.grad, z.grad, w.grad, o.grad)

이 코드를 실행하면 다음과 같이 출력됩니다.

tensor([2., 3., 4., 5.]) tensor([2., 3., 4., 5.]) None tensor([ 4., 6., 8., 10.]) None

코드에서 oz는 특정한 값으로 정의되지 않은 중간 변수(계산 결과에 따라 값이 달라질 수 있는 변수)입니다. 파이토치는 이러한 변수에 대해서는 기울기 값을 저장하지 않습니다. 하지만 이와 같은 중간 변수에 대해 z.register_hook(hook_fn)을 사용하면 기울기 값을 알 수 있습니다. 이와 같이 hook을 이용하면 중간 결괏값들을 확인할 수 있습니다.

신간 소식 구독하기
뉴스레터에 가입하시고 이메일로 신간 소식을 받아 보세요.