34 commits
7858117  update utils for transformers config in hydra (yaoyu-33, Sep 19, 2025)
457bace  temp save (yaoyu-33, Sep 19, 2025)
6937da4  Merge branch 'refs/heads/main' into qwen-25vl-training (yaoyu-33, Sep 22, 2025)
8a51440  Merge branch 'refs/heads/main' into qwen-25vl-training (yaoyu-33, Sep 30, 2025)
3bc6ba5  lint (yaoyu-33, Sep 30, 2025)
8061e0f  revert qwen-vl changes in gpt (yaoyu-33, Sep 30, 2025)
df4755a  revert qwen-vl changes in gpt #2 (yaoyu-33, Sep 30, 2025)
975efd2  Add mock dataset provider for qwen25 vl (yaoyu-33, Sep 30, 2025)
be708c2  add qwen25 vl dataset support from auto (yaoyu-33, Sep 30, 2025)
6822d34  lint (yaoyu-33, Sep 30, 2025)
bc8c605  update _attn_implementation (yaoyu-33, Oct 1, 2025)
689f491  update comments (yaoyu-33, Oct 1, 2025)
4f0e90f  add preloaded dataset provider (yaoyu-33, Oct 1, 2025)
2af0c2e  update _processor to a private attr (yaoyu-33, Oct 2, 2025)
ccf6abe  update qwen training utils (yaoyu-33, Oct 2, 2025)
94c6192  training bug fix (yaoyu-33, Oct 2, 2025)
95d3002  fix finalize grad (yaoyu-33, Oct 3, 2025)
4b7ef60  save qwen25 vl recipes (yaoyu-33, Oct 3, 2025)
608117e  add padding logic for pp (yaoyu-33, Oct 3, 2025)
a9f0e15  vlm step general (yaoyu-33, Oct 6, 2025)
6ddd4b3  default update (yaoyu-33, Oct 6, 2025)
f30aa39  Merge branch 'main' into qwen-25vl-training (yaoyu-33, Oct 6, 2025)
e425113  update to model specific visual inputs, also update mock dataset to b… (yaoyu-33, Oct 6, 2025)
5bc1f29  Merge branch 'main' into qwen-25vl-training (yaoyu-33, Oct 6, 2025)
90a0ff0  add ci tests (yaoyu-33, Oct 7, 2025)
49759bc  lint (yaoyu-33, Oct 8, 2025)
62ffa88  update dependency (yaoyu-33, Oct 8, 2025)
6af4e4c  build: add qwen-vl-utils and update lockfile (yaoyu-33, Oct 8, 2025)
7e0ceaf  remove `start_of_response_token` use (yaoyu-33, Oct 8, 2025)
a7e5fdc  add few more unit tests (yaoyu-33, Oct 8, 2025)
1e44b97  fix wandb reinit issue (yaoyu-33, Oct 8, 2025)
18012cd  Revert "fix wandb reinit issue" (yaoyu-33, Oct 9, 2025)
b0b910e  lint (yaoyu-33, Oct 9, 2025)
d2031ca  update and fix tests for vlm dataset (yaoyu-33, Oct 9, 2025)
2 changes: 1 addition & 1 deletion examples/recipes/llama/pretrain_llama3_8b.py
@@ -21,7 +21,7 @@

 Examples:
     Basic usage with default configuration:
-        $ torchrun --nproc_per_node=8 pretrain_llama3_8b.py
+        $ torchrun --nproc_per_node=8 examples/recipes/llama/pretrain_llama3_8b.py

     Using a custom YAML config file:
         $ torchrun --nproc_per_node=8 pretrain_llama3_8b.py --config-file my_custom_config.yaml
53 changes: 53 additions & 0 deletions examples/recipes/qwen_vl/conf/qwen25_vl_pretrain_override_example.yaml
@@ -0,0 +1,53 @@
# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Example override file for Qwen2.5-VL

model:
  seq_length: 4096

train:
  train_iters: 20
  global_batch_size: 8
  micro_batch_size: 1
  eval_iters: 5

optimizer:
  lr: 0.00025
  min_lr: 0.000025

scheduler:
  lr_warmup_iters: 10

checkpoint:
  # Directory to save to. If null, no checkpoint will be saved.
  save: null

dist:
  use_megatron_fsdp: false
  use_torch_fsdp2: false

logger:
  log_interval: 1

dataset:
  sequence_length: 4096

rng:
  seed: 42

ddp:
  grad_reduce_in_fp32: true


184 changes: 184 additions & 0 deletions examples/recipes/qwen_vl/pretrain_qwen25_vl.py
@@ -0,0 +1,184 @@
#!/usr/bin/env python3
# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
Qwen2.5-VL Pretraining Script with YAML and CLI Configuration Overrides.

This mirrors the Llama example flow and uses the Qwen-VL recipe helpers.
You can pick a specific recipe via `--recipe`, e.g., `qwen25_vl_3b_pretrain_config`,
`qwen25_vl_7b_pretrain_config`, etc.

Examples:
Basic usage with default configuration:
$ torchrun --nproc_per_node=8 pretrain_qwen25_vl.py
$ torchrun --nproc_per_node=8 examples/recipes/qwen_vl/pretrain_qwen25_vl.py


Using a custom YAML config file:
$ torchrun --nproc_per_node=8 pretrain_qwen25_vl.py --config-file conf/qwen25_vl_pretrain_override_example.yaml

CLI overrides:
$ torchrun --nproc_per_node=8 pretrain_qwen25_vl.py model.tensor_model_parallel_size=4 train.train_iters=100000

Selecting a specific recipe:
$ torchrun --nproc_per_node=8 pretrain_qwen25_vl.py --recipe qwen25_vl_7b_pretrain_config
"""

import argparse
import logging
import os
import sys
from pathlib import Path
from typing import Tuple

from omegaconf import OmegaConf

from megatron.bridge.recipes.qwen_vl import qwen25_vl as qwen_vl_recipes
from megatron.bridge.training.config import ConfigContainer
from megatron.bridge.training.pretrain import pretrain
from megatron.bridge.training.utils.omegaconf_utils import (
    apply_overrides,
    create_omegaconf_dict_config,
    parse_hydra_overrides,
)
from megatron.bridge.training.vlm_step import forward_step
from megatron.bridge.utils.common_utils import get_rank_safe


logger: logging.Logger = logging.getLogger(__name__)


SCRIPT_DIR: Path = Path(__file__).parent.resolve()
DEFAULT_CONFIG_FILENAME: str = "qwen25_vl_pretrain_override_example.yaml"
DEFAULT_CONFIG_FILE_PATH: Path = SCRIPT_DIR / "conf" / DEFAULT_CONFIG_FILENAME


def parse_cli_args() -> Tuple[argparse.Namespace, list[str]]:
    """Parse known script args and return remaining as Hydra-style overrides."""
    parser = argparse.ArgumentParser(
        description="Pretrain Qwen2.5-VL with YAML and CLI overrides",
        formatter_class=argparse.RawTextHelpFormatter,
    )
    parser.add_argument(
        "--config-file",
        type=str,
        default=str(DEFAULT_CONFIG_FILE_PATH),
        help="Path to the YAML OmegaConf override file. Default: conf/qwen25_vl_pretrain_override_example.yaml",
    )
    parser.add_argument(
        "--data-path",
        type=str,
        default=None,
        help="Path to JSON/JSONL dataset (preloaded conversation or legacy messages format).",
    )
    parser.add_argument(
        "--image-folder",
        type=str,
        default=None,
        help="Optional root for resolving relative image/video paths in dataset records.",
    )
    parser.add_argument(
        "--dataset-type",
        type=str,
        choices=["mock", "preloaded", "hf"],
        default=None,
        help=(
            "Dataset type to use: 'mock', 'preloaded', or 'hf'. "
            "If not set, auto-detects based on --data-path/--use-preloaded."
        ),
    )
    parser.add_argument(
        "--recipe",
        type=str,
        default="qwen25_vl_3b_pretrain_config",
        help=(
            "Name of the recipe function in megatron.bridge.recipes.qwen_vl.qwen25_vl to use, "
            "e.g., qwen25_vl_3b_pretrain_config, qwen25_vl_7b_pretrain_config."
        ),
    )
    parser.add_argument(
        "--use-preloaded",
        action="store_true",
        help="Use preloaded dataset provider (enabled automatically when --data-path is set).",
    )
    parser.add_argument("--debug", action="store_true", help="Enable debug logging")
    args, cli_dotlist_overrides = parser.parse_known_args()
    return args, cli_dotlist_overrides


def main() -> None:
    """
    Load the base VLM recipe config, apply YAML/CLI overrides, and start pretraining.
    """
    args, cli_overrides = parse_cli_args()

    logger.info("Megatron-Bridge Qwen2.5-VL Pretraining Script with YAML & CLI Overrides")
    logger.info("-----------------------------------------------------------------------")

    # Resolve the recipe function from the provided name
    recipe_name = getattr(args, "recipe", "qwen25_vl_3b_pretrain_config")
    available_recipes = [name for name in dir(qwen_vl_recipes) if name.endswith("_pretrain_config")]
    if not hasattr(qwen_vl_recipes, recipe_name):
        logger.error(
            "Unknown recipe '%s'. Available recipes: %s",
            recipe_name,
            ", ".join(sorted(available_recipes)),
        )
        sys.exit(2)
    pretrain_config = getattr(qwen_vl_recipes, recipe_name)

    # Determine dataset type based on CLI flag (overrides) or fall back to auto-detect
    use_preloaded_flag = bool(args.data_path) or bool(getattr(args, "use_preloaded", False))
    dataset_type = args.dataset_type or ("preloaded" if use_preloaded_flag else "mock")

    cfg: ConfigContainer = pretrain_config(
        dataset_type=dataset_type,
        train_data_path=args.data_path,
        valid_data_path=None,
        test_data_path=None,
        image_folder=args.image_folder,
    )
    logger.info("Loaded base configuration")

    if get_rank_safe() == 0:
        cfg.print_yaml()

    merged_omega_conf, excluded_fields = create_omegaconf_dict_config(cfg)

    if args.config_file:
        logger.debug(f"Loading YAML overrides from: {args.config_file}")
        if not os.path.exists(args.config_file):
            logger.error(f"Override YAML file not found: {args.config_file}")
            sys.exit(1)
        yaml_overrides_omega = OmegaConf.load(args.config_file)
        merged_omega_conf = OmegaConf.merge(merged_omega_conf, yaml_overrides_omega)

    if cli_overrides:
        logger.debug(f"Applying Hydra-style command-line overrides: {cli_overrides}")
        merged_omega_conf = parse_hydra_overrides(merged_omega_conf, cli_overrides)

    final_overrides_as_dict = OmegaConf.to_container(merged_omega_conf, resolve=True)
    apply_overrides(cfg, final_overrides_as_dict, excluded_fields)

    if get_rank_safe() == 0:
        logger.info("--- Final Merged Configuration ---")
        cfg.print_yaml()
        logger.info("----------------------------------")

    pretrain(config=cfg, forward_step_func=forward_step)


if __name__ == "__main__":
    main()
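
A note on the argument handling above: `parser.parse_known_args()` consumes the flags the script declares and passes every unrecognized token through untouched, which is what lets dotted overrides ride alongside regular options. A standalone illustration using only the standard library (not part of the PR):

import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--config-file", type=str, default=None)

# Recognized flags are consumed; the dotted overrides fall through unchanged.
args, overrides = parser.parse_known_args(
    ["--config-file", "my.yaml", "model.tensor_model_parallel_size=4", "train.train_iters=100000"]
)
print(args.config_file)  # my.yaml
print(overrides)         # ['model.tensor_model_parallel_size=4', 'train.train_iters=100000']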
1 change: 1 addition & 0 deletions pyproject.toml
@@ -78,6 +78,7 @@ dependencies = [
     "tqdm>=4.67.1",
     "hydra-core>1.3,<=1.3.2",
     "megatron-core[dev,mlm]>=0.14.0a0,<0.16.0",
+    "qwen-vl-utils",
 ]

53 changes: 53 additions & 0 deletions src/megatron/bridge/data/vlm_datasets/__init__.py
@@ -0,0 +1,53 @@
# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
VLM dataset utilities.

Public API re-exports:
- Makers: functions to build conversation examples from HF datasets
- Providers: classes that build PyTorch datasets bound to HF processors
- Collate fns: model-specific batch builders
"""

from .collate import COLLATE_FNS, default_collate_fn, phi4_mm_collate_fn, qwen2_5_collate_fn
from .conversation_dataset import VLMConversationDataset
from .hf_dataset_makers import (
    make_cord_v2_dataset,
    make_cv17_dataset,
    make_medpix_dataset,
    make_rdr_dataset,
)
from .hf_provider import HFDatasetConversationProvider
from .mock_provider import MockVLMConversationProvider
from .preloaded_provider import PreloadedVLMConversationProvider


__all__ = [
    # Makers
    "make_rdr_dataset",
    "make_cord_v2_dataset",
    "make_medpix_dataset",
    "make_cv17_dataset",
    # Dataset types/providers
    "VLMConversationDataset",
    "HFDatasetConversationProvider",
    "PreloadedVLMConversationProvider",
    "MockVLMConversationProvider",
    # Collation utilities
    "COLLATE_FNS",
    "default_collate_fn",
    "qwen2_5_collate_fn",
    "phi4_mm_collate_fn",
]
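
The diff itself does not show how these exports are wired together. Purely as orientation, here is a hypothetical sketch: the import names come from `__all__` above, but the constructor arguments and the `provide()` accessor are assumptions about the interface, not what this module documents:

from torch.utils.data import DataLoader

from megatron.bridge.data.vlm_datasets import MockVLMConversationProvider, qwen2_5_collate_fn

# Hypothetical: assumes a no-argument constructor and a provide() accessor;
# the real interface lives in mock_provider.py and may differ.
provider = MockVLMConversationProvider()
train_ds = provider.provide()

# qwen2_5_collate_fn is the Qwen2.5-specific batch builder re-exported above;
# per the module docstring, providers bind datasets to HF processors, so the
# real wiring likely threads a processor through the provider.
loader = DataLoader(train_ds, batch_size=2, collate_fn=qwen2_5_collate_fn)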