diff --git a/torchtitan/experiments/deterministic_vllm_rl/README.md b/torchtitan/experiments/deterministic_vllm_rl/README.md new file mode 100644 index 0000000000..d2ef719c0d --- /dev/null +++ b/torchtitan/experiments/deterministic_vllm_rl/README.md @@ -0,0 +1,262 @@ +# Deterministic RL Training with vLLM + +This experiment combines vLLM's deterministic kernels with PyTorch autograd to enable reinforcement learning training where forward passes produce bitwise-identical results across runs. + +## Overview + +RL training requires both fast inference for generating rollouts and gradient computation for policy updates. vLLM provides deterministic forward passes but does not support gradients. This experiment adds backward passes to vLLM's operations. + +The implementation: +1. Uses vLLM's batch-invariant kernels for forward passes +2. Implements custom backward passes for gradient computation +3. Provides weight conversion utilities between TorchTitan and vLLM formats + +### Features + +- Bitwise determinism: Same inputs produce identical outputs across runs +- Gradient support: Backward passes through vLLM operations +- Weight conversion: Utilities to convert between model formats + +Note: Currently supports single-device training only. + +## Architecture + +### Components + +1. `models/attention.py`: VLLMCompatibleFlashAttention + - Uses vLLM's Flash Attention for forward pass + - Implements custom backward pass for gradient computation + - Uses `num_splits=1` for deterministic behavior + +2. `models/qwen3/model_vllm_compat.py`: Qwen3VLLMCompatModel + - Qwen3 model with merged gate/up projections matching vLLM format + - Uses VLLMRMSNorm with gradient support + +3. `batch_invariant_backward.py`: Backward passes for vLLM operations + - Registers gradients for vLLM's batch-invariant operations + - Supports matmul, linear, and RMSNorm + - Patches Flash Attention for autograd + +4. `weights_vllm_compat.py`: Weight conversion utilities + - Converts between TorchTitan format (separate w1, w2, w3) and vLLM format (merged gate_up_proj) + - Provides bidirectional conversion functions + +5. `simple_rl.py`: RL training loop + - Generates rollouts using vLLM engine + - Computes advantages using GRPO-style ranking + - Updates policy using PPO + +## Installation + +### Prerequisites + +```bash +# Install vLLM with deterministic support +pip install vllm + +# Install TorchTitan (from the repository root) +pip install -e . + +# Install additional dependencies +pip install transformers safetensors huggingface_hub tensorboard +``` + +### Enable Batch Invariance + +Initialize vLLM's batch-invariant mode before training: + +```python +from vllm.model_executor.layers.batch_invariant import init_batch_invariance +init_batch_invariance() +``` + +## Usage + +### Quick Start + +```python +import torch +from vllm.model_executor.layers.batch_invariant import init_batch_invariance +from torchtitan.experiments.deterministic_vllm_rl import ( + enable_batch_invariant_backward_mode, + Qwen3VLLMCompatModel, +) + +# 1. Enable deterministic mode +init_batch_invariance() +enable_batch_invariant_backward_mode() + +# 2. Load model +from torchtitan.models.qwen3.model.args import Qwen3ModelArgs +model_args = Qwen3ModelArgs( + dim=2048, + n_layers=24, + n_heads=16, + n_kv_heads=2, + vocab_size=151936, +) +model = Qwen3VLLMCompatModel(model_args) + +# 3. Forward pass (deterministic) +input_ids = torch.randint(0, 151936, (2, 128), device='cuda') +logits = model(input_ids) + +# 4. Backward pass +loss = logits.sum() +loss.backward() +``` + +### Full RL Training + +Run the RL training loop: + +```bash +VLLM_BATCH_INVARIANT=1 VLLM_FLASH_ATTN_VERSION=3 python -m torchtitan.experiments.deterministic_vllm_rl.simple_rl +``` + +This will: +1. Download Qwen3-1.7B from HuggingFace +2. Initialize vLLM engine for rollouts +3. Generate samples for training prompts +4. Compute rewards and advantages +5. Update the policy using PPO +6. Log metrics to TensorBoard + +View training progress: +```bash +tensorboard --logdir=./outputs/rl_training +``` + +## How It Works + +### Deterministic Forward Pass + +vLLM's batch-invariant mode makes operations deterministic: + +```python +# These operations are deterministic when batch_invariance is enabled +y = torch.matmul(a, b) # Uses vLLM's deterministic matmul +output = flash_attn_varlen_func(q, k, v, num_splits=1) # Deterministic FA +``` + +### Backward Pass with Gradients + +Custom backward passes: +1. Re-compute attention weights deterministically +2. Use standard chain rule for gradients +3. Apply gradients through vLLM's deterministic operations + +```python +class FlashAttnWithBackward(torch.autograd.Function): + @staticmethod + def forward(ctx, q, k, v, ...): + # Use vLLM's forward implementation + return flash_attn_varlen_func(q, k, v, num_splits=1, ...) + + @staticmethod + def backward(ctx, grad_output): + # Compute gradients deterministically + # (re-compute attention weights and apply chain rule) + return grad_q, grad_k, grad_v, ... +``` + +### Bitwise Determinism Verification + +The training loop compares logprobs from vLLM and TorchTitan: + +```python +# During training, compare logprobs +vllm_logprobs = [from vLLM rollout] +titan_logprobs = [from TorchTitan forward pass] + +assert torch.equal(vllm_logprobs, titan_logprobs) +``` + +## Testing + +Run the test suite: + +```bash +cd torchtitan/experiments/deterministic_vllm_rl/tests + +# Test backward passes +python test_batch_invariant_backward.py + +# Test determinism +python test_exact_determinism.py +``` + +## Technical Details + +### Why Determinism Matters for RL + +RL training steps: +1. Generate rollouts by sampling from the policy +2. Compute rewards based on the samples +3. Update the policy using gradients + +If the forward pass during training differs from the forward pass during rollout, policy gradients may be incorrect. This matters for algorithms like PPO that compare old and new policy probabilities. + +This implementation uses the same kernels for both rollouts (vLLM) and training (TorchTitan) to ensure `logprobs_rollout == logprobs_training` bitwise. + +### Performance + +- Rollout speed: Uses vLLM's optimized kernels +- Training speed: Similar to standard TorchTitan +- Memory: Saves activations for custom backward passes + +### Limitations + +1. Custom backward requires uniform sequence lengths +2. Only causal attention is supported +3. Requires NVIDIA GPUs with Flash Attention support + +## Project Structure + +``` +deterministic_vllm_rl/ +├── README.md # Documentation +├── __init__.py # Package initialization +├── batch_invariant_backward.py # Backward passes for vLLM ops +├── weights_vllm_compat.py # Weight conversion utilities +├── simple_rl.py # RL training loop +├── models/ +│ ├── __init__.py +│ ├── attention.py # VLLMCompatibleFlashAttention +│ └── qwen3/ +│ ├── __init__.py +│ └── model_vllm_compat.py # vLLM-compatible Qwen3 model +├── weights/ +│ ├── __init__.py +│ ├── converter.py # Weight conversion script +│ └── README.md # Weight conversion documentation +└── tests/ + ├── __init__.py + ├── test_batch_invariant_backward.py # Test backward passes + └── test_exact_determinism.py # Test determinism +``` + +## TODO + +- `FlashAttnWithBackward` will need to become more composable and should not live exclusively within this directory. +- vLLM integration will need to become more generic with a provided Attention operator that is KV-cache compatible. +- vLLM parallelism will need to add generic parallelism initialization to support Monarch managed TP/DP. + +## Contributing + +This experiment is part of TorchTitan. To contribute: + +1. Test your changes with `pytest tests/` +2. Verify bitwise determinism is maintained +3. Update this README if adding new features + +## References + +- [vLLM Documentation](https://docs.vllm.ai/) +- [Flash Attention Paper](https://arxiv.org/abs/2205.14135) +- [PPO Algorithm](https://arxiv.org/abs/1707.06347) +- [GRPO: Group Relative Policy Optimization](https://arxiv.org/abs/2402.03300) + +## License + +This code is licensed under the BSD-style license found in the LICENSE file in the TorchTitan repository root directory. diff --git a/torchtitan/experiments/deterministic_vllm_rl/__init__.py b/torchtitan/experiments/deterministic_vllm_rl/__init__.py new file mode 100644 index 0000000000..067555251f --- /dev/null +++ b/torchtitan/experiments/deterministic_vllm_rl/__init__.py @@ -0,0 +1,34 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +""" +Deterministic RL training with vLLM experiment. + +This experiment provides tools for bitwise-deterministic reinforcement learning +training using vLLM for fast rollouts and TorchTitan for training. + +Key components: +- VLLMCompatibleFlashAttention: Flash attention with custom backward pass +- Qwen3VLLMCompatModel: vLLM-compatible model with merged projections +- batch_invariant_backward: Gradient support for vLLM's deterministic operations +- simple_rl: End-to-end RL training loop +""" + +from .batch_invariant_backward import ( + enable_batch_invariant_backward_mode, + rms_norm_with_gradients, + silu_and_mul_with_gradients, +) +from .models import VLLMCompatibleFlashAttention +from .models.qwen3 import Qwen3VLLMCompatModel + +__all__ = [ + "VLLMCompatibleFlashAttention", + "Qwen3VLLMCompatModel", + "enable_batch_invariant_backward_mode", + "rms_norm_with_gradients", + "silu_and_mul_with_gradients", +] diff --git a/torchtitan/experiments/deterministic_vllm_rl/batch_invariant_backward.py b/torchtitan/experiments/deterministic_vllm_rl/batch_invariant_backward.py new file mode 100644 index 0000000000..faccf8265d --- /dev/null +++ b/torchtitan/experiments/deterministic_vllm_rl/batch_invariant_backward.py @@ -0,0 +1,378 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +""" +Batch-invariant operations with backward pass support. + +This module adds gradient support to vLLM's deterministic batch_invariant mode +by registering backward operations that also use vLLM's deterministic kernels. + +Key architecture: +- Forward: Uses vLLM's batch_invariant Triton kernels (deterministic) +- Backward: Also uses vLLM's batch_invariant kernels (deterministic) + +This achieves bitwise-deterministic RL training where both rollouts (forward) +and training (forward + backward) produce identical results. + +Usage: + from vllm.model_executor.layers.batch_invariant import init_batch_invariance + from batch_invariant_backward import enable_batch_invariant_backward_mode + + # Initialize vLLM's deterministic mode first + init_batch_invariance() + + # Then enable gradient support + enable_batch_invariant_backward_mode() + + # Now all operations are deterministic AND support gradients + model = MyModel() + output = model(input) # deterministic forward + loss = compute_loss(output) + loss.backward() # gradients work with deterministic backward! +""" + +import torch +from torch.autograd import Function + + +# ============================================================================ +# Custom autograd Functions for vLLM operations +# ============================================================================ + + +class SiluAndMulFunction(Function): + """ + Autograd function for vLLM's SiluAndMul activation. + + Forward: splits input into [gate, up], returns silu(gate) * up + where silu(x) = x * sigmoid(x) + """ + + @staticmethod + def forward(ctx, x): + """ + Forward pass using vLLM's SiluAndMul. + + Args: + x: Input tensor [..., hidden_dim * 2] where first half is gate, second half is up + + Returns: + output: silu(gate) * up, shape [..., hidden_dim] + """ + from vllm.model_executor.layers.activation import SiluAndMul as VLLMSiluAndMul + + # Use vLLM's implementation for forward + vllm_silu_and_mul = VLLMSiluAndMul() + output = vllm_silu_and_mul(x) + + # Save for backward + ctx.save_for_backward(x) + + return output + + @staticmethod + def backward(ctx, grad_output): + """ + Backward pass for SiluAndMul. + + Let gate = x[:d], up = x[d:] where d = hidden_dim + Forward: out = silu(gate) * up = (gate * sigmoid(gate)) * up + + Gradients: + - grad_gate = grad_out * up * d_silu(gate) + - grad_up = grad_out * silu(gate) + + where d_silu(x) = sigmoid(x) * (1 + x * (1 - sigmoid(x))) + """ + (x,) = ctx.saved_tensors + + # Split input into gate and up + d = x.shape[-1] // 2 + gate = x[..., :d] + up = x[..., d:] + + # Compute sigmoid and silu for backward + sigmoid_gate = torch.sigmoid(gate) + silu_gate = gate * sigmoid_gate + + # Gradient of silu: d_silu(x) = sigmoid(x) * (1 + x * (1 - sigmoid(x))) + d_silu_gate = sigmoid_gate * (1 + gate * (1 - sigmoid_gate)) + + # Compute gradients + grad_gate = grad_output * up * d_silu_gate + grad_up = grad_output * silu_gate + + # Concatenate gradients + grad_x = torch.cat([grad_gate, grad_up], dim=-1) + + return grad_x + + +class RMSNormFunction(Function): + """ + Autograd function for RMS normalization using vLLM's Triton kernel in forward + and batch-invariant operations in backward. + """ + + @staticmethod + def forward(ctx, input, weight, eps): + """ + Forward pass using vLLM's rms_norm Triton kernel. + + Args: + input: Input tensor [*, hidden_size] + weight: Weight tensor [hidden_size] + eps: Epsilon for numerical stability + + Returns: + output: Normalized and scaled tensor [*, hidden_size] + """ + from vllm.model_executor.layers.batch_invariant import rms_norm as vllm_rms_norm + + # Use vLLM's Triton kernel for forward (deterministic) + output = vllm_rms_norm(input, weight, eps) + + # Save for backward + ctx.save_for_backward(input, weight) + ctx.eps = eps + + return output + + @staticmethod + def backward(ctx, grad_output): + """ + Backward pass using batch-invariant PyTorch operations. + + Returns: + (grad_input, grad_weight, None) + """ + input, weight = ctx.saved_tensors + eps = ctx.eps + + # Compute forward pass values needed for backward + # variance = mean(x^2) along last dim + variance = (input * input).mean(dim=-1, keepdim=True) + rms = torch.sqrt(variance + eps) + x_norm = input / rms + + # Gradient w.r.t. weight + # grad_weight = sum(grad_output * x_norm) over all dims except last + grad_weight = (grad_output * x_norm).sum(dim=tuple(range(grad_output.ndim - 1))) + + # Gradient w.r.t. input + # grad_x_norm = grad_output * weight + grad_x_norm = grad_output * weight + + # grad_x = (grad_x_norm - mean(grad_x_norm * x_norm) * x_norm) / rms + mean_term = (grad_x_norm * x_norm).mean(dim=-1, keepdim=True) + grad_input = (grad_x_norm - mean_term * x_norm) / rms + + return grad_input, grad_weight, None + + +# ============================================================================ +# Backward operation implementations for autograd +# ============================================================================ + + +def matmul_backward_impl(grad_output, self, other, output_mask): + """ + Backward pass for matmul: y = matmul(a, b) + Returns: (grad_a, grad_b) + + Args: + grad_output: Gradient from downstream + self: First input tensor (a) + other: Second input tensor (b) + output_mask: List of bools indicating which gradients to compute [self, other] + + grad_a = grad_output @ b.T + grad_b = a.T @ grad_output + + Uses torch.matmul which is overridden by vLLM's batch_invariant mode! + """ + grad_self = grad_other = None + + # output_mask is a list [compute_grad_self, compute_grad_other] + compute_grad_self = output_mask[0] if len(output_mask) > 0 else True + compute_grad_other = output_mask[1] if len(output_mask) > 1 else True + + if compute_grad_self: + # grad_self = grad_output @ other.T + if other.ndim == 2: + grad_self = torch.matmul(grad_output, other.t()) + elif other.ndim == 3: + grad_self = torch.matmul(grad_output, other.transpose(-2, -1)) + else: + grad_self = torch.matmul(grad_output, other.transpose(-2, -1)) + + if compute_grad_other: + # grad_other = self.T @ grad_output + if self.ndim == 2: + grad_other = torch.matmul(self.t(), grad_output) + elif self.ndim == 3: + grad_other = torch.matmul(self.transpose(-2, -1), grad_output) + else: + grad_other = torch.matmul(self.transpose(-2, -1), grad_output) + + return grad_self, grad_other + + +def linear_backward_impl(grad_output, input, weight, output_mask): + """ + Backward pass for linear: y = input @ weight.T + bias + Returns: (grad_input, grad_weight, grad_bias) + + Args: + grad_output: Gradient from downstream (actually the saved input!) + input: Input tensor (actually grad_output!) + weight: Weight tensor + output_mask: List of bools indicating which gradients to compute [input, weight, bias] + + PyTorch passes args in weird order: (saved_input, grad_output, weight, output_mask) + So we swap the first two args in our implementation. + """ + # Swap: PyTorch passes (saved_input, grad_output, ...) but we want (grad_output, input, ...) + input, grad_output = grad_output, input + + grad_input = grad_weight = grad_bias = None + + # output_mask is a list [compute_grad_input, compute_grad_weight, compute_grad_bias] + compute_grad_input = output_mask[0] if len(output_mask) > 0 else True + compute_grad_weight = output_mask[1] if len(output_mask) > 1 else True + compute_grad_bias = output_mask[2] if len(output_mask) > 2 else True + + if compute_grad_input: + # grad_input = grad_output @ weight + grad_input = torch.matmul(grad_output, weight) + + if compute_grad_weight: + # PyTorch linear: y = x @ W.T + b where W is [out, in] + # Backward: grad_W = grad_y.T @ x + # grad_output: (batch, out), input: (batch, in) + # grad_output.T @ input: (out, batch) @ (batch, in) = (out, in) ✓ + + # Handle multi-dimensional inputs + if input.ndim == 3: + # Reshape for matmul: (batch, seq, in) -> (batch*seq, in) + input_2d = input.reshape(-1, input.shape[-1]) + grad_output_2d = grad_output.reshape(-1, grad_output.shape[-1]) + # grad_output_2d: (batch*seq, out), input_2d: (batch*seq, in) + # grad_output_2d.T @ input_2d: (out, batch*seq) @ (batch*seq, in) = (out, in) ✓ + grad_weight = torch.matmul(grad_output_2d.transpose(0, 1), input_2d) + else: + # input: (batch, in), grad_output: (batch, out) + # grad_output.T @ input: (out, batch) @ (batch, in) = (out, in) ✓ + grad_weight = torch.matmul(grad_output.transpose(0, 1), input) + + if compute_grad_bias: + # grad_bias = sum(grad_output) along all dims except last + grad_bias = grad_output.sum(dim=tuple(range(grad_output.ndim - 1))) + + return grad_input, grad_weight, grad_bias + + +# ============================================================================ +# Registration +# ============================================================================ + +_batch_invariant_backward_mode = False +_batch_invariant_backward_lib = None + + +def enable_batch_invariant_backward_mode(): + """Enable batch invariant backward mode to support gradients. + + This function adds backward pass support to vLLM's existing batch_invariant + implementations by registering the backward operations. vLLM handles all the + forward passes, we just add gradient support. + """ + global _batch_invariant_backward_mode, _batch_invariant_backward_lib + + if _batch_invariant_backward_mode: + return + + # Get vLLM's batch_invariant library (already created by init_batch_invariance) + from vllm.model_executor.layers import batch_invariant as vllm_bi + + if ( + not hasattr(vllm_bi, "_batch_invariant_LIB") + or vllm_bi._batch_invariant_LIB is None + ): + raise RuntimeError( + "vLLM's batch_invariant mode is not initialized. " + "Call init_batch_invariance() first." + ) + + # Use vLLM's existing library - don't destroy it! + _batch_invariant_backward_lib = vllm_bi._batch_invariant_LIB + + # Just add the backward operations - everything else is already handled by vLLM + _batch_invariant_backward_lib.impl( + "aten::matmul_backward", matmul_backward_impl, "CUDA" + ) + _batch_invariant_backward_lib.impl( + "aten::linear_backward", linear_backward_impl, "CUDA" + ) + + _batch_invariant_backward_mode = True + + +def disable_batch_invariant_backward_mode(): + """Disable batch invariant backward mode.""" + global _batch_invariant_backward_mode, _batch_invariant_backward_lib + + if _batch_invariant_backward_lib is not None: + _batch_invariant_backward_lib._destroy() + + _batch_invariant_backward_mode = False + _batch_invariant_backward_lib = None + + +def is_batch_invariant_backward_mode_enabled(): + """Check if batch invariant backward mode is enabled.""" + return _batch_invariant_backward_mode + + +# ============================================================================ +# Public API for gradient-enabled vLLM operations +# ============================================================================ + + +def rms_norm_with_gradients( + input: torch.Tensor, weight: torch.Tensor, eps: float = 1e-6 +) -> torch.Tensor: + """ + RMS normalization with gradient support. + + Uses vLLM's Triton kernel for forward pass (deterministic) and + batch-invariant PyTorch operations for backward pass. + + Args: + input: Input tensor [*, hidden_size] + weight: Weight tensor [hidden_size] + eps: Epsilon for numerical stability + + Returns: + output: Normalized and scaled tensor [*, hidden_size] + """ + return RMSNormFunction.apply(input, weight, eps) + + +def silu_and_mul_with_gradients(x: torch.Tensor) -> torch.Tensor: + """ + SiluAndMul activation with gradient support. + + Uses vLLM's implementation for forward pass (deterministic) and + implements proper backward pass for training. + + Args: + x: Input tensor [..., hidden_dim * 2] where first half is gate, second half is up + + Returns: + output: silu(gate) * up, shape [..., hidden_dim] + """ + return SiluAndMulFunction.apply(x) diff --git a/torchtitan/experiments/deterministic_vllm_rl/models/__init__.py b/torchtitan/experiments/deterministic_vllm_rl/models/__init__.py new file mode 100644 index 0000000000..c8c11a170a --- /dev/null +++ b/torchtitan/experiments/deterministic_vllm_rl/models/__init__.py @@ -0,0 +1,13 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +""" +Models for deterministic vLLM RL training. +""" + +from .attention import VLLMCompatibleFlashAttention + +__all__ = ["VLLMCompatibleFlashAttention"] diff --git a/torchtitan/experiments/deterministic_vllm_rl/models/attention.py b/torchtitan/experiments/deterministic_vllm_rl/models/attention.py new file mode 100644 index 0000000000..33dd5a140d --- /dev/null +++ b/torchtitan/experiments/deterministic_vllm_rl/models/attention.py @@ -0,0 +1,196 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +""" +vLLM-compatible Flash Attention implementation for deterministic RL training. +""" + +import torch +from vllm.vllm_flash_attn import flash_attn_varlen_func + + +class VLLMCompatibleFlashAttention(torch.nn.Module): + """Wrapper around FlashAttention as used by VLLM""" + + def __init__(self) -> None: + super().__init__() + self.flash_attn_varlen_func = flash_attn_varlen_func + from vllm.attention.utils.fa_utils import get_flash_attn_version + from vllm.model_executor.layers.batch_invariant import vllm_is_batch_invariant + + self.vllm_is_batch_invariant = vllm_is_batch_invariant + self.fa_version = get_flash_attn_version() + + def forward( + self, + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + *, + scale: float | None = None, + ) -> torch.Tensor: + # Flash Attention varlen expects: (batch, seqlen, nheads, headdim) + # The input from TorchTitan is always (batch, num_heads, seq_len, head_dim) + # We need to transpose to (batch, seq_len, num_heads, head_dim) + + # Input is (batch, num_heads, seq_len, head_dim) - need to transpose + q = q.transpose(1, 2) # -> (batch, seq_len, num_heads, head_dim) + k = k.transpose(1, 2) + v = v.transpose(1, 2) + + # Get dimensions + batch_size, seq_len, num_heads, head_dim = q.shape + + # Convert to varlen format: flatten batch and sequence dimensions + # (batch, seqlen, nheads, headdim) -> (total_tokens, nheads, headdim) + q_varlen = q.reshape(-1, num_heads, head_dim) + k_varlen = k.reshape(-1, k.shape[2], head_dim) + v_varlen = v.reshape(-1, v.shape[2], head_dim) + + # Create cumulative sequence lengths + # cu_seqlens: [0, seq_len, 2*seq_len, ..., batch_size*seq_len] + cu_seqlens = torch.arange( + 0, (batch_size + 1) * seq_len, seq_len, dtype=torch.int32, device=q.device + ) + + # Wrap Flash Attention with manual backward pass + class FlashAttnWithBackward(torch.autograd.Function): + @staticmethod + def forward( + ctx, + q, + k, + v, + cu_seqlens, + seq_len, + scale, + num_splits, + flash_fn, + fa_version, + ): + # Call flash attention for forward (fast) + output = flash_fn( + q, + k, + v, + cu_seqlens_q=cu_seqlens, + cu_seqlens_k=cu_seqlens, + max_seqlen_q=seq_len, + max_seqlen_k=seq_len, + softmax_scale=scale, + causal=True, + num_splits=num_splits, + fa_version=fa_version, + ) + # Save for backward + ctx.save_for_backward(q, k, v, output) + ctx.scale = scale + ctx.seq_len = seq_len + return output + + @staticmethod + def backward(ctx, grad_output): + q, k, v, output = ctx.saved_tensors + scale = ctx.scale + seq_len = ctx.seq_len + + # Reshape from varlen back to batch format for attention computation + # Assume uniform sequence lengths (batch_size = total_tokens / seq_len) + total_tokens = q.shape[0] + num_heads = q.shape[1] + head_dim = q.shape[2] + batch_size = total_tokens // seq_len + + q_batch = q.reshape(batch_size, seq_len, num_heads, head_dim) + k_batch = k.reshape(batch_size, seq_len, num_heads, head_dim) + v_batch = v.reshape(batch_size, seq_len, num_heads, head_dim) + out_batch = output.reshape(batch_size, seq_len, num_heads, head_dim) + grad_out_batch = grad_output.reshape( + batch_size, seq_len, num_heads, head_dim + ) + + # Transpose to (batch, num_heads, seq_len, head_dim) + q_t = q_batch.transpose(1, 2) + k_t = k_batch.transpose(1, 2) + v_t = v_batch.transpose(1, 2) + out_t = out_batch.transpose(1, 2) + grad_out_t = grad_out_batch.transpose(1, 2) + + # Compute attention scores: QK^T + # q_t: (B, H, N, D), k_t: (B, H, N, D) -> scores: (B, H, N, N) + scores = torch.matmul(q_t, k_t.transpose(-2, -1)) * scale + + # Apply causal mask + causal_mask = torch.triu( + torch.ones(seq_len, seq_len, device=q.device, dtype=torch.bool), + diagonal=1, + ) + scores = scores.masked_fill(causal_mask, float("-inf")) + + # Softmax + attn_weights = torch.nn.functional.softmax( + scores, dim=-1 + ) # (B, H, N, N) + + # Backward through attention + # out = attn_weights @ v + # grad_v = attn_weights^T @ grad_out + grad_v_t = torch.matmul(attn_weights.transpose(-2, -1), grad_out_t) + + # grad_attn_weights = grad_out @ v^T + grad_attn_weights = torch.matmul(grad_out_t, v_t.transpose(-2, -1)) + + # Backward through softmax + # d_softmax = attn_weights * (grad_attn_weights - sum(grad_attn_weights * attn_weights)) + sum_term = (grad_attn_weights * attn_weights).sum(dim=-1, keepdim=True) + grad_scores = attn_weights * (grad_attn_weights - sum_term) + + # Apply causal mask to gradients + grad_scores = grad_scores.masked_fill(causal_mask, 0.0) + + # Backward through QK^T and scale + grad_scores = grad_scores * scale + + # grad_q = grad_scores @ K + grad_q_t = torch.matmul(grad_scores, k_t) + + # grad_k = grad_scores^T @ Q + grad_k_t = torch.matmul(grad_scores.transpose(-2, -1), q_t) + + # Transpose back and reshape to varlen format + grad_q = grad_q_t.transpose(1, 2).reshape( + total_tokens, num_heads, head_dim + ) + grad_k = grad_k_t.transpose(1, 2).reshape( + total_tokens, num_heads, head_dim + ) + grad_v = grad_v_t.transpose(1, 2).reshape( + total_tokens, num_heads, head_dim + ) + + return grad_q, grad_k, grad_v, None, None, None, None, None, None + + # Call Flash Attention varlen with custom backward + output_varlen = FlashAttnWithBackward.apply( + q_varlen, + k_varlen, + v_varlen, + cu_seqlens, + seq_len, + scale, + 1 if self.vllm_is_batch_invariant() else 0, + self.flash_attn_varlen_func, + self.fa_version, + ) + + # Convert back to batch format + # (total_tokens, nheads, headdim) -> (batch, seqlen, nheads, headdim) + output = output_varlen.reshape(batch_size, seq_len, num_heads, head_dim) + + # Transpose back to TorchTitan format: (batch, num_heads, seq_len, head_dim) + output = output.transpose(1, 2) + + return output diff --git a/torchtitan/experiments/deterministic_vllm_rl/models/qwen3/__init__.py b/torchtitan/experiments/deterministic_vllm_rl/models/qwen3/__init__.py new file mode 100644 index 0000000000..10f49db8b5 --- /dev/null +++ b/torchtitan/experiments/deterministic_vllm_rl/models/qwen3/__init__.py @@ -0,0 +1,13 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +""" +Qwen3 model with vLLM compatibility for deterministic RL training. +""" + +from .model_vllm_compat import Qwen3VLLMCompatModel + +__all__ = ["Qwen3VLLMCompatModel"] diff --git a/torchtitan/experiments/deterministic_vllm_rl/models/qwen3/model_vllm_compat.py b/torchtitan/experiments/deterministic_vllm_rl/models/qwen3/model_vllm_compat.py new file mode 100644 index 0000000000..dd84665091 --- /dev/null +++ b/torchtitan/experiments/deterministic_vllm_rl/models/qwen3/model_vllm_compat.py @@ -0,0 +1,368 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. +# +# Qwen3 model compatible with vLLM's implementation +# Uses merged gate_up projections and vLLM Flash Attention + +import torch +from torch import nn + +from torchtitan.components.tokenizer import BaseTokenizer + +# Import gradient-enabled operations from experiment utilities +from torchtitan.experiments.deterministic_vllm_rl.batch_invariant_backward import ( + rms_norm_with_gradients, + silu_and_mul_with_gradients, +) + +# Import from main torchtitan +from torchtitan.models.qwen3.model.args import Qwen3ModelArgs +from torchtitan.protocols.model import AttentionMasksType +from torchtitan.protocols.train_spec import ModelProtocol + +# Import from local experiment's models +from ..attention import VLLMCompatibleFlashAttention + + +# RoPE functions (same as original) +def precompute_rope_cache( + dim: int, max_seq_len: int, base: float = 1_000_000.0 +) -> torch.Tensor: + freqs = 1.0 / (base ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim)) + t = torch.arange(max_seq_len, dtype=freqs.dtype, device=freqs.device) + idx_theta = torch.outer(t, freqs).float() + freqs = torch.cat([idx_theta, idx_theta], dim=-1) + rope_cache = torch.cat([freqs.cos(), freqs.sin()], dim=-1) + return rope_cache + + +def rotate_half(x: torch.Tensor) -> torch.Tensor: + """Rotates half the hidden dims of the input.""" + x1 = x[..., : x.shape[-1] // 2] + x2 = x[..., x.shape[-1] // 2 :] + return torch.cat((-x2, x1), dim=-1) + + +def reshape_for_broadcast(rope_cache: torch.Tensor, x: torch.Tensor) -> torch.Tensor: + """Reshape frequency tensor for broadcasting.""" + ndim = x.ndim + assert ndim > 1 + _, seqlen, _, head_dim = x.shape + rope_cache = rope_cache[0:seqlen] + assert rope_cache.shape == (seqlen, head_dim * 2) + shape = [-1, seqlen, 1, head_dim * 2] + return rope_cache.view(*shape) + + +def apply_rotary_emb( + xq: torch.Tensor, xk: torch.Tensor, rope_cache: torch.Tensor +) -> tuple[torch.Tensor, torch.Tensor]: + head_dim = xq.shape[-1] + rope_cache = reshape_for_broadcast(rope_cache, xq) + cos = rope_cache[..., :head_dim].to(dtype=xq.dtype, device=xq.device) + sin = rope_cache[..., head_dim:].to(dtype=xq.dtype, device=xq.device) + xq_out = (xq * cos) + (rotate_half(xq) * sin) + xk_out = (xk * cos) + (rotate_half(xk) * sin) + return xq_out.type_as(xq), xk_out.type_as(xk) + + +def repeat_kv(x: torch.Tensor, n_rep: int) -> torch.Tensor: + """torch.repeat_interleave(x, dim=2, repeats=n_rep)""" + bs, slen, n_kv_heads, head_dim = x.shape + if n_rep == 1: + return x + return ( + torch.unsqueeze(x, dim=3) + .expand(bs, slen, n_kv_heads, n_rep, head_dim) + .reshape(bs, slen, n_kv_heads * n_rep, head_dim) + ) + + +class VLLMRMSNorm(nn.Module): + """ + RMSNorm using vLLM's exact Triton kernel for bitwise determinism. + Compatible with PyTorch's nn.RMSNorm interface but uses vLLM's implementation. + + Supports gradients through a custom autograd function that uses vLLM's + kernel for forward and batch-invariant PyTorch ops for backward. + """ + + def __init__(self, hidden_size: int, eps: float = 1e-6): + super().__init__() + self.eps = eps + self.weight = nn.Parameter(torch.ones(hidden_size)) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + # Use vLLM's RMSNorm with gradient support for training + return rms_norm_with_gradients(x, self.weight, self.eps) + + def reset_parameters(self): + nn.init.ones_(self.weight) + + +class FeedForwardVLLMCompat(nn.Module): + """ + FeedForward module compatible with vLLM implementation. + Uses merged gate_up projection like vLLM. + """ + + def __init__( + self, + dim: int, + hidden_dim: int, + ): + super().__init__() + + # Merged gate and up projections (like vLLM's gate_up_proj) + self.gate_up_proj = nn.Linear(dim, hidden_dim * 2, bias=False) + + # Down projection (like vLLM's down_proj) + self.down_proj = nn.Linear(hidden_dim, dim, bias=False) + + def forward(self, x): + # Project to gate and up in one go + gate_up = self.gate_up_proj(x) + # Apply SiluAndMul activation with gradient support + activated = silu_and_mul_with_gradients(gate_up) + # Project down + output = self.down_proj(activated) + return output + + def init_weights(self, init_std: float): + # Initialize like vLLM + nn.init.trunc_normal_(self.gate_up_proj.weight, mean=0.0, std=0.02) + nn.init.trunc_normal_(self.down_proj.weight, mean=0.0, std=init_std) + + +class Attention(nn.Module): + """ + Multi-head attention module compatible with vLLM. + """ + + def __init__(self, model_args: Qwen3ModelArgs): + super().__init__() + self.n_heads = model_args.n_heads + self.n_kv_heads = ( + model_args.n_heads + if model_args.n_kv_heads is None + else model_args.n_kv_heads + ) + self.n_rep = self.n_heads // self.n_kv_heads + self.head_dim = model_args.head_dim + self.scaling = self.head_dim**-0.5 + + # QK norm (Qwen3 specific) - use vLLM's RMSNorm + if model_args.qk_norm: + self.q_norm = VLLMRMSNorm(self.head_dim, eps=model_args.norm_eps) + self.k_norm = VLLMRMSNorm(self.head_dim, eps=model_args.norm_eps) + else: + self.q_norm = None + self.k_norm = None + + # QKV projections + self.wq = nn.Linear( + model_args.dim, model_args.n_heads * self.head_dim, bias=False + ) + self.wk = nn.Linear(model_args.dim, self.n_kv_heads * self.head_dim, bias=False) + self.wv = nn.Linear(model_args.dim, self.n_kv_heads * self.head_dim, bias=False) + self.wo = nn.Linear( + model_args.n_heads * self.head_dim, model_args.dim, bias=False + ) + + # Always use vLLM compatible flash attention + self.inner_attention = VLLMCompatibleFlashAttention() + + def init_weights(self, init_std: float): + for linear in (self.wq, self.wk, self.wv): + nn.init.trunc_normal_(linear.weight, mean=0.0, std=0.02) + nn.init.trunc_normal_(self.wo.weight, mean=0.0, std=init_std) + if self.q_norm is not None: + self.q_norm.reset_parameters() + if self.k_norm is not None: + self.k_norm.reset_parameters() + + def forward( + self, + x: torch.Tensor, + rope_cache: torch.Tensor, + attention_masks: AttentionMasksType | None, + ): + bs, seqlen, _ = x.shape + xq, xk, xv = self.wq(x), self.wk(x), self.wv(x) + + # Reshape to heads + xq = xq.view(bs, seqlen, -1, self.head_dim) + xk = xk.view(bs, seqlen, -1, self.head_dim) + xv = xv.view(bs, seqlen, -1, self.head_dim) + + # Apply QK norm + if self.q_norm: + xq = self.q_norm(xq) + if self.k_norm: + xk = self.k_norm(xk) + + # Apply rotary embedding + xq, xk = apply_rotary_emb(xq, xk, rope_cache) + + # Repeat k/v heads if needed + keys = repeat_kv(xk, self.n_rep) + values = repeat_kv(xv, self.n_rep) + + # Transpose for attention + xq = xq.transpose(1, 2) # (bs, n_local_heads, seqlen, head_dim) + xk = keys.transpose(1, 2) + xv = values.transpose(1, 2) + + # Apply flash attention (vLLM compatible, no flex attention) + assert ( + attention_masks is None + ), "vLLM compat mode doesn't use flex attention masks" + output = self.inner_attention(xq, xk, xv, scale=self.scaling) + + # Transpose back + output = output.transpose(1, 2).contiguous() + output = output.view(bs, seqlen, -1) + + return self.wo(output) + + +class TransformerBlock(nn.Module): + """ + TransformerBlock with vLLM-compatible FFN. + """ + + def __init__(self, layer_id: int, model_args: Qwen3ModelArgs): + super().__init__() + self.n_heads = model_args.n_heads + self.dim = model_args.dim + + self.attention = Attention(model_args) + + # Use vLLM-compatible FFN with merged projections + self.feed_forward = FeedForwardVLLMCompat( + dim=model_args.dim, hidden_dim=model_args.hidden_dim + ) + + self.attention_norm = VLLMRMSNorm(model_args.dim, eps=model_args.norm_eps) + self.ffn_norm = VLLMRMSNorm(model_args.dim, eps=model_args.norm_eps) + + if model_args.depth_init: + self.weight_init_std = 0.02 / (2 * (layer_id + 1)) ** 0.5 + else: + self.weight_init_std = 0.02 / (2 * model_args.n_layers) ** 0.5 + + def forward( + self, + x: torch.Tensor, + rope_cache: torch.Tensor, + attention_masks: AttentionMasksType | None, + ): + # Self attention with residual + attn_norm_out = self.attention_norm(x) + x = x + self.attention(attn_norm_out, rope_cache, attention_masks) + + # FFN with residual + ffn_norm_out = self.ffn_norm(x) + x = x + self.feed_forward(ffn_norm_out) + + return x + + def init_weights(self, buffer_device: torch.device): + for norm in (self.attention_norm, self.ffn_norm): + norm.reset_parameters() + self.attention.init_weights(self.weight_init_std) + self.feed_forward.init_weights(self.weight_init_std) + + +class Qwen3VLLMCompatModel(nn.Module, ModelProtocol): + """ + Qwen3 model with vLLM-compatible implementation. + Uses merged gate_up projections and vLLM Flash Attention. + """ + + def __init__(self, model_args: Qwen3ModelArgs): + super().__init__() + self.model_args = model_args + self.vocab_size = model_args.vocab_size + self.n_layers = model_args.n_layers + self.eos_id = model_args.eos_id + self.head_dim = model_args.head_dim + + self.tok_embeddings = nn.Embedding(model_args.vocab_size, model_args.dim) + + self.register_buffer( + "rope_cache", self._precompute_rope_cache(), persistent=False + ) + + self.layers = torch.nn.ModuleDict() + for layer_id in range(model_args.n_layers): + self.layers[str(layer_id)] = TransformerBlock(layer_id, model_args) + + self.norm = VLLMRMSNorm(model_args.dim, eps=model_args.norm_eps) + self.output = nn.Linear(model_args.dim, model_args.vocab_size, bias=False) + + # IMPORTANT: To match vLLM's behavior and Qwen3's config + # (tie_word_embeddings: true), tie output layer weights to + # embedding weights. When either weight updates during training, + # both update together + self.output.weight = self.tok_embeddings.weight + + def init_weights( + self, + buffer_device: torch.device | None = None, + ): + buffer_device = buffer_device or self.rope_cache.device + with torch.device(buffer_device): + self.rope_cache = self._precompute_rope_cache() + if self.tok_embeddings is not None: + nn.init.normal_(self.tok_embeddings.weight) + for layer in self.layers.values(): + if layer is not None: + layer.init_weights(buffer_device) + if self.norm is not None: + self.norm.reset_parameters() + final_out_std = self.model_args.dim**-0.5 + cutoff_factor = 3 + + if self.output is not None: + nn.init.trunc_normal_( + self.output.weight, + mean=0.0, + std=final_out_std, + a=-cutoff_factor * final_out_std, + b=cutoff_factor * final_out_std, + ) + + def _precompute_rope_cache(self) -> torch.Tensor: + return precompute_rope_cache( + self.model_args.head_dim, + self.model_args.max_seq_len, + self.model_args.rope_theta, + ) + + def get_attention_masks( + self, + input_batch: torch.Tensor, + tokenizer: BaseTokenizer, + extra_inputs: dict[str, torch.Tensor] | None = None, + ) -> AttentionMasksType | None: + # vLLM compat mode: no flex attention masks + return None + + def forward( + self, + tokens: torch.Tensor, + attention_masks: AttentionMasksType | None = None, + ): + h = self.tok_embeddings(tokens) if self.tok_embeddings else tokens + + for layer in self.layers.values(): + h = layer(h, self.rope_cache, attention_masks) + + h = self.norm(h) if self.norm else h + output = self.output(h) if self.output else h + + return output diff --git a/torchtitan/experiments/deterministic_vllm_rl/simple_rl.py b/torchtitan/experiments/deterministic_vllm_rl/simple_rl.py new file mode 100644 index 0000000000..ffc7d52eb0 --- /dev/null +++ b/torchtitan/experiments/deterministic_vllm_rl/simple_rl.py @@ -0,0 +1,1227 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +""" +Simple RL training loop with GRPO-style advantage estimation. + +This demonstrates: +1. Loading a model in TorchTitan format for training +2. Converting weights to vLLM format for fast rollouts +3. Generating samples using vLLM +4. Computing rewards (trivial/random for now) +5. Computing advantages using GRPO-style group ranking +6. Performing a policy gradient update on TorchTitan model +7. Optional real dataset support (GSM8K math dataset) +""" + +import os +import re + +import torch +import torch.nn.functional as F +from huggingface_hub import snapshot_download +from safetensors.torch import load_file, save_file +from torch.utils.tensorboard import SummaryWriter +from transformers import AutoConfig, AutoTokenizer + +from vllm import LLM, SamplingParams +from vllm.model_executor.layers.batch_invariant import init_batch_invariance + +from torchtitan.experiments.deterministic_vllm_rl.weights.converter import ( + torchtitan_to_vllm, + vllm_to_torchtitan, +) +from torchtitan.experiments.deterministic_vllm_rl.weights_vllm_compat import ( + torchtitan_to_vllm_compat, +) + +from torchtitan.models.qwen3.model.args import Qwen3ModelArgs + +init_batch_invariance() + + +class VLLMRolloutEngine: + """ + vLLM engine for fast rollouts with weight updates. + + Note: vLLM loads from model_config.model path, so we create a temporary + directory with updated weights and restart the engine. This is faster than + recreating temp dirs repeatedly and handles config/tokenizer files properly. + + Args: + model_path: Path to HuggingFace model (for config/tokenizer) + temp_checkpoint_dir: Directory to save temporary weight checkpoints + """ + + def __init__(self, model_path: str, temp_checkpoint_dir: str = "./converted"): + self.base_model_path = model_path + self.temp_model_dir = os.path.abspath( + os.path.join(temp_checkpoint_dir, "vllm_temp_model") + ) + os.makedirs(self.temp_model_dir, exist_ok=True) + + import glob + + # Copy config/tokenizer files from base model to temp dir + import shutil + + for file in [ + "config.json", + "tokenizer.json", + "tokenizer_config.json", + "special_tokens_map.json", + "merges.txt", + "vocab.json", + ]: + src = os.path.join(model_path, file) + if os.path.exists(src): + shutil.copy2(src, self.temp_model_dir) + + # Copy the original model shard files if they exist + # We'll overwrite these with our single model.safetensors later + for shard_file in glob.glob(os.path.join(model_path, "model-*.safetensors")): + dst = os.path.join(self.temp_model_dir, os.path.basename(shard_file)) + shutil.copy2(shard_file, dst) + + # Copy index file if it exists + index_file = os.path.join(model_path, "model.safetensors.index.json") + if os.path.exists(index_file): + shutil.copy2(index_file, self.temp_model_dir) + + self.llm = None + print("vLLM rollout engine initialized (will load on first use)") + + def update_weights(self, vllm_compat_state: dict) -> None: + """ + Update vLLM model weights from vLLM-compat state dict. + + This converts weights to vLLM format, saves them, and reloads using + vLLM's reload_weights() API after updating the model path config. + + Args: + vllm_compat_state: vLLM-compat model state dict (with gate_up_proj/down_proj) + """ + # Convert vLLM-compat -> vLLM (torchtitan_to_vllm handles both formats) + vllm_state = torchtitan_to_vllm(vllm_compat_state) + + # Save to temp model directory + checkpoint_path = os.path.join(self.temp_model_dir, "model.safetensors") + + # Update the shard files that vLLM will actually load + # We need to split our weights to match the original 2-shard structure + import glob + import json + + shard_files = sorted( + glob.glob(os.path.join(self.temp_model_dir, "model-*.safetensors")) + ) + index_file = os.path.join(self.temp_model_dir, "model.safetensors.index.json") + + if len(shard_files) == 2 and os.path.exists(index_file): + # Load the index to see which weights go in which shard + with open(index_file, "r") as f: + index_data = json.load(f) + + weight_map = index_data["weight_map"] + + # Split weights according to the index + shard1_weights = {} + shard2_weights = {} + + for key, value in vllm_state.items(): + shard_file = weight_map.get(key, shard_files[0]) + if "model-00001-of-00002" in shard_file: + shard1_weights[key] = value + else: + shard2_weights[key] = value + + # Ensure weights stay in bfloat16 + shard1_weights = { + k: v.to(torch.bfloat16) if v.dtype == torch.float32 else v + for k, v in shard1_weights.items() + } + shard2_weights = { + k: v.to(torch.bfloat16) if v.dtype == torch.float32 else v + for k, v in shard2_weights.items() + } + + # Save to the shard files + save_file(shard1_weights, shard_files[0]) + save_file(shard2_weights, shard_files[1]) + else: + # Ensure weights stay in bfloat16 + vllm_state = { + k: v.to(torch.bfloat16) if v.dtype == torch.float32 else v + for k, v in vllm_state.items() + } + # Fallback: save as single file + save_file(vllm_state, checkpoint_path) + + # First time: create the engine + if self.llm is None: + self.llm = LLM( + model=self.temp_model_dir, + trust_remote_code=True, + max_model_len=2048, + dtype="bfloat16", + gpu_memory_utilization=0.3, # Reduced from 0.5 + seed=42, # Fixed seed for determinism + enforce_eager=True, + ) + print("✓ Created new vLLM engine") + else: + # Use collective_rpc to call reload_weights on all workers + # This reloads weights from temp_model_dir without recreating the engine + self.llm.collective_rpc("reload_weights") + + @torch.no_grad() + def generate( + self, + prompt_texts: list[str], + max_new_tokens: int = 20, + temperature: float = 1.0, + n_samples_per_prompt: int = 4, + ) -> tuple[ + list[str], torch.Tensor, list[list[int]], list[list[float]], list[list[int]] + ]: + """ + Generate samples using vLLM. + + Args: + prompt_texts: List of prompt strings + max_new_tokens: Max tokens to generate + temperature: Sampling temperature + n_samples_per_prompt: Number of samples per prompt + + Returns: + completions: List of completion strings + log_probs: [batch] - Sum of log probs for each completion + token_ids: List of token ID lists for each completion (generated tokens only) + token_log_probs: List of per-token log prob lists for each completion + prompt_token_ids: List of prompt token ID lists for each completion + """ + sampling_params = SamplingParams( + temperature=temperature, + max_tokens=max_new_tokens, + n=n_samples_per_prompt, + seed=42, + logprobs=1, + prompt_logprobs=1, # Also get prompt log probs to access prompt token IDs + ) + + outputs = self.llm.generate(prompt_texts, sampling_params) + + # Extract completions and log probs + completions = [] + log_probs_list = [] + token_ids_list = [] + token_log_probs_list = [] + prompt_token_ids_list = [] + + for output in outputs: + # Extract prompt token IDs from the output + prompt_token_ids = output.prompt_token_ids + + for sample in output.outputs: + completions.append(sample.text) + + # Store prompt tokens for this sample + prompt_token_ids_list.append(prompt_token_ids) + + # Extract token IDs (generated tokens only) + token_ids = sample.token_ids + token_ids_list.append(token_ids) + + # Extract per-token log probs + per_token_log_probs = [ + list(logprob_dict.values())[0].logprob + for logprob_dict in sample.logprobs + ] + token_log_probs_list.append(per_token_log_probs) + + # Sum log probs across generated tokens + total_log_prob = sum(per_token_log_probs) + log_probs_list.append(total_log_prob) + + log_probs = torch.tensor(log_probs_list, dtype=torch.float32) + + return ( + completions, + log_probs, + token_ids_list, + token_log_probs_list, + prompt_token_ids_list, + ) + + def __del__(self): + """Cleanup vLLM engine.""" + if hasattr(self, "llm"): + del self.llm + torch.cuda.empty_cache() + + +def download_and_convert_model( + model_name: str, cache_dir: str = "./models", output_dir: str = "./converted" +) -> tuple[str, str]: + """ + Download model from HuggingFace and convert to TorchTitan format. + + Args: + model_name: HuggingFace model name (e.g., "Qwen/Qwen3-1.7B") + cache_dir: Directory to cache the downloaded model + output_dir: Directory to save converted weights + + Returns: + titan_checkpoint_path: Path to TorchTitan checkpoint + model_path: Path to downloaded HuggingFace model + """ + os.makedirs(output_dir, exist_ok=True) + + # Download model from HuggingFace + print(f"Downloading {model_name} from HuggingFace...") + model_path = snapshot_download( + model_name, + cache_dir=cache_dir, + allow_patterns=["*.safetensors", "*.json", "*.txt", "tokenizer.model"], + ) + print(f" Downloaded to: {model_path}") + + # Convert to TorchTitan format + print("Converting weights to TorchTitan format...") + titan_state = vllm_to_torchtitan(model_path) + titan_checkpoint_path = os.path.join(output_dir, "qwen3_torchtitan.safetensors") + save_file(titan_state, titan_checkpoint_path) + print(f" Saved TorchTitan weights to: {titan_checkpoint_path}") + + return titan_checkpoint_path, model_path + + +def load_model(checkpoint_path: str, model_path: str, use_vllm_compat: bool = True): + """ + Load TorchTitan model from checkpoint. + + Args: + checkpoint_path: Path to TorchTitan checkpoint + model_path: Path to HuggingFace model (for config) + use_vllm_compat: If True, use vLLM-compatible model, else use standard model + + Returns: + model: Loaded TorchTitan model + """ + # Load HuggingFace config + hf_config = AutoConfig.from_pretrained(model_path, trust_remote_code=True) + + # Create model args + model_args = Qwen3ModelArgs( + dim=hf_config.hidden_size, + n_layers=hf_config.num_hidden_layers, + n_heads=hf_config.num_attention_heads, + n_kv_heads=hf_config.num_key_value_heads, + vocab_size=hf_config.vocab_size, + head_dim=getattr( + hf_config, + "head_dim", + hf_config.hidden_size // hf_config.num_attention_heads, + ), + hidden_dim=hf_config.intermediate_size, + norm_eps=hf_config.rms_norm_eps, + rope_theta=hf_config.rope_theta, + max_seq_len=getattr(hf_config, "max_position_embeddings", 32768), + qk_norm=True, + depth_init=True, + eos_id=getattr(hf_config, "eos_token_id", 151645), + ) + + # state_dict is in standard TorchTitan format (w1, w2, w3) + state_dict = load_file(checkpoint_path) + + if use_vllm_compat: + # Create and load model (using vLLM-compat for bitwise determinism) + from torchtitan.experiments.deterministic_vllm_rl.models.qwen3 import ( + Qwen3VLLMCompatModel, + ) + + model = Qwen3VLLMCompatModel(model_args) + # Convert to vLLM-compat format (merged gate_up_proj, down_proj) + vllm_compat_state = torchtitan_to_vllm_compat(state_dict) + model.load_state_dict(vllm_compat_state, strict=False) + else: + # Use standard TorchTitan model + from torchtitan.models.qwen3 import Qwen3Model + + model = Qwen3Model(model_args) + # Load standard TorchTitan format directly + model.load_state_dict(state_dict, strict=False) + + model.to(torch.bfloat16) + + return model + + +def extract_numeric_answer(text: str) -> str | None: + """ + Extract numeric answer from model completion. + + Looks for patterns like "#### 123" or final numbers in the text. + + Args: + text: Completion text + + Returns: + Extracted answer as string, or None if not found + """ + # GSM8K uses #### to denote the final answer + match = re.search(r"####\s*(-?\d+(?:,\d+)*(?:\.\d+)?)", text) + if match: + # Remove commas from numbers + return match.group(1).replace(",", "") + + # Fallback: look for last number in text + numbers = re.findall(r"-?\d+(?:,\d+)*(?:\.\d+)?", text) + if numbers: + return numbers[-1].replace(",", "") + + return None + + +def math_reward_function( + completions: list[str], + expected_answers: list[str], + group_size: int = 4, +) -> torch.Tensor: + """ + Reward function for math problems (e.g., GSM8K). + + Gives high reward for correct answers, low for incorrect. + + Args: + completions: List of completion strings + expected_answers: List of expected answers (one per prompt, repeated for group_size) + group_size: Number of samples per prompt + + Returns: + rewards: [batch] + """ + rewards = [] + + for idx, completion in enumerate(completions): + # Map completion index to prompt index + prompt_idx = idx // group_size + expected = expected_answers[prompt_idx].strip().lower() + + # Extract answer from completion + predicted = extract_numeric_answer(completion) + + if predicted is None: + # No valid answer found + reward = 0.0 + elif predicted.lower() == expected: + # Correct answer + reward = 1.0 + else: + # Wrong answer + reward = 0.0 + + rewards.append(reward) + + return torch.tensor(rewards, dtype=torch.float32) + + +def load_gsm8k_dataset(split: str = "train", num_samples: int = 100): + """ + Load GSM8K dataset from HuggingFace. + + Args: + split: Dataset split ("train" or "test") + num_samples: Number of samples to load + + Returns: + prompts: List of problem prompts + answers: List of expected answers (numeric strings) + """ + try: + from datasets import load_dataset + + dataset = load_dataset("openai/gsm8k", "main", split=split) + + prompts = [] + answers = [] + + for i, item in enumerate(dataset): + if i >= num_samples: + break + + question = item["question"] + answer = item["answer"] + + # Extract the final numeric answer from the answer field + # GSM8K answers are like "some explanation\n#### 42" + answer_num = extract_numeric_answer(answer) + if answer_num is None: + continue + + # Format prompt for the model + prompt = f"Question: {question}\nAnswer:" + + prompts.append(prompt) + answers.append(answer_num) + + return prompts, answers + + except ImportError: + print("⚠ datasets library not installed. Install with: pip install datasets") + return None, None + except Exception as e: + print(f"⚠ Failed to load GSM8K dataset: {e}") + return None, None + + +def trivial_reward_function( + completions: list[str], + tokenizer=None, + expected_answers: list[str] | None = None, + group_size: int = 4, +) -> torch.Tensor: + """ + Reward function based on correctness and lowercase preference. + + Penalizes non-English characters to keep output in English. + Rewards correct answers to factual questions. + Penalizes capital letters to encourage lowercase output. + + Args: + completions: List of completion strings + tokenizer: Tokenizer to count tokens + expected_answers: List of expected answers (one per prompt, repeated for group_size) + group_size: Number of samples per prompt + + Returns: + rewards: [batch] + """ + batch_size = len(completions) + rewards = [] + + for idx, completion in enumerate(completions): + # Start with base reward of 1.0 + reward = 1.0 + + total_chars = len(completion) + if total_chars == 0: + rewards.append(0.0) + continue + + # Penalty for non-English characters (keep it in English) + # Count non-ASCII characters + non_ascii_count = sum(1 for c in completion if ord(c) > 127) + non_ascii_ratio = non_ascii_count / total_chars + # Strong penalty if >10% non-ASCII + if non_ascii_ratio > 0.1: + reward *= 0.1 # 10x penalty + + # Penalty for capital letters (encourage lowercase) + uppercase_count = sum(1 for c in completion if c.isupper()) + uppercase_ratio = uppercase_count / total_chars + # Apply penalty proportional to uppercase ratio + # 0% uppercase = no penalty (1.0x) + # 100% uppercase = strong penalty (0.1x) + # Linear interpolation: penalty = 1.0 - 0.9 * uppercase_ratio + uppercase_penalty = 1.0 - 0.9 * uppercase_ratio + reward *= uppercase_penalty + + # Bonus for correct answers + if expected_answers is not None: + # Map completion index to prompt index + prompt_idx = idx // group_size + expected_answer = expected_answers[prompt_idx].lower() + completion_lower = completion.lower() + + # Check if answer is in completion + if expected_answer in completion_lower: + reward *= 2.0 # 2x bonus for correct answer + else: + reward *= 0.5 # Penalty for wrong answer + + rewards.append(reward) + + rewards = torch.tensor(rewards, dtype=torch.float32) + + return rewards + + +def compute_grpo_advantages( + rewards: torch.Tensor, group_size: int = 4, beta: float = 0.1 +) -> torch.Tensor: + """ + Compute advantages using GRPO-style exponential weighting. + + GRPO uses exponential advantages within groups which can be numerically + unstable without bitwise determinism. Small differences in reward computation + can lead to drastically different exp(reward/beta) values. + + This implementation uses the proper GRPO formulation: + advantage_i = exp(reward_i / beta) / Z - 1 + where Z = mean(exp(reward_j / beta)) for j in group + + Args: + rewards: [batch] + group_size: Number of samples per prompt (batch must be divisible by this) + beta: Temperature parameter for exponential weighting (lower = more unstable) + + Returns: + advantages: [batch] + """ + batch_size = rewards.shape[0] + assert ( + batch_size % group_size == 0 + ), f"Batch size {batch_size} must be divisible by group_size {group_size}" + + num_groups = batch_size // group_size + rewards_grouped = rewards.view(num_groups, group_size) + + # GRPO exponential advantages: exp(reward / beta) + # This is numerically unstable and will explode without bitwise invariance! + exp_rewards = torch.exp(rewards_grouped / beta) + + # Normalize by group mean (this is where instability shows up) + group_mean_exp = exp_rewards.mean(dim=1, keepdim=True) + + # Advantage = normalized_exp - 1 + advantages_grouped = exp_rewards / group_mean_exp - 1.0 + + # Flatten back + advantages = advantages_grouped.view(-1) + + return advantages + + +def compute_grpo_advantages_stable( + rewards: torch.Tensor, group_size: int = 4 +) -> torch.Tensor: + """ + Compute advantages using simple mean-centering (stable fallback). + + This is a simplified version that just uses mean-centering within groups. + Use this if you want stable training without bitwise invariance. + + Args: + rewards: [batch] + group_size: Number of samples per prompt (batch must be divisible by this) + + Returns: + advantages: [batch] + """ + batch_size = rewards.shape[0] + assert ( + batch_size % group_size == 0 + ), f"Batch size {batch_size} must be divisible by group_size {group_size}" + + num_groups = batch_size // group_size + rewards_grouped = rewards.view(num_groups, group_size) + + # Compute advantages: reward - group_mean + group_means = rewards_grouped.mean(dim=1, keepdim=True) + advantages_grouped = rewards_grouped - group_means + + # Flatten back + advantages = advantages_grouped.view(-1) + + return advantages + + +def policy_gradient_loss( + log_probs: torch.Tensor, advantages: torch.Tensor +) -> torch.Tensor: + """ + Compute policy gradient loss. + + L = -E[log π(a|s) * A(s,a)] + + Args: + log_probs: [batch, seq_len] - Log probs of generated tokens + advantages: [batch] - Advantages for each sample + + Returns: + loss: scalar + """ + # Sum log probs across sequence for each sample + total_log_probs = log_probs.sum(dim=1) # [batch] + + # Policy gradient: -log_prob * advantage + pg_loss = -(total_log_probs * advantages).mean() + + return pg_loss + + +def compute_policy_gradient_loss_vllm( + model: torch.nn.Module, + vllm_token_ids: list[list[int]], + vllm_token_log_probs: list[list[float]], + prompt_token_ids: list[list[int]], + advantages: torch.Tensor, + kl_coef: float = 0.1, + ppo_clip_eps: float = 0.2, + entropy_coef: float = 0.01, +) -> tuple[torch.Tensor, dict]: + """ + Compute PPO policy gradient loss by re-evaluating completions under current policy. + + Args: + model: Current policy model + vllm_token_ids: Generated token IDs for each completion + vllm_token_log_probs: Per-token log probs from vLLM (reference) + prompt_token_ids: Prompt token IDs for each completion + advantages: [batch] - Advantages for each sample + kl_coef: KL divergence penalty coefficient + ppo_clip_eps: PPO clipping epsilon + entropy_coef: Entropy bonus coefficient + + Returns: + loss: Total loss (PG + entropy + KL) + metrics: Training metrics dict (includes per-token logprob deltas) + """ + device = next(model.parameters()).device + advantages = advantages.to(device) + + # Compute reference log probs from per-token values + # Use PyTorch's sum() to match the reduction order used for total_log_probs + # This ensures exactly zero KL divergence with batch invariance + ref_log_probs = torch.stack( + [ + torch.tensor(lps, dtype=torch.float32, device=device).sum() + for lps in vllm_token_log_probs + ] + ) + + # Compute log probs under current policy (WITH GRADIENTS) + batch_token_log_probs = [] + batch_total_log_probs = [] + + # Track per-token differences for the first sample + first_sample_deltas = [] + + for idx, (prompt_toks, gen_toks, vllm_toks_lp) in enumerate( + zip(prompt_token_ids, vllm_token_ids, vllm_token_log_probs) + ): + # Concatenate prompt + generated tokens + full_sequence = prompt_toks + gen_toks + full_tensor = torch.tensor( + full_sequence, dtype=torch.long, device=device + ).unsqueeze(0) + + # Forward pass + logits = model(full_tensor) + # Use F.log_softmax which is overridden by batch_invariant mode for determinism + # Convert to float32 to match vLLM's sampler behavior (use .to() to preserve gradients) + log_probs = F.log_softmax(logits[:, :-1, :].to(torch.float32), dim=-1) + target_tokens = full_tensor[:, 1:] + + # Extract log probs for generated tokens only + prompt_len = len(prompt_toks) + gen_start_idx = prompt_len - 1 + gen_end_idx = gen_start_idx + len(gen_toks) + + gen_token_logprobs = log_probs[0, gen_start_idx:gen_end_idx, :] + gen_token_ids = target_tokens[0, gen_start_idx:gen_end_idx] + token_lps = gen_token_logprobs.gather(1, gen_token_ids.unsqueeze(-1)).squeeze( + -1 + ) + + batch_token_log_probs.append(token_lps) + batch_total_log_probs.append(token_lps.sum()) + + # For the first sample, store raw tensors for bitwise comparison + if idx == 0: + # Keep bfloat16 tensors for bitwise comparison + titan_lps_bf16 = token_lps.detach().cpu() # Keep as bfloat16 + titan_lps_f32 = ( + token_lps.detach().cpu().float() + ) # Convert to float32 for display + + for token_id, vllm_lp, titan_lp_bf16, titan_lp_f32 in zip( + gen_toks, vllm_toks_lp, titan_lps_bf16, titan_lps_f32 + ): + first_sample_deltas.append( + { + "token_id": token_id, + "vllm_logprob": vllm_lp, + "titan_logprob_bf16": titan_lp_bf16, + "titan_logprob_f32": titan_lp_f32.item(), + } + ) + + total_log_probs = torch.stack(batch_total_log_probs) + + # Verify bitwise determinism between vLLM and TorchTitan + if first_sample_deltas: + vllm_lps_f32 = torch.tensor( + [d["vllm_logprob"] for d in first_sample_deltas], dtype=torch.float32 + ) + titan_lps_f32 = torch.tensor( + [d["titan_logprob_f32"] for d in first_sample_deltas], dtype=torch.float32 + ) + + bitwise_identical = torch.equal(vllm_lps_f32, titan_lps_f32) + + if bitwise_identical: + print( + f" ✓ vLLM-TorchTitan bitwise determinism verified: {len(first_sample_deltas)} tokens match exactly" + ) + else: + num_different = (vllm_lps_f32 != titan_lps_f32).sum().item() + deltas = (vllm_lps_f32 - titan_lps_f32).abs() + max_delta = deltas.max().item() + avg_delta = deltas.mean().item() + print( + f" ⚠ vLLM-TorchTitan logprobs differ: {num_different}/{len(first_sample_deltas)} tokens" + ) + print(f" Max delta: {max_delta:.6e}, Avg delta: {avg_delta:.6e}") + print( + f" vLLM logprobs: {[f'{lp:.10f}' for lp in vllm_lps_f32[:5].tolist()]}" + ) + print( + f" TorchTitan logprobs: {[f'{lp:.10f}' for lp in titan_lps_f32[:5].tolist()]}" + ) + + # PPO clipped objective + log_ratio = total_log_probs - ref_log_probs + ratio = torch.exp(log_ratio) + unclipped_loss = ratio * advantages + clipped_ratio = torch.clamp(ratio, 1 - ppo_clip_eps, 1 + ppo_clip_eps) + clipped_loss = clipped_ratio * advantages + pg_loss = -torch.min(unclipped_loss, clipped_loss).mean() + + # Entropy bonus + all_token_log_probs = torch.cat(batch_token_log_probs) + entropy = -all_token_log_probs.mean() + entropy_bonus = -entropy_coef * entropy + + # KL divergence penalty + kl_div = (ratio - 1 - log_ratio).mean() + + # Total loss + total_loss = pg_loss + entropy_bonus + kl_coef * kl_div + + metrics = { + "pg_loss": pg_loss.item(), + "entropy": entropy.item(), + "kl_div": kl_div.item(), + "ratio_mean": ratio.mean().item(), + "ratio_clipped_frac": (torch.abs(ratio - clipped_ratio) > 1e-6) + .float() + .mean() + .item(), + "per_token_deltas": first_sample_deltas, # Per-token logprob differences for first sample + } + + return total_loss, metrics + + +def rl_update_step( + model, + tokenizer, + vllm_engine: VLLMRolloutEngine, + prompt_texts: list[str], + optimizer: torch.optim.Optimizer, + expected_answers: list[str] | None = None, + group_size: int = 8, + max_new_tokens: int = 20, + temperature: float = 1.0, + use_vllm_compat: bool = True, + num_rollout_batches: int = 1, + reward_fn=None, + grpo_beta: float = 0.1, + use_stable_grpo: bool = False, +) -> dict: + """ + Perform one RL update step using vLLM for rollouts. + + Args: + model: Policy model (TorchTitan) + tokenizer: Tokenizer + vllm_engine: Persistent vLLM engine + prompt_texts: List of prompt strings + optimizer: Optimizer + expected_answers: List of expected answers for each prompt + group_size: Number of samples per prompt for GRPO + max_new_tokens: Max tokens to generate + temperature: Sampling temperature + use_vllm_compat: Whether to use vLLM-compatible model + num_rollout_batches: Number of rollout batches per update (more rollouts = more samples) + reward_fn: Reward function (defaults to trivial_reward_function) + grpo_beta: Beta parameter for GRPO exponential weighting (lower = more unstable) + use_stable_grpo: If True, use stable GRPO (mean-centering) instead of exponential + + Returns: + metrics: Dict of training metrics + """ + # Default reward function + if reward_fn is None: + reward_fn = trivial_reward_function + + # Update vLLM weights from current policy (only once per update) + titan_state = model.state_dict() + vllm_compat_state = torchtitan_to_vllm_compat(titan_state) + vllm_engine.update_weights(vllm_compat_state) + + # Accumulate gradients over multiple rollout batches + optimizer.zero_grad() + + all_completions = [] + all_rewards = [] + all_advantages = [] + total_loss = 0.0 + batch_metrics = [] + + for batch_idx in range(num_rollout_batches): + # Generate samples using vLLM + ( + completions, + vllm_log_probs, + vllm_token_ids, + vllm_token_log_probs, + prompt_token_ids, + ) = vllm_engine.generate( + prompt_texts, + max_new_tokens, + temperature, + n_samples_per_prompt=group_size, + ) + + # Compute rewards using provided reward function + if reward_fn == trivial_reward_function: + rewards = reward_fn(completions, tokenizer, expected_answers, group_size) + elif reward_fn == math_reward_function: + rewards = reward_fn(completions, expected_answers, group_size) + else: + rewards = reward_fn(completions, expected_answers, group_size) + + # Normalize rewards for stability (mean=0, std=1) + reward_mean = rewards.mean() + reward_std = rewards.std() + if reward_std > 1e-8: + rewards_normalized = (rewards - reward_mean) / reward_std + else: + rewards_normalized = rewards - reward_mean + + # Compute advantages using GRPO + if use_stable_grpo: + advantages = compute_grpo_advantages_stable(rewards_normalized, group_size) + else: + advantages = compute_grpo_advantages( + rewards_normalized, group_size, beta=grpo_beta + ) + + # Compute loss using current policy + loss, loss_metrics = compute_policy_gradient_loss_vllm( + model, + vllm_token_ids, + vllm_token_log_probs, + prompt_token_ids, + advantages, + kl_coef=0.1, + ) + + # Accumulate loss (will be averaged later) + loss = loss / num_rollout_batches + loss.backward() + total_loss += loss.item() + + # Track metrics + all_completions.extend(completions[:2]) # Sample 2 from each batch + all_rewards.append(reward_mean.item()) + all_advantages.append(advantages.mean().item()) + batch_metrics.append(loss_metrics) + + # Gradient clipping + torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0) + + # Update weights + optimizer.step() + + # Aggregate metrics across batches + avg_reward = sum(all_rewards) / len(all_rewards) + avg_advantage = sum(all_advantages) / len(all_advantages) + + # Use metrics from last batch for detailed stats + final_metrics = batch_metrics[-1] + + # Return aggregated metrics + metrics = { + "loss": total_loss, + "reward_mean": avg_reward, + "reward_std": batch_metrics[-1].get("reward_std", 0.0), + "advantage_mean": avg_advantage, + "advantage_std": batch_metrics[-1].get("advantage_std", 0.0), + "sample_completions": all_completions[:2], # First 2 for inspection + "num_rollout_batches": num_rollout_batches, + "total_samples": len(prompt_texts) * group_size * num_rollout_batches, + **final_metrics, # Include final batch metrics + } + + return metrics + + +def compute_weight_deltas(model: torch.nn.Module, initial_state: dict) -> dict: + """ + Compute weight changes from initial state based on magnitude (L2 norm). + + Args: + model: Current model + initial_state: Initial model state dict + + Returns: + Dictionary of weight delta statistics by module + """ + deltas = {} + module_stats = {} + + with torch.no_grad(): + current_state = model.state_dict() + + for name, current_param in current_state.items(): + if name not in initial_state: + continue + + # Move current param to CPU to compare with initial (avoid GPU OOM) + current_param_cpu = current_param.cpu() + initial_param = initial_state[name] + delta = current_param_cpu - initial_param + + # Extract module name (e.g., "layers.0.attention.wq" -> "layers.0") + parts = name.split(".") + if len(parts) >= 2: + module_name = ".".join(parts[:2]) + else: + module_name = parts[0] + + # Compute magnitude (L2 norm) of change + delta_norm = torch.linalg.vector_norm(delta).item() + param_norm = torch.linalg.vector_norm(current_param_cpu).item() + + # Relative change: ||delta|| / ||param|| + relative_change = delta_norm / (param_norm + 1e-8) + + # Accumulate module-level stats + if module_name not in module_stats: + module_stats[module_name] = {"norms": [], "relative": []} + + module_stats[module_name]["norms"].append(delta_norm) + module_stats[module_name]["relative"].append(relative_change) + + # Average module-level stats + for module_name, stats in module_stats.items(): + deltas[f"weight_delta/{module_name}/magnitude"] = sum(stats["norms"]) / len( + stats["norms"] + ) + deltas[f"weight_delta/{module_name}/relative_change"] = sum( + stats["relative"] + ) / len(stats["relative"]) + + return deltas + + +def main(): + """Simple RL training loop using vLLM for fast rollouts.""" + + # ========== Config ========== + model_name = "Qwen/Qwen3-1.7B" # HuggingFace model name + cache_dir = "./models" + output_dir = "./converted" + + # Training config + group_size = 8 # Samples per prompt for GRPO (increased from 4) + num_rollout_batches = 2 # Multiple rollout batches per update (NEW!) + num_steps = 100 + learning_rate = 1e-5 + + # GRPO config + use_stable_grpo = ( + False # Set to True for stable training, False to test bitwise invariance + ) + grpo_beta = 0.1 # Lower = more unstable (will explode without bitwise invariance!) + + # Dataset config + use_real_dataset = ( + True # Set to True to use GSM8K dataset (requires: pip install datasets) + ) + num_dataset_samples = 10 # Number of prompts from dataset + + # Check if batch invariance is enabled + from vllm.model_executor.layers.batch_invariant import vllm_is_batch_invariant + + use_vllm_compat = vllm_is_batch_invariant() + + if use_vllm_compat: + print("✓ Batch invariance detected - using vLLM-compatible model") + # Add backward pass support to vLLM's batch_invariant mode + print(" Adding gradient support to vLLM's batch_invariant mode...") + from torchtitan.experiments.deterministic_vllm_rl.batch_invariant_backward import ( + enable_batch_invariant_backward_mode, + ) + + enable_batch_invariant_backward_mode() + else: + print("⚠ Batch invariance NOT detected - using standard model") + if not use_stable_grpo: + print( + " WARNING: Exponential GRPO may be unstable without bitwise invariance!" + ) + + # Download and convert model + print("=" * 80) + print(f"Setting up model: {model_name}") + print("=" * 80) + titan_checkpoint_path, model_path = download_and_convert_model( + model_name, cache_dir, output_dir + ) + + # Load TorchTitan model for training + print("\nLoading TorchTitan model for training...") + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + model = load_model( + titan_checkpoint_path, model_path, use_vllm_compat=use_vllm_compat + ) + model = model.to(device) + model.train() + + # Save initial weights for delta computation (on CPU to save GPU memory) + print("Saving initial weights for tracking...") + initial_state = { + name: param.clone().cpu() for name, param in model.state_dict().items() + } + + # Initialize persistent vLLM engine for rollouts + print("\nInitializing vLLM engine for rollouts...") + vllm_engine = VLLMRolloutEngine(model_path) + + # Load tokenizer + tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True) + + # Load dataset + print("\n" + "=" * 80) + print("Dataset Configuration") + print("=" * 80) + + if use_real_dataset: + print(f"Attempting to load GSM8K dataset ({num_dataset_samples} samples)...") + prompt_texts, expected_answers = load_gsm8k_dataset( + split="train", num_samples=num_dataset_samples + ) + + if prompt_texts is None or len(prompt_texts) == 0: + print("⚠ Failed to load dataset, falling back to default prompts") + use_real_dataset = False + + if not use_real_dataset: + # Fallback: simple prompts with verifiable answers + print("Using default prompts (factual questions)") + prompts_with_answers = [ + ("The capital of France is", "paris"), + ("What is 7 times 8?", "56"), + ("The first president of the United States was", "washington"), + ("The chemical symbol for water is", "h2o"), + ("The largest planet in our solar system is", "jupiter"), + ] + prompt_texts = [p[0] for p in prompts_with_answers] + expected_answers = [p[1] for p in prompts_with_answers] + + # Select reward function + reward_fn = math_reward_function if use_real_dataset else trivial_reward_function + + print(f"Loaded {len(prompt_texts)} prompts") + print(f"Reward function: {reward_fn.__name__}") + print(f"First prompt: {prompt_texts[0][:80]}...") + + # Optimizer + optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate) + + # TensorBoard writer + writer = SummaryWriter("./outputs/rl_training") + print("\n" + "=" * 80) + print("TensorBoard logging enabled at: ./outputs/rl_training") + print("=" * 80) + + # Training loop + print(f"\nStarting RL training for {num_steps} steps...") + print(f" Prompts: {len(prompt_texts)}") + print(f" Samples per prompt: {group_size}") + print(f" Rollout batches per update: {num_rollout_batches}") + print( + f" Total samples per update: {len(prompt_texts) * group_size * num_rollout_batches}" + ) + print( + f" GRPO mode: {'Stable (mean-centering)' if use_stable_grpo else f'Exponential (beta={grpo_beta})'}" + ) + print("=" * 80) + + for step in range(num_steps): + metrics = rl_update_step( + model, + tokenizer, + vllm_engine, + prompt_texts, + optimizer, + expected_answers=expected_answers, + group_size=group_size, + max_new_tokens=20 if not use_real_dataset else 100, + temperature=1.0, + use_vllm_compat=use_vllm_compat, + num_rollout_batches=num_rollout_batches, + reward_fn=reward_fn, + grpo_beta=grpo_beta, + use_stable_grpo=use_stable_grpo, + ) + + # Compute weight deltas from initial state + weight_deltas = compute_weight_deltas(model, initial_state) + + # Log to TensorBoard + writer.add_scalar("rl/loss", metrics["loss"], step) + writer.add_scalar("rl/pg_loss", metrics["pg_loss"], step) + writer.add_scalar("rl/kl_div", metrics["kl_div"], step) + writer.add_scalar("rl/entropy", metrics["entropy"], step) + writer.add_scalar("rl/ratio_mean", metrics["ratio_mean"], step) + writer.add_scalar("rl/ratio_clipped_frac", metrics["ratio_clipped_frac"], step) + writer.add_scalar("rl/reward_mean", metrics["reward_mean"], step) + writer.add_scalar("rl/reward_std", metrics.get("reward_std", 0.0), step) + writer.add_scalar("rl/advantage_mean", metrics["advantage_mean"], step) + writer.add_scalar("rl/advantage_std", metrics.get("advantage_std", 0.0), step) + writer.add_scalar("rl/total_samples", metrics["total_samples"], step) + + # Log weight deltas + for key, value in weight_deltas.items(): + writer.add_scalar(key, value, step) + + print( + f"\nStep {step:3d} | Loss: {metrics['loss']:.4f} | " + f"Reward: {metrics['reward_mean']:+.3f} | " + f"Samples: {metrics['total_samples']}" + ) + print(f" Sample: {metrics['sample_completions'][0][:80]}...") + + # Check for NaN/Inf (sign of instability) + if not torch.isfinite(torch.tensor(metrics["loss"])): + print("\n" + "!" * 80) + print("ERROR: Loss is NaN/Inf! Training diverged.") + print( + "This likely means the exponential GRPO is unstable without bitwise invariance." + ) + print("Try setting use_stable_grpo=True or enabling batch invariance mode.") + print("!" * 80) + break + + print("\n" + "=" * 80) + print("Training complete!") + print("View TensorBoard: tensorboard --logdir=./outputs/rl_training") + print("=" * 80) + + # Cleanup + writer.close() + del vllm_engine + + +if __name__ == "__main__": + main() diff --git a/torchtitan/experiments/deterministic_vllm_rl/tests/__init__.py b/torchtitan/experiments/deterministic_vllm_rl/tests/__init__.py new file mode 100644 index 0000000000..138d5fc95e --- /dev/null +++ b/torchtitan/experiments/deterministic_vllm_rl/tests/__init__.py @@ -0,0 +1,9 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +""" +Tests for deterministic vLLM RL experiment. +""" diff --git a/torchtitan/experiments/deterministic_vllm_rl/tests/test_batch_invariant_backward.py b/torchtitan/experiments/deterministic_vllm_rl/tests/test_batch_invariant_backward.py new file mode 100644 index 0000000000..3ed9604d10 --- /dev/null +++ b/torchtitan/experiments/deterministic_vllm_rl/tests/test_batch_invariant_backward.py @@ -0,0 +1,164 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +""" +Test batch_invariant_backward module to ensure it works correctly. +""" + +import torch + +from torchtitan.experiments.deterministic_vllm_rl.batch_invariant_backward import ( + disable_batch_invariant_backward_mode, + enable_batch_invariant_backward_mode, + linear_batch_invariant_backward, + mm_batch_invariant_backward, +) + + +def test_mm_backward(): + """Test matrix multiplication with backward.""" + print("\n" + "=" * 80) + print("Testing mm_batch_invariant_backward...") + print("=" * 80) + + # Enable mode + from vllm.model_executor.layers.batch_invariant import disable_batch_invariant_mode + + disable_batch_invariant_mode() + enable_batch_invariant_backward_mode() + + # Create test tensors + a = torch.randn(4, 8, device="cuda", dtype=torch.bfloat16, requires_grad=True) + b = torch.randn(8, 16, device="cuda", dtype=torch.bfloat16, requires_grad=True) + + # Forward pass + c = mm_batch_invariant_backward(a, b) + print("Forward pass: a @ b = c") + print(f" a shape: {a.shape}, b shape: {b.shape}, c shape: {c.shape}") + + # Backward pass + loss = c.sum() + loss.backward() + + print("Backward pass successful!") + print(f" grad_a shape: {a.grad.shape if a.grad is not None else None}") + print(f" grad_b shape: {b.grad.shape if b.grad is not None else None}") + + assert a.grad is not None, "grad_a should not be None" + assert b.grad is not None, "grad_b should not be None" + print("✅ mm_backward test passed!") + + disable_batch_invariant_backward_mode() + + +def test_linear_backward(): + """Test linear layer with backward.""" + print("\n" + "=" * 80) + print("Testing linear_batch_invariant_backward...") + print("=" * 80) + + # Enable mode + from vllm.model_executor.layers.batch_invariant import disable_batch_invariant_mode + + disable_batch_invariant_mode() + enable_batch_invariant_backward_mode() + + # Create test tensors (3D input for realistic case) + input = torch.randn( + 2, 10, 64, device="cuda", dtype=torch.bfloat16, requires_grad=True + ) + weight = torch.randn( + 128, 64, device="cuda", dtype=torch.bfloat16, requires_grad=True + ) + bias = torch.randn(128, device="cuda", dtype=torch.bfloat16, requires_grad=True) + + # Forward pass + output = linear_batch_invariant_backward(input, weight, bias) + print("Forward pass: linear(input, weight, bias) = output") + print(f" input shape: {input.shape}") + print(f" weight shape: {weight.shape}") + print(f" output shape: {output.shape}") + + # Backward pass + loss = output.sum() + loss.backward() + + print("Backward pass successful!") + print(f" grad_input shape: {input.grad.shape if input.grad is not None else None}") + print( + f" grad_weight shape: {weight.grad.shape if weight.grad is not None else None}" + ) + print(f" grad_bias shape: {bias.grad.shape if bias.grad is not None else None}") + + assert input.grad is not None, "grad_input should not be None" + assert weight.grad is not None, "grad_weight should not be None" + assert bias.grad is not None, "grad_bias should not be None" + print("✅ linear_backward test passed!") + + disable_batch_invariant_backward_mode() + + +def test_deterministic_forward(): + """Test that forward passes are deterministic.""" + print("\n" + "=" * 80) + print("Testing deterministic forward passes...") + print("=" * 80) + + # Enable mode + from vllm.model_executor.layers.batch_invariant import disable_batch_invariant_mode + + disable_batch_invariant_mode() + enable_batch_invariant_backward_mode() + + # Create test tensors + a = torch.randn(32, 64, device="cuda", dtype=torch.bfloat16) + b = torch.randn(64, 128, device="cuda", dtype=torch.bfloat16) + + # Run forward pass twice + c1 = mm_batch_invariant_backward(a, b) + c2 = mm_batch_invariant_backward(a, b) + + # Check if results are identical + diff = (c1 - c2).abs().max().item() + print(f"Forward pass 1 result: {c1[0, :5]}") + print(f"Forward pass 2 result: {c2[0, :5]}") + print(f"Max absolute difference: {diff}") + + assert diff == 0.0, f"Forward passes should be deterministic, but diff={diff}" + print("✅ Deterministic forward test passed!") + + disable_batch_invariant_backward_mode() + + +def main(): + """Run all tests.""" + print("=" * 80) + print("Testing batch_invariant_backward module") + print("=" * 80) + + if not torch.cuda.is_available(): + print("❌ CUDA not available, skipping tests") + return + + try: + test_mm_backward() + test_linear_backward() + test_deterministic_forward() + + print("\n" + "=" * 80) + print("✅ All tests passed!") + print("=" * 80) + + except Exception as e: + print(f"\n❌ Test failed with error: {e}") + import traceback + + traceback.print_exc() + sys.exit(1) + + +if __name__ == "__main__": + main() diff --git a/torchtitan/experiments/deterministic_vllm_rl/tests/test_exact_determinism.py b/torchtitan/experiments/deterministic_vllm_rl/tests/test_exact_determinism.py new file mode 100644 index 0000000000..8d0ac3133e --- /dev/null +++ b/torchtitan/experiments/deterministic_vllm_rl/tests/test_exact_determinism.py @@ -0,0 +1,192 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +""" +Test if batch_invariant operations are EXACTLY deterministic. + +This runs the same operation multiple times and checks if results are bit-for-bit identical. +""" + +import torch +from vllm.model_executor.layers.batch_invariant import disable_batch_invariant_mode + +from torchtitan.experiments.deterministic_vllm_rl.batch_invariant_backward import ( + enable_batch_invariant_backward_mode, +) + +print("Enabling batch_invariant_backward mode...") +disable_batch_invariant_mode() +enable_batch_invariant_backward_mode() + + +def test_mm_exact_determinism(): + """Test if mm is exactly deterministic.""" + print("\n" + "=" * 80) + print("Testing mm exact determinism") + print("=" * 80) + + # Create random inputs + torch.manual_seed(42) + a = torch.randn(32, 64, device="cuda", dtype=torch.bfloat16) + b = torch.randn(64, 128, device="cuda", dtype=torch.bfloat16) + + # Run multiple times + results = [] + for i in range(5): + c = torch.mm(a, b) + results.append(c.clone()) + print( + f"Run {i + 1}: mean={c.float().mean().item():.10f}, " + f"std={c.float().std().item():.10f}" + ) + + # Check if all results are identical + all_same = True + for i in range(1, len(results)): + if not torch.equal(results[0], results[i]): + diff = (results[0] - results[i]).abs().max().item() + print(f" ✗ Run 1 vs Run {i + 1}: MAX DIFF = {diff}") + all_same = False + + if all_same: + print(" ✓ All runs produce IDENTICAL results (bit-for-bit)") + else: + print(" ✗ Results differ across runs") + + return all_same + + +def test_flash_attention_determinism(): + """Test if Flash Attention is exactly deterministic.""" + print("\n" + "=" * 80) + print("Testing Flash Attention exact determinism") + print("=" * 80) + + from vllm.vllm_flash_attn import flash_attn_varlen_func + + # Create random inputs + torch.manual_seed(42) + batch_size = 2 + seq_len = 16 + num_heads = 32 + head_dim = 128 + + q = torch.randn( + batch_size * seq_len, num_heads, head_dim, device="cuda", dtype=torch.bfloat16 + ) + k = torch.randn( + batch_size * seq_len, num_heads, head_dim, device="cuda", dtype=torch.bfloat16 + ) + v = torch.randn( + batch_size * seq_len, num_heads, head_dim, device="cuda", dtype=torch.bfloat16 + ) + + cu_seqlens = torch.arange( + 0, (batch_size + 1) * seq_len, seq_len, dtype=torch.int32, device="cuda" + ) + + # Run multiple times + results = [] + for i in range(5): + output = flash_attn_varlen_func( + q, + k, + v, + cu_seqlens_q=cu_seqlens, + cu_seqlens_k=cu_seqlens, + max_seqlen_q=seq_len, + max_seqlen_k=seq_len, + softmax_scale=1.0 / (head_dim**0.5), + causal=True, + num_splits=1, # Deterministic mode + ) + results.append(output.clone()) + print( + f"Run {i + 1}: mean={output.float().mean().item():.10f}, " + f"std={output.float().std().item():.10f}" + ) + + # Check if all results are identical + all_same = True + for i in range(1, len(results)): + if not torch.equal(results[0], results[i]): + diff = (results[0] - results[i]).abs().max().item() + print(f" ✗ Run 1 vs Run {i + 1}: MAX DIFF = {diff}") + all_same = False + + if all_same: + print(" ✓ All runs produce IDENTICAL results (bit-for-bit)") + else: + print(" ✗ Results differ across runs") + + return all_same + + +def test_log_softmax_determinism(): + """Test if log_softmax is exactly deterministic.""" + print("\n" + "=" * 80) + print("Testing log_softmax exact determinism") + print("=" * 80) + + import torch.nn.functional as F + + # Create random inputs + torch.manual_seed(42) + x = torch.randn(32, 151936, device="cuda", dtype=torch.bfloat16) + + # Run multiple times + results = [] + for i in range(5): + # Convert to float32 before log_softmax (as we do in training) + output = F.log_softmax(x.float(), dim=-1) + results.append(output.clone()) + print( + f"Run {i + 1}: mean={output.mean().item():.10f}, " + f"std={output.std().item():.10f}" + ) + + # Check if all results are identical + all_same = True + for i in range(1, len(results)): + if not torch.equal(results[0], results[i]): + diff = (results[0] - results[i]).abs().max().item() + print(f" ✗ Run 1 vs Run {i + 1}: MAX DIFF = {diff}") + all_same = False + + if all_same: + print(" ✓ All runs produce IDENTICAL results (bit-for-bit)") + else: + print(" ✗ Results differ across runs") + + return all_same + + +def main(): + print("Testing exact determinism of operations") + print("=" * 80) + + results = {} + results["mm"] = test_mm_exact_determinism() + results["flash_attention"] = test_flash_attention_determinism() + results["log_softmax"] = test_log_softmax_determinism() + + print("\n" + "=" * 80) + print("Summary") + print("=" * 80) + for op, is_deterministic in results.items(): + status = "✓ DETERMINISTIC" if is_deterministic else "✗ NON-DETERMINISTIC" + print(f"{op:<20}: {status}") + + if all(results.values()): + print("\n✓ All operations are exactly deterministic!") + else: + print( + "\n✗ Some operations are not deterministic - this explains the vLLM/TorchTitan difference" + ) + + +if __name__ == "__main__": + main() diff --git a/torchtitan/experiments/deterministic_vllm_rl/weights/README.md b/torchtitan/experiments/deterministic_vllm_rl/weights/README.md new file mode 100644 index 0000000000..abaccd165b --- /dev/null +++ b/torchtitan/experiments/deterministic_vllm_rl/weights/README.md @@ -0,0 +1,120 @@ +# Weight Converter: vLLM ↔ TorchTitan + +Minimal weight conversion between vLLM/HuggingFace and TorchTitan formats for Qwen3-1.7B. + +## Files + +- **`weight_converter.py`**: Core conversion functions +- **`test_converter.py`**: Download & test script (weight comparison) +- **`test_forward_passes.py`**: Forward pass test (logits comparison) + +## Quick Start + +### 1. Install dependencies + +```bash +pip install torch safetensors huggingface_hub transformers +``` + +### 2. Run weight conversion test (downloads Qwen3-1.7B automatically) + +```bash +python test_converter.py +``` + +This will: +1. Download Qwen3-1.7B from HuggingFace (~3.5GB) +2. Convert to TorchTitan format +3. Convert back to vLLM format (round-trip test) +4. Verify all weights match + +### 3. Run forward pass test (validates conversion accuracy) + +```bash +python test_forward_passes.py +``` + +This will: +1. Download Qwen3-1.7B (if not already cached) +2. Convert weights to TorchTitan format +3. Run forward pass on both vLLM (via transformers) and TorchTitan +4. Compare logits to verify conversion accuracy +5. Report differences and top token predictions + +### 4. Use custom directories + +```bash +python test_converter.py ./custom_cache ./custom_output +python test_forward_passes.py ./custom_cache ./custom_output +``` + +## Manual Usage + +### Convert vLLM to TorchTitan + +```python +from weight_converter import vllm_to_torchtitan +from safetensors.torch import save_file + +# Convert +titan_weights = vllm_to_torchtitan("path/to/vllm/model") + +# Save +save_file(titan_weights, "qwen3_torchtitan.safetensors") +``` + +### Convert TorchTitan to vLLM + +```python +from weight_converter import torchtitan_to_vllm +from safetensors.torch import load_file, save_file + +# Load TorchTitan weights +titan_weights = load_file("qwen3_torchtitan.safetensors") + +# Convert +vllm_weights = torchtitan_to_vllm(titan_weights) + +# Save +save_file(vllm_weights, "qwen3_vllm.safetensors") +``` + +## Command-line Interface + +```bash +# vLLM → TorchTitan +python weight_converter.py vllm_to_titan + +# TorchTitan → vLLM +python weight_converter.py titan_to_vllm +``` + +## Key Differences + +### Weight Name Mappings + +| vLLM/HuggingFace | TorchTitan | +|------------------|------------| +| `model.embed_tokens.weight` | `tok_embeddings.weight` | +| `model.layers.{N}.self_attn.q_proj.weight` | `layers.{N}.attention.wq.weight` | +| `model.layers.{N}.self_attn.k_proj.weight` | `layers.{N}.attention.wk.weight` | +| `model.layers.{N}.self_attn.v_proj.weight` | `layers.{N}.attention.wv.weight` | +| `model.layers.{N}.self_attn.o_proj.weight` | `layers.{N}.attention.wo.weight` | +| `model.layers.{N}.mlp.gate_proj.weight` | `layers.{N}.feed_forward.w1.weight` | +| `model.layers.{N}.mlp.up_proj.weight` | `layers.{N}.feed_forward.w3.weight` | +| `model.layers.{N}.mlp.down_proj.weight` | `layers.{N}.feed_forward.w2.weight` | +| `model.norm.weight` | `norm.weight` | +| `lm_head.weight` | `output.weight` | + +### Notes + +- Rotary embedding frequencies (`rotary_emb.inv_freq`) are computed on-the-fly in TorchTitan, so they're skipped during conversion +- Both formats support `.safetensors` and `.bin` (PyTorch) files +- Qwen3 uses q_norm/k_norm for attention normalization, which are preserved in both formats + +## Model Support + +Currently tested with: +- **Qwen3-1.7B** ✅ + +Should work with other Qwen3 models with same architecture. diff --git a/torchtitan/experiments/deterministic_vllm_rl/weights/__init__.py b/torchtitan/experiments/deterministic_vllm_rl/weights/__init__.py new file mode 100644 index 0000000000..c64a48688e --- /dev/null +++ b/torchtitan/experiments/deterministic_vllm_rl/weights/__init__.py @@ -0,0 +1,11 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +"""Weight conversion utilities for vLLM and TorchTitan.""" + +from .converter import torchtitan_to_vllm, vllm_to_torchtitan + +__all__ = ["vllm_to_torchtitan", "torchtitan_to_vllm"] diff --git a/torchtitan/experiments/deterministic_vllm_rl/weights/converter.py b/torchtitan/experiments/deterministic_vllm_rl/weights/converter.py new file mode 100644 index 0000000000..092af9c37d --- /dev/null +++ b/torchtitan/experiments/deterministic_vllm_rl/weights/converter.py @@ -0,0 +1,239 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +""" +Minimal weight converter between vLLM and TorchTitan formats for Qwen3-1.7B. + +This script provides bidirectional weight conversion: +- vllm_to_torchtitan: Load weights from vLLM format and convert to TorchTitan format +- torchtitan_to_vllm: Load weights from TorchTitan format and convert to vLLM format +""" + +from pathlib import Path + +import torch +from safetensors.torch import load_file, save_file + + +# Weight name mapping from HuggingFace/vLLM to TorchTitan +VLLM_TO_TITAN_MAP = { + "model.embed_tokens.weight": "tok_embeddings.weight", + # Attention weights + "model.layers.{}.self_attn.q_proj.weight": "layers.{}.attention.wq.weight", + "model.layers.{}.self_attn.k_proj.weight": "layers.{}.attention.wk.weight", + "model.layers.{}.self_attn.v_proj.weight": "layers.{}.attention.wv.weight", + "model.layers.{}.self_attn.o_proj.weight": "layers.{}.attention.wo.weight", + "model.layers.{}.self_attn.q_norm.weight": "layers.{}.attention.q_norm.weight", + "model.layers.{}.self_attn.k_norm.weight": "layers.{}.attention.k_norm.weight", + # MLP weights + "model.layers.{}.mlp.gate_proj.weight": "layers.{}.feed_forward.w1.weight", + "model.layers.{}.mlp.up_proj.weight": "layers.{}.feed_forward.w3.weight", + "model.layers.{}.mlp.down_proj.weight": "layers.{}.feed_forward.w2.weight", + # Layer norms + "model.layers.{}.input_layernorm.weight": "layers.{}.attention_norm.weight", + "model.layers.{}.post_attention_layernorm.weight": "layers.{}.ffn_norm.weight", + # Final norm and output + "model.norm.weight": "norm.weight", + "lm_head.weight": "output.weight", +} + + +def vllm_to_torchtitan( + vllm_path_or_state: str | dict[str, torch.Tensor] +) -> dict[str, torch.Tensor]: + """ + Load weights from vLLM format (HuggingFace) and convert to TorchTitan format. + + Args: + vllm_path_or_state: Either a path to vLLM model directory (contains .safetensors or .bin files) + OR a vLLM state dict + + Returns: + Dictionary with TorchTitan-formatted state dict + """ + # Check if input is a state dict or a path + if isinstance(vllm_path_or_state, dict): + vllm_state = vllm_path_or_state + print(f"Using provided vLLM state dict with {len(vllm_state)} weights") + else: + vllm_path = Path(vllm_path_or_state) + + # Load weights from vLLM format (try safetensors first, then .bin) + vllm_state = {} + safetensor_files = sorted(vllm_path.glob("*.safetensors")) + + if safetensor_files: + print(f"Loading {len(safetensor_files)} safetensors files...") + for st_file in safetensor_files: + if "index" not in st_file.name: # Skip index files + vllm_state.update(load_file(str(st_file))) + else: + # Fallback to .bin files + bin_files = sorted(vllm_path.glob("*.bin")) + print(f"Loading {len(bin_files)} .bin files...") + for bin_file in bin_files: + state = torch.load(bin_file, map_location="cpu", weights_only=True) + vllm_state.update(state) + + print(f"Loaded {len(vllm_state)} weights from vLLM format") + + # Convert to TorchTitan format + titan_state = {} + + for vllm_key, tensor in vllm_state.items(): + # Skip rotary embedding frequencies (not needed in TorchTitan) + if "rotary_emb.inv_freq" in vllm_key: + continue + + # Check if it's a layer-specific weight + if "layers." in vllm_key: + # Extract layer number + parts = vllm_key.split(".") + layer_idx = parts[2] + + # Create abstract key with placeholder + abstract_vllm_key = vllm_key.replace(f".{layer_idx}.", ".{}.") + + # Look up in mapping + if abstract_vllm_key in VLLM_TO_TITAN_MAP: + abstract_titan_key = VLLM_TO_TITAN_MAP[abstract_vllm_key] + titan_key = abstract_titan_key.format(layer_idx) + titan_state[titan_key] = tensor + else: + print(f"Warning: No mapping found for {vllm_key}") + else: + # Non-layer weight + if vllm_key in VLLM_TO_TITAN_MAP: + titan_key = VLLM_TO_TITAN_MAP[vllm_key] + titan_state[titan_key] = tensor + else: + print(f"Warning: No mapping found for {vllm_key}") + + print(f"Converted to {len(titan_state)} TorchTitan weights") + return titan_state + + +def torchtitan_to_vllm(titan_state: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]: + """ + Convert weights from TorchTitan format to vLLM format (HuggingFace). + + Args: + titan_state: TorchTitan state dict (can be in vLLM-compat format with gate_up_proj) + + Returns: + Dictionary with vLLM/HuggingFace-formatted state dict + """ + # Create reverse mapping + titan_to_vllm_map = {v: k for k, v in VLLM_TO_TITAN_MAP.items()} + + vllm_state = {} + + for titan_key, tensor in titan_state.items(): + # Handle merged gate_up_proj (vLLM-compat format) -> split into gate_proj + up_proj + if ".feed_forward.gate_up_proj.weight" in titan_key: + # Split into gate_proj (first half) and up_proj (second half) + hidden_dim = tensor.shape[0] // 2 + # CLONE to avoid aliasing - these are views into the original tensor + gate_weight = tensor[:hidden_dim].clone() + up_weight = tensor[hidden_dim:].clone() + + # Extract layer number + parts = titan_key.split(".") + layer_idx = parts[1] + + # Create vLLM keys + gate_key = f"model.layers.{layer_idx}.mlp.gate_proj.weight" + up_key = f"model.layers.{layer_idx}.mlp.up_proj.weight" + + vllm_state[gate_key] = gate_weight + vllm_state[up_key] = up_weight + continue + + # Handle down_proj (vLLM-compat format) + if ".feed_forward.down_proj.weight" in titan_key: + parts = titan_key.split(".") + layer_idx = parts[1] + vllm_key = f"model.layers.{layer_idx}.mlp.down_proj.weight" + # CLONE to avoid aliasing + vllm_state[vllm_key] = tensor.clone() + continue + + # Check if it's a layer-specific weight + if "layers." in titan_key: + # Extract layer number + parts = titan_key.split(".") + layer_idx = parts[1] + + # Create abstract key with placeholder + abstract_titan_key = titan_key.replace(f".{layer_idx}.", ".{}.") + + # Look up in reverse mapping + if abstract_titan_key in titan_to_vllm_map: + abstract_vllm_key = titan_to_vllm_map[abstract_titan_key] + vllm_key = abstract_vllm_key.format(layer_idx) + # CLONE to avoid aliasing + vllm_state[vllm_key] = tensor.clone() + else: + print(f"Warning: No mapping found for {titan_key}") + else: + # Non-layer weight + if titan_key in titan_to_vllm_map: + vllm_key = titan_to_vllm_map[titan_key] + # CLONE to avoid aliasing + vllm_state[vllm_key] = tensor.clone() + else: + print(f"Warning: No mapping found for {titan_key}") + + print(f"Converted to {len(vllm_state)} vLLM weights") + return vllm_state + + +# Example usage +if __name__ == "__main__": + import sys + + if len(sys.argv) < 3: + print("Usage:") + print(" Convert vLLM to TorchTitan:") + print( + " python weight_converter.py vllm_to_titan " + ) + print(" Convert TorchTitan to vLLM:") + print( + " python weight_converter.py titan_to_vllm " + ) + sys.exit(1) + + mode = sys.argv[1] + input_path = sys.argv[2] + output_path = sys.argv[3] + + if mode == "vllm_to_titan": + # Convert vLLM to TorchTitan + titan_state = vllm_to_torchtitan(input_path) + + # Save as safetensors + print(f"Saving to {output_path}...") + save_file(titan_state, output_path) + print("Done!") + + elif mode == "titan_to_vllm": + # Load TorchTitan checkpoint + print(f"Loading TorchTitan checkpoint from {input_path}...") + titan_state = load_file(input_path) + + # Convert to vLLM + vllm_state = torchtitan_to_vllm(titan_state) + + # Save as safetensors + print(f"Saving to {output_path}...") + save_file(vllm_state, output_path) + print("Done!") + + else: + print(f"Unknown mode: {mode}") + print("Use 'vllm_to_titan' or 'titan_to_vllm'") + sys.exit(1) diff --git a/torchtitan/experiments/deterministic_vllm_rl/weights_vllm_compat.py b/torchtitan/experiments/deterministic_vllm_rl/weights_vllm_compat.py new file mode 100644 index 0000000000..ed3293af78 --- /dev/null +++ b/torchtitan/experiments/deterministic_vllm_rl/weights_vllm_compat.py @@ -0,0 +1,142 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +""" +Weight conversion utilities for Qwen3VLLMCompatModel. + +Converts between: +- TorchTitan format (separate w1, w2, w3 in FFN) +- vLLM compat format (merged gate_up_proj = [w1; w3]) +""" + +from typing import Dict + +import torch + + +def torchtitan_to_vllm_compat( + torchtitan_state_dict: Dict[str, torch.Tensor] +) -> Dict[str, torch.Tensor]: + """ + Convert TorchTitan Qwen3 state dict to vLLM-compatible format. + + Main changes: + - Merge feed_forward.w1 and feed_forward.w3 into feed_forward.gate_up_proj + - Rename feed_forward.w2 to feed_forward.down_proj + """ + vllm_compat_state_dict = {} + + for key, tensor in torchtitan_state_dict.items(): + # Handle FFN weight merging + if ".feed_forward.w1.weight" in key: + # Get the corresponding w3 weight + w3_key = key.replace(".feed_forward.w1.weight", ".feed_forward.w3.weight") + w1_weight = tensor + w3_weight = torchtitan_state_dict[w3_key] + + # Merge: gate_up_proj = [w1; w3] (concatenate along output dim) + # torch.cat creates a new tensor, so no need to clone + gate_up_weight = torch.cat([w1_weight, w3_weight], dim=0) + + # Save with new key + new_key = key.replace( + ".feed_forward.w1.weight", ".feed_forward.gate_up_proj.weight" + ) + vllm_compat_state_dict[new_key] = gate_up_weight + + elif ".feed_forward.w3.weight" in key: + # Skip w3, already merged with w1 + continue + + elif ".feed_forward.w2.weight" in key: + # Rename w2 to down_proj + new_key = key.replace( + ".feed_forward.w2.weight", ".feed_forward.down_proj.weight" + ) + # CLONE to avoid aliasing + vllm_compat_state_dict[new_key] = tensor.clone() + + else: + # Keep all other weights as-is + # CLONE to avoid aliasing + vllm_compat_state_dict[key] = tensor.clone() + + return vllm_compat_state_dict + + +def vllm_compat_to_torchtitan( + vllm_compat_state_dict: Dict[str, torch.Tensor] +) -> Dict[str, torch.Tensor]: + """ + Convert vLLM-compatible state dict back to TorchTitan format. + + Main changes: + - Split feed_forward.gate_up_proj into separate w1 and w3 + - Rename feed_forward.down_proj to w2 + """ + torchtitan_state_dict = {} + + for key, tensor in vllm_compat_state_dict.items(): + # Handle FFN weight splitting + if ".feed_forward.gate_up_proj.weight" in key: + # Split into w1 (first half) and w3 (second half) + hidden_dim = tensor.shape[0] // 2 + # CLONE to avoid aliasing - these are views into the original tensor + w1_weight = tensor[:hidden_dim].clone() + w3_weight = tensor[hidden_dim:].clone() + + # Save with new keys + w1_key = key.replace( + ".feed_forward.gate_up_proj.weight", ".feed_forward.w1.weight" + ) + w3_key = key.replace( + ".feed_forward.gate_up_proj.weight", ".feed_forward.w3.weight" + ) + torchtitan_state_dict[w1_key] = w1_weight + torchtitan_state_dict[w3_key] = w3_weight + + elif ".feed_forward.down_proj.weight" in key: + # Rename down_proj to w2 + new_key = key.replace( + ".feed_forward.down_proj.weight", ".feed_forward.w2.weight" + ) + # CLONE to avoid aliasing + torchtitan_state_dict[new_key] = tensor.clone() + + else: + # Keep all other weights as-is + # CLONE to avoid aliasing + torchtitan_state_dict[key] = tensor.clone() + + return torchtitan_state_dict + + +if __name__ == "__main__": + # Test conversion + from safetensors.torch import load_file, save_file + + print("Loading TorchTitan checkpoint...") + checkpoint_path = "./converted/qwen3_torchtitan.safetensors" + titan_state = load_file(checkpoint_path) + + print(f"\nOriginal TorchTitan state dict has {len(titan_state)} keys") + print("Sample keys:") + for i, key in enumerate(list(titan_state.keys())[:10]): + print(f" {key}") + + print("\nConverting to vLLM-compat format...") + vllm_compat_state = torchtitan_to_vllm_compat(titan_state) + + print(f"\nvLLM-compat state dict has {len(vllm_compat_state)} keys") + print("Sample keys:") + for i, key in enumerate(list(vllm_compat_state.keys())[:10]): + print(f" {key}") + + # Save converted checkpoint + output_path = "./converted/qwen3_vllm_compat.safetensors" + print(f"\nSaving to {output_path}...") + save_file(vllm_compat_state, output_path) + print("Done!")