def serialize_train_example(input_ids, attention_mask, label):
feature = {
"input_ids": _bytes_feature(input_ids),
"attention_mask": _bytes_feature(attention_mask),
"label": _float_array_feature(label),
}
example_proto = tf.train.Example(features=tf.train.Features(feature=feature))
return example_proto.SerializeToString()
def serialize_test_example(input_ids, attention_mask):
feature = {
"input_ids": _bytes_feature(input_ids),
"attention_mask": _bytes_feature(attention_mask),
}
example_proto = tf.train.Example(features=tf.train.Features(feature=feature))
return example_proto.SerializeToString()
위 함수들은 TFRecord 변환을 위해 데이터를 이진 데이터 형태로 바꿔주는 함수입니다. 다음 TFRecordGenerator에서 위 함수를 사용해 텍스트를 이진 데이터로 변환합니다.