diff --git a/examples/recipes/llama/pretrain_llama3_8b.py b/examples/recipes/llama/pretrain_llama3_8b.py index 9757d747b..76ebde762 100644 --- a/examples/recipes/llama/pretrain_llama3_8b.py +++ b/examples/recipes/llama/pretrain_llama3_8b.py @@ -21,7 +21,7 @@ Examples: Basic usage with default configuration: - $ torchrun --nproc_per_node=8 pretrain_llama3_8b.py + $ torchrun --nproc_per_node=8 examples/recipes/llama/pretrain_llama3_8b.py Using a custom YAML config file: $ torchrun --nproc_per_node=8 pretrain_llama3_8b.py --config-file my_custom_config.yaml diff --git a/examples/recipes/qwen_vl/conf/qwen25_vl_pretrain_override_example.yaml b/examples/recipes/qwen_vl/conf/qwen25_vl_pretrain_override_example.yaml new file mode 100644 index 000000000..c2fbc78af --- /dev/null +++ b/examples/recipes/qwen_vl/conf/qwen25_vl_pretrain_override_example.yaml @@ -0,0 +1,53 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Example override file for Qwen2.5-VL + +model: + seq_length: 4096 + +train: + train_iters: 20 + global_batch_size: 8 + micro_batch_size: 1 + eval_iters: 5 + +optimizer: + lr: 0.00025 + min_lr: 0.000025 + +scheduler: + lr_warmup_iters: 10 + +checkpoint: + # Directory to save to. If null, no checkpoint will be saved. + save: null + +dist: + use_megatron_fsdp: false + use_torch_fsdp2: false + +logger: + log_interval: 1 + +dataset: + sequence_length: 4096 + +rng: + seed: 42 + +ddp: + grad_reduce_in_fp32: true + + diff --git a/examples/recipes/qwen_vl/finetune_qwen25_vl.py b/examples/recipes/qwen_vl/finetune_qwen25_vl.py new file mode 100644 index 000000000..0c5501471 --- /dev/null +++ b/examples/recipes/qwen_vl/finetune_qwen25_vl.py @@ -0,0 +1,199 @@ +#!/usr/bin/env python3 +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Qwen2.5-VL Finetuning Script with YAML and CLI Configuration Overrides. + +This mirrors the Llama example flow and uses the Qwen-VL recipe helpers. +You can pick a specific recipe via `--recipe`, e.g., `qwen25_vl_3b_finetune_config`, +`qwen25_vl_7b_finetune_config`, etc. 
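+
+Dataset selection: when `--data-path` points to a JSON/JSONL conversation dataset, the
+preloaded dataset provider is enabled automatically; without it, the mock provider is
+used by default. Pass `--dataset-type` (`mock`, `preloaded`, or `hf`) to override the
+auto-detection explicitly.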
+ +Examples: + Loading pretrained weights (recommended for finetune): + 1) Import HF checkpoint to Megatron format: + $ python examples/conversion/convert_checkpoints.py import \ + --hf-model Qwen/Qwen2.5-VL-3B-Instruct \ + --megatron-path /path/to/megatron_ckpt + + 2) Run finetune using the imported checkpoint: + $ torchrun --nproc_per_node=8 examples/recipes/qwen_vl/finetune_qwen25_vl.py \ + --pretrained-checkpoint /path/to/megatron_ckpt + + Using a custom YAML config file: + $ torchrun --nproc_per_node=8 finetune_qwen25_vl.py --config-file conf/qwen25_vl_pretrain_override_example.yaml + + CLI overrides: + $ torchrun --nproc_per_node=8 finetune_qwen25_vl.py model.tensor_model_parallel_size=4 train.train_iters=100000 + + Selecting a specific recipe: + $ torchrun --nproc_per_node=8 finetune_qwen25_vl.py --recipe qwen25_vl_7b_finetune_config +""" + +import argparse +import logging +import os +import sys +from pathlib import Path +from typing import Tuple + +from omegaconf import OmegaConf + +from megatron.bridge.recipes.qwen_vl import qwen25_vl as qwen_vl_recipes +from megatron.bridge.training.config import ConfigContainer +from megatron.bridge.training.pretrain import pretrain +from megatron.bridge.training.utils.omegaconf_utils import ( + apply_overrides, + create_omegaconf_dict_config, + parse_hydra_overrides, +) +from megatron.bridge.training.vlm_step import forward_step +from megatron.bridge.utils.common_utils import get_rank_safe + + +logger: logging.Logger = logging.getLogger(__name__) + + +SCRIPT_DIR: Path = Path(__file__).parent.resolve() +DEFAULT_CONFIG_FILENAME: str = "qwen25_vl_pretrain_override_example.yaml" +DEFAULT_CONFIG_FILE_PATH: Path = SCRIPT_DIR / "conf" / DEFAULT_CONFIG_FILENAME + + +def parse_cli_args() -> Tuple[argparse.Namespace, list[str]]: + """Parse known script args and return remaining as Hydra-style overrides.""" + parser = argparse.ArgumentParser( + description="Finetune Qwen2.5-VL with YAML and CLI overrides", + formatter_class=argparse.RawTextHelpFormatter, + ) + parser.add_argument( + "--config-file", + type=str, + default=str(DEFAULT_CONFIG_FILE_PATH), + help="Path to the YAML OmegaConf override file. Default: conf/qwen25_vl_pretrain_override_example.yaml", + ) + parser.add_argument( + "--data-path", + type=str, + default=None, + help="Path to JSON/JSONL dataset (preloaded conversation or legacy messages format).", + ) + parser.add_argument( + "--image-folder", + type=str, + default=None, + help="Optional root for resolving relative image/video paths in dataset records.", + ) + parser.add_argument( + "--dataset-type", + type=str, + choices=["mock", "preloaded", "hf"], + default=None, + help=( + "Dataset type to use: 'mock', 'preloaded', or 'hf'. " + "If not set, auto-detects based on --data-path/--use-preloaded." + ), + ) + parser.add_argument( + "--recipe", + type=str, + default="qwen25_vl_3b_finetune_config", + help=( + "Name of the recipe function in megatron.bridge.recipes.qwen_vl.qwen25_vl to use, " + "e.g., qwen25_vl_3b_finetune_config, qwen25_vl_7b_finetune_config." + ), + ) + parser.add_argument( + "--pretrained-checkpoint", + type=str, + default=None, + help=( + "Path to imported Megatron checkpoint directory to load before finetuning. " + "Generate it with scripts/import_hf_ckpt.py." 
+ ), + ) + parser.add_argument( + "--use-preloaded", + action="store_true", + help="Use preloaded dataset provider (enabled automatically when --data-path is set).", + ) + parser.add_argument("--debug", action="store_true", help="Enable debug logging") + args, cli_dotlist_overrides = parser.parse_known_args() + return args, cli_dotlist_overrides + + +def main() -> None: + """ + Load the base VLM recipe config, apply YAML/CLI overrides, and start pretraining. + """ + args, cli_overrides = parse_cli_args() + + logger.info("Megatron-Bridge Qwen2.5-VL Finetuning Script with YAML & CLI Overrides") + logger.info("-----------------------------------------------------------------------") + + # Resolve the recipe function from the provided name + recipe_name = getattr(args, "recipe", "qwen25_vl_3b_finetune_config") + available_recipes = [name for name in dir(qwen_vl_recipes) if name.endswith("_finetune_config")] + if not hasattr(qwen_vl_recipes, recipe_name): + logger.error( + "Unknown recipe '%s'. Available recipes: %s", + recipe_name, + ", ".join(sorted(available_recipes)), + ) + sys.exit(2) + pretrain_config = getattr(qwen_vl_recipes, recipe_name) + + # Determine dataset type based on CLI flag (overrides) or fall back to auto-detect + use_preloaded_flag = bool(args.data_path) or bool(getattr(args, "use_preloaded", False)) + dataset_type = args.dataset_type or ("preloaded" if use_preloaded_flag else "mock") + + cfg: ConfigContainer = pretrain_config( + dataset_type=dataset_type, + train_data_path=args.data_path, + valid_data_path=None, + test_data_path=None, + image_folder=args.image_folder, + pretrained_checkpoint=args.pretrained_checkpoint, + ) + logger.info("Loaded base configuration") + + if get_rank_safe() == 0: + cfg.print_yaml() + + merged_omega_conf, excluded_fields = create_omegaconf_dict_config(cfg) + + if args.config_file: + logger.debug(f"Loading YAML overrides from: {args.config_file}") + if not os.path.exists(args.config_file): + logger.error(f"Override YAML file not found: {args.config_file}") + sys.exit(1) + yaml_overrides_omega = OmegaConf.load(args.config_file) + merged_omega_conf = OmegaConf.merge(merged_omega_conf, yaml_overrides_omega) + + if cli_overrides: + logger.debug(f"Applying Hydra-style command-line overrides: {cli_overrides}") + merged_omega_conf = parse_hydra_overrides(merged_omega_conf, cli_overrides) + + final_overrides_as_dict = OmegaConf.to_container(merged_omega_conf, resolve=True) + apply_overrides(cfg, final_overrides_as_dict, excluded_fields) + + if get_rank_safe() == 0: + logger.info("--- Final Merged Configuration ---") + cfg.print_yaml() + logger.info("----------------------------------") + + pretrain(config=cfg, forward_step_func=forward_step) + + +if __name__ == "__main__": + main() diff --git a/pyproject.toml b/pyproject.toml index 12259c989..c011fa384 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -78,6 +78,7 @@ dependencies = [ "tqdm>=4.67.1", "hydra-core>1.3,<=1.3.2", "megatron-core[dev,mlm]>=0.14.0a0,<0.16.0", + "qwen-vl-utils", ] diff --git a/src/megatron/bridge/data/vlm_datasets/__init__.py b/src/megatron/bridge/data/vlm_datasets/__init__.py new file mode 100644 index 000000000..054de6537 --- /dev/null +++ b/src/megatron/bridge/data/vlm_datasets/__init__.py @@ -0,0 +1,58 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +VLM dataset utilities. + +Public API re-exports: +- Makers: functions to build conversation examples from HF datasets +- Providers: classes that build PyTorch datasets bound to HF processors +- Collate fns: model-specific batch builders +""" + +from megatron.bridge.data.vlm_datasets.collate import ( + COLLATE_FNS, + default_collate_fn, + phi4_mm_collate_fn, + qwen2_5_collate_fn, +) +from megatron.bridge.data.vlm_datasets.conversation_dataset import VLMConversationDataset +from megatron.bridge.data.vlm_datasets.hf_dataset_makers import ( + make_cord_v2_dataset, + make_cv17_dataset, + make_medpix_dataset, + make_rdr_dataset, +) +from megatron.bridge.data.vlm_datasets.hf_provider import HFDatasetConversationProvider +from megatron.bridge.data.vlm_datasets.mock_provider import MockVLMConversationProvider +from megatron.bridge.data.vlm_datasets.preloaded_provider import PreloadedVLMConversationProvider + + +__all__ = [ + # Makers + "make_rdr_dataset", + "make_cord_v2_dataset", + "make_medpix_dataset", + "make_cv17_dataset", + # Dataset types/providers + "VLMConversationDataset", + "HFDatasetConversationProvider", + "PreloadedVLMConversationProvider", + "MockVLMConversationProvider", + # Collation utilities + "COLLATE_FNS", + "default_collate_fn", + "qwen2_5_collate_fn", + "phi4_mm_collate_fn", +] diff --git a/src/megatron/bridge/data/vlm_datasets/collate.py b/src/megatron/bridge/data/vlm_datasets/collate.py new file mode 100644 index 000000000..e0079f776 --- /dev/null +++ b/src/megatron/bridge/data/vlm_datasets/collate.py @@ -0,0 +1,317 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Collation utilities for building VLM training batches from conversation examples. +""" + +import torch +import torch.nn.functional as F +from PIL import Image # noqa: F401 # may be used downstream by processors + +from megatron.bridge.data.vlm_datasets.token_utils import extract_skipped_token_ids +from megatron.bridge.training.utils.visual_inputs import Qwen2_5_VLVisualInputs + + +# Local message used when optional qwen_vl_utils dependency is missing +MISSING_QWEN_VL_UTILS_MSG = ( + "qwen_vl_utils is required for Qwen2.5 VL processing. Please `pip install qwen-vl-utils` or" + " provide compatible vision preprocessing." +) + +try: + from qwen_vl_utils import process_vision_info + + HAVE_QWEN_VL_UTILS = True +except ImportError: + HAVE_QWEN_VL_UTILS = False + + +def _gather_assistant_text_segments(example: dict) -> list[str]: + """Extract assistant text segments from the structured conversation example. 
+ + The example schema is expected to be {"conversation": [{"role": ..., "content": [...]} ...]} where + content is a list of items like {"type": "text"|"image"|..., "text": "..."}. + Returns a list of concatenated text strings, one per assistant turn. + """ + texts: list[str] = [] + for turn in example.get("conversation", []): + if turn.get("role") != "assistant": + continue + parts = turn.get("content", []) + buf = [] + if isinstance(parts, list): + for p in parts: + if isinstance(p, dict) and p.get("type") == "text" and isinstance(p.get("text"), str): + buf.append(p["text"]) + elif isinstance(parts, str): + buf.append(parts) + if buf: + texts.append("".join(buf)) + return texts + + +def create_multiturn_loss_mask_by_search( + example: dict, input_ids, processor, skipped_tokens: torch.Tensor +) -> list[int]: + """Tokenizer-agnostic masking via substring search of assistant texts. + + - Tokenize full conversation with processor already done -> input_ids + - Extract assistant text strings from the structured example + - For each assistant text, tokenize without special tokens and search sequentially + - On success, unmask that span; otherwise leave masked + """ + tokenizer = getattr(processor, "tokenizer", processor) + ids = input_ids.tolist() + mask = [0] * len(ids) + + def try_mark(span_text: str, start_from: int) -> int: + """Tokenize a span and mark its occurrence if found. Returns new search start index.""" + variants = [span_text, span_text + "\n"] + for text in variants: + span_tokens = tokenizer(text, add_special_tokens=False)["input_ids"] + if not span_tokens: + continue + # naive sequential search from start_from + for i in range(start_from, len(ids) - len(span_tokens) + 1): + if ids[i : i + len(span_tokens)] == span_tokens: + for j in range(i, i + len(span_tokens)): + mask[j] = 1 + return i + len(span_tokens) + return start_from + + search_start = 0 + for asst_text in _gather_assistant_text_segments(example): + search_start = try_mark(asst_text, search_start) + + # Ensure pad/skipped tokens are masked + ids_t = torch.tensor(ids) + for k, t in enumerate(ids_t): + if t in skipped_tokens: + mask[k] = 0 + return mask + + +def phi4_mm_collate_fn(examples, processor): + """Collate function for Phi-4 MM model audio input""" + + # Extract conversations and audio data + conversations = [example["conversation"] for example in examples] + audios = [example["audio"] for example in examples] + texts = [processor.apply_chat_template(conversation, tokenize=False) for conversation in conversations] + audio_inputs = [(audio["array"], audio["sampling_rate"]) if isinstance(audio, dict) else audio for audio in audios] + batch = processor( + text=texts, audios=audio_inputs, return_tensors="pt", padding=True, truncation=True, max_length=1024 + ) + labels = batch["input_ids"].clone()[:, 1:] + labels = torch.cat([labels, -100 * torch.ones_like(labels[:, :1])], dim=1) + + loss_masks = [] + for i, conversation in enumerate(conversations): + input_ids = batch["input_ids"][i].tolist() + + assistant_content = conversation[1]["content"] + assistant_tokens = processor.tokenizer(assistant_content, add_special_tokens=False)["input_ids"] + + loss_mask = [0] * len(input_ids) + for start_idx in range(len(input_ids) - len(assistant_tokens) + 1): + if input_ids[start_idx : start_idx + len(assistant_tokens)] == assistant_tokens: + for j in range(len(assistant_tokens)): + loss_mask[start_idx + j] = 1 + break + loss_masks.append(loss_mask) + + max_len = max(len(mask) for mask in loss_masks) + padded_loss_masks = [mask 
+ [0] * (max_len - len(mask)) for mask in loss_masks] + batch["loss_mask"] = torch.tensor(padded_loss_masks, dtype=torch.float) + + labels[batch["loss_mask"] == 0] = -100 + batch["labels"] = labels + + # Remove specified batch features if present + for key in ["input_image_embeds", "image_sizes", "image_attention_mask"]: + if key in batch: + del batch[key] + return batch + + +def qwen2_5_collate_fn(examples: list, processor) -> dict[str, torch.Tensor]: + """Collate function for Qwen2.5 VL model.""" + if not HAVE_QWEN_VL_UTILS: + raise ImportError(MISSING_QWEN_VL_UTILS_MSG) + + skipped_tokens = extract_skipped_token_ids(processor) + + texts = [processor.apply_chat_template(example["conversation"], tokenize=False) for example in examples] + # Build per-example images (list) and split by presence + per_example_images = [] + has_images = [] + for example in examples: + imgs = process_vision_info(example["conversation"])[0] + if imgs is None: + imgs = [] + elif not isinstance(imgs, list): + imgs = [imgs] + per_example_images.append(imgs) + has_images.append(len(imgs) > 0) + + idx_with = [i for i, h in enumerate(has_images) if h] + idx_without = [i for i, h in enumerate(has_images) if not h] + + batch_with = None + batch_without = None + + if idx_with: + texts_with = [texts[i] for i in idx_with] + images_with = [per_example_images[i] for i in idx_with] + batch_with = processor( + text=texts_with, + images=images_with, + padding=True, + return_tensors="pt", + min_pixels=200704, # 256*28*28 + max_pixels=1003520, # 1280*28*28 + ) + + if idx_without: + texts_without = [texts[i] for i in idx_without] + batch_without = processor( + text=texts_without, + padding=True, + return_tensors="pt", + ) + + # Merge batches back to original order + if batch_with is not None and batch_without is None: + batch = batch_with + elif batch_with is None and batch_without is not None: + batch = batch_without + else: + # Both exist: pad to common max length and interleave rows + pad_id = getattr(processor.tokenizer, "pad_token_id", 0) or 0 + in_with = batch_with["input_ids"] + in_without = batch_without["input_ids"] + max_len = max(in_with.shape[1], in_without.shape[1]) + + def pad_to(x, tgt_len): + if x.shape[1] == tgt_len: + return x + pad_len = tgt_len - x.shape[1] + return F.pad(x, (0, pad_len), value=pad_id) + + in_with = pad_to(in_with, max_len) + in_without = pad_to(in_without, max_len) + + input_ids = torch.full((len(examples), max_len), pad_id, dtype=in_with.dtype) + # Place rows + for row, i in enumerate(idx_with): + input_ids[i] = in_with[row] + for row, i in enumerate(idx_without): + input_ids[i] = in_without[row] + + batch = {"input_ids": input_ids} + # Carry over vision tensors if present + if "pixel_values" in batch_with: + batch["pixel_values"] = batch_with["pixel_values"] + if "image_grid_thw" in batch_with: + batch["image_grid_thw"] = batch_with["image_grid_thw"] + + labels = batch["input_ids"].clone()[:, 1:] + labels = torch.cat([labels, -100 * torch.ones_like(labels[:, :1])], dim=1) + labels[torch.isin(labels, skipped_tokens)] = -100 + batch["labels"] = labels + # Ensure position_ids exist for the model + if "position_ids" not in batch: + batch_size, seq_len = batch["input_ids"].shape + batch["position_ids"] = ( + torch.arange(seq_len, device=batch["input_ids"].device).unsqueeze(0).expand(batch_size, -1) + ) + # Prefer general search-based masking using structured example content (not template-specific) + loss_masks = [ + create_multiturn_loss_mask_by_search(example, input_ids, processor, 
skipped_tokens) + for example, input_ids in zip(examples, batch["input_ids"]) # type: ignore[arg-type] + ] + loss_mask_t = torch.tensor(loss_masks, dtype=torch.float, device=batch["input_ids"].device) + # Shift loss mask to align with next-token labels timeline + loss_mask_t = torch.cat([loss_mask_t[:, 1:], torch.zeros_like(loss_mask_t[:, :1])], dim=1) + # Enforce label masking to match shifted loss_mask + batch["labels"] = batch["labels"].masked_fill(loss_mask_t == 0, -100) + batch["loss_mask"] = loss_mask_t + # Build Qwen2VL visual inputs object and attach to batch; remove raw keys + visual_inputs = Qwen2_5_VLVisualInputs( + pixel_values=batch.get("pixel_values"), + image_grid_thw=batch.get("image_grid_thw"), + ) + if "pixel_values" in batch: + del batch["pixel_values"] + if "image_grid_thw" in batch: + del batch["image_grid_thw"] + batch["visual_inputs"] = visual_inputs + return batch + + +def default_collate_fn(examples: list, processor) -> dict[str, torch.Tensor]: + """Default collate function for VLM models.""" + if not HAVE_QWEN_VL_UTILS: + raise ImportError(MISSING_QWEN_VL_UTILS_MSG) + + skipped_tokens = extract_skipped_token_ids(processor) + + batch = processor.apply_chat_template( + [example["conversation"] for example in examples], + tokenize=True, + padding=True, + truncation=True, + return_tensors="pt", + return_dict=True, + ) + + if "position_ids" not in batch: + batch_size, seq_len = batch["input_ids"].shape + batch["position_ids"] = ( + torch.arange(seq_len, device=batch["input_ids"].device).unsqueeze(0).expand(batch_size, -1) + ) + + batch["pixel_values"] = batch["pixel_values"].to(torch.bfloat16) + labels = batch["input_ids"].clone()[:, 1:] + labels = torch.cat([labels, -100 * torch.ones_like(labels[:, :1])], dim=1) + labels[torch.isin(labels, skipped_tokens)] = -100 + batch["labels"] = labels + loss_masks = [ + create_multiturn_loss_mask_by_search(example, input_ids, processor, skipped_tokens) + for example, input_ids in zip(examples, batch["input_ids"]) # type: ignore[arg-type] + ] + loss_mask_t = torch.tensor(loss_masks, dtype=torch.float, device=batch["input_ids"].device) + # Shift loss mask to align with next-token labels timeline + loss_mask_t = torch.cat([loss_mask_t[:, 1:], torch.zeros_like(loss_mask_t[:, :1])], dim=1) + batch["labels"] = batch["labels"].masked_fill(loss_mask_t == 0, -100) + batch["loss_mask"] = loss_mask_t + # Build Qwen2VL visual inputs object and attach to batch; remove raw keys + visual_inputs = Qwen2_5_VLVisualInputs( + pixel_values=batch.get("pixel_values"), + image_grid_thw=batch.get("image_grid_thw"), + ) + if "pixel_values" in batch: + del batch["pixel_values"] + if "image_grid_thw" in batch: + del batch["image_grid_thw"] + batch["visual_inputs"] = visual_inputs + return batch + + +# Mapping of processor types to their collate functions +COLLATE_FNS = { + "Qwen2_5_VLProcessor": qwen2_5_collate_fn, + "default": default_collate_fn, +} diff --git a/src/megatron/bridge/data/vlm_datasets/conversation_dataset.py b/src/megatron/bridge/data/vlm_datasets/conversation_dataset.py new file mode 100644 index 000000000..7157118b5 --- /dev/null +++ b/src/megatron/bridge/data/vlm_datasets/conversation_dataset.py @@ -0,0 +1,65 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Core dataset types for conversation-style VLM examples. +""" + +from typing import Any, Callable, Dict, List, Optional + +import torch + +from megatron.bridge.data.vlm_datasets.collate import COLLATE_FNS + + +class VLMConversationDataset(torch.utils.data.Dataset): + """Repeating wrapper over a list of HF-style conversation examples. + + - Each base example is expected to contain a "conversation" key following + processor.apply_chat_template conventions. Optional modality fields like + "audio" are passed through and consumed by the collate function. + - Dataset length is set to a target length and indexes wrap around the + underlying list to meet the requested size. + - A `collate_fn` attribute is exposed so the framework can pass it to the + DataLoader. + """ + + def __init__( + self, + base_examples: List[Dict[str, Any]], + target_length: int, + processor: Any, + collate_impl: Optional[Callable[[list, Any], Dict[str, torch.Tensor]]] = None, + ) -> None: + assert isinstance(base_examples, list) and len(base_examples) > 0, "base_examples must be a non-empty list" + self._base_examples = base_examples + self._length = int(max(0, target_length)) + self._processor = processor + # Choose collate implementation by processor type name when not provided + collate_key = type(processor).__name__ if processor is not None else "default" + selected_impl = collate_impl or COLLATE_FNS.get(collate_key, COLLATE_FNS["default"]) # type: ignore[index] + + def _bound_collate(batch: list) -> Dict[str, torch.Tensor]: + return selected_impl(batch, self._processor) # type: ignore[call-arg] + + self.collate_fn = _bound_collate + + def __len__(self) -> int: + return self._length + + def __getitem__(self, idx: int) -> Dict[str, Any]: + if self._length == 0: + raise IndexError("Empty dataset") + base = self._base_examples[idx % len(self._base_examples)] + return base diff --git a/src/megatron/bridge/data/vlm_datasets/hf_dataset_makers.py b/src/megatron/bridge/data/vlm_datasets/hf_dataset_makers.py new file mode 100644 index 000000000..782406156 --- /dev/null +++ b/src/megatron/bridge/data/vlm_datasets/hf_dataset_makers.py @@ -0,0 +1,138 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Built-in maker functions that transform HuggingFace datasets into +conversation-style examples consumable by VLM processors. 
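+
+For reference, each maker returns a list of examples shaped roughly as follows
+(illustrative values; the "image" entry may be a PIL image or a dataset-provided id):
+
+    {
+        "conversation": [
+            {"role": "user", "content": [
+                {"type": "image", "image": ...},
+                {"type": "text", "text": "Describe this image."},
+            ]},
+            {"role": "assistant", "content": [{"type": "text", "text": "..."}]},
+        ],
+    }
+
+Audio makers such as `make_cv17_dataset` use plain-string content instead and attach an
+"audio" field that the collate function consumes.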
+""" + +import json +import random +from typing import Any, Dict, List + +from datasets import load_dataset + +from megatron.bridge.data.vlm_datasets.token_utils import json2token + + +def make_rdr_dataset( + path_or_dataset: str = "quintend/rdr-items", split: str = "train", **kwargs +) -> List[Dict[str, Any]]: + """Load and preprocess the RDR dataset for image-to-text fine-tuning. + + Returns a list of examples with a "conversation" field that includes an image and text. + """ + dataset = load_dataset(path_or_dataset, split=split) + + def format(example): + return { + "conversation": [ + { + "role": "user", + "content": [ + {"type": "image", "image": example["image"]}, + {"type": "text", "text": "Describe this image."}, + ], + }, + { + "role": "assistant", + "content": [{"type": "text", "text": example["text"]}], + }, + ], + } + + return [format(example) for example in dataset] + + +def make_cord_v2_dataset( + path_or_dataset: str = "naver-clova-ix/cord-v2", split: str = "train", **kwargs +) -> List[Dict[str, Any]]: + """Load and preprocess the CORD-V2 dataset for image-to-text fine-tuning.""" + dataset = load_dataset(path_or_dataset, split=split) + + def format(example): + ground_truth = json.loads(example["ground_truth"]) + if "gt_parses" in ground_truth: + assert isinstance(ground_truth["gt_parses"], list) + gt_jsons = ground_truth["gt_parses"] + else: + assert "gt_parse" in ground_truth and isinstance(ground_truth["gt_parse"], dict) + gt_jsons = [ground_truth["gt_parse"]] + + text = random.choice([json2token(gt_json, sort_json_key=True) for gt_json in gt_jsons]) + + return { + "conversation": [ + { + "role": "user", + "content": [ + {"type": "image", "image": example["image"]}, + {"type": "text", "text": "Describe this image."}, + ], + }, + {"role": "assistant", "content": [{"type": "text", "text": text}]}, + ], + } + + return [format(example) for example in dataset] + + +def make_medpix_dataset( + path_or_dataset: str = "mmoukouba/MedPix-VQA", split: str = "train", **kwargs +) -> List[Dict[str, Any]]: + """Load and preprocess the MedPix dataset for image-to-text fine-tuning.""" + dataset = load_dataset(path_or_dataset, split=split) + + def format(example): + return { + "conversation": [ + { + "role": "user", + "content": [ + {"type": "image", "image": example["image_id"]}, + {"type": "text", "text": example["question"]}, + ], + }, + {"role": "assistant", "content": [{"type": "text", "text": example["answer"]}]}, + ], + } + + return [format(example) for example in dataset] + + +def make_cv17_dataset( + path_or_dataset: str = "ysdede/commonvoice_17_tr_fixed", split: str = "train", **kwargs +) -> List[Dict[str, Any]]: + """Load and preprocess the CommonVoice 17 dataset for audio-to-text fine-tuning.""" + dataset = load_dataset(path_or_dataset, split=split) + # Be robust to simple list-like datasets used in tests without `column_names` attr + try: + all_columns = dataset.column_names # type: ignore[attr-defined] + except Exception: + first_example = dataset[0] if len(dataset) > 0 else {} + all_columns = list(first_example.keys()) if isinstance(first_example, dict) else [] + if hasattr(dataset, "remove_columns"): + columns_to_remove = [col for col in all_columns if col not in ["audio", "transcription"]] + dataset = dataset.remove_columns(columns_to_remove) + + def format(example): + return { + "conversation": [ + {"role": "user", "content": "<|audio_1|>Transcribe the Turkish audio clip."}, + {"role": "assistant", "content": example["transcription"]}, + ], + "audio": 
(example["audio"]["array"], example["audio"]["sampling_rate"]), + } + + return [format(example) for example in dataset] diff --git a/src/megatron/bridge/data/vlm_datasets/hf_provider.py b/src/megatron/bridge/data/vlm_datasets/hf_provider.py new file mode 100644 index 000000000..9ca1d43c7 --- /dev/null +++ b/src/megatron/bridge/data/vlm_datasets/hf_provider.py @@ -0,0 +1,116 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Provider that builds conversation datasets from HuggingFace datasets. +""" + +from dataclasses import dataclass +from typing import Any, Callable, Dict, List, Literal, Optional, Tuple + +import torch +from transformers import AutoProcessor + +from megatron.bridge.data.vlm_datasets.conversation_dataset import VLMConversationDataset +from megatron.bridge.data.vlm_datasets.hf_dataset_makers import ( + make_cord_v2_dataset, + make_cv17_dataset, + make_medpix_dataset, + make_rdr_dataset, +) +from megatron.bridge.training.config import DatasetBuildContext, DatasetProvider + + +@dataclass(kw_only=True) +class HFDatasetConversationProvider(DatasetProvider): + """DatasetProvider that builds VLM conversation datasets from HF datasets. + + This provider leverages simple maker functions that return lists of examples + with a "conversation" schema understood by model processors. It binds a + HuggingFace `AutoProcessor` for the specified model and selects an + appropriate collate function for batching. + """ + + # Required to match model.seq_length (enforced by ConfigContainer.validate) + sequence_length: int + + # HF processor/model identifier (e.g., "Qwen/Qwen2.5-VL-3B-Instruct") + hf_processor_path: str + + # Select which maker to use. Must match a function defined in makers module + # like `make_rdr_dataset`, `make_cord_v2_dataset`, `make_medpix_dataset`, `make_cv17_dataset`. + maker_name: str + + # Optional parameters forwarded to the selected maker + maker_kwargs: Optional[Dict[str, Any]] = None + + # Optional collate override. If None, inferred from processor type. + collate_impl: Optional[Callable[[list, Any], Dict[str, torch.Tensor]]] = None + + # Keep parity with GPTDatasetConfig usage in batching utilities + skip_getting_attention_mask_from_dataset: bool = True + + # DataloaderConfig fields are inherited (num_workers, dataloader_type, etc.) 
+ dataloader_type: Optional[Literal["single", "cyclic", "external"]] = "single" + + def _get_maker(self) -> Callable[..., List[Dict[str, Any]]]: + registry: Dict[str, Callable[..., List[Dict[str, Any]]]] = { + "make_rdr_dataset": make_rdr_dataset, + "make_cord_v2_dataset": make_cord_v2_dataset, + "make_medpix_dataset": make_medpix_dataset, + "make_cv17_dataset": make_cv17_dataset, + } + if self.maker_name in registry: + return registry[self.maker_name] + # Allow passing function name alias without prefix, e.g., "rdr" -> make_rdr_dataset + alias_map = { + "rdr": "make_rdr_dataset", + "cord_v2": "make_cord_v2_dataset", + "medpix": "make_medpix_dataset", + "cv17": "make_cv17_dataset", + } + if self.maker_name in alias_map and alias_map[self.maker_name] in registry: + return registry[alias_map[self.maker_name]] + raise ValueError(f"Unknown maker_name: {self.maker_name}") + + def _build_split_dataset( + self, + split: str, + target_length: int, + processor: Any, + ) -> Optional[VLMConversationDataset]: + if target_length <= 0: + return None + maker = self._get_maker() + kwargs = dict(self.maker_kwargs or {}) + kwargs.setdefault("split", split) + base_examples = maker(**kwargs) # type: ignore[misc] + if not isinstance(base_examples, list) or len(base_examples) == 0: + raise ValueError(f"Maker '{self.maker_name}' returned no examples for split='{split}'") + return VLMConversationDataset( + base_examples=base_examples, + target_length=target_length, + processor=processor, + collate_impl=self.collate_impl, + ) + + def build_datasets(self, context: DatasetBuildContext) -> Tuple[Optional[Any], Optional[Any], Optional[Any]]: + # Bind processor for the requested model + processor = AutoProcessor.from_pretrained(self.hf_processor_path, trust_remote_code=True) + + train_ds = self._build_split_dataset("train", context.train_samples, processor) + valid_ds = self._build_split_dataset("validation", context.valid_samples, processor) + test_ds = self._build_split_dataset("test", context.test_samples, processor) + + return train_ds, valid_ds, test_ds diff --git a/src/megatron/bridge/data/vlm_datasets/mock_provider.py b/src/megatron/bridge/data/vlm_datasets/mock_provider.py new file mode 100644 index 000000000..54297c95f --- /dev/null +++ b/src/megatron/bridge/data/vlm_datasets/mock_provider.py @@ -0,0 +1,113 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Generic mock conversation-style VLM dataset and provider. + +This module produces synthetic image(s) and minimal conversations that are +compatible with HF `AutoProcessor.apply_chat_template` and the collate +functions defined in `collate.py`. It is processor-agnostic and can be used +with any multimodal model whose processor supports the standard conversation +schema and optional `images` argument. 
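+
+A minimal construction sketch (values are placeholders; additional DatasetProvider
+fields may be required by a full training config):
+
+    provider = MockVLMConversationProvider(
+        sequence_length=4096,
+        hf_processor_path="Qwen/Qwen2.5-VL-3B-Instruct",
+    )
+    train_ds, valid_ds, test_ds = provider.build_datasets(context)  # DatasetBuildContext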
+""" + +from dataclasses import dataclass +from typing import Any, Dict, List, Literal, Optional, Tuple + +import numpy +from PIL import Image + +from megatron.bridge.data.vlm_datasets.conversation_dataset import VLMConversationDataset +from megatron.bridge.training.config import DatasetBuildContext, DatasetProvider + + +@dataclass(kw_only=True) +class MockVLMConversationProvider(DatasetProvider): + """DatasetProvider for generic mock VLM conversation datasets. + + Builds train/valid/test datasets using a HF AutoProcessor and the + `MockVLMConversationDataset` implementation. Intended to work across + different VLM models whose processors support the conversation schema. + """ + + # Required to match model.seq_length + sequence_length: int + + # HF processor/model ID (e.g., Qwen/Qwen2.5-VL-3B-Instruct or other VLMs) + hf_processor_path: str + + # Sample generation options + prompt: str = "Describe this image." + random_seed: int = 0 + image_size: Tuple[int, int] = (256, 256) + pad_to_max_length: bool = True + create_attention_mask: bool = True + + # Keep parity with GPTDatasetConfig usage in batching utilities + skip_getting_attention_mask_from_dataset: bool = True + + # Number of images per sample + num_images: int = 1 + + # Default dataloader type for VLM providers + dataloader_type: Optional[Literal["single", "cyclic", "external"]] = "single" + + # HF AutoProcessor instance will be set during build + _processor: Optional[Any] = None + + def _make_base_examples(self) -> List[Dict[str, Any]]: + # Single minimal conversation example; dataset will repeat to target length + num_images = max(0, int(getattr(self, "num_images", 1))) + w, h = self.image_size + rng = numpy.random.default_rng(seed=self.random_seed) + images = None + if num_images > 0: + # Embed in-memory PIL images directly in the conversation so that + # qwen_vl_utils.process_vision_info can discover them. + images = [ + Image.fromarray(rng.integers(low=0, high=256, size=(h, w, 3), dtype=numpy.uint8), mode="RGB") + for _ in range(num_images) + ] + + content = [{"type": "image", "image": img} for img in images] if images is not None else [] + content.append({"type": "text", "text": self.prompt}) + messages = [ + {"role": "user", "content": content}, + {"role": "assistant", "content": [{"type": "text", "text": "dummy assistant response"}]}, + ] + return [{"conversation": messages}] + + def build_datasets(self, context: DatasetBuildContext): + from transformers import AutoProcessor + + # Initialize and store processor + self._processor = AutoProcessor.from_pretrained(self.hf_processor_path, trust_remote_code=True) + + base_examples = self._make_base_examples() + + def _maybe_make(size: int) -> Optional[VLMConversationDataset]: + if not size or size <= 0: + return None + return VLMConversationDataset( + base_examples=base_examples, + target_length=size, + processor=self._processor, + collate_impl=None, # infer collate from processor type (qwen2_5_collate_fn) + ) + + train_ds = _maybe_make(context.train_samples) + valid_ds = _maybe_make(context.valid_samples) + test_ds = _maybe_make(context.test_samples) + + return train_ds, valid_ds, test_ds diff --git a/src/megatron/bridge/data/vlm_datasets/preloaded_provider.py b/src/megatron/bridge/data/vlm_datasets/preloaded_provider.py new file mode 100644 index 000000000..e34be8967 --- /dev/null +++ b/src/megatron/bridge/data/vlm_datasets/preloaded_provider.py @@ -0,0 +1,230 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Provider for datasets preloaded from JSON/JSONL files into conversation schema. +""" + +import json +import logging +import os +import re +from dataclasses import dataclass +from typing import Any, Dict, List, Literal, Optional, Tuple + +from transformers import AutoProcessor + +from megatron.bridge.data.vlm_datasets.conversation_dataset import VLMConversationDataset +from megatron.bridge.training.config import DatasetBuildContext, DatasetProvider + + +def _split_text_by_placeholders( + text: str, image_paths: List[str], video_paths: Optional[List[str]] = None +) -> List[Dict[str, Any]]: + """ + Split legacy text containing ""/"