버트 예제에서 사용할 데이터셋을 불러옵니다. 참고로 훈련과 테스트 용도의 데이터셋을 내려받았지만, 검증용 데이터셋을 위해 훈련 데이터셋을 임의로 나누었습니다.
코드 10-41 데이터셋 불러오기
train_df = pd.read_csv('../chap10/data/training.txt', sep='\t')
valid_df = pd.read_csv('../chap10/data/validing.txt', sep='\t')
test_df = pd.read_csv('../chap10/data/testing.txt', sep='\t')
모델 훈련을 위해 주어진 데이터셋의 10%만 사용합니다. 빠른 예제 처리를 위해 전체 데이터셋 중 일부만 사용하기 때문에 성능은 좋지 않을 수 있습니다. 컴퓨터 성능이 좋다면 전체 데이터셋을 모두 사용해도 좋습니다.
코드 10-42 불러온 데이터셋 중 일부만 사용
train_df = train_df.sample(frac=0.1, random_state=500)
valid_df = valid_df.sample(frac=0.1, random_state=500)
test_df = test_df.sample(frac=0.1, random_state=500)
주어진 데이터를 이용한 데이터셋(파이토치에서 사용하는 data.Dataset을 의미)을 생성하기 위한 함수를 정의합니다.
코드 10-43 데이터셋 생성
class Datasets(Dataset):
def __init__(self, df):
self.df = df
def __len__(self):
return len(self.df)
def __getitem__(self, idx):
text = self.df.iloc[idx, 1] ------ ①
label = self.df.iloc[idx, 2]
return text, label