def read_labeled_tfrecord(example, max_len=220):
    """Parse one serialized labeled example into model inputs and a label.

    The record is expected to carry three scalar features:
    ``input_ids`` and ``attention_mask`` as byte strings, and ``label`` as
    a float32. The two byte strings are decoded (as uint16 and bool
    respectively), cast to int32 and reshaped to a vector of ``max_len``
    elements.

    Args:
        example: A serialized example, as accepted by
            ``tf.io.parse_single_example``.
        max_len: Number of elements each decoded sequence is reshaped to.
            The reshape requires the decoded data to contain exactly this
            many elements.

    Returns:
        A pair ``((input_ids, attention_mask), label)`` where both
        sequences are int32 tensors of shape ``[max_len]`` and ``label``
        is the parsed float32 scalar.
    """
    feature_spec = {
        # Both sequences are stored as a single raw byte string each.
        "input_ids": tf.io.FixedLenFeature([], tf.string),
        "attention_mask": tf.io.FixedLenFeature([], tf.string),
        # An empty shape means the feature is a single scalar value.
        "label": tf.io.FixedLenFeature([], tf.float32),
    }
    parsed = tf.io.parse_single_example(example, feature_spec)

    def _unpack(raw_bytes, stored_dtype):
        # Raw bytes -> stored dtype -> int32 vector of length max_len.
        values = tf.io.decode_raw(raw_bytes, stored_dtype)
        values = tf.cast(values, tf.int32)
        return tf.reshape(values, [max_len])

    token_ids = _unpack(parsed["input_ids"], tf.uint16)
    mask = _unpack(parsed["attention_mask"], tf.bool)
    return (token_ids, mask), parsed["label"]
# Batching configuration.
BATCH_SIZE = 512
# Presumably the number of examples in the training set -- TODO confirm.
NUM_TRAIN = 1804874
# Whole batches per pass over NUM_TRAIN examples (remainder is dropped
# by the integer division).
STEPS_PER_EPOCH = NUM_TRAIN // BATCH_SIZE

# Sequence length; same value as the default ``max_len`` of
# read_labeled_tfrecord.
MAX_LEN = 220

# Shorthand for tf.data's autotune sentinel.
AUTO = tf.data.experimental.AUTOTUNE