데이터에 대한 전처리를 합니다. 평균과 표준편차에 맞게 데이터를 정규화하기 위한 코드입니다.
코드 7-30 데이터 전처리
import torchvision.transforms as transforms
mnist_transform = transforms.Compose([
transforms.ToTensor(),
transforms.Normalize((0.5,), (1.0,)) ------ 평균을 0.5, 표준편차를 1.0으로 데이터 정규화(데이터 분포를 조정)
])
torchvision.datasets에서 제공하는 데이터셋 중 MNIST 데이터셋을 내려받습니다.
코드 7-31 데이터셋 내려받기
from torchvision.datasets import MNIST
download_root = '../chap07/MNIST_DATASET' ------ MNIST를 내려받을 경로
train_dataset = MNIST(download_root, transform=mnist_transform, train=True, download=True) ------ ①
valid_dataset = MNIST(download_root, transform=mnist_transform, train=False, download=True)
test_dataset = MNIST(download_root, transform=mnist_transform, train=False, download=True)