그리고 nn.functional을 사용하는 예시 코드는 다음과 같습니다.
import torch.nn.functional as F inputs = torch.randn(64, 3, 244, 244) weight = torch.randn(64, 3, 3, 3) bias = torch.randn(64) outputs = F.conv2d(inputs, weight, bias, padding=1)
nn.Conv2d에서 input_channel과 output_channel을 사용해서 연산했다면 functional.conv2d는 입력(input)과 가중치(weight) 자체를 직접 넣어 줍니다. 이때 직접 넣어 준다는 의미는 가중치를 전달해야 할 때마다 가중치 값을 새로 정의해야 함을 의미합니다. 그 외에 채워야 하는 파라미터들은 nn.Conv2d와 비슷합니다.
다음은 nn.xx와 nn.functional.xx를 비교한 표입니다.
▼ 표 5-1 nn.xx와 nn.functional.xx의 사용 방법 비교
구분 |
nn.xx |
nn.functional.xx |
형태 |
nn.Conv2d: 클래스 nn.Module 클래스를 상속받아 사용 |
nn.functional.conv2d: 함수 def function (input)으로 정의된 순수한 함수 |
호출 방법 |
먼저 하이퍼파라미터를 전달한 후 함수 호출을 통해 데이터 전달 |
함수를 호출할 때 하이퍼파라미터, 데이터 전달 |
위치 |
nn.Sequential 내에 위치 |
nn.Sequential에 위치할 수 없음 |
파라미터 |
파라미터를 새로 정의할 필요 없음 |
가중치를 수동으로 전달해야 할 때마다 자체 가중치를 정의 |