훈련과 검증 용도의 데이터셋을 정의합니다.
코드 6-71 이미지 데이터셋 정의
train_dataset = DogvsCatDataset(train_images_filepaths, transform=ImageTransform(size, mean, std), phase='train')
val_dataset = DogvsCatDataset(val_images_filepaths, transform=ImageTransform(size, mean, std), phase='val')
index = 0
print(train_dataset.__getitem__(index)[0].size())
print(train_dataset.__getitem__(index)[1])
다음은 훈련 데이터셋 index 0의 이미지 크기와 레이블에 대한 출력 결과입니다.
torch.Size([3, 224, 224]) 0
이미지는 컬러(채널 3) 상태에서 224×224 크기를 갖고 있으며 레이블이 0이므로 고양이를 의미합니다.
데이터로더를 이용하여 데이터를 메모리로 불러옵니다. 불러올 때는 배치 크기만큼 나누어서 불러옵니다.
코드 6-72 데이터셋의 데이터를 메모리로 불러오기
train_iterator = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
valid_iterator = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)
dataloader_dict = {'train': train_iterator, 'val': valid_iterator}
batch_iterator = iter(train_iterator)
inputs, label = next(batch_iterator)
print(inputs.size())
print(label)