코드 6-9 데이터로더 정의
train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True) ------ ①
val_dataloader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)
dataloader_dict = {'train': train_dataloader, 'val': val_dataloader} ------ 훈련 데이터셋(train_dataloader)과 검증 데이터셋(val_dataloader)을 합쳐서 표현
batch_iterator = iter(train_dataloader)
inputs, label = next(batch_iterator)
print(inputs.size())
print(label)
① 파이토치의 데이터로더는 배치 관리를 담당합니다. 한 번에 모든 데이터를 불러오면 메모리에 부담을 줄 수 있기 때문에 데이터를 그룹으로 쪼개서 조금씩 불러옵니다.
ⓐ 첫 번째 파라미터: 데이터를 불러오기 위한 데이터셋입니다.
ⓑ batch_size: 한 번에 메모리로 불러올 데이터 크기로, 여기에서는 32개씩 데이터를 가져옵니다.
ⓒ shuffle: 메모리로 데이터를 가져올 때 임의로 섞어서 가져오도록 합니다.
다음은 데이터로더를 이용하여 훈련 데이터셋을 메모리로 불러온 후 데이터셋의 크기와 레이블을 출력한 결과입니다.
torch.Size([32, 3, 224, 224]) tensor([1, 1, 1, 1, 0, 1, 1, 0, 0, 0, 1, 1, 0, 0, 1, 0, 0, 1, 0, 0, 0, 1, 0, 1, 1, 1, 0, 0, 0, 0, 1, 0])