78 changes: 56 additions & 22 deletions src/instructlab/training/data_process.py
@@ -764,7 +764,7 @@ def unmask_messages(
)


def unmask_sample(
def unmask_sample_single(
sample: t.Dict[str, t.Any], tokenizer: PreTrainedTokenizer
) -> ProcessedMessagesData:
"""
@@ -793,6 +793,25 @@ def unmask_sample(
return unmask_messages(sample["messages"], tokenizer, unmask_roles)
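For reference, the per-sample path can still be exercised directly; a minimal sketch, where the record and tokenizer are placeholders rather than anything defined in this change:

sample = {
    "messages": [
        {"role": "user", "content": "What is 2 + 2?"},
        {"role": "assistant", "content": "4"},
    ],
    "unmask": False,
}
result = unmask_sample_single(sample, tokenizer)
# result carries parallel "input_ids" and "labels" token lists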


def unmask_batch(
batch: t.Dict[str, t.List[t.Any]], tokenizer: PreTrainedTokenizer
) -> t.Dict[str, t.List[t.Any]]:
"""Apply unmask_sample_single to every row of a batched (columnar) dict and collect the resulting input_ids and labels columns."""
input_ids_list = []
labels_list = []

for i in range(len(batch["messages"])):
sample = {key: batch[key][i] for key in batch}
result = unmask_sample_single(sample, tokenizer)

input_ids_list.append(result["input_ids"])
labels_list.append(result["labels"])

return {
"input_ids": input_ids_list,
"labels": labels_list,
}
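For context, this is essentially the wiring that process_samples sets up further down; a rough sketch, with ds and tokenizer as placeholders. Under batched=True, datasets passes the function a dict of column lists rather than a single row:

from functools import partial

tokenized = ds.map(
    partial(unmask_batch, tokenizer=tokenizer),
    batched=True,
    batch_size=1000,
    num_proc=4,
)
# tokenized gains "input_ids" and "labels" columns alongside the originals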


def extract_messages_from_pretraining_text(text: str) -> t.List[Message]:
"""
Given a message from a pretraining message that was formatted using either the generic
@@ -899,29 +918,37 @@ def pretraining_is_using_legacy_granite_chat_template(ds: Dataset) -> bool:


def ensure_dataset_is_compatible_with_legacy_format(
sample: t.Dict[str, t.Any],
) -> t.Dict[str, t.Any]:
batch: t.Dict[str, t.List[t.Any]],
) -> t.Dict[str, t.List[t.Any]]:
"""
Given a sample that uses the legacy pre-training format, we unroll the samples into ones with the
original messages contents.
Given a batch of samples that use the legacy pre-training format, unroll each sample into one with
its original message contents.
"""
# deepcopy to prevent re-referencing the existing objects
new_sample = {
"messages": [],
"unmask": sample.get("unmask", False),
}
for msg in sample["messages"]:
if msg["role"] != "pretraining":
new_sample["messages"].append(msg)
continue
processed_messages = []
unmask_flags = []

# handle unmasking
new_sample["messages"].extend(
extract_messages_from_pretraining_text(msg["content"])
)
new_sample["unmask"] = True
for messages, unmask_flag in zip(
batch["messages"], batch.get("unmask", [False] * len(batch["messages"]))
):
new_messages = []
unmask = unmask_flag

return new_sample
for msg in messages:
if msg["role"] != "pretraining":
new_messages.append(msg)
else:
new_messages.extend(
extract_messages_from_pretraining_text(msg["content"])
)
unmask = True # if any pretraining message is found, set unmask to True

processed_messages.append(new_messages)
unmask_flags.append(unmask)

return {
"messages": processed_messages,
"unmask": unmask_flags,
}
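An illustrative before/after for the batched conversion; the pretraining content string below is a placeholder, not the actual legacy template:

legacy_batch = {
    "messages": [[{"role": "pretraining", "content": "<legacy-formatted text>"}]],
    "unmask": [False],
}
converted = ensure_dataset_is_compatible_with_legacy_format(legacy_batch)
# converted["messages"][0] holds the messages recovered by
# extract_messages_from_pretraining_text, and converted["unmask"][0] is True
# because a pretraining message was found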


def filter_samples_by_length(
@@ -1051,6 +1078,8 @@ def load_and_validate_dataset(data_path: str, num_procs: int) -> Dataset:

return data.map(
ensure_dataset_is_compatible_with_legacy_format,
batched=True,
batch_size=1000,
num_proc=num_procs,
desc="Ensuring dataset is compatible with legacy format.",
)
@@ -1082,16 +1111,21 @@ def configure_tokenizer(model_path: str) -> PreTrainedTokenizer:


def process_samples(
data: Dataset, tokenizer: PreTrainedTokenizer, num_cpu_procs: int
data: Dataset,
tokenizer: PreTrainedTokenizer,
num_cpu_procs: int,
batch_size: int = 1000,
) -> Dataset:
"""Process samples to generate input_ids and labels."""

# Create a wrapper function that binds the tokenizer to the unmask function
process_sample_fn = partial(unmask_sample, tokenizer=tokenizer)
process_sample_fn = partial(unmask_batch, tokenizer=tokenizer)

# Process the dataset
processed_data = data.map(
process_sample_fn,
batched=True,
batch_size=batch_size,
num_proc=num_cpu_procs,
desc="Converting samples into input_ids and labels...",
load_from_cache_file=False,
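Put together, the updated entry points can be exercised roughly as in the sketch below; the paths and process counts are placeholders, and the unit test further down does much the same with an in-memory dataset:

tokenizer = configure_tokenizer("/path/to/model")
data = load_and_validate_dataset("/path/to/data.jsonl", num_procs=4)
processed = process_samples(data, tokenizer, num_cpu_procs=4, batch_size=1000)
# processed now exposes "input_ids" and "labels" columns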
63 changes: 62 additions & 1 deletion tests/unit/test_data_process.py
@@ -6,6 +6,9 @@
import typing as t
import unittest

# Third Party
from datasets import Dataset

try:
# Third Party
import pytest
@@ -16,7 +19,11 @@

try:
# Third Party
from transformers import AutoTokenizer, PreTrainedTokenizer
from transformers import (
AutoTokenizer,
PreTrainedTokenizer,
LlamaTokenizerFast,
)

TRANSFORMERS_AVAILABLE = True
except ImportError:
@@ -26,6 +33,7 @@
# First Party
from instructlab.training.data_process import (
MASK_TOKEN,
process_samples,
UNMASK_BEGIN_TOKEN,
UNMASK_END_TOKEN,
UNMASK_REASONING_BEGIN_TOKEN,
@@ -990,5 +998,58 @@ def test_edge_cases_with_reasoning_content(self):
)


@pytest.fixture(scope="module")
def tokenizer():
tokenizer = LlamaTokenizerFast.from_pretrained("HuggingFaceH4/zephyr-7b-alpha")

# Ensure UNMASK tokens are treated atomically
tokenizer.add_special_tokens(
{"additional_special_tokens": ["<|UNMASK_BEGIN|>", "<|UNMASK_END|>"]}
)

# Safety: add a pad token if it's missing
if tokenizer.pad_token is None:
tokenizer.pad_token = tokenizer.eos_token or "</s>"

return tokenizer


def test_process_samples_outputs_input_ids_and_labels(tokenizer):
# Create a dummy dataset of 100 samples
messages = [
[
{"role": "user", "content": f"Hello {i}"},
{"role": "assistant", "content": f"Hi there {i}!"},
{"role": "pretraining", "content": f"Pretraining text {i}"},
]
for i in range(100)
]

unmask_flags = [True for _ in range(100)]

dummy_data = Dataset.from_dict(
{
"messages": messages,
"unmask": unmask_flags,
}
)

# Use a batch size smaller than the dataset so multiple batches are processed
processed = process_samples(dummy_data, tokenizer, num_cpu_procs=1, batch_size=8)

# Check the structure
assert "input_ids" in processed.column_names
assert "labels" in processed.column_names
assert len(processed) == 100

# Check that input_ids and labels exist and have matching lengths for a few representative samples
for i in [0, 25, 50, 99]:
sample = processed[i]
assert isinstance(sample["input_ids"], list)
assert isinstance(sample["labels"], list)
assert len(sample["input_ids"]) == len(sample["labels"])
assert all(isinstance(x, int) for x in sample["input_ids"])


if __name__ == "__main__":
unittest.main()