[TRTLLM-9457][feat] Add cute dsl fp8 gemm for Blackwell #10130
base: main
Conversation
📝 Walkthrough

This PR introduces FP8-accelerated CuTe DSL operations for Blackwell (SM100/103) hardware. It adds three new custom ops (GEMM, batched GEMM, and group-blockwise GEMM) with corresponding tunable runner classes, extends ModelConfig with block-scaling flags, threads these flags through the attention and linear modules, and adds comprehensive test coverage with shape inference utilities.
Estimated code review effort: 🎯 4 (Complex) | ⏱️ ~60 minutes
🚥 Pre-merge checks: ✅ 1 passed | ❌ 2 failed (2 warnings)
Actionable comments posted: 8
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (1)
tests/unittest/_torch/modules/test_fused_moe.py (1)
1049-1160: Multi-GPU wrapper not updated for the new `use_cute_dsl_fp8` argument

`test_fused_moe_fp8_blockwise_cute_dsl` now expects `use_cute_dsl_fp8` before `mapping`, but `test_fused_moe_fp8_blockwise_cute_dsl_multi_gpu` still only passes 7 positional args. As a result:

- The `Mapping(...)` instance is currently bound to `use_cute_dsl_fp8`.
- The `mapping` argument falls back to its default (`None`), breaking the intended TP/EP setup.
- `use_cute_dsl_fp8` receives a non-bool object, which may also affect kernel selection.

This makes the multi-GPU test incorrect and potentially hides issues on the CuTe DSL FP8 path.

Consider explicitly passing a boolean before `Mapping(...)` in the `executor.map` call, e.g. hard-coding `True` or adding a parametrize on `use_cute_dsl_fp8`:

```diff
-        test_fused_moe_fp8_blockwise_cute_dsl,
-        *zip(*[(
-            torch.bfloat16,
-            72,
-            384,
-            384,
-            routing_method,
-            weight_loading_mode,
-            Mapping(
-                world_size=world_size,
-                tp_size=world_size,
-                moe_ep_size=ep_size,
-                moe_tp_size=world_size // ep_size,
-            ),
-        )] * world_size),
+        test_fused_moe_fp8_blockwise_cute_dsl,
+        *zip(*[(
+            torch.bfloat16,
+            72,
+            384,
+            384,
+            routing_method,
+            weight_loading_mode,
+            True,  # or parametrize/use both False/True if needed
+            Mapping(
+                world_size=world_size,
+                tp_size=world_size,
+                moe_ep_size=ep_size,
+                moe_tp_size=world_size // ep_size,
+            ),
+        )] * world_size),
```

Also applies to: 1351-1381
🤖 Fix all issues with AI agents
In @tensorrt_llm/_torch/custom_ops/cute_dsl_custom_ops.py:
- Around lines 2651-2664: The validation in the registered fake kernel for `trtllm::cute_dsl_fp8_bmm_blackwell` uses `assert False`, which is stripped under Python `-O`. Replace those asserts with explicit exceptions (e.g., `raise AssertionError("out.dtype != bf16")` and `raise AssertionError("out.shape != (batch_size, m, n)")`) inside the anonymous function registered via `@torch.library.register_fake`, so the checks always run.
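A minimal sketch of that fix (the argument list and shape derivation below are assumptions for illustration, not the exact fake-kernel signature in the PR; only the explicit raises are the point):

```python
import torch

# Hypothetical sketch: explicit raises survive `python -O`, unlike `assert False`.
# Argument names and the weight layout are assumed for illustration only.
@torch.library.register_fake("trtllm::cute_dsl_fp8_bmm_blackwell")
def _(input, weight, input_scale, weight_scale, out):
    batch_size, m, _ = input.shape
    n = weight.shape[-2]  # assumed weight layout
    if out.dtype != torch.bfloat16:
        raise AssertionError("out.dtype != bf16")
    if tuple(out.shape) != (batch_size, m, n):
        raise AssertionError("out.shape != (batch_size, m, n)")
```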
- Around lines 2721-2741: The call to `self.__class__.kernel_class.can_implement` references an undefined variable `batch_size`. Define `batch_size = 1` (or pass the literal `1`) before the list comprehension so that `can_implement` receives the correct batch parameter, consistent with the forward method, which uses `batch=1`.
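A sketch of the intent (the tactic iteration and the `can_implement()` argument order are paraphrased assumptions; only the `batch_size = 1` definition is taken from the review):

```python
# Sketch only: define batch_size before the comprehension so can_implement()
# sees the same batch=1 that the forward path uses.
batch_size = 1
valid_tactics = [
    tactic for tactic in candidate_tactics  # candidate_tactics is a placeholder name
    if self.__class__.kernel_class.can_implement(batch_size, m, n, k, *tactic)
]
```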
- Around lines 2448-2454: The log references an undefined variable `sm_version`. Before calling `logger.debug` in the block that checks `is_sm_100f()` on the CuteDSL FP8 BMM path, retrieve the SM version (e.g., via the existing helper such as `get_sm_version()`) and use that variable in the f-string, or remove the `sm_version` interpolation entirely, to avoid the `NameError`.
- Around lines 2207-2213: The debug message in the early exit for `is_sm_100f()` (the block that returns `[]` for the "CuteDSL FP8 GEMM only supports SM 100 family" check) references an undefined `sm_version`. Either remove `sm_version` from the message or replace it with a real SM query (e.g., `get_sm_version()` / `get_device_sm()`) and include that result.
- Around lines 2688-2694: The log inside the method that checks `if not is_sm_100f()` references an undefined `sm_version`. Obtain the current SM version before logging (e.g., call the existing helper or query the device capability) and use that variable in the f-string, or drop the interpolation and log a generic message.
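All three call sites above need the same shape of fix. A hedged sketch, assuming `get_sm_version()` is the helper already available to this module and that the early exit returns an empty tactic list as the review describes:

```python
# Sketch of one possible fix, applicable to each of the three logging sites above.
if not is_sm_100f():
    sm_version = get_sm_version()  # assumed existing helper; use whatever the file imports
    logger.debug(
        f"CuteDSL FP8 GEMM only supports the SM 100 family, got SM {sm_version}; "
        "skipping this backend.")
    return []
```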
In @tensorrt_llm/_torch/modules/fused_moe/fused_moe_cute_dsl.py:
- Around lines 606-624: Align each expert group's token range to the next multiple of 128 before building `group_offset`: compute `total_tokens = x.shape[0] * x.shape[1]`; for each slot use `start = expert_first_token_offset[i]` and `end = expert_first_token_offset[i+1]` (or `total_tokens` for the last slot), then set `padded_end = min(ceil(end / 128) * 128, total_tokens)` and assign `group_offset[start:padded_end] = i`. Ensure any remaining tokens up to `total_tokens` are assigned (e.g., to the last group) so the `group_offset` length stays `total_tokens`. Keep this logic only when `use_cute_dsl_fp8` and `is_sm_100f()`, before calling `torch.ops.trtllm.cute_dsl_fp8_group_blockwise_gemm_blackwell`; the relevant symbols are `group_offset`, `expert_first_token_offset`, `fp8_quantize_1x128`, `num_slots`, `use_cute_dsl_fp8`, and `cute_dsl_fp8_group_blockwise_gemm_blackwell` (see the sketch below).
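An illustrative sketch of that padding scheme, following the review text rather than the PR's actual code (`x`, `num_slots`, and `expert_first_token_offset` come from the surrounding MoE forward; the integer-ceil arithmetic is an assumed implementation detail):

```python
import torch

# Sketch: pad each expert's token range up to the next multiple of 128, never
# exceeding the total token count. Later iterations overwrite the padded tail
# of the previous group wherever their real tokens begin, and the last group's
# range already ends at total_tokens, so every entry gets assigned.
total_tokens = x.shape[0] * x.shape[1]
group_offset = torch.empty(total_tokens, dtype=torch.int32, device="cuda")
for i in range(num_slots):
    start = int(expert_first_token_offset[i])
    end = int(expert_first_token_offset[i + 1]) if i < num_slots - 1 else total_tokens
    padded_end = min((end + 127) // 128 * 128, total_tokens)
    group_offset[start:padded_end] = i
```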
In @tensorrt_llm/_torch/utils.py:
- Around lines 304-310: The function `fp8_scale_infer_shape` is missing a Google-style docstring, and the file lacks the required NVIDIA copyright header. Add a docstring to `fp8_scale_infer_shape` mirroring the style and content of the existing `fp4_scale_infer_shape` and `fp4_unswizzled_scale_infer_shape` functions (describe the `input_shapes` parameter, the return value, and the behavior for 2D/3D inputs and batch handling), and prepend the file with the NVIDIA copyright header including the year of the latest meaningful modification.
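A possible docstring shape (the wording is a sketch meant to mirror the `fp4_*` siblings; the parameter name `input_shapes` comes from the review text, and the exact semantics should be checked against the real implementation):

```python
def fp8_scale_infer_shape(input_shapes):
    """Infer the shape of the FP8 1x128 block-scale tensor for an activation.

    Args:
        input_shapes: Shapes of the operands; the first entry is the activation
            shape, either ``[m, k]`` (2D) or ``[b, m, k]`` (3D/batched).

    Returns:
        The shape of the per-128-element scale-factor tensor, with the batch
        dimension folded in when the input is 3D.
    """
    ...  # existing implementation unchanged
```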
🧹 Nitpick comments (4)
tensorrt_llm/_torch/model_config.py (1)
125-127: Consider more descriptive naming to clarify hardware targeting.

The new configuration flags enable CuTe DSL block-scaling paths, which according to the PR are optimized for Blackwell hardware. Based on learnings, when adding hardware capability checks or configuration, descriptive names help clarify the specific GPU architectures being targeted.

Consider renaming for clarity:

- `use_cute_dsl_blockscaling_mm` → `use_cute_dsl_blockscaling_mm_blackwell` or `enable_blackwell_cute_dsl_mm`
- `use_cute_dsl_blockscaling_bmm` → `use_cute_dsl_blockscaling_bmm_blackwell` or `enable_blackwell_cute_dsl_bmm`

This makes it immediately clear that these flags enable Blackwell-optimized implementations, consistent with the naming pattern in `cute_dsl_fp8_gemm_blackwell`.

Alternative: keep the current naming if these paths may support other architectures in the future. If the CuTe DSL paths are intended to support multiple SM versions beyond Blackwell (SM100/103), the current generic naming may be appropriate. In that case, ensure the hardware gating logic (e.g., `is_sm_100f()`) in the consuming code clearly documents which architectures are supported. Based on learnings, using descriptive names that identify specific GPU architectures helps clarify hardware-capability selection.
tensorrt_llm/_torch/modules/fused_moe/fused_moe_cute_dsl.py (1)
606-650: Consider refactoring to reduce code duplication.

The conditional logic for switching between the Blackwell-optimized path and the reference implementation is duplicated for both the w3_w1 GEMM (lines 617-632) and the w2 GEMM (lines 635-650). This duplication makes the code harder to maintain.

♻️ Suggested refactor to eliminate duplication

Extract the conditional GEMM logic into a helper method:

```python
def _run_fp8_group_gemm(
    self,
    x: torch.Tensor,
    weight: torch.Tensor,
    x_sf: torch.Tensor,
    weight_scale: torch.Tensor,
    group_offset: torch.Tensor,
    expert_first_token_offset: torch.Tensor,
    weight_dtype: torch.dtype,
) -> torch.Tensor:
    """Run FP8 grouped GEMM with appropriate backend."""
    if is_sm_100f() and self.use_cute_dsl_fp8:
        return torch.ops.trtllm.cute_dsl_fp8_group_blockwise_gemm_blackwell(
            input=x,
            weight=weight.view(weight_dtype),
            input_scale=x_sf,
            weight_scale=weight_scale,
            group_offset=group_offset,
        )
    else:
        return cute_dsl_fp8_group_blockwise_gemm_ref(
            a=x,
            b=weight.view(weight_dtype),
            a_sf=x_sf,
            b_sf=weight_scale,
            offset_array=expert_first_token_offset,
        )
```

Then update the main logic:

```python
# First GEMM: w3_w1
x = self._run_fp8_group_gemm(x, self.w3_w1_weight, x_sf, self.quant_scales[0],
                             group_offset, expert_first_token_offset, weight_dtype)
x = swiglu_fused_moe(x)
x, x_sf = torch.ops.trtllm.fp8_quantize_1x128(x)

# Second GEMM: w2
x = self._run_fp8_group_gemm(x, self.w2_weight, x_sf, self.quant_scales[1],
                             group_offset, expert_first_token_offset, weight_dtype)
```

This eliminates duplication and makes it easier to add additional backends or modify the selection logic in the future.
tensorrt_llm/_torch/modules/attention.py (1)
695-735: CuTe DSL FP8 BMM path on SM100 is correctly gated and integrated

The updated `fp8_block_scaling_bmm_out` and MLA call sites introduce a CuTe DSL FP8 BMM fast path on Blackwell while preserving existing behavior elsewhere:

- For SM90/89 and SM120, logic is unchanged: `mat1` is quantized on the fly and `torch.ops.trtllm.fp8_block_scaling_bmm_out` is used.
- For SM100 (`is_sm_100f(sm_version)`):
  - When `use_cute_dsl_blockscaling_bmm` is `True`, you:
    - Quantize `mat1` via `fp8_batched_quantize_1x128_permute102`.
    - Call `cute_dsl_fp8_bmm_blackwell(mat1_fp8, mat2_fp8, mat1_scale, mat2_scale, out)` directly.
    - Avoid allocating BF16 dequant buffers.
  - When `use_cute_dsl_blockscaling_bmm` is `False`, you:
    - In `MLA.create_weights`, allocate BF16 `k_b_proj_trans_dequant` / `v_b_proj_dequant` only if `has_fp8_block_scales` and `is_sm_100f()`, and assert `self.dtype == torch.bfloat16`.
    - Use those dequant buffers in the fallback `torch.bmm` path: `mat1.transpose(0, 1)` × `mat2_dequant.transpose(1, 2)` → `out`.
    - Shapes are consistent at all call sites (per-head BMM with `[num_heads, seq, k] x [num_heads, k, v]`).

The callers in MLA (`forward_absorption_generation`, `forward_absorption_context`, `forward_sparse_mla_kvcache_bf16`) all pass:

- FP8 weights + scales (`k_b_proj_trans` / `v_b_proj` + `_scale`),
- the optional BF16 dequant tensors, and
- the `self.use_cute_dsl_blockscaling_bmm` flag,

so the helper can select the right kernel per SM and configuration without extra branching at the call sites.

Only a very small nit: `mat1_scale = None` after the CuTe DSL call in `fp8_block_scaling_bmm_out` is redundant and can be dropped, but it is harmless.

Also applies to: 1076-1134, 1901-1909, 1980-1987, 2036-2043, 2092-2099, 2160-2167, 2237-2244
tensorrt_llm/_torch/custom_ops/cute_dsl_custom_ops.py (1)
2187-2196: Consider annotating mutable class attributes with `ClassVar`.

The `kernel_cache` class attributes in `CuteDSLFp8BlackwellLinear`, `CuteDSLFp8BlackwellBmm`, and `CuteDSLFp8BlackwellGroupGemm` are mutable dictionaries shared across instances. Annotating with `ClassVar` makes this intent explicit.

♻️ Example for CuteDSLFp8BlackwellLinear

```diff
+from typing import ClassVar, Dict, Any
+
 class CuteDSLFp8BlackwellLinear(TunableRunner):
     kernel_class = BlockwiseGemmKernel
-    kernel_cache = dict()
+    kernel_cache: ClassVar[Dict[Any, Any]] = {}
```
📜 Review details
Configuration used: Path: .coderabbit.yaml
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (12)

- tensorrt_llm/_torch/custom_ops/cute_dsl_custom_ops.py
- tensorrt_llm/_torch/cute_dsl_kernels/blackwell/blockwise_gemm/__init__.py
- tensorrt_llm/_torch/cute_dsl_kernels/blackwell/blockwise_gemm/blockwise_gemm.py
- tensorrt_llm/_torch/cute_dsl_kernels/blackwell/blockwise_gemm/contiguous_grouped_gemm.py
- tensorrt_llm/_torch/model_config.py
- tensorrt_llm/_torch/modules/attention.py
- tensorrt_llm/_torch/modules/fused_moe/fused_moe_cute_dsl.py
- tensorrt_llm/_torch/modules/linear.py
- tensorrt_llm/_torch/utils.py
- tests/unittest/_torch/modules/test_fused_moe.py
- tests/unittest/_torch/thop/parallel/test_fp8_block_scale_gemm.py
- tests/unittest/_torch/thop/parallel/test_fp8_block_scale_group_gemm.py
🧰 Additional context used
📓 Path-based instructions (2)
**/*.py
📄 CodeRabbit inference engine (CODING_GUIDELINES.md)
**/*.py: The code developed for TensorRT-LLM should conform to Python 3.8+
Indent Python code with 4 spaces. Do not use tabs
Always maintain the namespace when importing Python modules, even if only one class or function from a module is used
Python filenames should use snake_case (e.g., `some_file.py`)
Python classes should use PascalCase (e.g., `class SomeClass`)
Python functions and methods should use snake_case (e.g., `def my_awesome_function():`)
Python local variables should use snake_case, with prefix `k` for variable names that start with a number (e.g., `k_99th_percentile`)
Python global variables should use upper snake_case with prefix `G` (e.g., `G_MY_GLOBAL`)
Python constants should use upper snake_case (e.g., `MY_CONSTANT`)
Avoid shadowing variables declared in an outer scope in Python
Initialize all externally visible members of a Python class in the constructor
For Python interfaces that may be used outside a file, prefer docstrings over comments
Use comments in Python for code within a function, or interfaces that are local to a file
Use Google-style docstrings for Python classes and functions, which can be parsed by Sphinx
Python attributes and variables can be documented inline with the format `"""<type>: Description"""`
Avoid using reflection in Python when functionality can be easily achieved without reflection
When using try-except blocks in Python, limit the except clause to the smallest set of errors possible
When using try-except blocks in Python to handle multiple possible variable types (duck-typing), keep the body of the try as small as possible and use the else block for the main logic
Files:

- tensorrt_llm/_torch/model_config.py
- tensorrt_llm/_torch/modules/linear.py
- tensorrt_llm/_torch/modules/fused_moe/fused_moe_cute_dsl.py
- tensorrt_llm/_torch/utils.py
- tests/unittest/_torch/modules/test_fused_moe.py
- tests/unittest/_torch/thop/parallel/test_fp8_block_scale_gemm.py
- tensorrt_llm/_torch/modules/attention.py
- tests/unittest/_torch/thop/parallel/test_fp8_block_scale_group_gemm.py
- tensorrt_llm/_torch/custom_ops/cute_dsl_custom_ops.py
**/*.{cpp,cc,cxx,h,hpp,hxx,cu,cuh,py}
📄 CodeRabbit inference engine (CODING_GUIDELINES.md)
All TensorRT-LLM source files (.cpp, .h, .cu, .py, and other source files) should contain an NVIDIA copyright header with the year of latest meaningful modification
Files:

- tensorrt_llm/_torch/model_config.py
- tensorrt_llm/_torch/modules/linear.py
- tensorrt_llm/_torch/modules/fused_moe/fused_moe_cute_dsl.py
- tensorrt_llm/_torch/utils.py
- tests/unittest/_torch/modules/test_fused_moe.py
- tests/unittest/_torch/thop/parallel/test_fp8_block_scale_gemm.py
- tensorrt_llm/_torch/modules/attention.py
- tests/unittest/_torch/thop/parallel/test_fp8_block_scale_group_gemm.py
- tensorrt_llm/_torch/custom_ops/cute_dsl_custom_ops.py
🧠 Learnings (14)
📓 Common learnings
Learnt from: nzmora-nvidia
Repo: NVIDIA/TensorRT-LLM PR: 9163
File: tensorrt_llm/_torch/auto_deploy/custom_ops/quant.py:107-113
Timestamp: 2025-11-14T11:22:03.729Z
Learning: In TensorRT-LLM AutoDeploy custom ops, when adding hardware capability checks to select between kernel implementations (e.g., cuBLAS vs. CUDA kernel), use descriptive variable names that identify the specific GPU architectures or families being targeted (e.g., `is_blackwell_geforce_or_ada`) rather than generic names like `enable_cuda_core`. This makes it clear that the code is selecting an implementation path based on hardware capabilities, not enabling/disabling hardware features.
📚 Learning: 2025-09-19T21:28:13.751Z
Learnt from: jhaotingc
Repo: NVIDIA/TensorRT-LLM PR: 7856
File: cpp/tensorrt_llm/thop/fp8BlockScaleMoe.cpp:159-166
Timestamp: 2025-09-19T21:28:13.751Z
Learning: In TensorRT-LLM blockScaleMoe routing (cpp/tensorrt_llm/kernels/trtllmGenKernels/blockScaleMoe/runner.cu), the DeepSeek routing method performs reinterpret_cast<float*>(routingLogits) at line 89, which could cause issues if routing_logits are BF16. However, Qwen3-FP8 models use RenormalizeNaive routing method and are not affected by this dtype casting issue.
Applied to files:
- tensorrt_llm/_torch/modules/linear.py
- tests/unittest/_torch/modules/test_fused_moe.py
📚 Learning: 2025-10-20T16:54:09.824Z
Learnt from: nvchenghaoz
Repo: NVIDIA/TensorRT-LLM PR: 8469
File: tensorrt_llm/_torch/auto_deploy/custom_ops/rms_norm.py:6-6
Timestamp: 2025-10-20T16:54:09.824Z
Learning: In tensorrt_llm/_torch/auto_deploy/custom_ops/rms_norm.py, the import `from ...modules.mamba.layernorm_gated import _layer_norm_fwd` is correct and should not be changed to modules.fla.layernorm_gated. The _layer_norm_fwd function exists in both modules/mamba/layernorm_gated.py and modules/fla/layernorm_gated.py, but the mamba version is the intended implementation for this use case.
Applied to files:
tensorrt_llm/_torch/modules/linear.py
📚 Learning: 2025-08-19T12:45:11.997Z
Learnt from: amitz-nv
Repo: NVIDIA/TensorRT-LLM PR: 7033
File: tensorrt_llm/_torch/pyexecutor/model_engine.py:0-0
Timestamp: 2025-08-19T12:45:11.997Z
Learning: In tensorrt_llm/_torch/pyexecutor/model_engine.py, DoRA (Delta Orthogonal Rank Adaptation) functionality was removed from the PyTorch flow to eliminate issues with inverted DoRA detection logic. The original is_dora condition was checking if scaling_vec_pointer == 0, which was potentially incorrect.
Applied to files:
tensorrt_llm/_torch/modules/linear.py
📚 Learning: 2026-01-06T03:07:15.754Z
Learnt from: CR
Repo: NVIDIA/TensorRT-LLM PR: 0
File: CODING_GUIDELINES.md:0-0
Timestamp: 2026-01-06T03:07:15.754Z
Learning: Applies to **/*.{cpp,cc,cxx,h,hpp,hxx,cu,cuh,py} : All TensorRT-LLM source files (.cpp, .h, .cu, .py, and other source files) should contain an NVIDIA copyright header with the year of latest meaningful modification
Applied to files:
tests/unittest/_torch/thop/parallel/test_fp8_block_scale_gemm.py
📚 Learning: 2025-10-22T06:53:47.017Z
Learnt from: xinhe-nv
Repo: NVIDIA/TensorRT-LLM PR: 8534
File: scripts/format_test_list.py:1-6
Timestamp: 2025-10-22T06:53:47.017Z
Learning: The file `scripts/format_test_list.py` in the TensorRT-LLM repository does not require the NVIDIA Apache-2.0 copyright header.
Applied to files:
tests/unittest/_torch/thop/parallel/test_fp8_block_scale_gemm.py
📚 Learning: 2025-08-14T21:04:50.248Z
Learnt from: thorjohnsen
Repo: NVIDIA/TensorRT-LLM PR: 6910
File: cpp/tensorrt_llm/batch_manager/kvCacheManager.cpp:0-0
Timestamp: 2025-08-14T21:04:50.248Z
Learning: In KV cache onboarding logic during prefill in cpp/tensorrt_llm/batch_manager/kvCacheManager.cpp, when calculating which blocks fall within the attention window, use getTokensPerBlock() to advance token indices rather than block->getUniqueTokens().size(), because the calculation needs to consider the post-prefill state where blocks will be filled to capacity, not their current token count.
Applied to files:
tensorrt_llm/_torch/modules/attention.py
📚 Learning: 2025-12-19T06:31:54.973Z
Learnt from: nvyocox
Repo: NVIDIA/TensorRT-LLM PR: 10117
File: tensorrt_llm/_torch/auto_deploy/transform/library/fuse_rope_attention.py:336-339
Timestamp: 2025-12-19T06:31:54.973Z
Learning: In tensorrt_llm/_torch/auto_deploy/transform/library/fuse_rope_attention.py, the cast to torch.float16 for qkv_node before creating the AttentionPlugin is intentional and required because DriveOS LLM expects float16 dtype specifically. This should not be changed to preserve original dtype or made configurable for bfloat16 models in the DriveOS LLM ONNX export path.
Applied to files:
tensorrt_llm/_torch/modules/attention.py
📚 Learning: 2025-09-29T15:14:28.503Z
Learnt from: amitz-nv
Repo: NVIDIA/TensorRT-LLM PR: 8063
File: tensorrt_llm/lora_manager.py:1080-1112
Timestamp: 2025-09-29T15:14:28.503Z
Learning: In tensorrt_llm/lora_manager.py, when calculating part_sizes for attn_qkv fused LoRA modules, the sizes are correctly multiplied by tp_size because model_config.num_heads and model_config.num_kv_heads are already divided by tp_size (per-TP-rank values), so multiplication is needed to get the original full concatenated dimension size. The interleave_fused_lora_weights_for_tp function provides proper validation with asserts for total size and TP divisibility.
Applied to files:
tensorrt_llm/_torch/modules/attention.py
📚 Learning: 2025-09-29T15:14:28.503Z
Learnt from: amitz-nv
Repo: NVIDIA/TensorRT-LLM PR: 8063
File: tensorrt_llm/lora_manager.py:1080-1112
Timestamp: 2025-09-29T15:14:28.503Z
Learning: In tensorrt_llm/lora_manager.py, when calculating part_sizes for attn_qkv fused LoRA modules, the sizes are correctly multiplied by tp_size because model_config.num_heads and model_config.num_kv_heads are already divided by tp_size (per-TP-rank values), so multiplication is needed to get the original full concatenated dimension size. The interleave_fused_lora_weights_for_tp function provides proper validation.
Applied to files:
tensorrt_llm/_torch/modules/attention.py
📚 Learning: 2025-08-14T15:43:23.107Z
Learnt from: MatthiasKohl
Repo: NVIDIA/TensorRT-LLM PR: 6904
File: tensorrt_llm/_torch/attention_backend/trtllm.py:259-262
Timestamp: 2025-08-14T15:43:23.107Z
Learning: In TensorRT-LLM's attention backend, tensor parameters in the plan() method are assigned directly without validation (dtype, device, contiguity checks). This maintains consistency across all tensor inputs and follows the pattern of trusting callers to provide correctly formatted tensors.
Applied to files:
tensorrt_llm/_torch/modules/attention.py
📚 Learning: 2025-11-14T11:22:03.729Z
Learnt from: nzmora-nvidia
Repo: NVIDIA/TensorRT-LLM PR: 9163
File: tensorrt_llm/_torch/auto_deploy/custom_ops/quant.py:107-113
Timestamp: 2025-11-14T11:22:03.729Z
Learning: In TensorRT-LLM AutoDeploy custom ops, when adding hardware capability checks to select between kernel implementations (e.g., cuBLAS vs. CUDA kernel), use descriptive variable names that identify the specific GPU architectures or families being targeted (e.g., `is_blackwell_geforce_or_ada`) rather than generic names like `enable_cuda_core`. This makes it clear that the code is selecting an implementation path based on hardware capabilities, not enabling/disabling hardware features.
Applied to files:
tensorrt_llm/_torch/custom_ops/cute_dsl_custom_ops.py
📚 Learning: 2025-08-21T21:48:35.135Z
Learnt from: djns99
Repo: NVIDIA/TensorRT-LLM PR: 7104
File: cpp/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/epilogue/fusion/sm90_visitor_scatter.hpp:399-417
Timestamp: 2025-08-21T21:48:35.135Z
Learning: CUTLASS extensions in TensorRT-LLM (located under cpp/tensorrt_llm/cutlass_extensions/) are designed to integrate with and extend functionality in the external CUTLASS repository. When analyzing these extensions, their consumers and functionality wiring may exist in the CUTLASS codebase rather than within TensorRT-LLM itself.
Applied to files:
tensorrt_llm/_torch/custom_ops/cute_dsl_custom_ops.py
📚 Learning: 2025-12-12T10:07:31.564Z
Learnt from: lirundong
Repo: NVIDIA/TensorRT-LLM PR: 9725
File: tensorrt_llm/_torch/custom_ops/cuda_tile_custom_ops.py:110-178
Timestamp: 2025-12-12T10:07:31.564Z
Learning: In PyTorch custom operators registered with torch.library.custom_op, mutable operators that return None and specify mutates_args do not require a register_fake decorator. Mutation tracking is handled automatically without needing a FakeTensor kernel. This applies to Python custom op definitions in tensorrt_llm/_torch/custom_ops that use mutates_args and return None; verify you are not relying on register_fake in these cases.
Applied to files:
tensorrt_llm/_torch/custom_ops/cute_dsl_custom_ops.py
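A toy illustration of that learning (hypothetical op name, unrelated to this PR): a Python custom op that mutates its output and returns `None` needs only `mutates_args`, with no `register_fake`:

```python
import torch

# Hypothetical example op, not part of TensorRT-LLM's real op set.
@torch.library.custom_op("example::scale_inplace", mutates_args=("out",))
def scale_inplace(x: torch.Tensor, out: torch.Tensor) -> None:
    # Mutation is tracked automatically; no FakeTensor kernel is required.
    out.copy_(x * 2.0)
```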
🧬 Code graph analysis (6)
tensorrt_llm/_torch/modules/fused_moe/fused_moe_cute_dsl.py (2)
- tensorrt_llm/_utils.py (1): `is_sm_100f` (758-761)
- tensorrt_llm/_torch/custom_ops/cute_dsl_custom_ops.py (1): `cute_dsl_fp8_group_blockwise_gemm_blackwell` (2888-2906)
tests/unittest/_torch/modules/test_fused_moe.py (2)
- tensorrt_llm/_torch/modules/fused_moe/routing.py (1): `DefaultMoeRoutingMethod` (188-219)
- tensorrt_llm/_torch/modules/fused_moe/interface.py (1): `MoEWeightLoadingMode` (17-23)
tests/unittest/_torch/thop/parallel/test_fp8_block_scale_gemm.py (6)
- cpp/include/tensorrt_llm/common/cudaUtils.h (1): `isSM100Family` (325-329)
- tests/unittest/utils/util.py (2): `isSM100Family` (98-100), `getSMVersion` (62-81)
- tensorrt_llm/_utils.py (2): `dtype` (997-998), `dtype` (1005-1015)
- tests/unittest/_torch/helpers.py (1): `per_block_cast_to_fp8` (33-46)
- tensorrt_llm/_torch/autotuner.py (1): `autotune` (257-295)
- tensorrt_llm/_torch/custom_ops/cute_dsl_custom_ops.py (2): `cute_dsl_fp8_gemm_blackwell` (2394-2411), `cute_dsl_fp8_bmm_blackwell` (2630-2649)
tensorrt_llm/_torch/modules/attention.py (3)
- cpp/tensorrt_llm/kernels/trtllmGenKernels/gemm/trtllmGen_gemm_export/trtllm/gen/MmaDecl.h (1): `trtllm` (31-109)
- tensorrt_llm/_torch/custom_ops/cute_dsl_custom_ops.py (1): `cute_dsl_fp8_bmm_blackwell` (2630-2649)
- tensorrt_llm/_utils.py (1): `is_sm_100f` (758-761)
tests/unittest/_torch/thop/parallel/test_fp8_block_scale_group_gemm.py (2)
- tests/unittest/_torch/helpers.py (2): `calc_diff` (76-80), `per_block_cast_to_fp8` (33-46)
- tensorrt_llm/_torch/custom_ops/cute_dsl_custom_ops.py (1): `cute_dsl_fp8_group_blockwise_gemm_blackwell` (2888-2906)
tensorrt_llm/_torch/custom_ops/cute_dsl_custom_ops.py (1)
- tensorrt_llm/_torch/utils.py (6): `fp4_scale_infer_shape` (287-292), `fp8_scale_infer_shape` (304-309), `get_last_power_of_2_num_tokens_buckets` (277-284), `last_positive_power_of_2` (254-259), `shape` (141-142), `_` (228-234)
🪛 Ruff (0.14.10)
tests/unittest/_torch/thop/parallel/test_fp8_block_scale_group_gemm.py
40-40: Loop control variable i not used within loop body
(B007)
41-41: Standard pseudo-random generators are not suitable for cryptographic purposes
(S311)
tensorrt_llm/_torch/custom_ops/cute_dsl_custom_ops.py
2189-2189: Mutable class attributes should be annotated with typing.ClassVar
(RUF012)
2204-2204: Unused method argument: profile
(ARG002)
2205-2205: Unused method argument: kwargs
(ARG002)
2210-2210: Undefined name sm_version
(F821)
2417-2417: Unused function argument: input_scale
(ARG001)
2418-2418: Unused function argument: weight_scale
(ARG001)
2430-2430: Mutable class attributes should be annotated with typing.ClassVar
(RUF012)
2445-2445: Unused method argument: profile
(ARG002)
2446-2446: Unused method argument: kwargs
(ARG002)
2451-2451: Undefined name sm_version
(F821)
2655-2655: Unused function argument: input_scale
(ARG001)
2656-2656: Unused function argument: weight_scale
(ARG001)
2659-2659: Local variable k is assigned to but never used
Remove assignment to unused variable k
(F841)
2662-2662: Do not assert False (python -O removes these calls), raise AssertionError()
Replace assert False
(B011)
2664-2664: Do not assert False (python -O removes these calls), raise AssertionError()
Replace assert False
(B011)
2668-2668: Mutable class attributes should be annotated with typing.ClassVar
(RUF012)
2686-2686: Unused method argument: profile
(ARG002)
2687-2687: Unused method argument: kwargs
(ARG002)
2691-2691: Undefined name sm_version
(F821)
2698-2698: Local variable group_num is assigned to but never used
Remove assignment to unused variable group_num
(F841)
2736-2736: Undefined name batch_size
(F821)
2913-2913: Unused function argument: input_scale
(ARG001)
2914-2914: Unused function argument: weight_scale
(ARG001)
2915-2915: Unused function argument: group_offset
(ARG001)
2918-2918: Local variable num_group is assigned to but never used
Remove assignment to unused variable num_group
(F841)
2918-2918: Local variable k is assigned to but never used
Remove assignment to unused variable k
(F841)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
- GitHub Check: Pre-commit Check
🔇 Additional comments (13)
tensorrt_llm/_torch/modules/fused_moe/fused_moe_cute_dsl.py (4)
311-311: LGTM! Docstring updated for new parameter.

The docstring correctly documents the new `use_cute_dsl_fp8` parameter.

332-332: LGTM! Parameter follows established patterns.

The new `use_cute_dsl_fp8` parameter is properly typed with a sensible default value for backward compatibility.
608-616: Verify group_offset computation correctness.

The `group_offset` tensor maps each token position to its expert index. However, please verify:

- Size calculation: `x.shape[0] * x.shape[1]` assumes x is 2D after quantization. Confirm this matches the actual tensor shape at this point in the code.
- Boundary condition: The line `end = expert_first_token_offset[i + 1] if i < self.num_slots - 1 else x.shape[0] * x.shape[1]` handles the last expert specially. Verify this correctly captures all tokens for the final expert group.
- Expert index mapping: The assignment `group_offset[start:end] = i` uses the loop index `i` as the expert identifier. Confirm this matches the expert indexing expected by the `cute_dsl_fp8_group_blockwise_gemm_blackwell` operator.

Consider adding assertions to validate the group_offset computation:

🛡️ Suggested validation checks

```diff
+    # Validate group_offset construction
+    assert x.dim() == 2, f"Expected 2D tensor after quantization, got shape {x.shape}"
+    total_tokens = x.shape[0] * x.shape[1]
     group_offset = torch.empty(x.shape[0] * x.shape[1],
                                dtype=torch.int32,
                                device="cuda")
     for i in range(self.num_slots):
         start = expert_first_token_offset[i]
         end = expert_first_token_offset[
             i + 1] if i < self.num_slots - 1 else x.shape[0] * x.shape[1]
+        assert start <= end <= total_tokens, f"Invalid expert boundaries: start={start}, end={end}, total={total_tokens}"
         group_offset[start:end] = i
+    # Verify all tokens are covered
+    assert expert_first_token_offset[-1] == total_tokens or expert_first_token_offset[self.num_slots-1] <= total_tokens, \
+        "Expert offset boundaries don't cover all tokens"
```
617-624: group_offset reuse is correct

The concern is unfounded. Neither `swiglu_fused_moe()` nor `fp8_quantize_1x128()` changes the expert-to-token mapping:

- `swiglu_fused_moe()` chunks on the feature dimension (dim=-1) and applies element-wise operations only; it preserves token ordering entirely.
- `fp8_quantize_1x128()` is a pure quantization operation that converts the tensor datatype and produces scale factors; it does not reorder tokens or change the logical tensor shape.

The `group_offset` tensor remains valid for both GEMMs.

Likely an incorrect or invalid review comment.
tensorrt_llm/_torch/modules/linear.py (1)
741-747: Operator registration and alignment handling verified.

The `trtllm::cute_dsl_fp8_gemm_blackwell` operator is properly registered in `tensorrt_llm/_torch/custom_ops/cute_dsl_custom_ops.py` with both custom op and fake implementations. The quantization path is correct: `fp8_quantize_1x128` produces FP8-quantized tensors in 1×128 blocks (blocking dimension K into 128-element chunks), which are then consumed by `cute_dsl_fp8_gemm_blackwell`. The alignment requirement (K divisible by 128) is inherent to the quantization op itself and is typically satisfied by standard LLM hidden dimensions on Blackwell architectures. No additional padding or alignment checks are needed at the call site.

tests/unittest/_torch/thop/parallel/test_fp8_block_scale_gemm.py (3)
1-1: Header year range update is correct

The SPDX copyright line now covers 2022–2026, matching the latest modification year and repository guidelines.
112-154: Cute DSL FP8 block-scale GEMM test matches existing reference pattern

This Blackwell-only test:

- Uses the same input/weight setup and FP8 quantization (`fp8_quantize_1x128` + `per_block_cast_to_fp8`) as `test_fp8_block_scale_gemm`.
- Runs `cute_dsl_fp8_gemm_blackwell` once under `autotune()` and once post-tuning.
- Checks both the custom `calc_diff` metric and `torch.testing.assert_close` with the same `1e-3` tolerances.

The coverage and numerics look well-aligned with the existing non-CuTe FP8 GEMM tests.
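For reference, the tune-then-replay pattern the test exercises looks roughly like this (tensor setup omitted; the operand names are placeholders and the import path is inferred from the autotuner reference above):

```python
from tensorrt_llm._torch.autotuner import autotune

# First call profiles candidate tactics; the second replays the tuned kernel.
with autotune():
    out = torch.ops.trtllm.cute_dsl_fp8_gemm_blackwell(
        act_fp8, weight_fp8, act_scale, weight_scale)
out = torch.ops.trtllm.cute_dsl_fp8_gemm_blackwell(
    act_fp8, weight_fp8, act_scale, weight_scale)
```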
218-267: Cute DSL FP8 block-scale BMM test is consistent with existing BMM path

This test:

- Reuses the same shapes and FP8 preparation (`fp8_batched_quantize_1x128_permute102` + `per_block_cast_to_fp8`) as `test_fp8_block_scale_bmm`.
- Restricts execution to Blackwell via `isSM100Family()`.
- Tunes `cute_dsl_fp8_bmm_blackwell` once under `autotune()` and then reruns the tuned kernel.
- Compares against the dense `einsum` reference with `1e-3` diff and `assert_close` tolerances.

The wiring and tolerances look correct and should meaningfully validate the new CuTe DSL BMM path.
tensorrt_llm/_torch/modules/attention.py (1)
267-297: New CuTe DSL block-scaling flags are plumbed cleanly through Attention and MLA

`Attention` and `MLA` now both read `config.use_cute_dsl_blockscaling_mm` / `config.use_cute_dsl_blockscaling_bmm` and pass `use_cute_dsl_blockscaling_mm` into all relevant `Linear` layers (qkv, o_proj, kv_a_proj_with_mqa, q/q_b/kv_b projections, MLA o_proj). This gives a single configuration source for switching to the CuTe DSL FP8 block-scaling MM kernels.

Behavior-wise this looks sound, assuming:

- `ModelConfig` defines these booleans with default `False`, and
- any external/custom `ModelConfig` constructors are updated so these attributes always exist.

Might be worth double-checking all `ModelConfig` creation sites and serialization/deserialization (e.g., config loading in training/inference scripts) to ensure the new fields are always present and defaulted, to avoid `AttributeError` in older pipelines.

Also applies to: 875-981
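If backward compatibility with configs that predate these fields is a concern, one defensive option (a suggestion only, not necessarily what the PR does) is to default them at the consumer:

```python
# Hedged sketch: fall back to False when an older ModelConfig lacks the fields.
self.use_cute_dsl_blockscaling_mm = getattr(
    config, "use_cute_dsl_blockscaling_mm", False)
self.use_cute_dsl_blockscaling_bmm = getattr(
    config, "use_cute_dsl_blockscaling_bmm", False)
```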
tensorrt_llm/_torch/custom_ops/cute_dsl_custom_ops.py (4)
8-16: LGTM!

The new imports for `is_sm_100f` and the shape inference utilities are correctly added.

340-343: LGTM!

The new kernel imports for `BlockwiseGemmKernel` and `BlockwiseContiguousGroupedGemmKernel` are correctly structured and follow the established import pattern.

2390-2426: LGTM!

The custom op and fake registration are correctly implemented. The output shape inference correctly derives `[m, n]` from the input tensors with bf16 dtype.

2883-2919: LGTM!

The custom op and fake registration are correctly implemented. The output shape inference correctly derives `[m, n]` from the input tensors.
/bot run

PR_Github #31547 [ run ] triggered by Bot. Commit:

PR_Github #31547 [ run ] completed with state

/bot run

PR_Github #31699 [ run ] triggered by Bot. Commit:

PR_Github #31699 [ run ] completed with state

/bot run --stage-list "B300-PyTorch-1"

PR_Github #31740 [ run ] triggered by Bot. Commit:

PR_Github #31740 [ run ] completed with state

/bot run --stage-list "B300-PyTorch-1"

PR_Github #31918 [ run ] triggered by Bot. Commit:

PR_Github #31918 [ run ] completed with state
/bot run --stage-list "B300-PyTorch-1"
Signed-off-by: Yifei Zhang <219273404+yifeizhang-c@users.noreply.github.com>
Signed-off-by: Yifei Zhang <219273404+yifeizhang-c@users.noreply.github.com>
I was adding the model accuracy test back to . And since the choice of whether to use the cute_dsl backend or not seems too minor for most users/developers, these arguments will not be added to .
Signed-off-by: Yifei Zhang <219273404+yifeizhang-c@users.noreply.github.com>
/bot run

PR_Github #33582 [ run ] triggered by Bot. Commit:

PR_Github #33582 [ run ] completed with state
Signed-off-by: Yifei Zhang <219273404+yifeizhang-c@users.noreply.github.com>
/bot run --disable-fail-fast

PR_Github #33607 [ run ] triggered by Bot. Commit:

PR_Github #33607 [ run ] completed with state

/bot run

PR_Github #33662 [ run ] triggered by Bot. Commit:

PR_Github #33662 [ run ] completed with state

/bot run

PR_Github #33680 [ run ] triggered by Bot. Commit:
Description
Test Coverage
PR Checklist
Please review the following before submitting your PR:
PR description clearly explains what and why. If using CodeRabbit's summary, please make sure it makes sense.
PR Follows TRT-LLM CODING GUIDELINES to the best of your knowledge.
Test cases are provided for new code paths (see test instructions)
Any new dependencies have been scanned for license and vulnerabilities
CODEOWNERS updated if ownership changes
Documentation updated as needed
Update tava architecture diagram if there is a significant design change in PR.
The reviewers assigned automatically/manually are appropriate for the PR.
Please check this after reviewing the above items as appropriate for this PR.
GitHub Bot Help
`/bot [-h] ['run', 'kill', 'skip', 'reuse-pipeline'] ...`

Provide a user-friendly way for developers to interact with a Jenkins server.
Run `/bot [-h|--help]` to print this help message. See details below for each supported subcommand.
Details
`run [--reuse-test (optional)pipeline-id --disable-fail-fast --skip-test --stage-list "A10-PyTorch-1, xxx" --gpu-type "A30, H100_PCIe" --test-backend "pytorch, cpp" --add-multi-gpu-test --only-multi-gpu-test --disable-multi-gpu-test --post-merge --extra-stage "H100_PCIe-TensorRT-Post-Merge-1, xxx" --detailed-log --debug(experimental)]`

Launch build/test pipelines. All previously running jobs will be killed.

- `--reuse-test (optional)pipeline-id` (OPTIONAL): Allow the new pipeline to reuse build artifacts and skip successful test stages from a specified pipeline, or the last pipeline if no pipeline-id is indicated. If the Git commit ID has changed, this option will always be ignored. The DEFAULT behavior of the bot is to reuse build artifacts and successful test results from the last pipeline.
- `--disable-reuse-test` (OPTIONAL): Explicitly prevent the pipeline from reusing build artifacts and skipping successful test stages from a previous pipeline. Ensures that all builds and tests are run regardless of previous successes.
- `--disable-fail-fast` (OPTIONAL): Disable fail fast on build/tests/infra failures.
- `--skip-test` (OPTIONAL): Skip all test stages, but still run build stages, package stages and sanity check stages. Note: Does NOT update GitHub check status.
- `--stage-list "A10-PyTorch-1, xxx"` (OPTIONAL): Only run the specified test stages. Examples: "A10-PyTorch-1, xxx". Note: Does NOT update GitHub check status.
- `--gpu-type "A30, H100_PCIe"` (OPTIONAL): Only run the test stages on the specified GPU types. Examples: "A30, H100_PCIe". Note: Does NOT update GitHub check status.
- `--test-backend "pytorch, cpp"` (OPTIONAL): Skip test stages which don't match the specified backends. Only supports [pytorch, cpp, tensorrt, triton]. Examples: "pytorch, cpp" (does not run test stages with tensorrt or triton backend). Note: Does NOT update GitHub pipeline status.
- `--only-multi-gpu-test` (OPTIONAL): Only run the multi-GPU tests. Note: Does NOT update GitHub check status.
- `--disable-multi-gpu-test` (OPTIONAL): Disable the multi-GPU tests. Note: Does NOT update GitHub check status.
- `--add-multi-gpu-test` (OPTIONAL): Force run the multi-GPU tests in addition to running the L0 pre-merge pipeline.
- `--post-merge` (OPTIONAL): Run the L0 post-merge pipeline instead of the ordinary L0 pre-merge pipeline.
- `--extra-stage "H100_PCIe-TensorRT-Post-Merge-1, xxx"` (OPTIONAL): Run the ordinary L0 pre-merge pipeline and the specified test stages. Examples: --extra-stage "H100_PCIe-TensorRT-Post-Merge-1, xxx".
- `--detailed-log` (OPTIONAL): Enable flushing out all logs to the Jenkins console. This will significantly increase the log volume and may slow down the job.
- `--debug` (OPTIONAL): Experimental feature. Enable access to the CI container for debugging purposes. Note: Specify exactly one stage in the `stage-list` parameter to access the appropriate container environment. Note: Does NOT update GitHub check status.
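For example, `/bot run --stage-list "B300-PyTorch-1"` (used earlier in this PR) runs only that single stage, while `/bot run --disable-fail-fast` keeps the full pre-merge pipeline going past the first failure.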
docs/source/reference/ci-overview.mdand the
scripts/test_to_stage_mapping.pyhelper.kill
killKill all running builds associated with pull request.
skip
`skip --comment COMMENT`

Skip testing for the latest commit on the pull request. `--comment "Reason for skipping build/test"` is required. IMPORTANT NOTE: This is dangerous since lack of user care and validation can cause top of tree to break.

reuse-pipeline

`reuse-pipeline`

Reuse a previous pipeline to validate the current commit. This action will also kill all currently running builds associated with the pull request. IMPORTANT NOTE: This is dangerous since lack of user care and validation can cause top of tree to break.