attention_mask = tf.cast(attention_mask, tf.int32) attention_mask = tf.reshape(attention_mask, [max_len]) return input_ids, attention_mask AUTO = tf.data.experimental.AUTOTUNE BATCH_SIZE = 64 MAX_LEN = 220 TFRECORD_PATH = "../input/fkms-jigsaw-tfrecord-files/" tfrec_path = os.path.join(TFRECORD_PATH, "test.tfrecord") dataset = tf.data.TFRecordDataset(tfrec_path, num_parallel_reads=AUTO) dataset = dataset.map( lambda x: read_non_labeled_tfrecord(x, MAX_LEN), num_parallel_calls=AUTO ) dataset = dataset.batch(BATCH_SIZE) dataset = dataset.prefetch(AUTO) >>> dataset <PrefetchDataset shapes: {input_ids: (None, 220), attention_mask: (None, 220)}, types: {input_ids: tf.int32, attention_mask: tf.int32}>
이제 학습에 사용한 모델 구조를 다시 불러오고 추론을 진행하겠습니다.