for i in range(chunk_size):
if self.dataset_type == "TRAIN":
sample = serialize_train_example(
input_ids[i].tobytes(),
attention_mask[i].tobytes(),
label_chunk[i],
)
else:
sample = serialize_test_example(
input_ids[i].tobytes(), attention_mask[i].tobytes()
)
writer.write(sample)
TFRecordGenerator는 텍스트 데이터를 TFRecord 데이터로 변환하고 저장하는 기능을 합니다. TFRecord를 만들기 전에 텍스트를 토큰으로 바꿔줄 토크나이저를 구성해야 합니다.
from transformers import BertTokenizerFast
tokenizer = BertTokenizerFast.from_pretrained("bert-base-uncased")
save_path = "./bert_base_uncased/"
os.makedirs(save_path, exist_ok=True)
tokenizer.save_pretrained(save_path)