코드 8-22 데이터셋 가져오기
train_dataset = datasets.ImageFolder(
root=r'../chap08/data/archive/train',
transform=train_transform
)
train_dataloader = torch.utils.data.DataLoader(
train_dataset, batch_size=32, shuffle=True,
)
val_dataset = datasets.ImageFolder(
root=r'../chap08/data/archive/test',
transform=val_transform
)
val_dataloader = torch.utils.data.DataLoader(
val_dataset, batch_size=32, shuffle=False,
)
이제 모델을 생성할 텐데, 네트워크를 직접 구축하는 것이 아닌 사전 학습된 ResNet50을 사용할 예정입니다. 6장에서 배웠듯이 사전 학습된 모델을 사용할 경우 간편하게 네트워크를 구성하고 사용할 수 있는 장점이 있습니다.
코드 8-23 모델 생성
def resnet50(pretrained=True, requires_grad=False):
model = models.resnet50(progress=True, pretrained=pretrained)
if requires_grad == False: ------ 파라미터를 고정하여 backward() 중에 기울기가 계산되지 않도록 합니다. requires_grad=False를 파라미터로 받았기 때문에 해당 구문이 실행됩니다.
for param in model.parameters():
param.requires_grad = False
elif requires_grad == True: ------ 파라미터 값이 backward() 중에 기울기 계산에 반영됩니다.
for param in model.parameters():
param.requires_grad = True
model.fc = nn.Linear(2048, 2) ------ 마지막 분류를 위한 계층은 학습을 진행합니다.
return model