Skip to content

Commit 34dffde

Browse files
committed
Added batch mapping for the process_samples function in process_messages_into_input_ids.
Signed-off-by: aryanorpe <[email protected]>
1 parent c092c46 commit 34dffde

File tree

1 file changed

+22
-1
lines changed

1 file changed

+22
-1
lines changed

src/instructlab/training/data_process.py

Lines changed: 22 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -589,7 +589,7 @@ def unmask_messages(
589589
)
590590

591591

592-
def unmask_sample(
592+
def unmask_sample_single(
593593
sample: t.Dict[str, t.Any], tokenizer: PreTrainedTokenizer
594594
) -> ProcessedMessagesData:
595595
"""
@@ -618,6 +618,25 @@ def unmask_sample(
618618
return unmask_messages(sample["messages"], tokenizer, unmask_roles)
619619

620620

621+
def unmask_sample(
622+
batch: t.Dict[str, t.List[t.Any]], tokenizer: PreTrainedTokenizer
623+
) -> t.Dict[str, t.List[t.Any]]:
624+
input_ids_list = []
625+
labels_list = []
626+
627+
for i in range(len(batch["messages"])):
628+
sample = {key: batch[key][i] for key in batch}
629+
result = unmask_sample_single(sample, tokenizer)
630+
631+
input_ids_list.append(result["input_ids"])
632+
labels_list.append(result["labels"])
633+
634+
return {
635+
"input_ids": input_ids_list,
636+
"labels": labels_list,
637+
}
638+
639+
621640
def extract_messages_from_pretraining_text(text: str) -> t.List[Message]:
622641
"""
623642
Given a message from a pretraining message that was formatted using either the generic
@@ -925,6 +944,8 @@ def process_samples(
925944
# Process the dataset
926945
processed_data = data.map(
927946
process_sample_fn,
947+
batched=True,
948+
batch_size=1000,
928949
num_proc=num_cpu_procs,
929950
desc="Converting samples into input_ids and labels...",
930951
load_from_cache_file=False,

0 commit comments

Comments
 (0)