AlexNet은 파라미터를 6000만 개 사용하는 모델입니다. 이때 충분한 데이터가 없으면 과적합이 발생하는 등 테스트 데이터에 대한 성능이 좋지 않습니다. 우리가 사용할 예제에서는 데이터셋을 상당히 제한하여 사용하고 있기 때문에 성능은 좋지 않다는 것을 미리 언급합니다. 성능이 좋은 결과를 원한다면 충분한 데이터셋을 확보하고 테스트를 진행하면 됩니다. 예를 들어 캐글에서 내려받은 모든 이미지를 사용하는 것뿐만 아니라 전처리 부분에서 데이터를 많이 확장(RandomRotation, RandomHorizontalFlip 등을 이용)시켜 예제를 진행해야 합니다.
torch.utils.data.Dataset을 상속받아 커스텀 데이터셋(custom dataset)을 정의합니다. torch.utils.data.Dataset 클래스를 상속받아 커스텀 데이터셋을 만들어 보겠습니다.
코드 6-25 커스텀 데이터셋 정의
class DogvsCatDataset(Dataset):
def __init__(self, file_list, transform=None, phase='train'):
self.file_list = file_list ------ 이미지 데이터가 위치한 파일 경로
self.transform = transform ------ 이미지 데이터 전처리
self.phase = phase ------ self.phase는 ImageTransform()에서 정의한 ‘train’과 ‘val’을 의미
def __len__(self):
return len(self.file_list)
def __getitem__(self, idx):
img_path = self.file_list[idx] ------ 이미지 데이터의 인덱스를 가져오기
img = Image.open(img_path)
img_transformed = self.transform(img, self.phase)
label = img_path.split('/')[-1].split('.')[0] ------ 레이블 값을 가져오기
if label == 'dog':
label = 1
elif label == 'cat':
label = 0
return img_transformed, label ------ 전처리가 적용된 이미지와 레이블 반환