try to support sgl-kernel in piecewise-cuda-graph #11716
Closed
I can't reproduce it in a single script. Output:

bbuf python3 /home/bbuf/reproduce_rope_triton_issue.py
/usr/local/lib/python3.12/dist-packages/torch/cuda/__init__.py:63: FutureWarning: The pynvml package is deprecated. Please install nvidia-ml-py instead. If you did not install pynvml directly, please report this to the maintainers of the package that installed pynvml for you.
import pynvml # type: ignore[import]
INFO:__main__:✅ sgl_kernel is available
INFO:__main__:
📝 Entering torch.compile mode...
INFO:__main__: ✅ All CustomOps set to use sgl-kernel in torch.compile mode
INFO:__main__:
⚙️ Compiling model with torch.compile (backend=inductor)...
INFO:__main__:🔥 Running warm-up to trigger compilation...
/usr/local/lib/python3.12/dist-packages/torch/_dynamo/variables/functions.py:1575: UserWarning: Dynamo detected a call to a `functools.lru_cache`-wrapped function. Dynamo ignores the cache wrapper and directly traces the wrapped function. Silent incorrectness is only a *potential* risk, not something we have observed. Enable TORCH_LOGS="+dynamo" for a DEBUG stack trace.
torch._dynamo.utils.warn_once(msg)
INFO:__main__: ✅ Compilation complete
INFO:__main__:
📊 Profiling with torch.profiler to check actual kernels...
void flashinfer::norm::RMSNormKernel<8u, __nv_bfloat16>(__nv_bfloat16*, __nv_bfloat16*, __nv_bfloat16*, unsigned int, unsigned int, unsigned int, float, float)
nvjet_tst_128x32_64x10_4x1_v_bz_TNT
triton_poi_fused__to_copy_0
triton_poi_fused__to_copy_1
void flashinfer::BatchQKApplyRotaryPosIdsCosSinCacheHeadParallelismKernel<false, 128u, 8u, 16u, __nv_bfloat16, long>(__nv_bfloat16*, __nv_bfloat16*, __nv_bfloat16*, __nv_bfloat16*, float*, long*, unsigned int, unsigned int, unsigned int, unsigned int, unsigned long, unsigned long, unsigned long, unsigned long, unsigned long, unsigned long, unsigned long, unsigned long)
void pytorch_flash::flash_fwd_kernel<Flash_fwd_kernel_traits<128, 128, 64, 4, false, false, cutlass::bfloat16_t, Flash_kernel_traits<128, 128, 64, 4, cutlass::bfloat16_t> >, false, false, false, false, false, true, false, false>(pytorch_flash::Flash_fwd_params)
void flashinfer::norm::FusedAddRMSNormKernel<8u, __nv_bfloat16>(__nv_bfloat16*, __nv_bfloat16*, __nv_bfloat16*, unsigned int, unsigned int, unsigned int, float, float)

Script (reproduce_rope_triton_issue.py):

import torch
import torch.nn as nn
from typing import Optional, Tuple
import logging
logging.basicConfig(level=logging.INFO, format='%(levelname)s:%(name)s:%(message)s')
logger = logging.getLogger(__name__)
# Check sgl_kernel availability
try:
    from sgl_kernel import rmsnorm, fused_add_rmsnorm, apply_rope_with_cos_sin_cache_inplace
    HAS_SGL_KERNEL = True
    logger.info("✅ sgl_kernel is available")
except ImportError:
    HAS_SGL_KERNEL = False
    logger.error("❌ sgl_kernel not available - this script requires sgl_kernel to run")
    exit(1)
# ============================================================================
# Simplified CustomOp
# ============================================================================
class CustomOp(nn.Module):
    def __init__(self):
        super().__init__()
        self._forward_method = self.dispatch_forward()
        # States for torch.compile
        self._original_forward_method = None
        self.is_torch_compile = False

    def enter_torch_compile(self, num_tokens: int):
        # Skip if the Op has already entered compile mode.
        # NOTE(alcanderian): Some Ops (for example RotaryEmbedding) will be reused
        # among layers and `enter_torch_compile` will be called many times.
        # We should prevent `self._original_forward_method` from being overridden when
        # it is not the first time `enter_torch_compile` is called.
        if self.is_torch_compile:
            return
        self._original_forward_method = self._forward_method
        # For RMSNorm and RotaryEmbedding, keep using sgl-kernel implementations
        # instead of falling back to forward_native, to avoid performance degradation
        op_name = self.__class__.__name__
        if any(name in op_name for name in ["RMSNorm", "RotaryEmbedding", "Llama3RotaryEmbedding"]):
            # Keep the original forward method (forward_cuda with sgl-kernel)
            # Don't switch to forward_native
            pass
        else:
            self._forward_method = self.forward_native
        self.is_torch_compile = True

    def leave_torch_compile(self):
        # Skip if the Op has already exited compile mode.
        if not self.is_torch_compile:
            return
        self._forward_method = self._original_forward_method
        self._original_forward_method = None
        self.is_torch_compile = False

    def forward(self, *args, **kwargs):
        return self._forward_method(*args, **kwargs)

    def forward_native(self, *args, **kwargs):
        raise NotImplementedError

    def forward_cuda(self, *args, **kwargs):
        raise NotImplementedError

    def dispatch_forward(self):
        return self.forward_cuda
# ============================================================================
# RMSNorm
# ============================================================================
class RMSNorm(CustomOp):
    def __init__(self, hidden_size: int, eps: float = 1e-6):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.eps = eps

    def forward_cuda(self, x: torch.Tensor, residual: Optional[torch.Tensor] = None):
        # Reshape to 2D for sgl_kernel
        original_shape = x.shape
        if x.ndim == 3:
            x = x.view(-1, x.shape[-1])
            if residual is not None:
                residual = residual.view(-1, residual.shape[-1])
        if residual is not None:
            fused_add_rmsnorm(x, residual, self.weight.data, self.eps)
            if len(original_shape) == 3:
                x = x.view(original_shape)
                residual = residual.view(original_shape)
            return x, residual
        else:
            out = rmsnorm(x, self.weight.data, self.eps)
            if len(original_shape) == 3:
                out = out.view(original_shape)
            return out

    def forward_native(self, x: torch.Tensor, residual: Optional[torch.Tensor] = None):
        if residual is not None:
            x_added = x + residual
            variance = x_added.pow(2).mean(-1, keepdim=True)
            out = x_added * torch.rsqrt(variance + self.eps) * self.weight
            return out, x_added
        else:
            variance = x.pow(2).mean(-1, keepdim=True)
            out = x * torch.rsqrt(variance + self.eps) * self.weight
            return out
# ============================================================================
# RotaryEmbedding
# ============================================================================
class RotaryEmbedding(CustomOp):
    def __init__(self, head_size: int = 128, max_position: int = 4096, base: int = 10000, is_neox_style: bool = True):
        super().__init__()
        self.head_size = head_size
        self.max_position = max_position
        self.base = base
        self.is_neox_style = is_neox_style
        # Compute cos/sin cache (keep as float32 for sgl_kernel)
        cache = self._compute_cos_sin_cache()
        self.register_buffer("cos_sin_cache", cache.float(), persistent=False)

    def _compute_inv_freq(self, base):
        """Compute the inverse frequency."""
        inv_freq = 1.0 / (base ** (torch.arange(0, self.head_size, 2, dtype=torch.float) / self.head_size))
        return inv_freq

    def _compute_cos_sin_cache(self):
        """Compute the cos and sin cache."""
        inv_freq = self._compute_inv_freq(self.base)
        t = torch.arange(self.max_position, dtype=torch.float)
        freqs = torch.einsum("i,j -> ij", t, inv_freq)
        cos = freqs.cos()
        sin = freqs.sin()
        cache = torch.cat((cos, sin), dim=-1)
        return cache

    def forward_cuda(self, positions: torch.Tensor, query: torch.Tensor, key: torch.Tensor):
        # Ensure cos_sin_cache is float32
        cos_sin_cache = self.cos_sin_cache
        if cos_sin_cache.dtype != torch.float32:
            cos_sin_cache = cos_sin_cache.float()
        # Flatten positions
        if positions.ndim > 1:
            positions = positions.reshape(-1)
        # Call sgl_kernel rope
        apply_rope_with_cos_sin_cache_inplace(
            positions=positions,
            query=query,
            key=key,
            head_size=self.head_size,
            cos_sin_cache=cos_sin_cache,
            is_neox=True,
        )
        return query, key

    def forward_native(self, positions: torch.Tensor, query: torch.Tensor, key: torch.Tensor):
        # Simplified PyTorch implementation
        cos_sin = self.cos_sin_cache[positions]
        cos, sin = cos_sin.chunk(2, dim=-1)
        cos = cos.repeat(1, 1, 2).unsqueeze(-2)
        sin = sin.repeat(1, 1, 2).unsqueeze(-2)

        def rotate_neox(x):
            x1, x2 = x[..., : x.shape[-1] // 2], x[..., x.shape[-1] // 2 :]
            return torch.cat((-x2, x1), dim=-1)

        query_rot = query * cos + rotate_neox(query) * sin
        key_rot = key * cos + rotate_neox(key) * sin
        return query_rot, key_rot
# ============================================================================
# Llama3RotaryEmbedding (same as SGLang implementation)
# ============================================================================
class Llama3RotaryEmbedding(RotaryEmbedding):
    """Llama3 style rotary embedding with frequency scaling"""

    def __init__(
        self,
        head_size: int = 128,
        max_position: int = 131072,
        base: int = 500000,
        is_neox_style: bool = True,
        scaling_factor: float = 8.0,
        low_freq_factor: float = 1.0,
        high_freq_factor: float = 4.0,
        orig_max_position: int = 8192,
    ):
        self.scaling_factor = scaling_factor
        self.low_freq_factor = low_freq_factor
        self.high_freq_factor = high_freq_factor
        self.orig_max_position = orig_max_position
        super().__init__(head_size, max_position, base, is_neox_style)

    def _compute_inv_freq(self, base):
        """Llama3-specific inverse frequency computation with scaling"""
        import math

        inv_freqs = super()._compute_inv_freq(base)
        low_freq_wavelen = self.orig_max_position / self.low_freq_factor
        high_freq_wavelen = self.orig_max_position / self.high_freq_factor
        wave_len = 2 * math.pi / inv_freqs
        if self.low_freq_factor != self.high_freq_factor:
            smooth = (self.orig_max_position / wave_len - self.low_freq_factor) / (
                self.high_freq_factor - self.low_freq_factor
            )
        else:
            smooth = 0
        new_freqs = torch.where(
            wave_len < high_freq_wavelen,
            inv_freqs,
            torch.where(
                wave_len > low_freq_wavelen,
                inv_freqs / self.scaling_factor,
                (1 - smooth) * inv_freqs / self.scaling_factor + smooth * inv_freqs,
            ),
        )
        return new_freqs
# ============================================================================
# Simple Test Model
# ============================================================================
class SimpleModel(nn.Module):
    def __init__(self, hidden_size: int = 4096, num_heads: int = 32, head_size: int = 128, use_llama3_rope: bool = False):
        super().__init__()
        self.hidden_size = hidden_size
        self.num_heads = num_heads
        self.head_size = head_size
        self.input_norm = RMSNorm(hidden_size)
        # Choose RoPE type
        if use_llama3_rope:
            self.rotary_emb = Llama3RotaryEmbedding(
                head_size=head_size,
                max_position=131072,
                base=500000,
                is_neox_style=True,
                scaling_factor=8.0,
                low_freq_factor=1.0,
                high_freq_factor=4.0,
                orig_max_position=8192,
            )
        else:
            self.rotary_emb = RotaryEmbedding(head_size=head_size)
        self.qkv_proj = nn.Linear(hidden_size, 3 * num_heads * head_size, bias=False)
        self.output_norm = RMSNorm(hidden_size)

    def forward(self, hidden_states: torch.Tensor, positions: torch.Tensor, residual: Optional[torch.Tensor] = None):
        # Input norm
        if residual is None:
            hidden_states = self.input_norm(hidden_states)
            residual = hidden_states
        else:
            hidden_states, residual = self.input_norm(hidden_states, residual)
        # QKV projection
        batch_size, seq_len = hidden_states.shape[:2]
        qkv = self.qkv_proj(hidden_states)
        qkv = qkv.view(batch_size, seq_len, 3, self.num_heads, self.head_size)
        q, k, v = qkv.unbind(dim=2)
        # Apply RoPE
        q = q.reshape(batch_size * seq_len, self.num_heads, self.head_size)
        k = k.reshape(batch_size * seq_len, self.num_heads, self.head_size)
        q, k = self.rotary_emb(positions.reshape(-1), q, k)
        q = q.view(batch_size, seq_len, self.num_heads, self.head_size)
        k = k.view(batch_size, seq_len, self.num_heads, self.head_size)
        # Simplified attention (just for testing)
        attn_output = torch.nn.functional.scaled_dot_product_attention(
            q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2)
        ).transpose(1, 2)
        output = attn_output.reshape(batch_size, seq_len, -1)
        # Output norm
        output, residual = self.output_norm(output, residual)
        return output, residual
# ============================================================================
# Main Test
# ============================================================================
def test_rope_implementation(use_llama3: bool, label: str):
    """Test a specific RoPE implementation"""
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    dtype = torch.bfloat16
    # Create model
    model = SimpleModel(use_llama3_rope=use_llama3).to(device, dtype)
    # Enter torch.compile mode
    logger.info("\n📝 Entering torch.compile mode...")
    model.input_norm.enter_torch_compile(num_tokens=16)
    model.output_norm.enter_torch_compile(num_tokens=16)
    model.rotary_emb.enter_torch_compile(num_tokens=16)
    logger.info(" ✅ All CustomOps set to use sgl-kernel in torch.compile mode")
    # Compile
    logger.info("\n⚙️ Compiling model with torch.compile (backend=inductor)...")
    compiled_model = torch.compile(model, backend="inductor")
    # Warm up
    batch_size, seq_len = 2, 16
    hidden_states = torch.randn(batch_size, seq_len, 4096, device=device, dtype=dtype)
    positions = torch.arange(seq_len, device=device).unsqueeze(0).expand(batch_size, -1)
    logger.info("🔥 Running warm-up to trigger compilation...")
    _ = compiled_model(hidden_states, positions)
    logger.info(" ✅ Compilation complete")
    # Profile
    logger.info("\n📊 Profiling with torch.profiler to check actual kernels...")
    with torch.profiler.profile(
        activities=[torch.profiler.ProfilerActivity.CUDA],
        record_shapes=True
    ) as prof:
        output, residual = compiled_model(hidden_states, positions)
    for event in prof.key_averages():
        if event.device_type == torch.profiler.DeviceType.CUDA:
            print(event.key)


test_rope_implementation(
    use_llama3=True,
    label="Llama3RotaryEmbedding"
)
With the PR change, when I run it, the FixFunctionalizationPass output is below. But in the torch profiler, the rope kernel is always the torch.compile Triton version, while the fused_rms_norm kernel is the sgl_kernel version.

I don't know why the rope kernel can't be replaced by the sgl-kernel rope in piecewise CUDA graph.
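Not a definitive answer, but one direction to check: if torch.compile is allowed to trace into the rope path, Inductor can decompose it and emit its own Triton kernel, whereas an op registered through torch.library.custom_op stays opaque and survives compilation as a single call. Below is a minimal sketch of that idea for the repro script, assuming PyTorch >= 2.4; the sgl_repro::rope_inplace name and the wrapper itself are only for illustration and are not part of this PR.

import torch
from sgl_kernel import apply_rope_with_cos_sin_cache_inplace


# Hypothetical wrapper (not from this PR): register the sgl-kernel rope as an
# opaque custom op so Dynamo/Inductor keep the call instead of lowering it to Triton.
@torch.library.custom_op("sgl_repro::rope_inplace", mutates_args=("query", "key"))
def rope_inplace(
    positions: torch.Tensor,
    query: torch.Tensor,
    key: torch.Tensor,
    head_size: int,
    cos_sin_cache: torch.Tensor,
) -> None:
    apply_rope_with_cos_sin_cache_inplace(
        positions=positions,
        query=query,
        key=key,
        head_size=head_size,
        cos_sin_cache=cos_sin_cache,
        is_neox=True,
    )


@rope_inplace.register_fake
def _(positions, query, key, head_size, cos_sin_cache) -> None:
    # The op mutates query/key in place and returns nothing,
    # so the fake (meta) implementation has nothing to compute.
    return None

If RotaryEmbedding.forward_cuda calls rope_inplace(...) instead of the raw sgl_kernel function, the profiler output (or TORCH_LOGS="output_code") should show whether the rope call now survives Inductor, which would help confirm whether the Triton rope comes from Inductor tracing into the op.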