
Conversation

@BBuf (Collaborator) commented on Oct 16, 2025

With this PR's change, when I run

python3 -m sglang.launch_server --model meta-llama/Llama-3.3-70B-Instruct --tp 8 --port 30000 --enable-piecewise-cuda-graph --piecewise-cuda-graph-max-tokens 8192

the FixFunctionalizationPass output is shown below. However, in the torch profiler trace, the rope kernel is always the torch.compile Triton version:

[image]

while the fused_rms_norm kernel is the sgl_kernel version:

[image]

I don't know why the rope kernel isn't replaced by the sgl-kernel rope in piecewise CUDA graph mode.
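
For reference, one way to confirm which rope kernel is actually dispatched inside the compiled/captured region is to wrap a single forward pass in torch.profiler and scan the CUDA kernel names: the sgl-kernel path shows up as a flashinfer BatchQKApplyRotary* kernel, while an Inductor fallback shows up as a generated triton_* kernel. A minimal sketch, where run_one_forward() is a placeholder for whatever drives one batch through the model:

import torch
from torch.profiler import profile, ProfilerActivity

def run_one_forward():
    # Placeholder: drive one batch through the piecewise-CUDA-graph /
    # torch.compile region of the model.
    ...

with profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA]) as prof:
    run_one_forward()
    torch.cuda.synchronize()

# Print every kernel that looks rope-related or Inductor-generated.
for evt in prof.key_averages():
    name = evt.key
    if "Rotary" in name or "rope" in name.lower() or name.startswith("triton_"):
        print(name)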

/usr/local/lib/python3.12/dist-packages/torch/cuda/__init__.py:63: FutureWarning: The pynvml package is deprecated. Please install nvidia-ml-py instead. If you did not install pynvml directly, please report this to the maintainers of the package that installed pynvml for you.
  import pynvml  # type: ignore[import]
`torch_dtype` is deprecated! Use `dtype` instead!
[2025-10-16 09:05:54] server_args=ServerArgs(model_path='meta-llama/Llama-3.3-70B-Instruct', tokenizer_path='meta-llama/Llama-3.3-70B-Instruct', tokenizer_mode='auto', tokenizer_worker_num=1, skip_tokenizer_init=False, load_format='auto', model_loader_extra_config='{}', trust_remote_code=False, modelopt_quant=None, modelopt_checkpoint_restore_path=None, modelopt_checkpoint_save_path=None, context_length=None, is_embedding=False, enable_multimodal=None, revision=None, model_impl='auto', host='127.0.0.1', port=30000, grpc_mode=False, skip_server_warmup=False, warmups=None, nccl_port=None, dtype='auto', quantization=None, quantization_param_path=None, kv_cache_dtype='auto', enable_fp32_lm_head=False, mem_fraction_static=0.897, max_running_requests=None, max_queued_requests=None, max_total_tokens=None, chunked_prefill_size=8192, max_prefill_tokens=16384, schedule_policy='fcfs', enable_priority_scheduling=False, schedule_low_priority_values_first=False, priority_scheduling_preemption_threshold=10, schedule_conservativeness=1.0, page_size=1, hybrid_kvcache_ratio=None, swa_full_tokens_ratio=0.8, disable_hybrid_swa_memory=False, radix_eviction_policy='lru', device='cuda', elastic_ep_backend=None, mooncake_ib_device=None, tp_size=8, pp_size=1, pp_max_micro_batch_size=None, stream_interval=1, stream_output=False, random_seed=925905738, constrained_json_whitespace_pattern=None, constrained_json_disable_any_whitespace=False, watchdog_timeout=300, dist_timeout=None, download_dir=None, base_gpu_id=0, gpu_id_step=1, sleep_on_idle=False, log_level='info', log_level_http=None, log_requests=False, log_requests_level=2, crash_dump_folder=None, crash_on_nan=False, show_time_cost=False, enable_metrics=False, enable_metrics_for_all_schedulers=False, tokenizer_metrics_custom_labels_header='x-custom-labels', tokenizer_metrics_allowed_custom_labels=None, bucket_time_to_first_token=None, bucket_inter_token_latency=None, bucket_e2e_request_latency=None, collect_tokens_histogram=False, prompt_tokens_buckets=None, generation_tokens_buckets=None, decode_log_interval=40, enable_request_time_stats_logging=False, kv_events_config=None, gc_warning_threshold_secs=0.0, enable_trace=False, oltp_traces_endpoint='localhost:4317', api_key=None, served_model_name='meta-llama/Llama-3.3-70B-Instruct', weight_version='default', chat_template=None, completion_template=None, file_storage_path='sglang_storage', enable_cache_report=False, reasoning_parser=None, tool_call_parser=None, tool_server=None, sampling_defaults='model', dp_size=1, load_balance_method='round_robin', load_watch_interval=0.1, prefill_round_robin_balance=False, dist_init_addr=None, nnodes=1, node_rank=0, json_model_override_args='{}', preferred_sampling_params=None, enable_lora=None, max_lora_rank=None, lora_target_modules=None, lora_paths=None, max_loaded_loras=None, max_loras_per_batch=8, lora_eviction_policy='lru', lora_backend='triton', max_lora_chunk_size=16, attention_backend=None, decode_attention_backend=None, prefill_attention_backend=None, sampling_backend='flashinfer', grammar_backend='xgrammar', mm_attention_backend=None, nsa_prefill='flashmla_prefill', nsa_decode='fa3', enable_beta_spec=False, speculative_algorithm=None, speculative_draft_model_path=None, speculative_draft_model_revision=None, speculative_draft_load_format=None, speculative_num_steps=None, speculative_eagle_topk=None, speculative_num_draft_tokens=None, speculative_accept_threshold_single=1.0, speculative_accept_threshold_acc=1.0, speculative_token_map=None, 
speculative_attention_mode='prefill', speculative_ngram_min_match_window_size=1, speculative_ngram_max_match_window_size=12, speculative_ngram_min_bfs_breadth=1, speculative_ngram_max_bfs_breadth=10, speculative_ngram_match_type='BFS', speculative_ngram_branch_length=18, speculative_ngram_capacity=10000000, ep_size=1, moe_a2a_backend='none', moe_runner_backend='auto', flashinfer_mxfp4_moe_precision='default', enable_flashinfer_allreduce_fusion=False, deepep_mode='auto', ep_num_redundant_experts=0, ep_dispatch_algorithm='static', init_expert_location='trivial', enable_eplb=False, eplb_algorithm='auto', eplb_rebalance_num_iterations=1000, eplb_rebalance_layers_per_chunk=None, eplb_min_rebalancing_utilization_threshold=1.0, expert_distribution_recorder_mode=None, expert_distribution_recorder_buffer_size=1000, enable_expert_distribution_metrics=False, deepep_config=None, moe_dense_tp_size=None, max_mamba_cache_size=None, mamba_ssm_dtype='float32', mamba_full_memory_ratio=0.9, enable_hierarchical_cache=False, hicache_ratio=2.0, hicache_size=0, hicache_write_policy='write_through', hicache_io_backend='kernel', hicache_mem_layout='layer_first', hicache_storage_backend=None, hicache_storage_prefetch_policy='best_effort', hicache_storage_backend_extra_config=None, enable_lmcache=False, enable_double_sparsity=False, ds_channel_config_path=None, ds_heavy_channel_num=32, ds_heavy_token_num=256, ds_heavy_channel_type='qk', ds_sparse_decode_threshold=4096, cpu_offload_gb=0, offload_group_size=-1, offload_num_in_group=1, offload_prefetch_step=1, offload_mode='cpu', multi_item_scoring_delimiter=None, disable_radix_cache=False, cuda_graph_max_bs=512, cuda_graph_bs=[1, 2, 4, 8, 12, 16, 24, 32, 40, 48, 56, 64, 72, 80, 88, 96, 104, 112, 120, 128, 136, 144, 152, 160, 168, 176, 184, 192, 200, 208, 216, 224, 232, 240, 248, 256, 272, 288, 304, 320, 336, 352, 368, 384, 400, 416, 432, 448, 464, 480, 496, 512], disable_cuda_graph=False, disable_cuda_graph_padding=False, enable_profile_cuda_graph=False, enable_cudagraph_gc=False, enable_nccl_nvls=False, enable_symm_mem=False, disable_flashinfer_cutlass_moe_fp4_allgather=False, enable_tokenizer_batch_encode=False, disable_outlines_disk_cache=False, disable_custom_all_reduce=False, enable_mscclpp=False, enable_torch_symm_mem=False, disable_overlap_schedule=False, enable_mixed_chunk=False, enable_dp_attention=False, enable_dp_lm_head=False, enable_two_batch_overlap=False, enable_single_batch_overlap=False, tbo_token_distribution_threshold=0.48, enable_torch_compile=False, enable_piecewise_cuda_graph=True, torch_compile_max_bs=32, piecewise_cuda_graph_max_tokens=8192, piecewise_cuda_graph_tokens=[4, 8, 12, 16, 20, 24, 28, 32, 48, 64, 80, 96, 112, 128, 144, 160, 176, 192, 208, 224, 240, 256, 288, 320, 352, 384, 416, 448, 480, 512, 640, 768, 896, 1024, 1152, 1280, 1408, 1536, 1664, 1792, 1920, 2048, 2176, 2304, 2432, 2560, 2688, 2816, 2944, 3072, 3200, 3328, 3456, 3584, 3712, 3840, 3968, 4096, 4352, 4608, 4864, 5120, 5376, 5632, 5888, 6144, 6400, 6656, 6912, 7168, 7424, 7680, 7936, 8192], torchao_config='', enable_nan_detection=False, enable_p2p_check=False, triton_attention_reduce_in_fp32=False, triton_attention_num_kv_splits=8, triton_attention_split_tile_size=None, num_continuous_decode_steps=1, delete_ckpt_after_loading=False, enable_memory_saver=False, enable_weights_cpu_backup=False, allow_auto_truncate=False, enable_custom_logit_processor=False, flashinfer_mla_disable_ragged=False, disable_shared_experts_fusion=False, disable_chunked_prefix_cache=False, 
disable_fast_image_processor=False, keep_mm_feature_on_device=False, enable_return_hidden_states=False, scheduler_recv_interval=1, numa_node=None, enable_deterministic_inference=False, enable_dynamic_batch_tokenizer=False, dynamic_batch_tokenizer_batch_size=32, dynamic_batch_tokenizer_batch_timeout=0.002, debug_tensor_dump_output_folder=None, debug_tensor_dump_input_file=None, debug_tensor_dump_inject=False, disaggregation_mode='null', disaggregation_transfer_backend='mooncake', disaggregation_bootstrap_port=8998, disaggregation_decode_tp=None, disaggregation_decode_dp=None, disaggregation_prefill_pp=1, disaggregation_ib_device=None, disaggregation_decode_enable_offload_kvcache=False, num_reserved_decode_tokens=512, disaggregation_decode_polling_interval=1, custom_weight_loader=[], weight_loader_disable_mmap=False, remote_instance_weight_loader_seed_instance_ip=None, remote_instance_weight_loader_seed_instance_service_port=None, remote_instance_weight_loader_send_weights_group_ports=None, enable_pdmux=False, pdmux_config_path=None, sm_group_num=8)
[2025-10-16 09:05:55] Using default HuggingFace chat template with detected content format: string
`torch_dtype` is deprecated! Use `dtype` instead!
[2025-10-16 09:06:03 TP0] Init torch distributed begin.
[2025-10-16 09:06:04 TP5] Init torch distributed begin.
[2025-10-16 09:06:05 TP4] Init torch distributed begin.
[2025-10-16 09:06:05 TP2] Init torch distributed begin.
[2025-10-16 09:06:05 TP6] Init torch distributed begin.
[2025-10-16 09:06:05 TP1] Init torch distributed begin.
[2025-10-16 09:06:06 TP7] Init torch distributed begin.
[2025-10-16 09:06:06 TP3] Init torch distributed begin.
[Gloo] Rank 2 is connected to 7 peer ranks. Expected number of connected peer ranks is : 7
[Gloo] Rank 0 is connected to 7 peer ranks. Expected number of connected peer ranks is : 7
[Gloo] Rank 4 is connected to 7 peer ranks. Expected number of connected peer ranks is : 7
[Gloo] Rank 3 is connected to 7 peer ranks. Expected number of connected peer ranks is : 7
[Gloo] Rank 1 is connected to 7 peer ranks. Expected number of connected peer ranks is : 7
[Gloo] Rank 5 is connected to 7 peer ranks. Expected number of connected peer ranks is : 7
[Gloo] Rank 7 is connected to 7 peer ranks. Expected number of connected peer ranks is : 7
[Gloo] Rank 6 is connected to 7 peer ranks. Expected number of connected peer ranks is : 7
[Gloo] Rank 0 is connected to 7 peer ranks. Expected number of connected peer ranks is : 7
[Gloo] Rank 1 is connected to 7 peer ranks. Expected number of connected peer ranks is : 7
[Gloo] Rank 6 is connected to 7 peer ranks. Expected number of connected peer ranks is : 7
[Gloo] Rank 4 is connected to 7 peer ranks. Expected number of connected peer ranks is : 7
[Gloo] Rank 5 is connected to 7 peer ranks. Expected number of connected peer ranks is : 7
[Gloo] Rank 2 is connected to 7 peer ranks. Expected number of connected peer ranks is : 7
[Gloo] Rank 3 is connected to 7 peer ranks. Expected number of connected peer ranks is : 7
[Gloo] Rank 7 is connected to 7 peer ranks. Expected number of connected peer ranks is : 7
[2025-10-16 09:06:06 TP0] sglang is using nccl==2.27.3
[Gloo] Rank 0 is connected to 0 peer ranks. Expected number of connected peer ranks is : 0
[Gloo] Rank 0 is connected to 0 peer ranks. Expected number of connected peer ranks is : 0
[Gloo] Rank 0 is connected to 0 peer ranks. Expected number of connected peer ranks is : 0
[Gloo] Rank 0 is connected to 0 peer ranks. Expected number of connected peer ranks is : 0
[Gloo] Rank 0 is connected to 0 peer ranks. Expected number of connected peer ranks is : 0
[Gloo] Rank 0 is connected to 0 peer ranks. Expected number of connected peer ranks is : 0
[Gloo] Rank 0 is connected to 0 peer ranks. Expected number of connected peer ranks is : 0
[Gloo] Rank 0 is connected to 0 peer ranks. Expected number of connected peer ranks is : 0
[Gloo] Rank 0 is connected to 0 peer ranks. Expected number of connected peer ranks is : 0
[Gloo] Rank 0 is connected to 0 peer ranks. Expected number of connected peer ranks is : 0
[Gloo] Rank 0 is connected to 0 peer ranks. Expected number of connected peer ranks is : 0
[Gloo] Rank 0 is connected to 0 peer ranks. Expected number of connected peer ranks is : 0
[Gloo] Rank 0 is connected to 0 peer ranks. Expected number of connected peer ranks is : 0
[Gloo] Rank 0 is connected to 0 peer ranks. Expected number of connected peer ranks is : 0
[Gloo] Rank 0 is connected to 0 peer ranks. Expected number of connected peer ranks is : 0
[Gloo] Rank 0 is connected to 0 peer ranks. Expected number of connected peer ranks is : 0
[Gloo] Rank 1 is connected to 7 peer ranks. Expected number of connected peer ranks is : 7
[Gloo] Rank 0 is connected to 7 peer ranks. Expected number of connected peer ranks is : 7
[Gloo] Rank 2 is connected to 7 peer ranks. Expected number of connected peer ranks is : 7
[Gloo] Rank 3 is connected to 7 peer ranks. Expected number of connected peer ranks is : 7
[Gloo] Rank 4 is connected to 7 peer ranks. Expected number of connected peer ranks is : 7
[Gloo] Rank 6 is connected to 7 peer ranks. Expected number of connected peer ranks is : 7
[Gloo] Rank 5 is connected to 7 peer ranks. Expected number of connected peer ranks is : 7
[Gloo] Rank 7 is connected to 7 peer ranks. Expected number of connected peer ranks is : 7
[2025-10-16 09:06:09 TP7] Init torch distributed ends. mem usage=1.77 GB
[2025-10-16 09:06:09 TP6] Init torch distributed ends. mem usage=2.01 GB
[2025-10-16 09:06:09 TP0] Init torch distributed ends. mem usage=1.96 GB
[2025-10-16 09:06:09 TP5] Init torch distributed ends. mem usage=2.01 GB
[2025-10-16 09:06:09 TP2] Init torch distributed ends. mem usage=2.01 GB
[2025-10-16 09:06:09 TP4] Init torch distributed ends. mem usage=2.01 GB
[2025-10-16 09:06:09 TP3] Init torch distributed ends. mem usage=2.01 GB
[2025-10-16 09:06:09 TP1] Init torch distributed ends. mem usage=2.01 GB
[2025-10-16 09:06:09 TP0] MOE_RUNNER_BACKEND is not initialized, the backend will be automatically selected
[2025-10-16 09:06:10 TP5] Load weight begin. avail mem=137.29 GB
[2025-10-16 09:06:10 TP4] Load weight begin. avail mem=137.29 GB
[2025-10-16 09:06:10 TP6] Load weight begin. avail mem=137.29 GB
[2025-10-16 09:06:10 TP2] Load weight begin. avail mem=137.29 GB
[2025-10-16 09:06:10 TP0] Load weight begin. avail mem=137.34 GB
[2025-10-16 09:06:10 TP7] Load weight begin. avail mem=137.52 GB
[2025-10-16 09:06:10 TP3] Load weight begin. avail mem=137.29 GB
[2025-10-16 09:06:10 TP1] Load weight begin. avail mem=137.29 GB
[2025-10-16 09:06:10 TP0] Using model weights format ['*.safetensors']

Loading safetensors checkpoint shards:   0% Completed | 0/30 [00:00<?, ?it/s]

Loading safetensors checkpoint shards:   3% Completed | 1/30 [00:00<00:08,  3.41it/s]

Loading safetensors checkpoint shards:   7% Completed | 2/30 [00:00<00:08,  3.43it/s]

Loading safetensors checkpoint shards:  10% Completed | 3/30 [00:00<00:07,  3.50it/s]

Loading safetensors checkpoint shards:  13% Completed | 4/30 [00:01<00:07,  3.33it/s]

Loading safetensors checkpoint shards:  17% Completed | 5/30 [00:01<00:07,  3.22it/s]

Loading safetensors checkpoint shards:  20% Completed | 6/30 [00:01<00:07,  3.20it/s]

Loading safetensors checkpoint shards:  23% Completed | 7/30 [00:02<00:07,  3.20it/s]

Loading safetensors checkpoint shards:  27% Completed | 8/30 [00:02<00:06,  3.23it/s]

Loading safetensors checkpoint shards:  30% Completed | 9/30 [00:02<00:06,  3.40it/s]

Loading safetensors checkpoint shards:  33% Completed | 10/30 [00:03<00:05,  3.37it/s]

Loading safetensors checkpoint shards:  37% Completed | 11/30 [00:03<00:05,  3.29it/s]

Loading safetensors checkpoint shards:  40% Completed | 12/30 [00:03<00:05,  3.29it/s]

Loading safetensors checkpoint shards:  43% Completed | 13/30 [00:03<00:05,  3.37it/s]

Loading safetensors checkpoint shards:  47% Completed | 14/30 [00:04<00:04,  3.28it/s]

Loading safetensors checkpoint shards:  50% Completed | 15/30 [00:04<00:04,  3.20it/s]

Loading safetensors checkpoint shards:  53% Completed | 16/30 [00:04<00:03,  3.98it/s]

Loading safetensors checkpoint shards:  57% Completed | 17/30 [00:04<00:03,  3.85it/s]

Loading safetensors checkpoint shards:  60% Completed | 18/30 [00:05<00:03,  3.58it/s]

Loading safetensors checkpoint shards:  63% Completed | 19/30 [00:05<00:03,  3.42it/s]

Loading safetensors checkpoint shards:  67% Completed | 20/30 [00:05<00:03,  3.31it/s]

Loading safetensors checkpoint shards:  70% Completed | 21/30 [00:06<00:02,  3.22it/s]

Loading safetensors checkpoint shards:  73% Completed | 22/30 [00:06<00:02,  3.15it/s]

Loading safetensors checkpoint shards:  77% Completed | 23/30 [00:06<00:02,  3.09it/s]

Loading safetensors checkpoint shards:  80% Completed | 24/30 [00:07<00:01,  3.36it/s]

Loading safetensors checkpoint shards:  83% Completed | 25/30 [00:07<00:01,  3.33it/s]

Loading safetensors checkpoint shards:  87% Completed | 26/30 [00:07<00:01,  3.34it/s]

Loading safetensors checkpoint shards:  90% Completed | 27/30 [00:08<00:00,  3.27it/s]

Loading safetensors checkpoint shards:  93% Completed | 28/30 [00:08<00:00,  3.14it/s]
[2025-10-16 09:06:20 TP5] Load weight end. type=LlamaForCausalLM, dtype=torch.bfloat16, avail mem=120.77 GB, mem usage=16.52 GB.
[2025-10-16 09:06:20 TP7] Load weight end. type=LlamaForCausalLM, dtype=torch.bfloat16, avail mem=121.00 GB, mem usage=16.52 GB.
[2025-10-16 09:06:20 TP6] Load weight end. type=LlamaForCausalLM, dtype=torch.bfloat16, avail mem=120.77 GB, mem usage=16.52 GB.
[2025-10-16 09:06:20 TP2] Load weight end. type=LlamaForCausalLM, dtype=torch.bfloat16, avail mem=120.77 GB, mem usage=16.52 GB.

Loading safetensors checkpoint shards:  97% Completed | 29/30 [00:08<00:00,  3.09it/s]
[2025-10-16 09:06:20 TP1] Load weight end. type=LlamaForCausalLM, dtype=torch.bfloat16, avail mem=120.77 GB, mem usage=16.52 GB.

Loading safetensors checkpoint shards: 100% Completed | 30/30 [00:09<00:00,  3.10it/s]

Loading safetensors checkpoint shards: 100% Completed | 30/30 [00:09<00:00,  3.30it/s]

[2025-10-16 09:06:21 TP0] Load weight end. type=LlamaForCausalLM, dtype=torch.bfloat16, avail mem=120.82 GB, mem usage=16.52 GB.
[2025-10-16 09:06:21 TP3] Load weight end. type=LlamaForCausalLM, dtype=torch.bfloat16, avail mem=120.77 GB, mem usage=16.52 GB.
[2025-10-16 09:06:21 TP4] Load weight end. type=LlamaForCausalLM, dtype=torch.bfloat16, avail mem=120.77 GB, mem usage=16.52 GB.
[2025-10-16 09:06:21 TP0] Using KV cache dtype: torch.bfloat16
[2025-10-16 09:06:21 TP2] KV Cache is allocated. #tokens: 2795207, K size: 53.31 GB, V size: 53.31 GB
[2025-10-16 09:06:21 TP2] Memory pool end. avail mem=11.87 GB
[2025-10-16 09:06:21 TP0] KV Cache is allocated. #tokens: 2795207, K size: 53.31 GB, V size: 53.31 GB
[2025-10-16 09:06:21 TP0] Memory pool end. avail mem=11.92 GB
[2025-10-16 09:06:21 TP4] KV Cache is allocated. #tokens: 2795207, K size: 53.31 GB, V size: 53.31 GB
[2025-10-16 09:06:21 TP4] Memory pool end. avail mem=11.87 GB
[2025-10-16 09:06:21 TP7] KV Cache is allocated. #tokens: 2795207, K size: 53.31 GB, V size: 53.31 GB
[2025-10-16 09:06:21 TP7] Memory pool end. avail mem=12.10 GB
[2025-10-16 09:06:22 TP5] KV Cache is allocated. #tokens: 2795207, K size: 53.31 GB, V size: 53.31 GB
[2025-10-16 09:06:22 TP1] KV Cache is allocated. #tokens: 2795207, K size: 53.31 GB, V size: 53.31 GB
[2025-10-16 09:06:22 TP3] KV Cache is allocated. #tokens: 2795207, K size: 53.31 GB, V size: 53.31 GB
[2025-10-16 09:06:22 TP5] Memory pool end. avail mem=11.87 GB
[2025-10-16 09:06:22 TP6] KV Cache is allocated. #tokens: 2795207, K size: 53.31 GB, V size: 53.31 GB
[2025-10-16 09:06:22 TP1] Memory pool end. avail mem=11.87 GB
[2025-10-16 09:06:22 TP3] Memory pool end. avail mem=11.87 GB
[2025-10-16 09:06:22 TP6] Memory pool end. avail mem=11.87 GB
[2025-10-16 09:06:22 TP7] Capture cuda graph begin. This can take up to several minutes. avail mem=12.01 GB
[2025-10-16 09:06:22 TP2] Capture cuda graph begin. This can take up to several minutes. avail mem=11.78 GB
[2025-10-16 09:06:22 TP0] Capture cuda graph begin. This can take up to several minutes. avail mem=11.82 GB
[2025-10-16 09:06:22 TP0] Capture cuda graph bs [1, 2, 4, 8, 12, 16, 24, 32, 40, 48, 56, 64, 72, 80, 88, 96, 104, 112, 120, 128, 136, 144, 152, 160, 168, 176, 184, 192, 200, 208, 216, 224, 232, 240, 248, 256, 272, 288, 304, 320, 336, 352, 368, 384, 400, 416, 432, 448, 464, 480, 496, 512]
[2025-10-16 09:06:22 TP4] Capture cuda graph begin. This can take up to several minutes. avail mem=11.78 GB
[2025-10-16 09:06:22 TP5] Capture cuda graph begin. This can take up to several minutes. avail mem=11.78 GB
[2025-10-16 09:06:22 TP1] Capture cuda graph begin. This can take up to several minutes. avail mem=11.78 GB
[2025-10-16 09:06:22 TP3] Capture cuda graph begin. This can take up to several minutes. avail mem=11.78 GB
[2025-10-16 09:06:22 TP6] Capture cuda graph begin. This can take up to several minutes. avail mem=11.78 GB

  0%|          | 0/52 [00:00<?, ?it/s]
Capturing batches (bs=512 avail_mem=11.08 GB):   0%|          | 0/52 [00:00<?, ?it/s]
Capturing batches (bs=512 avail_mem=11.08 GB):   2%|| 1/52 [00:00<00:46,  1.09it/s]
Capturing batches (bs=496 avail_mem=10.58 GB):   2%|| 1/52 [00:00<00:46,  1.09it/s]
Capturing batches (bs=496 avail_mem=10.58 GB):   4%|| 2/52 [00:01<00:25,  1.96it/s]
Capturing batches (bs=480 avail_mem=10.56 GB):   4%|| 2/52 [00:01<00:25,  1.96it/s]
Capturing batches (bs=480 avail_mem=10.56 GB):   6%|| 3/52 [00:01<00:16,  2.91it/s]
Capturing batches (bs=464 avail_mem=10.55 GB):   6%|| 3/52 [00:01<00:16,  2.91it/s]
Capturing batches (bs=464 avail_mem=10.55 GB):   8%|| 4/52 [00:01<00:12,  3.86it/s]
Capturing batches (bs=448 avail_mem=10.54 GB):   8%|| 4/52 [00:01<00:12,  3.86it/s]
Capturing batches (bs=448 avail_mem=10.54 GB):  10%|| 5/52 [00:01<00:10,  4.54it/s]
Capturing batches (bs=432 avail_mem=10.52 GB):  10%|| 5/52 [00:01<00:10,  4.54it/s]
Capturing batches (bs=432 avail_mem=10.52 GB):  12%|█▏        | 6/52 [00:01<00:08,  5.17it/s]
Capturing batches (bs=416 avail_mem=10.51 GB):  12%|█▏        | 6/52 [00:01<00:08,  5.17it/s]
Capturing batches (bs=416 avail_mem=10.51 GB):  13%|█▎        | 7/52 [00:01<00:07,  5.76it/s]
Capturing batches (bs=400 avail_mem=10.49 GB):  13%|█▎        | 7/52 [00:01<00:07,  5.76it/s]
Capturing batches (bs=400 avail_mem=10.49 GB):  15%|█▌        | 8/52 [00:01<00:07,  6.21it/s]
Capturing batches (bs=384 avail_mem=10.48 GB):  15%|█▌        | 8/52 [00:01<00:07,  6.21it/s]
Capturing batches (bs=384 avail_mem=10.48 GB):  17%|█▋        | 9/52 [00:02<00:06,  6.52it/s]
Capturing batches (bs=368 avail_mem=10.47 GB):  17%|█▋        | 9/52 [00:02<00:06,  6.52it/s]
Capturing batches (bs=368 avail_mem=10.47 GB):  19%|█▉        | 10/52 [00:02<00:06,  6.89it/s]
Capturing batches (bs=352 avail_mem=10.46 GB):  19%|█▉        | 10/52 [00:02<00:06,  6.89it/s]
Capturing batches (bs=352 avail_mem=10.46 GB):  21%|██        | 11/52 [00:02<00:05,  7.10it/s]
Capturing batches (bs=336 avail_mem=10.44 GB):  21%|██        | 11/52 [00:02<00:05,  7.10it/s]
Capturing batches (bs=336 avail_mem=10.44 GB):  23%|██▎       | 12/52 [00:02<00:05,  7.14it/s]
Capturing batches (bs=320 avail_mem=10.43 GB):  23%|██▎       | 12/52 [00:02<00:05,  7.14it/s]
Capturing batches (bs=320 avail_mem=10.43 GB):  25%|██▌       | 13/52 [00:02<00:05,  7.28it/s]
Capturing batches (bs=304 avail_mem=10.41 GB):  25%|██▌       | 13/52 [00:02<00:05,  7.28it/s]
Capturing batches (bs=304 avail_mem=10.41 GB):  27%|██▋       | 14/52 [00:02<00:05,  7.41it/s]
Capturing batches (bs=288 avail_mem=10.40 GB):  27%|██▋       | 14/52 [00:02<00:05,  7.41it/s]
Capturing batches (bs=288 avail_mem=10.40 GB):  29%|██▉       | 15/52 [00:02<00:05,  7.36it/s]
Capturing batches (bs=272 avail_mem=10.39 GB):  29%|██▉       | 15/52 [00:02<00:05,  7.36it/s]
Capturing batches (bs=272 avail_mem=10.39 GB):  31%|███       | 16/52 [00:03<00:05,  7.09it/s]
Capturing batches (bs=256 avail_mem=10.37 GB):  31%|███       | 16/52 [00:03<00:05,  7.09it/s]
Capturing batches (bs=256 avail_mem=10.37 GB):  33%|███▎      | 17/52 [00:03<00:04,  7.29it/s]
Capturing batches (bs=248 avail_mem=10.36 GB):  33%|███▎      | 17/52 [00:03<00:04,  7.29it/s]
Capturing batches (bs=248 avail_mem=10.36 GB):  35%|███▍      | 18/52 [00:03<00:04,  7.23it/s]
Capturing batches (bs=240 avail_mem=10.34 GB):  35%|███▍      | 18/52 [00:03<00:04,  7.23it/s]
Capturing batches (bs=240 avail_mem=10.34 GB):  37%|███▋      | 19/52 [00:03<00:04,  7.27it/s]
Capturing batches (bs=232 avail_mem=10.33 GB):  37%|███▋      | 19/52 [00:03<00:04,  7.27it/s]
Capturing batches (bs=232 avail_mem=10.33 GB):  38%|███▊      | 20/52 [00:03<00:04,  7.42it/s]
Capturing batches (bs=224 avail_mem=10.32 GB):  38%|███▊      | 20/52 [00:03<00:04,  7.42it/s]
Capturing batches (bs=224 avail_mem=10.32 GB):  40%|████      | 21/52 [00:03<00:04,  7.40it/s]
Capturing batches (bs=216 avail_mem=10.30 GB):  40%|████      | 21/52 [00:03<00:04,  7.40it/s]
Capturing batches (bs=216 avail_mem=10.30 GB):  42%|████▏     | 22/52 [00:03<00:04,  7.39it/s]
Capturing batches (bs=208 avail_mem=10.29 GB):  42%|████▏     | 22/52 [00:03<00:04,  7.39it/s]
Capturing batches (bs=208 avail_mem=10.29 GB):  44%|████▍     | 23/52 [00:04<00:03,  7.31it/s]
Capturing batches (bs=200 avail_mem=10.27 GB):  44%|████▍     | 23/52 [00:04<00:03,  7.31it/s]
Capturing batches (bs=200 avail_mem=10.27 GB):  46%|████▌     | 24/52 [00:04<00:03,  7.40it/s]
Capturing batches (bs=192 avail_mem=10.26 GB):  46%|████▌     | 24/52 [00:04<00:03,  7.40it/s]
Capturing batches (bs=192 avail_mem=10.26 GB):  48%|████▊     | 25/52 [00:04<00:03,  7.39it/s]
Capturing batches (bs=184 avail_mem=10.24 GB):  48%|████▊     | 25/52 [00:04<00:03,  7.39it/s]
Capturing batches (bs=184 avail_mem=10.24 GB):  50%|█████     | 26/52 [00:04<00:03,  7.43it/s]
Capturing batches (bs=176 avail_mem=10.23 GB):  50%|█████     | 26/52 [00:04<00:03,  7.43it/s]
Capturing batches (bs=176 avail_mem=10.23 GB):  52%|█████▏    | 27/52 [00:04<00:03,  7.32it/s]
Capturing batches (bs=168 avail_mem=10.21 GB):  52%|█████▏    | 27/52 [00:04<00:03,  7.32it/s]
Capturing batches (bs=168 avail_mem=10.21 GB):  54%|█████▍    | 28/52 [00:04<00:03,  7.41it/s]
Capturing batches (bs=160 avail_mem=10.20 GB):  54%|█████▍    | 28/52 [00:04<00:03,  7.41it/s]
Capturing batches (bs=160 avail_mem=10.20 GB):  56%|█████▌    | 29/52 [00:04<00:03,  7.51it/s]
Capturing batches (bs=152 avail_mem=10.19 GB):  56%|█████▌    | 29/52 [00:04<00:03,  7.51it/s]
Capturing batches (bs=152 avail_mem=10.19 GB):  58%|█████▊    | 30/52 [00:04<00:02,  7.48it/s]
Capturing batches (bs=144 avail_mem=10.17 GB):  58%|█████▊    | 30/52 [00:04<00:02,  7.48it/s]
Capturing batches (bs=144 avail_mem=10.17 GB):  60%|█████▉    | 31/52 [00:05<00:02,  7.49it/s]
Capturing batches (bs=136 avail_mem=10.16 GB):  60%|█████▉    | 31/52 [00:05<00:02,  7.49it/s]
Capturing batches (bs=136 avail_mem=10.16 GB):  62%|██████▏   | 32/52 [00:05<00:02,  7.47it/s]
Capturing batches (bs=128 avail_mem=10.14 GB):  62%|██████▏   | 32/52 [00:05<00:02,  7.47it/s]
Capturing batches (bs=128 avail_mem=10.14 GB):  63%|██████▎   | 33/52 [00:05<00:02,  7.51it/s]
Capturing batches (bs=120 avail_mem=10.13 GB):  63%|██████▎   | 33/52 [00:05<00:02,  7.51it/s]
Capturing batches (bs=120 avail_mem=10.13 GB):  65%|██████▌   | 34/52 [00:05<00:02,  7.41it/s]
Capturing batches (bs=112 avail_mem=10.12 GB):  65%|██████▌   | 34/52 [00:05<00:02,  7.41it/s]
Capturing batches (bs=112 avail_mem=10.12 GB):  67%|██████▋   | 35/52 [00:05<00:02,  7.28it/s]
Capturing batches (bs=104 avail_mem=10.10 GB):  67%|██████▋   | 35/52 [00:05<00:02,  7.28it/s]
Capturing batches (bs=104 avail_mem=10.10 GB):  69%|██████▉   | 36/52 [00:05<00:02,  7.28it/s]
Capturing batches (bs=96 avail_mem=10.08 GB):  69%|██████▉   | 36/52 [00:05<00:02,  7.28it/s] 
Capturing batches (bs=96 avail_mem=10.08 GB):  71%|███████   | 37/52 [00:05<00:02,  7.29it/s]
Capturing batches (bs=88 avail_mem=10.07 GB):  71%|███████   | 37/52 [00:05<00:02,  7.29it/s]
Capturing batches (bs=88 avail_mem=10.07 GB):  73%|███████▎  | 38/52 [00:06<00:01,  7.26it/s]
Capturing batches (bs=80 avail_mem=10.05 GB):  73%|███████▎  | 38/52 [00:06<00:01,  7.26it/s]
Capturing batches (bs=80 avail_mem=10.05 GB):  75%|███████▌  | 39/52 [00:06<00:01,  6.90it/s]
Capturing batches (bs=72 avail_mem=10.04 GB):  75%|███████▌  | 39/52 [00:06<00:01,  6.90it/s]
Capturing batches (bs=72 avail_mem=10.04 GB):  77%|███████▋  | 40/52 [00:06<00:01,  6.81it/s]
Capturing batches (bs=64 avail_mem=10.02 GB):  77%|███████▋  | 40/52 [00:06<00:01,  6.81it/s]
Capturing batches (bs=64 avail_mem=10.02 GB):  79%|███████▉  | 41/52 [00:06<00:01,  6.83it/s]
Capturing batches (bs=56 avail_mem=10.01 GB):  79%|███████▉  | 41/52 [00:06<00:01,  6.83it/s]
Capturing batches (bs=56 avail_mem=10.01 GB):  81%|████████  | 42/52 [00:06<00:01,  7.00it/s]
Capturing batches (bs=48 avail_mem=9.99 GB):  81%|████████  | 42/52 [00:06<00:01,  7.00it/s] 
Capturing batches (bs=48 avail_mem=9.99 GB):  83%|████████▎ | 43/52 [00:06<00:01,  6.92it/s]
Capturing batches (bs=40 avail_mem=9.98 GB):  83%|████████▎ | 43/52 [00:06<00:01,  6.92it/s]
Capturing batches (bs=40 avail_mem=9.98 GB):  85%|████████▍ | 44/52 [00:06<00:01,  6.94it/s]
Capturing batches (bs=32 avail_mem=9.96 GB):  85%|████████▍ | 44/52 [00:06<00:01,  6.94it/s]
Capturing batches (bs=32 avail_mem=9.96 GB):  87%|████████▋ | 45/52 [00:07<00:00,  7.08it/s]
Capturing batches (bs=24 avail_mem=9.95 GB):  87%|████████▋ | 45/52 [00:07<00:00,  7.08it/s]
Capturing batches (bs=24 avail_mem=9.95 GB):  88%|████████▊ | 46/52 [00:07<00:00,  7.18it/s]
Capturing batches (bs=16 avail_mem=9.93 GB):  88%|████████▊ | 46/52 [00:07<00:00,  7.18it/s]
Capturing batches (bs=16 avail_mem=9.93 GB):  90%|█████████ | 47/52 [00:07<00:00,  7.33it/s]
Capturing batches (bs=12 avail_mem=9.92 GB):  90%|█████████ | 47/52 [00:07<00:00,  7.33it/s]
Capturing batches (bs=12 avail_mem=9.92 GB):  92%|█████████▏| 48/52 [00:07<00:00,  7.36it/s]
Capturing batches (bs=8 avail_mem=9.90 GB):  92%|█████████▏| 48/52 [00:07<00:00,  7.36it/s] 
Capturing batches (bs=8 avail_mem=9.90 GB):  94%|█████████▍| 49/52 [00:07<00:00,  7.18it/s]
Capturing batches (bs=4 avail_mem=9.89 GB):  94%|█████████▍| 49/52 [00:07<00:00,  7.18it/s]
Capturing batches (bs=4 avail_mem=9.89 GB):  96%|█████████▌| 50/52 [00:07<00:00,  7.40it/s]
Capturing batches (bs=2 avail_mem=9.87 GB):  96%|█████████▌| 50/52 [00:07<00:00,  7.40it/s]
Capturing batches (bs=2 avail_mem=9.87 GB):  98%|█████████▊| 51/52 [00:07<00:00,  7.49it/s]
Capturing batches (bs=1 avail_mem=9.86 GB):  98%|█████████▊| 51/52 [00:07<00:00,  7.49it/s]
Capturing batches (bs=1 avail_mem=9.86 GB): 100%|██████████| 52/52 [00:08<00:00,  6.50it/s]
Capturing batches (bs=1 avail_mem=9.86 GB): 100%|██████████| 52/52 [00:08<00:00,  6.45it/s]
[2025-10-16 09:06:30 TP0] Registering 8372 cuda graph addresses
[2025-10-16 09:06:31 TP1] Capture cuda graph end. Time elapsed: 9.13 s. mem usage=1.98 GB. avail mem=9.79 GB.
[2025-10-16 09:06:31 TP3] Capture cuda graph end. Time elapsed: 9.14 s. mem usage=1.98 GB. avail mem=9.79 GB.
[2025-10-16 09:06:31 TP0] Capture cuda graph end. Time elapsed: 9.17 s. mem usage=1.98 GB. avail mem=9.84 GB.
[2025-10-16 09:06:31 TP0] Capture cuda graph num tokens [4, 8, 12, 16, 20, 24, 28, 32, 48, 64, 80, 96, 112, 128, 144, 160, 176, 192, 208, 224, 240, 256, 288, 320, 352, 384, 416, 448, 480, 512, 640, 768, 896, 1024, 1152, 1280, 1408, 1536, 1664, 1792, 1920, 2048, 2176, 2304, 2432, 2560, 2688, 2816, 2944, 3072, 3200, 3328, 3456, 3584, 3712, 3840, 3968, 4096, 4352, 4608, 4864, 5120, 5376, 5632, 5888, 6144, 6400, 6656, 6912, 7168, 7424, 7680, 7936, 8192]
[2025-10-16 09:06:31 TP5] Capture cuda graph end. Time elapsed: 9.18 s. mem usage=1.98 GB. avail mem=9.79 GB.
[2025-10-16 09:06:31 TP4] Capture cuda graph end. Time elapsed: 9.22 s. mem usage=1.98 GB. avail mem=9.79 GB.
[2025-10-16 09:06:31 TP2] Capture cuda graph end. Time elapsed: 9.23 s. mem usage=1.98 GB. avail mem=9.79 GB.
[2025-10-16 09:06:31 TP7] Capture cuda graph end. Time elapsed: 9.24 s. mem usage=1.98 GB. avail mem=10.03 GB.
/usr/local/lib/python3.12/dist-packages/torch/_dynamo/variables/functions.py:1575: UserWarning: Dynamo detected a call to a `functools.lru_cache`-wrapped function. Dynamo ignores the cache wrapper and directly traces the wrapped function. Silent incorrectness is only a *potential* risk, not something we have observed. Enable TORCH_LOGS="+dynamo" for a DEBUG stack trace.
  torch._dynamo.utils.warn_once(msg)
[2025-10-16 09:06:31 TP6] Capture cuda graph end. Time elapsed: 9.23 s. mem usage=1.98 GB. avail mem=9.79 GB.
[2025-10-16 09:06:38 TP6] Processing auto_functionalized node: sgl_kernel.rmsnorm.default
[2025-10-16 09:06:38 TP6] Calling _defunctionalize_rmsnorm
[2025-10-16 09:06:38 TP6] RMSNorm node: auto_functionalized
[2025-10-16 09:06:38 TP6] RMSNorm node.args[0]: sgl_kernel.rmsnorm.default
[2025-10-16 09:06:38 TP6] RMSNorm node kwargs: ['output', 'input', 'weight', 'eps', 'enable_pdl']
[2025-10-16 09:06:38 TP6] RMSNorm getitem_users indices: [1]
[2025-10-16 09:06:38 TP6] Processing auto_functionalized node: sgl_kernel.apply_rope_pos_ids_cos_sin_cache.default
[2025-10-16 09:06:38 TP6] Calling _defunctionalize_rope
alias: True

====================================================================================================
🔍 BEFORE FIX FUNCTIONALIZATION - FX Graph:
====================================================================================================
graph():
    %arg0_1 : [num_users=5] = placeholder[target=arg0_1]
    %arg1_1 : [num_users=8] = placeholder[target=arg1_1]
    %arg2_1 : [num_users=1] = placeholder[target=arg2_1]
    %arg3_1 : [num_users=1] = placeholder[target=arg3_1]
    %arg4_1 : [num_users=1] = placeholder[target=arg4_1]
    %arg5_1 : [num_users=1] = placeholder[target=arg5_1]
    %arg6_1 : [num_users=0] = placeholder[target=arg6_1]
    %arg7_1 : [num_users=1] = placeholder[target=arg7_1]
    %empty : [num_users=1] = call_function[target=torch.ops.aten.empty.memory_format](args = ([%arg1_1, 8192],), kwargs = {dtype: torch.bfloat16, layout: torch.strided, device: cuda:6, pin_memory: False})
    %permute : [num_users=1] = call_function[target=torch.ops.aten.permute.default](args = (%empty, [0, 1]), kwargs = {})
    %ge : [num_users=1] = call_function[target=torch.ops.aten.ge.Scalar](args = (%arg0_1, 96192), kwargs = {})
    %lt : [num_users=1] = call_function[target=torch.ops.aten.lt.Scalar](args = (%arg0_1, 112224), kwargs = {})
    %bitwise_and : [num_users=2] = call_function[target=torch.ops.aten.bitwise_and.Tensor](args = (%ge, %lt), kwargs = {})
    %ge_1 : [num_users=1] = call_function[target=torch.ops.aten.ge.Scalar](args = (%arg0_1, 128256), kwargs = {})
    %lt_1 : [num_users=1] = call_function[target=torch.ops.aten.lt.Scalar](args = (%arg0_1, 128256), kwargs = {})
    %bitwise_and_1 : [num_users=2] = call_function[target=torch.ops.aten.bitwise_and.Tensor](args = (%ge_1, %lt_1), kwargs = {})
    %bitwise_or : [num_users=2] = call_function[target=torch.ops.aten.bitwise_or.Tensor](args = (%bitwise_and, %bitwise_and_1), kwargs = {})
    %bitwise_not : [num_users=1] = call_function[target=torch.ops.aten.bitwise_not.default](args = (%bitwise_or,), kwargs = {})
    %unsqueeze : [num_users=1] = call_function[target=torch.ops.aten.unsqueeze.default](args = (%bitwise_not, -1), kwargs = {})
    %full_default : [num_users=1] = call_function[target=torch.ops.aten.full.default](args = ([], 0.0), kwargs = {dtype: torch.bfloat16, layout: torch.strided, device: cuda:6, pin_memory: False})
    %mul : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%bitwise_and, 96192), kwargs = {})
    %mul_2 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%bitwise_and_1, 112224), kwargs = {})
    %add_16 : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%mul, %mul_2), kwargs = {})
    %sub_10 : [num_users=1] = call_function[target=torch.ops.aten.sub.Tensor](args = (%arg0_1, %add_16), kwargs = {})
    %mul_6 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%bitwise_or, %sub_10), kwargs = {})
    %embedding : [num_users=1] = call_function[target=torch.ops.aten.embedding.default](args = (%arg2_1, %mul_6), kwargs = {})
    %where : [num_users=1] = call_function[target=torch.ops.aten.where.self](args = (%unsqueeze, %full_default, %embedding), kwargs = {})
    %outplace_all_reduce : [num_users=2] = call_function[target=torch.ops.sglang.outplace_all_reduce.default](args = (%where, tp:0, ca), kwargs = {})
    %auto_functionalized : [num_users=1] = call_function[target=torch.ops.higher_order.auto_functionalized](args = (sgl_kernel.rmsnorm.default,), kwargs = {output: %permute, input: %outplace_all_reduce, weight: %arg3_1, eps: 1e-05, enable_pdl: True})
    %getitem_1 : [num_users=1] = call_function[target=operator.getitem](args = (%auto_functionalized, 1), kwargs = {})
    %permute_1 : [num_users=1] = call_function[target=torch.ops.aten.permute.default](args = (%arg4_1, [1, 0]), kwargs = {})
    %mm : [num_users=2] = call_function[target=torch.ops.aten.mm.default](args = (%getitem_1, %permute_1), kwargs = {})
    %split_with_sizes : [num_users=2] = call_function[target=torch.ops.aten.split_with_sizes.default](args = (%mm, [1024, 128, 128], -1), kwargs = {})
    %getitem_2 : [num_users=2] = call_function[target=operator.getitem](args = (%split_with_sizes, 0), kwargs = {})
    %getitem_3 : [num_users=2] = call_function[target=operator.getitem](args = (%split_with_sizes, 1), kwargs = {})
    %view : [num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%getitem_2, [%arg1_1, -1, 128]), kwargs = {})
    %view_1 : [num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%getitem_3, [%arg1_1, -1, 128]), kwargs = {})
    %view_2 : [num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%getitem_2, [%arg1_1, -1, 128]), kwargs = {})
    %view_3 : [num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%getitem_3, [%arg1_1, -1, 128]), kwargs = {})
    %auto_functionalized_1 : [num_users=2] = call_function[target=torch.ops.higher_order.auto_functionalized](args = (sgl_kernel.apply_rope_pos_ids_cos_sin_cache.default,), kwargs = {q: %view, k: %view_1, q_rope: %view_2, k_rope: %view_3, cos_sin_cache: %arg7_1, pos_ids: %arg5_1, interleave: False, enable_pdl: False, cuda_stream: 0, v: None, k_buffer: None, v_buffer: None, kv_cache_loc: None})
    %getitem_6 : [num_users=1] = call_function[target=operator.getitem](args = (%auto_functionalized_1, 1), kwargs = {})
    %getitem_7 : [num_users=1] = call_function[target=operator.getitem](args = (%auto_functionalized_1, 2), kwargs = {})
    %view_4 : [num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%getitem_6, [%arg1_1, 1024]), kwargs = {})
    %slice_scatter : [num_users=1] = call_function[target=torch.ops.aten.slice_scatter.default](args = (%mm, %view_4, 1, 0, 1024), kwargs = {})
    %view_6 : [num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%getitem_7, [%arg1_1, 128]), kwargs = {})
    %slice_scatter_1 : [num_users=3] = call_function[target=torch.ops.aten.slice_scatter.default](args = (%slice_scatter, %view_6, 1, 1024, 1152), kwargs = {})
    %split_with_sizes_5 : [num_users=1] = call_function[target=torch.ops.aten.split_with_sizes.default](args = (%slice_scatter_1, [1024, 128, 128], -1), kwargs = {})
    %getitem_22 : [num_users=1] = call_function[target=operator.getitem](args = (%split_with_sizes_5, 0), kwargs = {})
    %split_with_sizes_6 : [num_users=1] = call_function[target=torch.ops.aten.split_with_sizes.default](args = (%slice_scatter_1, [1024, 128, 128], -1), kwargs = {})
    %getitem_26 : [num_users=1] = call_function[target=operator.getitem](args = (%split_with_sizes_6, 1), kwargs = {})
    %split_with_sizes_7 : [num_users=1] = call_function[target=torch.ops.aten.split_with_sizes.default](args = (%slice_scatter_1, [1024, 128, 128], -1), kwargs = {})
    %getitem_30 : [num_users=1] = call_function[target=operator.getitem](args = (%split_with_sizes_7, 2), kwargs = {})
    %view_10 : [num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%getitem_26, [-1, 1, 128]), kwargs = {})
    %view_11 : [num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%getitem_30, [-1, 1, 128]), kwargs = {})
    %empty_1 : [num_users=1] = call_function[target=torch.ops.aten.empty.memory_format](args = ([%arg1_1, 1024],), kwargs = {dtype: torch.bfloat16, layout: torch.strided, device: cuda:6, pin_memory: False})
    %permute_2 : [num_users=1] = call_function[target=torch.ops.aten.permute.default](args = (%empty_1, [0, 1]), kwargs = {})
    return (getitem_22, view_10, view_11, permute_2, outplace_all_reduce)

📊 Graph Statistics:
  Total nodes: 57
  Auto-functionalized nodes: 2
  Auto-functionalized node details:
    1. auto_functionalized: auto_functionalized
       First arg: sgl_kernel.rmsnorm.default
    2. auto_functionalized_1: auto_functionalized
       First arg: sgl_kernel.apply_rope_pos_ids_cos_sin_cache.default
====================================================================================================


====================================================================================================
✅ AFTER FIX FUNCTIONALIZATION - FX Graph:
[2025-10-16 09:06:38 TP5] Processing auto_functionalized node: sgl_kernel.rmsnorm.default
[2025-10-16 09:06:38 TP5] Calling _defunctionalize_rmsnorm
[2025-10-16 09:06:38 TP5] RMSNorm node: auto_functionalized
[2025-10-16 09:06:38 TP5] RMSNorm node.args[0]: sgl_kernel.rmsnorm.default
[2025-10-16 09:06:38 TP5] RMSNorm node kwargs: ['output', 'input', 'weight', 'eps', 'enable_pdl']
[2025-10-16 09:06:38 TP5] RMSNorm getitem_users indices: [1]
[2025-10-16 09:06:38 TP5] Processing auto_functionalized node: sgl_kernel.apply_rope_pos_ids_cos_sin_cache.default
[2025-10-16 09:06:38 TP5] Calling _defunctionalize_rope
alias: True

====================================================================================================
🔍 BEFORE FIX FUNCTIONALIZATION - FX Graph:
====================================================================================================
graph():
    %arg0_1 : [num_users=5] = placeholder[target=arg0_1]
    %arg1_1 : [num_users=8] = placeholder[target=arg1_1]
    %arg2_1 : [num_users=1] = placeholder[target=arg2_1]
    %arg3_1 : [num_users=1] = placeholder[target=arg3_1]
    %arg4_1 : [num_users=1] = placeholder[target=arg4_1]
    %arg5_1 : [num_users=1] = placeholder[target=arg5_1]
    %arg6_1 : [num_users=0] = placeholder[target=arg6_1]
    %arg7_1 : [num_users=1] = placeholder[target=arg7_1]
    %empty : [num_users=1] = call_function[target=torch.ops.aten.empty.memory_format](args = ([%arg1_1, 8192],), kwargs = {dtype: torch.bfloat16, layout: torch.strided, device: cuda:5, pin_memory: False})
    %permute : [num_users=1] = call_function[target=torch.ops.aten.permute.default](args = (%empty, [0, 1]), kwargs = {})
    %ge : [num_users=1] = call_function[target=torch.ops.aten.ge.Scalar](args = (%arg0_1, 80160), kwargs = {})
    %lt : [num_users=1] = call_function[target=torch.ops.aten.lt.Scalar](args = (%arg0_1, 96192), kwargs = {})
    %bitwise_and : [num_users=2] = call_function[target=torch.ops.aten.bitwise_and.Tensor](args = (%ge, %lt), kwargs = {})
    %ge_1 : [num_users=1] = call_function[target=torch.ops.aten.ge.Scalar](args = (%arg0_1, 128256), kwargs = {})
    %lt_1 : [num_users=1] = call_function[target=torch.ops.aten.lt.Scalar](args = (%arg0_1, 128256), kwargs = {})
    %bitwise_and_1 : [num_users=2] = call_function[target=torch.ops.aten.bitwise_and.Tensor](args = (%ge_1, %lt_1), kwargs = {})
    %bitwise_or : [num_users=2] = call_function[target=torch.ops.aten.bitwise_or.Tensor](args = (%bitwise_and, %bitwise_and_1), kwargs = {})
    %bitwise_not : [num_users=1] = call_function[target=torch.ops.aten.bitwise_not.default](args = (%bitwise_or,), kwargs = {})
    %unsqueeze : [num_users=1] = call_function[target=torch.ops.aten.unsqueeze.default](args = (%bitwise_not, -1), kwargs = {})
    %full_default : [num_users=1] = call_function[target=torch.ops.aten.full.default](args = ([], 0.0), kwargs = {dtype: torch.bfloat16, layout: torch.strided, device: cuda:5, pin_memory: False})
    %mul : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%bitwise_and, 80160), kwargs = {})
    %mul_2 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%bitwise_and_1, 112224), kwargs = {})
    %add_16 : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%mul, %mul_2), kwargs = {})
    %sub_10 : [num_users=1] = call_function[target=torch.ops.aten.sub.Tensor](args = (%arg0_1, %add_16), kwargs = {})
    %mul_6 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%bitwise_or, %sub_10), kwargs = {})
    %embedding : [num_users=1] = call_function[target=torch.ops.aten.embedding.default](args = (%arg2_1, %mul_6), kwargs = {})
    %where : [num_users=1] = call_function[target=torch.ops.aten.where.self](args = (%unsqueeze, %full_default, %embedding), kwargs = {})
    %outplace_all_reduce : [num_users=2] = call_function[target=torch.ops.sglang.outplace_all_reduce.default](args = (%where, tp:0, ca), kwargs = {})
    %auto_functionalized : [num_users=1] = call_function[target=torch.ops.higher_order.auto_functionalized](args = (sgl_kernel.rmsnorm.default,), kwargs = {output: %permute, input: %outplace_all_reduce, weight: %arg3_1, eps: 1e-05, enable_pdl: True})
    %getitem_1 : [num_users=1] = call_function[target=operator.getitem](args = (%auto_functionalized, 1), kwargs = {})
    %permute_1 : [num_users=1] = call_function[target=torch.ops.aten.permute.default](args = (%arg4_1, [1, 0]), kwargs = {})
    %mm : [num_users=2] = call_function[target=torch.ops.aten.mm.default](args = (%getitem_1, %permute_1), kwargs = {})
    %split_with_sizes : [num_users=2] = call_function[target=torch.ops.aten.split_with_sizes.default](args = (%mm, [1024, 128, 128], -1), kwargs = {})
    %getitem_2 : [num_users=2] = call_function[target=operator.getitem](args = (%split_with_sizes, 0), kwargs = {})
    %getitem_3 : [num_users=2] = call_function[target=operator.getitem](args = (%split_with_sizes, 1), kwargs = {})
    %view : [num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%getitem_2, [%arg1_1, -1, 128]), kwargs = {})
    %view_1 : [num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%getitem_3, [%arg1_1, -1, 128]), kwargs = {})
    %view_2 : [num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%getitem_2, [%arg1_1, -1, 128]), kwargs = {})
    %view_3 : [num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%getitem_3, [%arg1_1, -1, 128]), kwargs = {})
    %auto_functionalized_1 : [num_users=2] = call_function[target=torch.ops.higher_order.auto_functionalized](args = (sgl_kernel.apply_rope_pos_ids_cos_sin_cache.default,), kwargs = {q: %view, k: %view_1, q_rope: %view_2, k_rope: %view_3, cos_sin_cache: %arg7_1, pos_ids: %arg5_1, interleave: False, enable_pdl: False, cuda_stream: 0, v: None, k_buffer: None, v_buffer: None, kv_cache_loc: None})
    %getitem_6 : [num_users=1] = call_function[target=operator.getitem](args = (%auto_functionalized_1, 1), kwargs = {})
    %getitem_7 : [num_users=1] = call_function[target=operator.getitem](args = (%auto_functionalized_1, 2), kwargs = {})
    %view_4 : [num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%getitem_6, [%arg1_1, 1024]), kwargs = {})
    %slice_scatter : [num_users=1] = call_function[target=torch.ops.aten.slice_scatter.default](args = (%mm, %view_4, 1, 0, 1024), kwargs = {})
    %view_6 : [num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%getitem_7, [%arg1_1, 128]), kwargs = {})
    %slice_scatter_1 : [num_users=3] = call_function[target=torch.ops.aten.slice_scatter.default](args = (%slice_scatter, %view_6, 1, 1024, 1152), kwargs = {})
    %split_with_sizes_5 : [num_users=1] = call_function[target=torch.ops.aten.split_with_sizes.default](args = (%slice_scatter_1, [1024, 128, 128], -1), kwargs = {})
    %getitem_22 : [num_users=1] = call_function[target=operator.getitem](args = (%split_with_sizes_5, 0), kwargs = {})
    %split_with_sizes_6 : [num_users=1] = call_function[target=torch.ops.aten.split_with_sizes.default](args = (%slice_scatter_1, [1024, 128, 128], -1), kwargs = {})
    %getitem_26 : [num_users=1] = call_function[target=operator.getitem](args = (%split_with_sizes_6, 1), kwargs = {})
    %split_with_sizes_7 : [num_users=1] = call_function[target=torch.ops.aten.split_with_sizes.default](args = (%slice_scatter_1, [1024, 128, 128], -1), kwargs = {})
    %getitem_30 : [num_users=1] = call_function[target=operator.getitem](args = (%split_with_sizes_7, 2), kwargs = {})
    %view_10 : [num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%getitem_26, [-1, 1, 128]), kwargs = {})
    %view_11 : [num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%getitem_30, [-1, 1, 128]), kwargs = {})
    %empty_1 : [num_users=1] = call_function[target=torch.ops.aten.empty.memory_format](args = ([%arg1_1, 1024],), kwargs = {dtype: torch.bfloat16, layout: torch.strided, device: cuda:5, pin_memory: False})
    %permute_2 : [num_users=1] = call_function[target=torch.ops.aten.permute.default](args = (%empty_1, [0, 1]), kwargs = {})
    return (getitem_22, view_10, view_11, permute_2, outplace_all_reduce)

📊 Graph Statistics:
  Total nodes: 57
  Auto-functionalized nodes: 2
  Auto-functionalized node details:
    1. auto_functionalized: auto_functionalized
       First arg: sgl_kernel.rmsnorm.default
    2. auto_functionalized_1: auto_functionalized
       First arg: sgl_kernel.apply_rope_pos_ids_cos_sin_cache.default
====================================================================================================


====================================================================================================
✅ AFTER FIX FUNCTIONALIZATION - FX Graph:
[2025-10-16 09:06:38 TP4] Processing auto_functionalized node: sgl_kernel.rmsnorm.default
[2025-10-16 09:06:38 TP4] Calling _defunctionalize_rmsnorm
[2025-10-16 09:06:38 TP4] RMSNorm node: auto_functionalized
[2025-10-16 09:06:38 TP4] RMSNorm node.args[0]: sgl_kernel.rmsnorm.default
[2025-10-16 09:06:38 TP4] RMSNorm node kwargs: ['output', 'input', 'weight', 'eps', 'enable_pdl']
[2025-10-16 09:06:38 TP4] RMSNorm getitem_users indices: [1]
[2025-10-16 09:06:38 TP4] Processing auto_functionalized node: sgl_kernel.apply_rope_pos_ids_cos_sin_cache.default
[2025-10-16 09:06:38 TP4] Calling _defunctionalize_rope
alias: True

....
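
For reference, here is a minimal sketch (my own illustration, not the actual FixFunctionalizationPass) of how auto_functionalized nodes like the two reported above can be enumerated in a torch.fx graph; gm is assumed to be the post-grad GraphModule that the pass sees:

# Sketch only: list auto_functionalized nodes and the op each one wraps,
# e.g. sgl_kernel.rmsnorm.default or sgl_kernel.apply_rope_pos_ids_cos_sin_cache.default.
import torch

def list_auto_functionalized(gm: torch.fx.GraphModule):
    hits = []
    for node in gm.graph.nodes:
        if (
            node.op == "call_function"
            and node.target is torch.ops.higher_order.auto_functionalized
        ):
            hits.append((node.name, node.args[0]))  # args[0] is the wrapped op
    return hits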

@BBuf
Collaborator Author

BBuf commented Oct 17, 2025

I can't reproduce the issue in a standalone script: in the profiler output below, the rope kernel is the sgl-kernel (flashinfer) version.

Output:

bbuf python3 /home/bbuf/reproduce_rope_triton_issue.py
/usr/local/lib/python3.12/dist-packages/torch/cuda/__init__.py:63: FutureWarning: The pynvml package is deprecated. Please install nvidia-ml-py instead. If you did not install pynvml directly, please report this to the maintainers of the package that installed pynvml for you.
  import pynvml  # type: ignore[import]
INFO:__main__:✅ sgl_kernel is available
INFO:__main__:
📝 Entering torch.compile mode...
INFO:__main__:   ✅ All CustomOps set to use sgl-kernel in torch.compile mode
INFO:__main__:
⚙️  Compiling model with torch.compile (backend=inductor)...
INFO:__main__:🔥 Running warm-up to trigger compilation...
/usr/local/lib/python3.12/dist-packages/torch/_dynamo/variables/functions.py:1575: UserWarning: Dynamo detected a call to a `functools.lru_cache`-wrapped function. Dynamo ignores the cache wrapper and directly traces the wrapped function. Silent incorrectness is only a *potential* risk, not something we have observed. Enable TORCH_LOGS="+dynamo" for a DEBUG stack trace.
  torch._dynamo.utils.warn_once(msg)
INFO:__main__:   ✅ Compilation complete
INFO:__main__:
📊 Profiling with torch.profiler to check actual kernels...
void flashinfer::norm::RMSNormKernel<8u, __nv_bfloat16>(__nv_bfloat16*, __nv_bfloat16*, __nv_bfloat16*, unsigned int, unsigned int, unsigned int, float, float)
nvjet_tst_128x32_64x10_4x1_v_bz_TNT
triton_poi_fused__to_copy_0
triton_poi_fused__to_copy_1
void flashinfer::BatchQKApplyRotaryPosIdsCosSinCacheHeadParallelismKernel<false, 128u, 8u, 16u, __nv_bfloat16, long>(__nv_bfloat16*, __nv_bfloat16*, __nv_bfloat16*, __nv_bfloat16*, float*, long*, unsigned int, unsigned int, unsigned int, unsigned int, unsigned long, unsigned long, unsigned long, unsigned long, unsigned long, unsigned long, unsigned long, unsigned long)
void pytorch_flash::flash_fwd_kernel<Flash_fwd_kernel_traits<128, 128, 64, 4, false, false, cutlass::bfloat16_t, Flash_kernel_traits<128, 128, 64, 4, cutlass::bfloat16_t> >, false, false, false, false, false, true, false, false>(pytorch_flash::Flash_fwd_params)
void flashinfer::norm::FusedAddRMSNormKernel<8u, __nv_bfloat16>(__nv_bfloat16*, __nv_bfloat16*, __nv_bfloat16*, unsigned int, unsigned int, unsigned int, float, float)
Script:

import torch
import torch.nn as nn
from typing import Optional, Tuple
import logging

logging.basicConfig(level=logging.INFO, format='%(levelname)s:%(name)s:%(message)s')
logger = logging.getLogger(__name__)

# Check sgl_kernel availability
try:
    from sgl_kernel import rmsnorm, fused_add_rmsnorm, apply_rope_with_cos_sin_cache_inplace
    HAS_SGL_KERNEL = True
    logger.info("✅ sgl_kernel is available")
except ImportError:
    HAS_SGL_KERNEL = False
    logger.error("❌ sgl_kernel not available - this script requires sgl_kernel to run")
    exit(1)


# ============================================================================
# Simplified CustomOp
# ============================================================================

class CustomOp(nn.Module):
    def __init__(self):
        super().__init__()
        self._forward_method = self.dispatch_forward()

        # States for torch.compile
        self._original_forward_method = None
        self.is_torch_compile = False

    def enter_torch_compile(self, num_tokens: int):
        # Skip if Op is already entered compile mode.
        # NOTE(alcanderian): Some Ops(for example RotaryEmbedding) will be reused
        # among layers and `enter_torch_compile` will be called many times.
        # We should prevent `self._original_forward_method` from being overridden when
        # it is not the first time `enter_torch_compile` called.
        if self.is_torch_compile:
            return

        self._original_forward_method = self._forward_method
        
        # For RMSNorm and RotaryEmbedding, keep using sgl-kernel implementations
        # instead of falling back to forward_native, to avoid performance degradation
        op_name = self.__class__.__name__
        if any(name in op_name for name in ["RMSNorm", "RotaryEmbedding", "Llama3RotaryEmbedding"]):
            # Keep the original forward method (forward_cuda with sgl-kernel)
            # Don't switch to forward_native
            pass
        else:
            self._forward_method = self.forward_native
        self.is_torch_compile = True

    def leave_torch_compile(self):
        # Skip if Op is already exited compile mode.
        if not self.is_torch_compile:
            return

        self._forward_method = self._original_forward_method
        self._original_forward_method = None
        self.is_torch_compile = False

    def forward(self, *args, **kwargs):
        return self._forward_method(*args, **kwargs)

    def forward_native(self, *args, **kwargs):
        raise NotImplementedError

    def forward_cuda(self, *args, **kwargs):
        raise NotImplementedError
    
    def dispatch_forward(self):
        return self.forward_cuda


# ============================================================================
# RMSNorm
# ============================================================================

class RMSNorm(CustomOp):
    def __init__(self, hidden_size: int, eps: float = 1e-6):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.eps = eps

    def forward_cuda(self, x: torch.Tensor, residual: Optional[torch.Tensor] = None):
        # Reshape to 2D for sgl_kernel
        original_shape = x.shape
        if x.ndim == 3:
            x = x.view(-1, x.shape[-1])
            if residual is not None:
                residual = residual.view(-1, residual.shape[-1])
        
        if residual is not None:
            fused_add_rmsnorm(x, residual, self.weight.data, self.eps)
            if len(original_shape) == 3:
                x = x.view(original_shape)
                residual = residual.view(original_shape)
            return x, residual
        else:
            out = rmsnorm(x, self.weight.data, self.eps)
            if len(original_shape) == 3:
                out = out.view(original_shape)
            return out

    def forward_native(self, x: torch.Tensor, residual: Optional[torch.Tensor] = None):
        if residual is not None:
            x_added = x + residual
            variance = x_added.pow(2).mean(-1, keepdim=True)
            out = x_added * torch.rsqrt(variance + self.eps) * self.weight
            return out, x_added
        else:
            variance = x.pow(2).mean(-1, keepdim=True)
            out = x * torch.rsqrt(variance + self.eps) * self.weight
            return out


# ============================================================================
# RotaryEmbedding
# ============================================================================

class RotaryEmbedding(CustomOp):
    def __init__(self, head_size: int = 128, max_position: int = 4096, base: int = 10000, is_neox_style: bool = True):
        super().__init__()
        self.head_size = head_size
        self.max_position = max_position
        self.base = base
        self.is_neox_style = is_neox_style
        
        # Compute cos/sin cache (keep as float32 for sgl_kernel)
        cache = self._compute_cos_sin_cache()
        self.register_buffer("cos_sin_cache", cache.float(), persistent=False)
    
    def _compute_inv_freq(self, base):
        """Compute the inverse frequency."""
        inv_freq = 1.0 / (base ** (torch.arange(0, self.head_size, 2, dtype=torch.float) / self.head_size))
        return inv_freq
    
    def _compute_cos_sin_cache(self):
        """Compute the cos and sin cache."""
        inv_freq = self._compute_inv_freq(self.base)
        t = torch.arange(self.max_position, dtype=torch.float)
        freqs = torch.einsum("i,j -> ij", t, inv_freq)
        cos = freqs.cos()
        sin = freqs.sin()
        cache = torch.cat((cos, sin), dim=-1)
        return cache

    def forward_cuda(self, positions: torch.Tensor, query: torch.Tensor, key: torch.Tensor):
        # Ensure cos_sin_cache is float32
        cos_sin_cache = self.cos_sin_cache
        if cos_sin_cache.dtype != torch.float32:
            cos_sin_cache = cos_sin_cache.float()
        
        # Flatten positions
        if positions.ndim > 1:
            positions = positions.reshape(-1)
        
        # Call sgl_kernel rope
        apply_rope_with_cos_sin_cache_inplace(
            positions=positions,
            query=query,
            key=key,
            head_size=self.head_size,
            cos_sin_cache=cos_sin_cache,
            is_neox=self.is_neox_style,
        )
        return query, key

    def forward_native(self, positions: torch.Tensor, query: torch.Tensor, key: torch.Tensor):
        # Simplified PyTorch implementation (neox style); positions is expected to be 1D
        cos_sin = self.cos_sin_cache[positions]      # (num_tokens, head_size)
        cos, sin = cos_sin.chunk(2, dim=-1)          # each (num_tokens, head_size // 2)
        cos = cos.repeat(1, 2).unsqueeze(-2)         # (num_tokens, 1, head_size)
        sin = sin.repeat(1, 2).unsqueeze(-2)
        
        def rotate_neox(x):
            x1, x2 = x[..., : x.shape[-1] // 2], x[..., x.shape[-1] // 2 :]
            return torch.cat((-x2, x1), dim=-1)
        
        query_rot = query * cos + rotate_neox(query) * sin
        key_rot = key * cos + rotate_neox(key) * sin
        return query_rot, key_rot


# ============================================================================
# Llama3RotaryEmbedding (same as SGLang implementation)
# ============================================================================

class Llama3RotaryEmbedding(RotaryEmbedding):
    """Llama3 style rotary embedding with frequency scaling"""
    
    def __init__(
        self,
        head_size: int = 128,
        max_position: int = 131072,
        base: int = 500000,
        is_neox_style: bool = True,
        scaling_factor: float = 8.0,
        low_freq_factor: float = 1.0,
        high_freq_factor: float = 4.0,
        orig_max_position: int = 8192,
    ):
        self.scaling_factor = scaling_factor
        self.low_freq_factor = low_freq_factor
        self.high_freq_factor = high_freq_factor
        self.orig_max_position = orig_max_position
        super().__init__(head_size, max_position, base, is_neox_style)
    
    def _compute_inv_freq(self, base):
        """Llama3-specific inverse frequency computation with scaling"""
        import math
        
        inv_freqs = super()._compute_inv_freq(base)
        low_freq_wavelen = self.orig_max_position / self.low_freq_factor
        high_freq_wavelen = self.orig_max_position / self.high_freq_factor

        wave_len = 2 * math.pi / inv_freqs
        if self.low_freq_factor != self.high_freq_factor:
            smooth = (self.orig_max_position / wave_len - self.low_freq_factor) / (
                self.high_freq_factor - self.low_freq_factor
            )
        else:
            smooth = 0
        new_freqs = torch.where(
            wave_len < high_freq_wavelen,
            inv_freqs,
            torch.where(
                wave_len > low_freq_wavelen,
                inv_freqs / self.scaling_factor,
                (1 - smooth) * inv_freqs / self.scaling_factor + smooth * inv_freqs,
            ),
        )
        return new_freqs


# ============================================================================
# Simple Test Model
# ============================================================================

class SimpleModel(nn.Module):
    def __init__(self, hidden_size: int = 4096, num_heads: int = 32, head_size: int = 128, use_llama3_rope: bool = False):
        super().__init__()
        self.hidden_size = hidden_size
        self.num_heads = num_heads
        self.head_size = head_size
        
        self.input_norm = RMSNorm(hidden_size)
        
        # Choose RoPE type
        if use_llama3_rope:
            self.rotary_emb = Llama3RotaryEmbedding(
                head_size=head_size,
                max_position=131072,
                base=500000,
                is_neox_style=True,
                scaling_factor=8.0,
                low_freq_factor=1.0,
                high_freq_factor=4.0,
                orig_max_position=8192,
            )
        else:
            self.rotary_emb = RotaryEmbedding(head_size=head_size)
        
        self.qkv_proj = nn.Linear(hidden_size, 3 * num_heads * head_size, bias=False)
        self.output_norm = RMSNorm(hidden_size)

    def forward(self, hidden_states: torch.Tensor, positions: torch.Tensor, residual: Optional[torch.Tensor] = None):
        # Input norm
        if residual is None:
            hidden_states = self.input_norm(hidden_states)
            residual = hidden_states
        else:
            hidden_states, residual = self.input_norm(hidden_states, residual)
        
        # QKV projection
        batch_size, seq_len = hidden_states.shape[:2]
        qkv = self.qkv_proj(hidden_states)
        qkv = qkv.view(batch_size, seq_len, 3, self.num_heads, self.head_size)
        q, k, v = qkv.unbind(dim=2)
        
        # Apply RoPE
        q = q.reshape(batch_size * seq_len, self.num_heads, self.head_size)
        k = k.reshape(batch_size * seq_len, self.num_heads, self.head_size)
        q, k = self.rotary_emb(positions.reshape(-1), q, k)
        q = q.view(batch_size, seq_len, self.num_heads, self.head_size)
        k = k.view(batch_size, seq_len, self.num_heads, self.head_size)
        
        # Simplified attention (just for testing)
        attn_output = torch.nn.functional.scaled_dot_product_attention(
            q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2)
        ).transpose(1, 2)
        
        output = attn_output.reshape(batch_size, seq_len, -1)
        
        # Output norm
        output, residual = self.output_norm(output, residual)
        return output, residual


# ============================================================================
# Main Test
# ============================================================================

def test_rope_implementation(use_llama3: bool, label: str):
    """Test a specific RoPE implementation"""

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    dtype = torch.bfloat16
    
    # Create model
    model = SimpleModel(use_llama3_rope=use_llama3).to(device, dtype)
    
    # Enter torch.compile mode
    logger.info("\n📝 Entering torch.compile mode...")
    model.input_norm.enter_torch_compile(num_tokens=16)
    model.output_norm.enter_torch_compile(num_tokens=16)
    model.rotary_emb.enter_torch_compile(num_tokens=16)
    logger.info("   ✅ All CustomOps set to use sgl-kernel in torch.compile mode")
    
    # Compile
    logger.info("\n⚙️  Compiling model with torch.compile (backend=inductor)...")
    compiled_model = torch.compile(model, backend="inductor")
    
    # Warm up
    batch_size, seq_len = 2, 16
    hidden_states = torch.randn(batch_size, seq_len, 4096, device=device, dtype=dtype)
    positions = torch.arange(seq_len, device=device).unsqueeze(0).expand(batch_size, -1)
    
    logger.info("🔥 Running warm-up to trigger compilation...")
    _ = compiled_model(hidden_states, positions)
    logger.info("   ✅ Compilation complete")
    
    # Profile
    logger.info("\n📊 Profiling with torch.profiler to check actual kernels...")
    with torch.profiler.profile(
        activities=[torch.profiler.ProfilerActivity.CUDA],
        record_shapes=True
    ) as prof:
        output, residual = compiled_model(hidden_states, positions)
    
    for event in prof.key_averages():
        if event.device_type == torch.profiler.DeviceType.CUDA:
            print(event.key)
    

test_rope_implementation(
    use_llama3=True,
    label="Llama3RotaryEmbedding",
)
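
As a follow-up idea (not part of the script above, and assuming PyTorch 2.x's torch._logging API), enabling output-code logging prints the Inductor-generated wrapper code, which makes it easy to see whether the rope op is emitted as an extern sgl-kernel call or as a Triton kernel:

# Hypothetical debugging aid: print Inductor's generated code for each compiled graph.
import torch._logging

torch._logging.set_logs(output_code=True)

# Equivalent environment-variable form:
#   TORCH_LOGS="output_code" python3 reproduce_rope_triton_issue.py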

@BBuf BBuf closed this Oct 24, 2025
@zhyncs zhyncs deleted the piece_cuda_graph_support_sgl_kernel branch October 24, 2025 21:04