Skip to content

Commit 35bffe9

Browse files
bwastiteja-raofegin
authored
Add deterministic RL training experiment with vLLM (#1975)
This experiment provides a complete framework for bitwise-deterministic reinforcement learning training that combines vLLM for fast rollouts and TorchTitan for training with gradients. Key features: - Bitwise-deterministic forward and backward passes - vLLM-compatible Qwen3 model with merged projections - Custom Flash Attention with gradient support - Gradient support for vLLM's batch-invariant operations - Complete RL training loop with GRPO-style advantages - Comprehensive test suite for determinism verification Components: - models/attention.py: VLLMCompatibleFlashAttention - models/qwen3/model_vllm_compat.py: vLLM-compatible Qwen3 model - batch_invariant_backward.py: Gradient support for vLLM operations - simple_rl.py: End-to-end RL training loop - tests/: Test suite for backward passes and determinism --------- Co-authored-by: Teja <[email protected]> Co-authored-by: Chien-Chin Huang <[email protected]>
1 parent 5ecc871 commit 35bffe9

File tree

15 files changed

+3368
-0
lines changed

15 files changed

+3368
-0
lines changed
Lines changed: 262 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,262 @@
1+
# Deterministic RL Training with vLLM
2+
3+
This experiment combines vLLM's deterministic kernels with PyTorch autograd to enable reinforcement learning training where forward passes produce bitwise-identical results across runs.
4+
5+
## Overview
6+
7+
RL training requires both fast inference for generating rollouts and gradient computation for policy updates. vLLM provides deterministic forward passes but does not support gradients. This experiment adds backward passes to vLLM's operations.
8+
9+
The implementation:
10+
1. Uses vLLM's batch-invariant kernels for forward passes
11+
2. Implements custom backward passes for gradient computation
12+
3. Provides weight conversion utilities between TorchTitan and vLLM formats
13+
14+
### Features
15+
16+
- Bitwise determinism: Same inputs produce identical outputs across runs
17+
- Gradient support: Backward passes through vLLM operations
18+
- Weight conversion: Utilities to convert between model formats
19+
20+
Note: Currently supports single-device training only.
21+
22+
## Architecture
23+
24+
### Components
25+
26+
1. `models/attention.py`: VLLMCompatibleFlashAttention
27+
- Uses vLLM's Flash Attention for forward pass
28+
- Implements custom backward pass for gradient computation
29+
- Uses `num_splits=1` for deterministic behavior
30+
31+
2. `models/qwen3/model_vllm_compat.py`: Qwen3VLLMCompatModel
32+
- Qwen3 model with merged gate/up projections matching vLLM format
33+
- Uses VLLMRMSNorm with gradient support
34+
35+
3. `batch_invariant_backward.py`: Backward passes for vLLM operations
36+
- Registers gradients for vLLM's batch-invariant operations
37+
- Supports matmul, linear, and RMSNorm
38+
- Patches Flash Attention for autograd
39+
40+
4. `weights_vllm_compat.py`: Weight conversion utilities
41+
- Converts between TorchTitan format (separate w1, w2, w3) and vLLM format (merged gate_up_proj)
42+
- Provides bidirectional conversion functions
43+
44+
5. `simple_rl.py`: RL training loop
45+
- Generates rollouts using vLLM engine
46+
- Computes advantages using GRPO-style ranking
47+
- Updates policy using PPO
48+
49+
## Installation
50+
51+
### Prerequisites
52+
53+
```bash
54+
# Install vLLM with deterministic support
55+
pip install vllm
56+
57+
# Install TorchTitan (from the repository root)
58+
pip install -e .
59+
60+
# Install additional dependencies
61+
pip install transformers safetensors huggingface_hub tensorboard
62+
```
63+
64+
### Enable Batch Invariance
65+
66+
Initialize vLLM's batch-invariant mode before training:
67+
68+
```python
69+
from vllm.model_executor.layers.batch_invariant import init_batch_invariance
70+
init_batch_invariance()
71+
```
72+
73+
## Usage
74+
75+
### Quick Start
76+
77+
```python
78+
import torch
79+
from vllm.model_executor.layers.batch_invariant import init_batch_invariance
80+
from torchtitan.experiments.deterministic_vllm_rl import (
81+
enable_batch_invariant_backward_mode,
82+
Qwen3VLLMCompatModel,
83+
)
84+
85+
# 1. Enable deterministic mode
86+
init_batch_invariance()
87+
enable_batch_invariant_backward_mode()
88+
89+
# 2. Load model
90+
from torchtitan.models.qwen3.model.args import Qwen3ModelArgs
91+
model_args = Qwen3ModelArgs(
92+
dim=2048,
93+
n_layers=24,
94+
n_heads=16,
95+
n_kv_heads=2,
96+
vocab_size=151936,
97+
)
98+
model = Qwen3VLLMCompatModel(model_args)
99+
100+
# 3. Forward pass (deterministic)
101+
input_ids = torch.randint(0, 151936, (2, 128), device='cuda')
102+
logits = model(input_ids)
103+
104+
# 4. Backward pass
105+
loss = logits.sum()
106+
loss.backward()
107+
```
108+
109+
### Full RL Training
110+
111+
Run the RL training loop:
112+
113+
```bash
114+
VLLM_BATCH_INVARIANT=1 VLLM_FLASH_ATTN_VERSION=3 python -m torchtitan.experiments.deterministic_vllm_rl.simple_rl
115+
```
116+
117+
This will:
118+
1. Download Qwen3-1.7B from HuggingFace
119+
2. Initialize vLLM engine for rollouts
120+
3. Generate samples for training prompts
121+
4. Compute rewards and advantages
122+
5. Update the policy using PPO
123+
6. Log metrics to TensorBoard
124+
125+
View training progress:
126+
```bash
127+
tensorboard --logdir=./outputs/rl_training
128+
```
129+
130+
## How It Works
131+
132+
### Deterministic Forward Pass
133+
134+
vLLM's batch-invariant mode makes operations deterministic:
135+
136+
```python
137+
# These operations are deterministic when batch_invariance is enabled
138+
y = torch.matmul(a, b) # Uses vLLM's deterministic matmul
139+
output = flash_attn_varlen_func(q, k, v, num_splits=1) # Deterministic FA
140+
```
141+
142+
### Backward Pass with Gradients
143+
144+
Custom backward passes:
145+
1. Re-compute attention weights deterministically
146+
2. Use standard chain rule for gradients
147+
3. Apply gradients through vLLM's deterministic operations
148+
149+
```python
150+
class FlashAttnWithBackward(torch.autograd.Function):
151+
@staticmethod
152+
def forward(ctx, q, k, v, ...):
153+
# Use vLLM's forward implementation
154+
return flash_attn_varlen_func(q, k, v, num_splits=1, ...)
155+
156+
@staticmethod
157+
def backward(ctx, grad_output):
158+
# Compute gradients deterministically
159+
# (re-compute attention weights and apply chain rule)
160+
return grad_q, grad_k, grad_v, ...
161+
```
162+
163+
### Bitwise Determinism Verification
164+
165+
The training loop compares logprobs from vLLM and TorchTitan:
166+
167+
```python
168+
# During training, compare logprobs
169+
vllm_logprobs = [from vLLM rollout]
170+
titan_logprobs = [from TorchTitan forward pass]
171+
172+
assert torch.equal(vllm_logprobs, titan_logprobs)
173+
```
174+
175+
## Testing
176+
177+
Run the test suite:
178+
179+
```bash
180+
cd torchtitan/experiments/deterministic_vllm_rl/tests
181+
182+
# Test backward passes
183+
python test_batch_invariant_backward.py
184+
185+
# Test determinism
186+
python test_exact_determinism.py
187+
```
188+
189+
## Technical Details
190+
191+
### Why Determinism Matters for RL
192+
193+
RL training steps:
194+
1. Generate rollouts by sampling from the policy
195+
2. Compute rewards based on the samples
196+
3. Update the policy using gradients
197+
198+
If the forward pass during training differs from the forward pass during rollout, policy gradients may be incorrect. This matters for algorithms like PPO that compare old and new policy probabilities.
199+
200+
This implementation uses the same kernels for both rollouts (vLLM) and training (TorchTitan) to ensure `logprobs_rollout == logprobs_training` bitwise.
201+
202+
### Performance
203+
204+
- Rollout speed: Uses vLLM's optimized kernels
205+
- Training speed: Similar to standard TorchTitan
206+
- Memory: Saves activations for custom backward passes
207+
208+
### Limitations
209+
210+
1. Custom backward requires uniform sequence lengths
211+
2. Only causal attention is supported
212+
3. Requires NVIDIA GPUs with Flash Attention support
213+
214+
## Project Structure
215+
216+
```
217+
deterministic_vllm_rl/
218+
├── README.md # Documentation
219+
├── __init__.py # Package initialization
220+
├── batch_invariant_backward.py # Backward passes for vLLM ops
221+
├── weights_vllm_compat.py # Weight conversion utilities
222+
├── simple_rl.py # RL training loop
223+
├── models/
224+
│ ├── __init__.py
225+
│ ├── attention.py # VLLMCompatibleFlashAttention
226+
│ └── qwen3/
227+
│ ├── __init__.py
228+
│ └── model_vllm_compat.py # vLLM-compatible Qwen3 model
229+
├── weights/
230+
│ ├── __init__.py
231+
│ ├── converter.py # Weight conversion script
232+
│ └── README.md # Weight conversion documentation
233+
└── tests/
234+
├── __init__.py
235+
├── test_batch_invariant_backward.py # Test backward passes
236+
└── test_exact_determinism.py # Test determinism
237+
```
238+
239+
## TODO
240+
241+
- `FlashAttnWithBackward` will need to become more composable and should not live exclusively within this directory.
242+
- vLLM integration will need to become more generic with a provided Attention operator that is KV-cache compatible.
243+
- vLLM parallelism will need to add generic parallelism initialization to support Monarch managed TP/DP.
244+
245+
## Contributing
246+
247+
This experiment is part of TorchTitan. To contribute:
248+
249+
1. Test your changes with `pytest tests/`
250+
2. Verify bitwise determinism is maintained
251+
3. Update this README if adding new features
252+
253+
## References
254+
255+
- [vLLM Documentation](https://docs.vllm.ai/)
256+
- [Flash Attention Paper](https://arxiv.org/abs/2205.14135)
257+
- [PPO Algorithm](https://arxiv.org/abs/1707.06347)
258+
- [GRPO: Group Relative Policy Optimization](https://arxiv.org/abs/2402.03300)
259+
260+
## License
261+
262+
This code is licensed under the BSD-style license found in the LICENSE file in the TorchTitan repository root directory.
Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,34 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the BSD-style license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
"""
8+
Deterministic RL training with vLLM experiment.
9+
10+
This experiment provides tools for bitwise-deterministic reinforcement learning
11+
training using vLLM for fast rollouts and TorchTitan for training.
12+
13+
Key components:
14+
- VLLMCompatibleFlashAttention: Flash attention with custom backward pass
15+
- Qwen3VLLMCompatModel: vLLM-compatible model with merged projections
16+
- batch_invariant_backward: Gradient support for vLLM's deterministic operations
17+
- simple_rl: End-to-end RL training loop
18+
"""
19+
20+
from .batch_invariant_backward import (
21+
enable_batch_invariant_backward_mode,
22+
rms_norm_with_gradients,
23+
silu_and_mul_with_gradients,
24+
)
25+
from .models import VLLMCompatibleFlashAttention
26+
from .models.qwen3 import Qwen3VLLMCompatModel
27+
28+
__all__ = [
29+
"VLLMCompatibleFlashAttention",
30+
"Qwen3VLLMCompatModel",
31+
"enable_batch_invariant_backward_mode",
32+
"rms_norm_with_gradients",
33+
"silu_and_mul_with_gradients",
34+
]

0 commit comments

Comments
 (0)