
Commit 2823b41

Add missing weight conversion files and fix imports
This commit adds the missing weight conversion utilities that are required by Bram's base commit (simple_rl.py imports them but they were missing):

- weights_vllm_compat.py: Converts between TorchTitan and vLLM-compat formats (merges/splits gate_up_proj for FFN layers)
- weights/converter.py: Converts between vLLM HuggingFace and TorchTitan formats
- weights/__init__.py: Package init
- weights/README.md: Documentation for weight converters

Import fixes:

- simple_rl.py: Use local models.qwen3 instead of torchtitan.models.qwen3.model
- model_vllm_compat.py: Import VLLMCompatibleFlashAttention from local ..attention and Qwen3ModelArgs from torchtitan.models.qwen3.model.args
1 parent 5fc60f8 commit 2823b41

File tree

6 files changed: +478 −3 lines changed

torchtitan/experiments/deterministic_vllm_rl/models/qwen3/model_vllm_compat.py

Lines changed: 5 additions & 2 deletions
```diff
@@ -17,12 +17,15 @@
     create_attention_mask,
     get_causal_mask_mod,
     get_document_mask_mod,
-    VLLMCompatibleFlashAttention,
 )
 from torchtitan.protocols.model import AttentionMasksType
 from torchtitan.protocols.train_spec import ModelProtocol
 
-from .args import Qwen3ModelArgs
+# Import from local experiment's models
+from ..attention import VLLMCompatibleFlashAttention
+
+# Import from main torchtitan
+from torchtitan.models.qwen3.model.args import Qwen3ModelArgs
 
 # Import vLLM's exact operations for bitwise determinism
 from vllm.model_executor.layers.activation import SiluAndMul as VLLMSiluAndMul
```
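For context on the vLLM import in this hunk: `SiluAndMul` is the SwiGLU activation vLLM applies to the output of a fused gate/up projection, taking SiLU of the first half and multiplying by the second half. A plain-PyTorch equivalent is sketched below purely for illustration; it is not the vLLM source, which also carries custom-op plumbing for determinism.

```python
import torch
import torch.nn.functional as F

def silu_and_mul(gate_up: torch.Tensor) -> torch.Tensor:
    # gate_up: [..., 2 * intermediate_size], the output of a fused gate_up_proj
    gate, up = gate_up.chunk(2, dim=-1)
    return F.silu(gate) * up
```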

torchtitan/experiments/deterministic_vllm_rl/simple_rl.py

Lines changed: 1 addition & 1 deletion
```diff
@@ -307,7 +307,7 @@ def load_model(checkpoint_path: str, model_path: str, use_vllm_compat: bool = Tr
 
     if use_vllm_compat:
         # Create and load model (using vLLM-compat for bitwise determinism)
-        from torchtitan.models.qwen3.model.model_vllm_compat import Qwen3VLLMCompatModel
+        from models.qwen3 import Qwen3VLLMCompatModel
         model = Qwen3VLLMCompatModel(model_args)
         # Convert to vLLM-compat format (merged gate_up_proj, down_proj)
         vllm_compat_state = torchtitan_to_vllm_compat(state_dict)
```
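The hunk above calls `torchtitan_to_vllm_compat`, which lives in the new `weights_vllm_compat.py`; that file's diff is not shown in this excerpt. As a rough, illustrative sketch of the gate_up_proj merge the commit message describes, the TorchTitan-to-vLLM-compat direction presumably concatenates each layer's `w1` (gate) and `w3` (up) projections. The key names and the helper below are assumptions for illustration, not the committed code:

```python
import torch

def merge_gate_up(titan_state: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]:
    """Hypothetical sketch: fuse w1/w3 into a single gate_up_proj per layer (vLLM-compat layout)."""
    merged = dict(titan_state)
    layer_ids = {k.split(".")[1] for k in titan_state if k.startswith("layers.")}
    for i in layer_ids:
        w1 = merged.pop(f"layers.{i}.feed_forward.w1.weight", None)  # gate_proj
        w3 = merged.pop(f"layers.{i}.feed_forward.w3.weight", None)  # up_proj
        if w1 is not None and w3 is not None:
            # weights/converter.py splits gate_up_proj as [gate; up] along dim 0,
            # so the merge concatenates in the same order.
            merged[f"layers.{i}.feed_forward.gate_up_proj.weight"] = torch.cat([w1, w3], dim=0)
    return merged
```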
weights/README.md

Lines changed: 120 additions & 0 deletions
# Weight Converter: vLLM ↔ TorchTitan

Minimal weight conversion between vLLM/HuggingFace and TorchTitan formats for Qwen3-1.7B.

## Files

- **`weight_converter.py`**: Core conversion functions
- **`test_converter.py`**: Download & test script (weight comparison)
- **`test_forward_passes.py`**: Forward pass test (logits comparison)

## Quick Start

### 1. Install dependencies

```bash
pip install torch safetensors huggingface_hub transformers
```

### 2. Run weight conversion test (downloads Qwen3-1.7B automatically)

```bash
python test_converter.py
```

This will:
1. Download Qwen3-1.7B from HuggingFace (~3.5GB)
2. Convert to TorchTitan format
3. Convert back to vLLM format (round-trip test)
4. Verify all weights match
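The final check (all weights match after the round trip) can also be reproduced by hand. A minimal sketch, assuming the converted state dicts fit in memory and using the two functions `weight_converter.py` exports; `test_converter.py` remains the actual test entry point:

```python
import torch
from weight_converter import vllm_to_torchtitan, torchtitan_to_vllm

titan = vllm_to_torchtitan("path/to/vllm/model")             # vLLM -> TorchTitan
round_trip = vllm_to_torchtitan(torchtitan_to_vllm(titan))   # -> vLLM -> TorchTitan again

# Every key should survive the round trip with bit-identical values.
assert titan.keys() == round_trip.keys()
for name, tensor in titan.items():
    assert torch.equal(tensor, round_trip[name]), f"mismatch in {name}"
print("All weights match")
```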
### 3. Run forward pass test (validates conversion accuracy)

```bash
python test_forward_passes.py
```

This will:
1. Download Qwen3-1.7B (if not already cached)
2. Convert weights to TorchTitan format
3. Run a forward pass on both vLLM (via transformers) and TorchTitan
4. Compare logits to verify conversion accuracy
5. Report differences and top token predictions
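Steps 4 and 5 reduce to comparing the two models' output logits. A minimal sketch of such a comparison; `hf_logits` and `titan_logits` are placeholders for the two forward-pass outputs, not names from `test_forward_passes.py`:

```python
import torch

def compare_logits(hf_logits: torch.Tensor, titan_logits: torch.Tensor, k: int = 5) -> None:
    # Both tensors: [batch, seq_len, vocab_size]
    diff = (hf_logits.float() - titan_logits.float()).abs()
    print(f"max abs diff:  {diff.max().item():.3e}")
    print(f"mean abs diff: {diff.mean().item():.3e}")
    # The top-k next-token predictions at the last position should agree.
    hf_top = hf_logits[0, -1].topk(k).indices.tolist()
    titan_top = titan_logits[0, -1].topk(k).indices.tolist()
    print(f"top-{k} tokens (HF/vLLM):    {hf_top}")
    print(f"top-{k} tokens (TorchTitan): {titan_top}")
```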
### 4. Use custom directories

```bash
python test_converter.py ./custom_cache ./custom_output
python test_forward_passes.py ./custom_cache ./custom_output
```

## Manual Usage

### Convert vLLM to TorchTitan

```python
from weight_converter import vllm_to_torchtitan
from safetensors.torch import save_file

# Convert
titan_weights = vllm_to_torchtitan("path/to/vllm/model")

# Save
save_file(titan_weights, "qwen3_torchtitan.safetensors")
```

### Convert TorchTitan to vLLM

```python
from weight_converter import torchtitan_to_vllm
from safetensors.torch import load_file, save_file

# Load TorchTitan weights
titan_weights = load_file("qwen3_torchtitan.safetensors")

# Convert
vllm_weights = torchtitan_to_vllm(titan_weights)

# Save
save_file(vllm_weights, "qwen3_vllm.safetensors")
```

## Command-line Interface

```bash
# vLLM → TorchTitan
python weight_converter.py vllm_to_titan <vllm_path> <output.safetensors>

# TorchTitan → vLLM
python weight_converter.py titan_to_vllm <titan_checkpoint.safetensors> <output.safetensors>
```

## Key Differences

### Weight Name Mappings

| vLLM/HuggingFace | TorchTitan |
|------------------|------------|
| `model.embed_tokens.weight` | `tok_embeddings.weight` |
| `model.layers.{N}.self_attn.q_proj.weight` | `layers.{N}.attention.wq.weight` |
| `model.layers.{N}.self_attn.k_proj.weight` | `layers.{N}.attention.wk.weight` |
| `model.layers.{N}.self_attn.v_proj.weight` | `layers.{N}.attention.wv.weight` |
| `model.layers.{N}.self_attn.o_proj.weight` | `layers.{N}.attention.wo.weight` |
| `model.layers.{N}.mlp.gate_proj.weight` | `layers.{N}.feed_forward.w1.weight` |
| `model.layers.{N}.mlp.up_proj.weight` | `layers.{N}.feed_forward.w3.weight` |
| `model.layers.{N}.mlp.down_proj.weight` | `layers.{N}.feed_forward.w2.weight` |
| `model.norm.weight` | `norm.weight` |
| `lm_head.weight` | `output.weight` |

### Notes

- Rotary embedding frequencies (`rotary_emb.inv_freq`) are computed on-the-fly in TorchTitan, so they're skipped during conversion
- Both formats support `.safetensors` and `.bin` (PyTorch) files
- Qwen3 uses q_norm/k_norm for attention normalization, which are preserved in both formats
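On the first note: the skipped `rotary_emb.inv_freq` buffer is cheap to recompute from the model config, which is why it never needs to be converted. A generic sketch of the standard RoPE frequencies (not TorchTitan's exact code; `head_dim` and `rope_theta` come from the config):

```python
import torch

def rope_inv_freq(head_dim: int, rope_theta: float) -> torch.Tensor:
    # inv_freq[i] = rope_theta ** (-2i / head_dim), one entry per rotated pair of dims
    return 1.0 / (rope_theta ** (torch.arange(0, head_dim, 2).float() / head_dim))
```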
## Model Support

Currently tested with:
- **Qwen3-1.7B**

It should also work with other Qwen3 models that share the same architecture.
weights/__init__.py

Lines changed: 5 additions & 0 deletions

```python
"""Weight conversion utilities for vLLM and TorchTitan."""

from .converter import vllm_to_torchtitan, torchtitan_to_vllm

__all__ = ["vllm_to_torchtitan", "torchtitan_to_vllm"]
```
weights/converter.py

Lines changed: 226 additions & 0 deletions
```python
"""
Minimal weight converter between vLLM and TorchTitan formats for Qwen3-1.7B.

This script provides bidirectional weight conversion:
- vllm_to_torchtitan: Load weights from vLLM format and convert to TorchTitan format
- torchtitan_to_vllm: Load weights from TorchTitan format and convert to vLLM format
"""

import torch
from safetensors.torch import load_file, save_file
from pathlib import Path


# Weight name mapping from HuggingFace/vLLM to TorchTitan
VLLM_TO_TITAN_MAP = {
    "model.embed_tokens.weight": "tok_embeddings.weight",
    # Attention weights
    "model.layers.{}.self_attn.q_proj.weight": "layers.{}.attention.wq.weight",
    "model.layers.{}.self_attn.k_proj.weight": "layers.{}.attention.wk.weight",
    "model.layers.{}.self_attn.v_proj.weight": "layers.{}.attention.wv.weight",
    "model.layers.{}.self_attn.o_proj.weight": "layers.{}.attention.wo.weight",
    "model.layers.{}.self_attn.q_norm.weight": "layers.{}.attention.q_norm.weight",
    "model.layers.{}.self_attn.k_norm.weight": "layers.{}.attention.k_norm.weight",
    # MLP weights
    "model.layers.{}.mlp.gate_proj.weight": "layers.{}.feed_forward.w1.weight",
    "model.layers.{}.mlp.up_proj.weight": "layers.{}.feed_forward.w3.weight",
    "model.layers.{}.mlp.down_proj.weight": "layers.{}.feed_forward.w2.weight",
    # Layer norms
    "model.layers.{}.input_layernorm.weight": "layers.{}.attention_norm.weight",
    "model.layers.{}.post_attention_layernorm.weight": "layers.{}.ffn_norm.weight",
    # Final norm and output
    "model.norm.weight": "norm.weight",
    "lm_head.weight": "output.weight",
}


def vllm_to_torchtitan(vllm_path_or_state: str | dict[str, torch.Tensor]) -> dict[str, torch.Tensor]:
    """
    Load weights from vLLM format (HuggingFace) and convert to TorchTitan format.

    Args:
        vllm_path_or_state: Either a path to vLLM model directory (contains .safetensors or .bin files)
            OR a vLLM state dict

    Returns:
        Dictionary with TorchTitan-formatted state dict
    """
    # Check if input is a state dict or a path
    if isinstance(vllm_path_or_state, dict):
        vllm_state = vllm_path_or_state
        print(f"Using provided vLLM state dict with {len(vllm_state)} weights")
    else:
        vllm_path = Path(vllm_path_or_state)

        # Load weights from vLLM format (try safetensors first, then .bin)
        vllm_state = {}
        safetensor_files = sorted(vllm_path.glob("*.safetensors"))

        if safetensor_files:
            print(f"Loading {len(safetensor_files)} safetensors files...")
            for st_file in safetensor_files:
                if "index" not in st_file.name:  # Skip index files
                    vllm_state.update(load_file(str(st_file)))
        else:
            # Fallback to .bin files
            bin_files = sorted(vllm_path.glob("*.bin"))
            print(f"Loading {len(bin_files)} .bin files...")
            for bin_file in bin_files:
                state = torch.load(bin_file, map_location="cpu", weights_only=True)
                vllm_state.update(state)

        print(f"Loaded {len(vllm_state)} weights from vLLM format")

    # Convert to TorchTitan format
    titan_state = {}

    for vllm_key, tensor in vllm_state.items():
        # Skip rotary embedding frequencies (not needed in TorchTitan)
        if "rotary_emb.inv_freq" in vllm_key:
            continue

        # Check if it's a layer-specific weight
        if "layers." in vllm_key:
            # Extract layer number
            parts = vllm_key.split(".")
            layer_idx = parts[2]

            # Create abstract key with placeholder
            abstract_vllm_key = vllm_key.replace(f".{layer_idx}.", ".{}.")

            # Look up in mapping
            if abstract_vllm_key in VLLM_TO_TITAN_MAP:
                abstract_titan_key = VLLM_TO_TITAN_MAP[abstract_vllm_key]
                titan_key = abstract_titan_key.format(layer_idx)
                titan_state[titan_key] = tensor
            else:
                print(f"Warning: No mapping found for {vllm_key}")
        else:
            # Non-layer weight
            if vllm_key in VLLM_TO_TITAN_MAP:
                titan_key = VLLM_TO_TITAN_MAP[vllm_key]
                titan_state[titan_key] = tensor
            else:
                print(f"Warning: No mapping found for {vllm_key}")

    print(f"Converted to {len(titan_state)} TorchTitan weights")
    return titan_state


def torchtitan_to_vllm(titan_state: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]:
    """
    Convert weights from TorchTitan format to vLLM format (HuggingFace).

    Args:
        titan_state: TorchTitan state dict (can be in vLLM-compat format with gate_up_proj)

    Returns:
        Dictionary with vLLM/HuggingFace-formatted state dict
    """
    # Create reverse mapping
    titan_to_vllm_map = {v: k for k, v in VLLM_TO_TITAN_MAP.items()}

    vllm_state = {}

    for titan_key, tensor in titan_state.items():
        # Handle merged gate_up_proj (vLLM-compat format) -> split into gate_proj + up_proj
        if ".feed_forward.gate_up_proj.weight" in titan_key:
            # Split into gate_proj (first half) and up_proj (second half)
            hidden_dim = tensor.shape[0] // 2
            # CLONE to avoid aliasing - these are views into the original tensor
            gate_weight = tensor[:hidden_dim].clone()
            up_weight = tensor[hidden_dim:].clone()

            # Extract layer number
            parts = titan_key.split(".")
            layer_idx = parts[1]

            # Create vLLM keys
            gate_key = f"model.layers.{layer_idx}.mlp.gate_proj.weight"
            up_key = f"model.layers.{layer_idx}.mlp.up_proj.weight"

            vllm_state[gate_key] = gate_weight
            vllm_state[up_key] = up_weight
            continue

        # Handle down_proj (vLLM-compat format)
        if ".feed_forward.down_proj.weight" in titan_key:
            parts = titan_key.split(".")
            layer_idx = parts[1]
            vllm_key = f"model.layers.{layer_idx}.mlp.down_proj.weight"
            # CLONE to avoid aliasing
            vllm_state[vllm_key] = tensor.clone()
            continue

        # Check if it's a layer-specific weight
        if "layers." in titan_key:
            # Extract layer number
            parts = titan_key.split(".")
            layer_idx = parts[1]

            # Create abstract key with placeholder
            abstract_titan_key = titan_key.replace(f".{layer_idx}.", ".{}.")

            # Look up in reverse mapping
            if abstract_titan_key in titan_to_vllm_map:
                abstract_vllm_key = titan_to_vllm_map[abstract_titan_key]
                vllm_key = abstract_vllm_key.format(layer_idx)
                # CLONE to avoid aliasing
                vllm_state[vllm_key] = tensor.clone()
            else:
                print(f"Warning: No mapping found for {titan_key}")
        else:
            # Non-layer weight
            if titan_key in titan_to_vllm_map:
                vllm_key = titan_to_vllm_map[titan_key]
                # CLONE to avoid aliasing
                vllm_state[vllm_key] = tensor.clone()
            else:
                print(f"Warning: No mapping found for {titan_key}")

    print(f"Converted to {len(vllm_state)} vLLM weights")
    return vllm_state


# Example usage
if __name__ == "__main__":
    import sys

    if len(sys.argv) < 4:  # need <mode> <input_path> <output_path>
        print("Usage:")
        print("  Convert vLLM to TorchTitan:")
        print("    python weight_converter.py vllm_to_titan <vllm_model_path> <output_path>")
        print("  Convert TorchTitan to vLLM:")
        print("    python weight_converter.py titan_to_vllm <titan_checkpoint_path> <output_path>")
        sys.exit(1)

    mode = sys.argv[1]
    input_path = sys.argv[2]
    output_path = sys.argv[3]

    if mode == "vllm_to_titan":
        # Convert vLLM to TorchTitan
        titan_state = vllm_to_torchtitan(input_path)

        # Save as safetensors
        print(f"Saving to {output_path}...")
        save_file(titan_state, output_path)
        print("Done!")

    elif mode == "titan_to_vllm":
        # Load TorchTitan checkpoint
        print(f"Loading TorchTitan checkpoint from {input_path}...")
        titan_state = load_file(input_path)

        # Convert to vLLM
        vllm_state = torchtitan_to_vllm(titan_state)

        # Save as safetensors
        print(f"Saving to {output_path}...")
        save_file(vllm_state, output_path)
        print("Done!")

    else:
        print(f"Unknown mode: {mode}")
        print("Use 'vllm_to_titan' or 'titan_to_vllm'")
        sys.exit(1)
```

0 commit comments
