diff --git a/src/instructlab/training/data_process.py b/src/instructlab/training/data_process.py
index 3cee19f2..5a814fbc 100644
--- a/src/instructlab/training/data_process.py
+++ b/src/instructlab/training/data_process.py
@@ -764,7 +764,7 @@ def unmask_messages(
     )
 
 
-def unmask_sample(
+def unmask_sample_single(
     sample: t.Dict[str, t.Any], tokenizer: PreTrainedTokenizer
 ) -> ProcessedMessagesData:
     """
@@ -793,6 +793,25 @@ def unmask_sample(
     return unmask_messages(sample["messages"], tokenizer, unmask_roles)
 
 
+def unmask_batch(
+    batch: t.Dict[str, t.List[t.Any]], tokenizer: PreTrainedTokenizer
+) -> t.Dict[str, t.List[t.Any]]:
+    input_ids_list = []
+    labels_list = []
+
+    for i in range(len(batch["messages"])):
+        sample = {key: batch[key][i] for key in batch}
+        result = unmask_sample_single(sample, tokenizer)
+
+        input_ids_list.append(result["input_ids"])
+        labels_list.append(result["labels"])
+
+    return {
+        "input_ids": input_ids_list,
+        "labels": labels_list,
+    }
+
+
 def extract_messages_from_pretraining_text(text: str) -> t.List[Message]:
     """
     Given a message from a pretraining message that was formatted using either the generic
@@ -899,29 +918,37 @@ def pretraining_is_using_legacy_granite_chat_template(ds: Dataset) -> bool:
 
 
 def ensure_dataset_is_compatible_with_legacy_format(
-    sample: t.Dict[str, t.Any],
-) -> t.Dict[str, t.Any]:
+    batch: t.Dict[str, t.List[t.Any]],
+) -> t.Dict[str, t.List[t.Any]]:
     """
-    Given a sample that uses the legacy pre-training format, we unroll the samples into ones with the
-    original messages contents.
+    Given a batch of samples using the legacy pre-training format, unroll the samples into ones with
+    the original message contents.
     """
-    # deepcopy to prevent re-referencing the existing objects
-    new_sample = {
-        "messages": [],
-        "unmask": sample.get("unmask", False),
-    }
-    for msg in sample["messages"]:
-        if msg["role"] != "pretraining":
-            new_sample["messages"].append(msg)
-            continue
+    processed_messages = []
+    unmask_flags = []
 
-        # handle unmasking
-        new_sample["messages"].extend(
-            extract_messages_from_pretraining_text(msg["content"])
-        )
-        new_sample["unmask"] = True
+    for messages, unmask_flag in zip(
+        batch["messages"], batch.get("unmask", [False] * len(batch["messages"]))
+    ):
+        new_messages = []
+        unmask = unmask_flag
 
-    return new_sample
+        for msg in messages:
+            if msg["role"] != "pretraining":
+                new_messages.append(msg)
+            else:
+                new_messages.extend(
+                    extract_messages_from_pretraining_text(msg["content"])
+                )
+                unmask = True  # if any pretraining message is found, set unmask to True
+
+        processed_messages.append(new_messages)
+        unmask_flags.append(unmask)
+
+    return {
+        "messages": processed_messages,
+        "unmask": unmask_flags,
+    }
 
 
 def filter_samples_by_length(
@@ -1051,6 +1078,8 @@ def load_and_validate_dataset(data_path: str, num_procs: int) -> Dataset:
 
     return data.map(
         ensure_dataset_is_compatible_with_legacy_format,
+        batched=True,
+        batch_size=1000,
         num_proc=num_procs,
         desc="Ensuring dataset is compatible with legacy format.",
     )
@@ -1082,16 +1111,21 @@ def configure_tokenizer(model_path: str) -> PreTrainedTokenizer:
 
 
 def process_samples(
-    data: Dataset, tokenizer: PreTrainedTokenizer, num_cpu_procs: int
+    data: Dataset,
+    tokenizer: PreTrainedTokenizer,
+    num_cpu_procs: int,
+    batch_size: int = 1000,
 ) -> Dataset:
     """Process samples to generate input_ids and labels."""
-    # Create a wrapper function for unmask_sample
-    process_sample_fn = partial(unmask_sample, tokenizer=tokenizer)
+    # Create a wrapper function for unmask_batch
+    process_sample_fn = partial(unmask_batch, tokenizer=tokenizer)
 
     # Process the dataset
     processed_data = data.map(
         process_sample_fn,
+        batched=True,
+        batch_size=batch_size,
         num_proc=num_cpu_procs,
         desc="Converting samples into input_ids and labels...",
         load_from_cache_file=False,
diff --git a/tests/unit/test_data_process.py b/tests/unit/test_data_process.py
index ba0bc4cc..9cffec0d 100644
--- a/tests/unit/test_data_process.py
+++ b/tests/unit/test_data_process.py
@@ -6,6 +6,9 @@
 import typing as t
 import unittest
 
+# Third Party
+from datasets import Dataset
+
 try:
     # Third Party
     import pytest
@@ -16,7 +19,11 @@
 
 try:
     # Third Party
-    from transformers import AutoTokenizer, PreTrainedTokenizer
+    from transformers import (
+        AutoTokenizer,
+        LlamaTokenizerFast,
+        PreTrainedTokenizer,
+    )
 
     TRANSFORMERS_AVAILABLE = True
 except ImportError:
@@ -26,6 +33,7 @@
 # First Party
 from instructlab.training.data_process import (
     MASK_TOKEN,
+    process_samples,
     UNMASK_BEGIN_TOKEN,
     UNMASK_END_TOKEN,
     UNMASK_REASONING_BEGIN_TOKEN,
@@ -990,5 +998,58 @@ def test_edge_cases_with_reasoning_content(self):
         )
 
 
+@pytest.fixture(scope="module")
+def tokenizer():
+    tokenizer = LlamaTokenizerFast.from_pretrained("HuggingFaceH4/zephyr-7b-alpha")
+
+    # Ensure UNMASK tokens are treated atomically
+    tokenizer.add_special_tokens(
+        {"additional_special_tokens": ["<|UNMASK_BEGIN|>", "<|UNMASK_END|>"]}
+    )
+
+    # Safety: add a pad token if it's missing
+    if tokenizer.pad_token is None:
+        tokenizer.pad_token = tokenizer.eos_token or ""
+
+    return tokenizer
+
+
+def test_process_samples_outputs_input_ids_and_labels(tokenizer):
+    # Create a dummy dataset of 100 samples
+    messages = [
+        [
+            {"role": "user", "content": f"Hello {i}"},
+            {"role": "assistant", "content": f"Hi there {i}!"},
+            {"role": "pretraining", "content": f"Pretraining text {i}"},
+        ]
+        for i in range(100)
+    ]
+
+    unmask_flags = [True for _ in range(100)]
+
+    dummy_data = Dataset.from_dict(
+        {
+            "messages": messages,
+            "unmask": unmask_flags,
+        }
+    )
+
+    # Use a small batch size so the map runs across several batches
+    processed = process_samples(dummy_data, tokenizer, num_cpu_procs=1, batch_size=8)
+
+    # Check the structure
+    assert "input_ids" in processed.column_names
+    assert "labels" in processed.column_names
+    assert len(processed) == 100
+
+    # Check that input_ids and labels exist and match in length for a few representative samples
+    for i in [0, 25, 50, 99]:
+        sample = processed[i]
+        assert isinstance(sample["input_ids"], list)
+        assert isinstance(sample["labels"], list)
+        assert len(sample["input_ids"]) == len(sample["labels"])
+        assert all(isinstance(x, int) for x in sample["input_ids"])
+
+
 if __name__ == "__main__":
     unittest.main()
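
Reviewer note (not part of the diff): with `batched=True`, `datasets.Dataset.map` hands the mapped function a columnar dict (one list per column) instead of a single row, which is why `unmask_batch` rebuilds per-row dicts before delegating to `unmask_sample_single`. Below is a minimal, self-contained sketch of that contract; `fake_unmask` and `batch_fn` are hypothetical stand-ins for the tokenizer-backed `unmask_sample_single` and `unmask_batch`, not code from this PR:

```python
# Sketch of the columnar batch contract used by unmask_batch.
# fake_unmask and batch_fn are illustrative stand-ins only.
from datasets import Dataset


def fake_unmask(sample):
    # Pretend tokenization: one token id per message; labels mirror input_ids.
    ids = list(range(len(sample["messages"])))
    return {"input_ids": ids, "labels": ids}


def batch_fn(batch):
    # `batch` is columnar: {"messages": [row0_msgs, row1_msgs, ...], "unmask": [...]}
    input_ids_list, labels_list = [], []
    for i in range(len(batch["messages"])):
        row = {key: batch[key][i] for key in batch}  # rebuild one row dict
        result = fake_unmask(row)
        input_ids_list.append(result["input_ids"])
        labels_list.append(result["labels"])
    return {"input_ids": input_ids_list, "labels": labels_list}


ds = Dataset.from_dict(
    {
        "messages": [[{"role": "user", "content": f"hi {i}"}] for i in range(4)],
        "unmask": [False] * 4,
    }
)
out = ds.map(batch_fn, batched=True, batch_size=2)
assert "input_ids" in out.column_names and "labels" in out.column_names
```

Because `batch_fn` returns exactly one output row per input row, `map` keeps the original `messages` and `unmask` columns alongside the new `input_ids` and `labels`, which matches what the new unit test asserts via `processed.column_names`.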
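
One more note on the batched rewrite of `ensure_dataset_is_compatible_with_legacy_format`: the per-sample semantics are preserved, so a `pretraining` message flips `unmask` to `True` only for the sample that contains it, never for the whole batch. A toy check of that invariant, with a hypothetical `fake_extract` standing in for `extract_messages_from_pretraining_text`:

```python
# Toy check that per-sample unmask flags stay independent within a batch.
# fake_extract is a hypothetical stand-in for
# extract_messages_from_pretraining_text.
def fake_extract(text):
    return [
        {"role": "user", "content": text},
        {"role": "assistant", "content": text},
    ]


def unroll_batch(batch):
    processed_messages, unmask_flags = [], []
    for messages, unmask in zip(
        batch["messages"], batch.get("unmask", [False] * len(batch["messages"]))
    ):
        new_messages = []
        for msg in messages:
            if msg["role"] != "pretraining":
                new_messages.append(msg)
            else:
                new_messages.extend(fake_extract(msg["content"]))
                unmask = True  # any pretraining message flips this sample's flag
        processed_messages.append(new_messages)
        unmask_flags.append(unmask)
    return {"messages": processed_messages, "unmask": unmask_flags}


batch = {
    "messages": [
        [{"role": "pretraining", "content": "doc"}],
        [{"role": "user", "content": "hi"}],
    ],
    "unmask": [False, False],
}
out = unroll_batch(batch)
assert out["unmask"] == [True, False]  # flags are set per sample, not per batch
```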