@@ -589,7 +589,7 @@ def unmask_messages(
     )
 
 
-def unmask_sample(
+def unmask_sample_single(
     sample: t.Dict[str, t.Any], tokenizer: PreTrainedTokenizer
 ) -> ProcessedMessagesData:
     """
@@ -618,6 +618,25 @@ def unmask_sample(
     return unmask_messages(sample["messages"], tokenizer, unmask_roles)
 
 
+def unmask_sample(
+    batch: t.Dict[str, t.List[t.Any]], tokenizer: PreTrainedTokenizer
+) -> t.Dict[str, t.List[t.Any]]:
+    input_ids_list = []
+    labels_list = []
+
+    for i in range(len(batch["messages"])):
+        sample = {key: batch[key][i] for key in batch}
+        result = unmask_sample_single(sample, tokenizer)
+
+        input_ids_list.append(result["input_ids"])
+        labels_list.append(result["labels"])
+
+    return {
+        "input_ids": input_ids_list,
+        "labels": labels_list,
+    }
+
+
 def extract_messages_from_pretraining_text(text: str) -> t.List[Message]:
     """
     Given a message from a pretraining message that was formatted using either the generic
@@ -925,6 +944,8 @@ def process_samples(
     # Process the dataset
     processed_data = data.map(
         process_sample_fn,
+        batched=True,
+        batch_size=1000,
         num_proc=num_cpu_procs,
         desc="Converting samples into input_ids and labels...",
         load_from_cache_file=False,
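
For reference, below is a minimal, self-contained sketch of the dict-of-lists contract that datasets.Dataset.map(batched=True) imposes on the new unmask_sample wrapper above. The toy_unmask_single / toy_unmask_batch names, the dummy tokenization, and the toy dataset are illustrative stand-ins and are not part of this patch; only the batching pattern and the map() keyword arguments mirror the diff.

import typing as t

from datasets import Dataset


def toy_unmask_single(sample: t.Dict[str, t.Any]) -> t.Dict[str, t.List[int]]:
    # Stand-in for unmask_sample_single: pretend every message becomes one
    # token and every token is unmasked.
    n = len(sample["messages"])
    return {"input_ids": list(range(n)), "labels": list(range(n))}


def toy_unmask_batch(batch: t.Dict[str, t.List[t.Any]]) -> t.Dict[str, t.List[t.Any]]:
    # With batched=True, map() hands the function a dict of column -> list of
    # values and expects the same dict-of-lists shape back, so the per-sample
    # function is applied row by row, as in unmask_sample above.
    input_ids_list, labels_list = [], []
    for i in range(len(batch["messages"])):
        sample = {key: batch[key][i] for key in batch}
        result = toy_unmask_single(sample)
        input_ids_list.append(result["input_ids"])
        labels_list.append(result["labels"])
    return {"input_ids": input_ids_list, "labels": labels_list}


data = Dataset.from_dict(
    {
        "messages": [
            [{"role": "user", "content": "hi"}],
            [{"role": "user", "content": "hello"}, {"role": "assistant", "content": "hey"}],
        ]
    }
)

processed = data.map(
    toy_unmask_batch,
    batched=True,
    batch_size=1000,
    desc="Converting samples into input_ids and labels...",
    load_from_cache_file=False,
)
print(processed[0]["input_ids"])  # [0]
print(processed[1]["labels"])     # [0, 1]

One practical note: batched=True changes what the mapping function receives, so the original per-sample function can no longer be passed to map() directly. The patch keeps the per-sample logic in unmask_sample_single and only adapts its inputs and outputs in the batched wrapper.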