다음은 내려받은 ResNet18의 합성곱층을 사용하되 파라미터에 대해서는 학습을 하지 않도록 고정시킵니다.
코드 5-16 사전 훈련된 모델의 파라미터 학습 유무 지정
def set_parameter_requires_grad(model, feature_extracting=True):
if feature_extracting:
for param in model.parameters():
param.requires_grad = False ------ ①
set_parameter_requires_grad(resnet18)
① 역전파 중 파라미터들에 대한 변화를 계산할 필요가 없음을 나타냅니다. 즉, 모델의 일부를 고정하고 나머지를 학습하고자 할 때 requires_grad = False로 설정합니다. 이때 모델의 일부는 합성곱층(convolutional layer)과 풀링(pooling)층을 의미합니다.
내려받은 ResNet18의 마지막 부분에 완전연결층을 추가합니다. 추가된 완전연결층은 개와 고양이 클래스를 분류하는 용도로 사용됩니다.
▲ 그림 5-36 ResNet18에 완전연결층 추가
코드 5-17 ResNet18에 완전연결층 추가
resnet18.fc = nn.Linear(512, 2) ------ 2는 클래스가 두 개라는 의미