더북(TheBook)
def read_labeled_tfrecord(example, max_len=220):
    """Parse one serialized labeled record into model inputs and a label.

    Args:
        example: A serialized record accepted by
            ``tf.io.parse_single_example`` (presumably a scalar string
            tensor holding a ``tf.train.Example`` proto -- confirm against
            the writer side).
        max_len: Number of elements each decoded sequence is reshaped to.
            The reshape requires the decoded byte string to contain exactly
            this many elements.

    Returns:
        ``((input_ids, attention_mask), label)`` where ``input_ids`` and
        ``attention_mask`` are int32 tensors of shape ``[max_len]`` and
        ``label`` is the float32 scalar stored in the record.
    """
    # Shape [] means one element per feature; tf.string is a raw byte string,
    # so each sequence is stored as a single packed blob.
    feature_spec = {
        "input_ids": tf.io.FixedLenFeature([], tf.string),
        "attention_mask": tf.io.FixedLenFeature([], tf.string),
        "label": tf.io.FixedLenFeature([], tf.float32),
    }
    parsed = tf.io.parse_single_example(example, feature_spec)

    def _unpack(key, stored_dtype):
        # Reinterpret the byte blob as `stored_dtype`, widen to int32, then
        # pin the static shape to [max_len].
        flat = tf.io.decode_raw(parsed[key], stored_dtype)
        return tf.reshape(tf.cast(flat, tf.int32), [max_len])

    token_ids = _unpack("input_ids", tf.uint16)  # ids are stored as uint16
    mask = _unpack("attention_mask", tf.bool)  # mask is stored as bool

    return (token_ids, mask), parsed["label"]

# Input-pipeline / training configuration constants.

# Sentinel that lets tf.data choose tuning values (e.g. parallelism) itself.
AUTO = tf.data.experimental.AUTOTUNE

# Same value as the default `max_len` of read_labeled_tfrecord.
MAX_LEN = 220

BATCH_SIZE = 512
NUM_TRAIN = 1_804_874  # presumably the number of training examples -- confirm
# Floor division: a trailing partial batch is not counted as a step.
STEPS_PER_EPOCH = NUM_TRAIN // BATCH_SIZE
신간 소식 구독하기
뉴스레터에 가입하시고 이메일로 신간 소식을 받아 보세요.