tfrec_path = os.path.join(GCS_DS_PATH, "train.tfrecord")
dataset = tf.data.TFRecordDataset(tfrec_path, num_parallel_reads=AUTO)
dataset = dataset.map(
lambda x: read_labeled_tfrecord(x, MAX_LEN), num_parallel_calls=AUTO
)
dataset = dataset.repeat()
dataset = dataset.shuffle(2048)
dataset = dataset.batch(BATCH_SIZE)
dataset = dataset.prefetch(AUTO)
>>> dataset
<PrefetchDataset shapes: (((None, 220), (None, 220)), (None,)),
types: ((tf.int32, tf.int32), tf.float32)>
한 가지 재미있는 부분은 학습이 시작되기 전까지는 스토리지에 있는 데이터셋의 경로와 메타데이터만 전달될 뿐, 캐글 노트북(컴퓨팅 인스턴스)의 메모리나 디스크를 전혀 점유하지 않는다는 점입니다. 실제로 TFRecord 파일을 불러오고, ‘tf.data.TFRecordDataset’을 구성할 때 캐글 노트북의 세션 현황을 보면 메모리 변화가 거의 나타나지 않습니다. 이는 모델 학습이 시작될 때도 마찬가지입니다. 학습이 시작되면 구글 클라우드 스토리지에서 TFRecord 데이터가 TPU로 직접 전송되는 것이지요. 이러한 이유 때문에 캐글 노트북의, 보통의 컴퓨팅 리소스로도 충분히 TPU 학습할 수 있습니다.