더북(TheBook)

코드 8-22 데이터셋 가져오기

train_dataset = datasets.ImageFolder(
    root=r'../chap08/data/archive/train',
    transform=train_transform
)
train_dataloader = torch.utils.data.DataLoader(
    train_dataset, batch_size=32, shuffle=True,
)
val_dataset = datasets.ImageFolder(
    root=r'../chap08/data/archive/test',
    transform=val_transform
)
val_dataloader = torch.utils.data.DataLoader(
    val_dataset, batch_size=32, shuffle=False,
)

이제 모델을 생성할 텐데, 네트워크를 직접 구축하는 것이 아닌 사전 학습된 ResNet50을 사용할 예정입니다. 6장에서 배웠듯이 사전 학습된 모델을 사용할 경우 간편하게 네트워크를 구성하고 사용할 수 있는 장점이 있습니다.

코드 8-23 모델 생성

def resnet50(pretrained=True, requires_grad=False):
    model = models.resnet50(progress=True, pretrained=pretrained)
    if requires_grad == False: ------ 파라미터를 고정하여 backward() 중에 기울기가 계산되지 않도록 합니다. requires_grad=False를 파라미터로 받았기 때문에 해당 구문이 실행됩니다.
        for param in model.parameters():
            param.requires_grad = False
    elif requires_grad == True: ------ 파라미터 값이 backward() 중에 기울기 계산에 반영됩니다.
        for param in model.parameters():
            param.requires_grad = True
    model.fc = nn.Linear(2048, 2) ------ 마지막 분류를 위한 계층은 학습을 진행합니다.
    return model
신간 소식 구독하기
뉴스레터에 가입하시고 이메일로 신간 소식을 받아 보세요.