Commit ae4ebd3

lint
1 parent 90bc444 commit ae4ebd3

File tree

9 files changed: +380 -192 lines changed

torchtitan/experiments/deterministic_vllm_rl/batch_invariant_backward.py

Lines changed: 29 additions & 17 deletions

@@ -42,6 +42,7 @@
 # Custom autograd Functions for vLLM operations
 # ============================================================================

+
 class SiluAndMulFunction(Function):
     """
     Autograd function for vLLM's SiluAndMul activation.
@@ -86,7 +87,7 @@ def backward(ctx, grad_output):

         where d_silu(x) = sigmoid(x) * (1 + x * (1 - sigmoid(x)))
         """
-        x, = ctx.saved_tensors
+        (x,) = ctx.saved_tensors

         # Split input into gate and up
         d = x.shape[-1] // 2
@@ -176,6 +177,7 @@ def backward(ctx, grad_output):
 # Backward operation implementations for autograd
 # ============================================================================

+
 def matmul_backward_impl(grad_output, self, other, output_mask):
     """
     Backward pass for matmul: y = matmul(a, b)
@@ -324,8 +326,8 @@ def rms_norm_backward_impl(grad_output, input, weight, eps):
 # Registration
 # ============================================================================

-_batch_invariant_backward_MODE = False
-_batch_invariant_backward_LIB = None
+_batch_invariant_backward_mode = False
+_batch_invariant_backward_lib = None


 def patch_batch_invariant_with_gradients():
@@ -335,28 +337,35 @@ def patch_batch_invariant_with_gradients():
     implementations by registering the backward operations. vLLM handles all the
     forward passes, we just add gradient support.
     """
-    global _batch_invariant_backward_MODE, _batch_invariant_backward_LIB
+    global _batch_invariant_backward_mode, _batch_invariant_backward_lib

-    if _batch_invariant_backward_MODE:
+    if _batch_invariant_backward_mode:
         return

     # Get vLLM's batch_invariant library (already created by init_batch_invariance)
     from vllm.model_executor.layers import batch_invariant as vllm_bi

-    if not hasattr(vllm_bi, '_batch_invariant_LIB') or vllm_bi._batch_invariant_LIB is None:
+    if (
+        not hasattr(vllm_bi, "_batch_invariant_LIB")
+        or vllm_bi._batch_invariant_LIB is None
+    ):
         raise RuntimeError(
             "vLLM's batch_invariant mode is not initialized. "
             "Call init_batch_invariance() first."
         )

     # Use vLLM's existing library - don't destroy it!
-    _batch_invariant_backward_LIB = vllm_bi._batch_invariant_LIB
+    _batch_invariant_backward_lib = vllm_bi._batch_invariant_LIB

     # Just add the backward operations - everything else is already handled by vLLM
-    _batch_invariant_backward_LIB.impl("aten::matmul_backward", matmul_backward_impl, "CUDA")
-    _batch_invariant_backward_LIB.impl("aten::linear_backward", linear_backward_impl, "CUDA")
+    _batch_invariant_backward_lib.impl(
+        "aten::matmul_backward", matmul_backward_impl, "CUDA"
+    )
+    _batch_invariant_backward_lib.impl(
+        "aten::linear_backward", linear_backward_impl, "CUDA"
+    )

-    _batch_invariant_backward_MODE = True
+    _batch_invariant_backward_mode = True


 def enable_batch_invariant_backward_mode():
@@ -366,25 +375,28 @@ def enable_batch_invariant_backward_mode():

 def disable_batch_invariant_backward_mode():
     """Disable batch invariant backward mode."""
-    global _batch_invariant_backward_MODE, _batch_invariant_backward_LIB
+    global _batch_invariant_backward_mode, _batch_invariant_backward_lib

-    if _batch_invariant_backward_LIB is not None:
-        _batch_invariant_backward_LIB._destroy()
+    if _batch_invariant_backward_lib is not None:
+        _batch_invariant_backward_lib._destroy()

-    _batch_invariant_backward_MODE = False
-    _batch_invariant_backward_LIB = None
+    _batch_invariant_backward_mode = False
+    _batch_invariant_backward_lib = None


 def is_batch_invariant_backward_mode_enabled():
     """Check if batch invariant backward mode is enabled."""
-    return _batch_invariant_backward_MODE
+    return _batch_invariant_backward_mode


 # ============================================================================
 # Public API for gradient-enabled vLLM operations
 # ============================================================================

-def rms_norm_with_gradients(input: torch.Tensor, weight: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
+
+def rms_norm_with_gradients(
+    input: torch.Tensor, weight: torch.Tensor, eps: float = 1e-6
+) -> torch.Tensor:
     """
     RMS normalization with gradient support.
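
The d_silu(x) expression in the docstring above is the standard derivative of SiLU. As a sanity check, here is a minimal self-contained sketch (plain PyTorch, no vLLM dependency; the helper name is illustrative, not part of this module) that verifies the formula against autograd:

    import torch

    def d_silu(x: torch.Tensor) -> torch.Tensor:
        # d/dx [x * sigmoid(x)] = sigmoid(x) * (1 + x * (1 - sigmoid(x)))
        s = torch.sigmoid(x)
        return s * (1 + x * (1 - s))

    x = torch.randn(8, dtype=torch.float64, requires_grad=True)
    torch.nn.functional.silu(x).sum().backward()
    assert torch.allclose(x.grad, d_silu(x.detach()))

Note the ordering this file enforces via the RuntimeError above: vLLM's init_batch_invariance() must already have created _batch_invariant_LIB before patch_batch_invariant_with_gradients() layers the backward registrations on top of it.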

torchtitan/experiments/deterministic_vllm_rl/models/attention.py

Lines changed: 47 additions & 15 deletions

@@ -14,11 +14,13 @@

 class VLLMCompatibleFlashAttention(torch.nn.Module):
     """Wrapper around FlashAttention as used by VLLM"""
+
     def __init__(self) -> None:
         super().__init__()
         self.flash_attn_varlen_func = flash_attn_varlen_func
-        from vllm.model_executor.layers.batch_invariant import vllm_is_batch_invariant
         from vllm.attention.utils.fa_utils import get_flash_attn_version
+        from vllm.model_executor.layers.batch_invariant import vllm_is_batch_invariant
+
         self.vllm_is_batch_invariant = vllm_is_batch_invariant
         self.fa_version = get_flash_attn_version()

@@ -51,17 +53,29 @@ def forward(
         # Create cumulative sequence lengths
         # cu_seqlens: [0, seq_len, 2*seq_len, ..., batch_size*seq_len]
         cu_seqlens = torch.arange(
-            0, (batch_size + 1) * seq_len, seq_len,
-            dtype=torch.int32, device=q.device
+            0, (batch_size + 1) * seq_len, seq_len, dtype=torch.int32, device=q.device
         )

         # Wrap Flash Attention with manual backward pass
         class FlashAttnWithBackward(torch.autograd.Function):
             @staticmethod
-            def forward(ctx, q, k, v, cu_seqlens, seq_len, scale, num_splits, flash_fn, fa_version):
+            def forward(
+                ctx,
+                q,
+                k,
+                v,
+                cu_seqlens,
+                seq_len,
+                scale,
+                num_splits,
+                flash_fn,
+                fa_version,
+            ):
                 # Call flash attention for forward (fast)
                 output = flash_fn(
-                    q, k, v,
+                    q,
+                    k,
+                    v,
                     cu_seqlens_q=cu_seqlens,
                     cu_seqlens_k=cu_seqlens,
                     max_seqlen_q=seq_len,
@@ -94,7 +108,9 @@ def backward(ctx, grad_output):
                 k_batch = k.reshape(batch_size, seq_len, num_heads, head_dim)
                 v_batch = v.reshape(batch_size, seq_len, num_heads, head_dim)
                 out_batch = output.reshape(batch_size, seq_len, num_heads, head_dim)
-                grad_out_batch = grad_output.reshape(batch_size, seq_len, num_heads, head_dim)
+                grad_out_batch = grad_output.reshape(
+                    batch_size, seq_len, num_heads, head_dim
+                )

                 # Transpose to (batch, num_heads, seq_len, head_dim)
                 q_t = q_batch.transpose(1, 2)
@@ -108,11 +124,16 @@ def backward(ctx, grad_output):
                 scores = torch.matmul(q_t, k_t.transpose(-2, -1)) * scale

                 # Apply causal mask
-                causal_mask = torch.triu(torch.ones(seq_len, seq_len, device=q.device, dtype=torch.bool), diagonal=1)
-                scores = scores.masked_fill(causal_mask, float('-inf'))
+                causal_mask = torch.triu(
+                    torch.ones(seq_len, seq_len, device=q.device, dtype=torch.bool),
+                    diagonal=1,
+                )
+                scores = scores.masked_fill(causal_mask, float("-inf"))

                 # Softmax
-                attn_weights = torch.nn.functional.softmax(scores, dim=-1)  # (B, H, N, N)
+                attn_weights = torch.nn.functional.softmax(
+                    scores, dim=-1
+                )  # (B, H, N, N)

                 # Backward through attention
                 # out = attn_weights @ v
@@ -140,18 +161,29 @@ def backward(ctx, grad_output):
                 grad_k_t = torch.matmul(grad_scores.transpose(-2, -1), q_t)

                 # Transpose back and reshape to varlen format
-                grad_q = grad_q_t.transpose(1, 2).reshape(total_tokens, num_heads, head_dim)
-                grad_k = grad_k_t.transpose(1, 2).reshape(total_tokens, num_heads, head_dim)
-                grad_v = grad_v_t.transpose(1, 2).reshape(total_tokens, num_heads, head_dim)
+                grad_q = grad_q_t.transpose(1, 2).reshape(
+                    total_tokens, num_heads, head_dim
+                )
+                grad_k = grad_k_t.transpose(1, 2).reshape(
+                    total_tokens, num_heads, head_dim
+                )
+                grad_v = grad_v_t.transpose(1, 2).reshape(
+                    total_tokens, num_heads, head_dim
+                )

                 return grad_q, grad_k, grad_v, None, None, None, None, None, None

         # Call Flash Attention varlen with custom backward
         output_varlen = FlashAttnWithBackward.apply(
-            q_varlen, k_varlen, v_varlen,
-            cu_seqlens, seq_len, scale,
+            q_varlen,
+            k_varlen,
+            v_varlen,
+            cu_seqlens,
+            seq_len,
+            scale,
             1 if self.vllm_is_batch_invariant() else 0,
-            self.flash_attn_varlen_func, self.fa_version,
+            self.flash_attn_varlen_func,
+            self.fa_version,
         )

         # Convert back to batch format
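
The custom backward above recomputes dense causal attention by hand while the forward pass still goes through flash_attn_varlen_func. A small sketch of the reference math it mirrors (illustrative shapes and variable names, float64 for tight tolerances; not the experiment's API), checked against autograd:

    import torch

    B, H, N, D = 2, 4, 16, 8
    scale = D ** -0.5
    q = torch.randn(B, H, N, D, dtype=torch.float64, requires_grad=True)
    k = torch.randn(B, H, N, D, dtype=torch.float64, requires_grad=True)
    v = torch.randn(B, H, N, D, dtype=torch.float64, requires_grad=True)

    # Reference forward: scaled dot-product attention with a causal mask.
    scores = torch.matmul(q, k.transpose(-2, -1)) * scale
    causal_mask = torch.triu(torch.ones(N, N, dtype=torch.bool), diagonal=1)
    scores = scores.masked_fill(causal_mask, float("-inf"))
    attn = torch.softmax(scores, dim=-1)
    out = torch.matmul(attn, v)

    grad_out = torch.randn_like(out)
    out.backward(grad_out)

    # Manual backward, following the same steps as the diff above.
    with torch.no_grad():
        grad_v = torch.matmul(attn.transpose(-2, -1), grad_out)
        grad_attn = torch.matmul(grad_out, v.transpose(-2, -1))
        # Softmax backward: dS = A * (dA - sum(dA * A, dim=-1, keepdim=True))
        grad_scores = attn * (grad_attn - (grad_attn * attn).sum(-1, keepdim=True))
        grad_scores = grad_scores * scale
        grad_q = torch.matmul(grad_scores, k)
        grad_k = torch.matmul(grad_scores.transpose(-2, -1), q)

    for manual, auto in ((grad_q, q.grad), (grad_k, k.grad), (grad_v, v.grad)):
        assert torch.allclose(manual, auto)

The production path keeps the fast varlen kernel for the forward (with num_splits forced to 1 when vllm_is_batch_invariant() is true) and only falls back to this dense recomputation inside the custom backward.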

torchtitan/experiments/deterministic_vllm_rl/models/qwen3/model_vllm_compat.py

Lines changed: 5 additions & 5 deletions

@@ -12,17 +12,17 @@

 from torchtitan.components.tokenizer import BaseTokenizer

-# Import from main torchtitan
-from torchtitan.models.qwen3.model.args import Qwen3ModelArgs
-from torchtitan.protocols.model import AttentionMasksType
-from torchtitan.protocols.train_spec import ModelProtocol
-
 # Import gradient-enabled operations from experiment utilities
 from torchtitan.experiments.deterministic_vllm_rl.batch_invariant_backward import (
     rms_norm_with_gradients,
     silu_and_mul_with_gradients,
 )

+# Import from main torchtitan
+from torchtitan.models.qwen3.model.args import Qwen3ModelArgs
+from torchtitan.protocols.model import AttentionMasksType
+from torchtitan.protocols.train_spec import ModelProtocol
+
 # Import from local experiment's models
 from ..attention import VLLMCompatibleFlashAttention
