
Commit 41f28e2

lirundong, xjmxyt, abylinkin-hue, xiaoqiqi177, and bealwang committed:

add cuda.tile RMS norm kernel and bindings

Co-authored-by: Jinman Xie <[email protected]>
Co-authored-by: Alexey Bylinkin <[email protected]>
Co-authored-by: Qiqi Xiao <[email protected]>
Co-authored-by: Biao Wang <[email protected]>
Co-authored-by: Thomas Schmid <[email protected]>
Signed-off-by: Rundong (David) Li <[email protected]>

1 parent: 2d8245d

File tree: 8 files changed, +1070 −1 lines changed


requirements.txt
Lines changed: 2 additions & 0 deletions

```diff
@@ -81,3 +81,5 @@ torchao>=0.14.1
 cuda-core
 llist
 dynamic_path_manager
+cuda-tile>=1.0.1
+nvidia-cuda-tileiras>=13.1
```
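Both new files below gate their imports on a `cuda_tile_utils.IS_CUDA_TILE_AVAILABLE` flag whose definition is not part of the excerpt shown here (it is presumably among the other changed files). A minimal sketch of such a probe, assuming it is a plain import check against the `cuda-tile` package pinned above:

```python
# Hypothetical sketch of cuda_tile_utils.py; the real module is not shown in
# this commit view. The flag name matches the imports used below.
try:
    import cuda.tile  # noqa: F401  (provided by the cuda-tile package)

    IS_CUDA_TILE_AVAILABLE = True
except ImportError:
    IS_CUDA_TILE_AVAILABLE = False
```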
__init__.py (new file)
Lines changed: 33 additions & 0 deletions

```python
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from ..cuda_tile_utils import IS_CUDA_TILE_AVAILABLE

if IS_CUDA_TILE_AVAILABLE:
    from .rms_norm import rms_norm_kernel
    from .rms_norm import rms_norm_kernel_gather
    from .rms_norm import rms_norm_kernel_static_persistent
    from .rms_norm_fuse_residual import rms_norm_fuse_residual_kernel
    from .rms_norm_fuse_residual import rms_norm_fuse_residual_kernel_gather
    from .rms_norm_fuse_residual import rms_norm_fuse_residual_kernel_static_persistent

__all__ = [
    "rms_norm_kernel",
    "rms_norm_kernel_gather",
    "rms_norm_kernel_static_persistent",
    "rms_norm_fuse_residual_kernel",
    "rms_norm_fuse_residual_kernel_gather",
    "rms_norm_fuse_residual_kernel_static_persistent",
]
```
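With the guard in place, callers can import every kernel variant from the package root when CUDA Tile is present and fall back otherwise; a usage sketch (the package path `kernels.rms_norm_cutile` is illustrative, since this commit view omits the real one):

```python
# Hypothetical caller; "kernels.rms_norm_cutile" stands in for the package
# that the __init__.py above defines.
from kernels import rms_norm_cutile as rk

if rk.IS_CUDA_TILE_AVAILABLE:
    kernel = rk.rms_norm_kernel_static_persistent  # pick a variant
else:
    kernel = None  # fall back to a non-cuda.tile implementation
```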
rms_norm.py (new file)
Lines changed: 196 additions & 0 deletions

```python
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Adapted from https://github.com/NVIDIA/cutile-python/blob/main/test/kernels/rms_norm.py
from ..cuda_tile_utils import IS_CUDA_TILE_AVAILABLE

if IS_CUDA_TILE_AVAILABLE:
    import cuda.tile as ct


    @ct.kernel
    def rms_norm_kernel(
        x,
        w,
        out,
        Rstd,
        N: ct.Constant[int],
        eps: ct.Constant[float],
        TILE_SIZE: ct.Constant[int],
        use_gemma: ct.Constant[bool],
    ):
        """Standard RMSNorm kernel for non-static persistent mode with tiled loads."""
        row = ct.bid(0)
        _rms = ct.full((1, TILE_SIZE), 0.0, dtype=ct.float32)
        num_tiles = ct.cdiv(x.shape[1], TILE_SIZE)

        # First pass: accumulate this row's sum of squares, one tile at a time.
        for j in range(0, num_tiles):
            xj = ct.load(
                x, index=(row, j), shape=(1, TILE_SIZE),
                allow_tma=False,
                latency=1,
            )
            xj = ct.astype(xj, ct.float32)
            _rms += xj * xj

        # Compute the reciprocal RMS for this row and store it in Rstd
        rms = ct.rsqrt(ct.sum(_rms, axis=1, keepdims=False) / N + eps)
        ct.store(Rstd, index=(row,), tile=rms)

        # Second pass: reload each tile, normalize, and scale by the weight.
        for j in range(0, num_tiles):
            wj = ct.load(
                w, index=(j,), shape=(TILE_SIZE,),
                allow_tma=False,
                latency=1,
            )
            wj = ct.astype(wj, ct.float32)
            # Apply Gemma-style bias if enabled
            if use_gemma:
                wj = wj + 1.0
            xj = ct.load(
                x, index=(row, j), shape=(1, TILE_SIZE),
                allow_tma=False,
                latency=1,
            )
            xj = ct.astype(xj, ct.float32)
            yj = xj * rms * wj
            yj = ct.astype(yj, x.dtype)
            ct.store(
                out, index=(row, j), tile=yj,
                allow_tma=False,
                latency=1,
            )


    @ct.kernel
    def rms_norm_kernel_gather(
        x,
        w,
        out,
        Rstd,
        N: ct.Constant[int],
        eps: ct.Constant[float],
        TILE_SIZE: ct.Constant[int],
        use_gemma: ct.Constant[bool],
    ):
        """Standard RMSNorm kernel for non-static persistent mode with pointer (gather/scatter) loads."""
        row = ct.bid(0)
        _rms = ct.full((TILE_SIZE,), 0.0, dtype=ct.float32)
        num_tiles = ct.cdiv(N, TILE_SIZE)
        offsets = ct.arange(TILE_SIZE, dtype=ct.int32)

        # First pass: accumulate this row's sum of squares via gathered loads.
        for j in range(0, num_tiles):
            offs = j * TILE_SIZE + offsets
            xj = ct.gather(x, (row, offs), latency=1)
            xj = ct.astype(xj, ct.float32)
            _rms += xj * xj

        # Compute the reciprocal RMS for this row and store it in Rstd
        rms = ct.rsqrt(ct.sum(_rms, axis=0, keepdims=False) / N + eps)
        ct.scatter(Rstd, row, rms)

        # Second pass: normalize, scale, and scatter the result.
        for j in range(0, num_tiles):
            offs = j * TILE_SIZE + offsets
            wj = ct.gather(w, offs, latency=1)
            wj = ct.astype(wj, ct.float32)
            # Apply Gemma-style bias if enabled
            if use_gemma:
                wj = wj + 1.0
            xj = ct.gather(x, (row, offs), latency=1)
            xj = ct.astype(xj, ct.float32)
            yj = xj * rms * wj
            yj = ct.astype(yj, x.dtype)
            ct.scatter(out, (row, offs), yj, latency=1)


    @ct.kernel
    def rms_norm_kernel_static_persistent(
        X,  # Input tensor
        Y,  # Output tensor
        W,  # Weight tensor
        TILE_SIZE_M: ct.Constant[int],  # Rows per block
        TILE_SIZE_N: ct.Constant[int],  # Columns per block
        eps: ct.Constant[float],  # Epsilon value
        use_gemma: ct.Constant[bool],  # Gemma-style weight bias
    ):
        """
        CuTile static persistent RMSNorm kernel that processes multiple blocks per program.
        Each program processes multiple blocks in a loop for better efficiency.
        """
        # Get program ID
        pid = ct.bid(0)

        # Infer tensor dimensions from input shape
        M = X.shape[0]  # Number of rows
        N = X.shape[1]  # Number of columns

        # Calculate upper bound: the number of row blocks to process
        upper_bound = (M + TILE_SIZE_M - 1) // TILE_SIZE_M

        # Load the weight vector once (shared across all blocks processed by this program)
        w = ct.load(W, index=(0,), shape=(TILE_SIZE_N,))
        w = ct.astype(w, ct.float32)
        # Apply Gemma-style bias if enabled
        if use_gemma:
            w = w + 1.0

        # Static persistent loop: each program processes multiple blocks
        num_tiles_x = ct.num_blocks(0)
        for current_block in range(pid, upper_bound, num_tiles_x):
            # Load input tile
            x = ct.load(
                X, index=(current_block, 0), shape=(TILE_SIZE_M, TILE_SIZE_N),
                latency=10,  # +2% perf from this hint
            )
            x = ct.astype(x, ct.float32)

            # Step 1: Compute x^2
            x_squared = ct.mul(x, x)

            # Step 2: Reduce sum along axis=1 (columns)
            x2_sum = ct.sum(x_squared, axis=1, keepdims=True)  # Shape: [TILE_SIZE_M, 1]

            # Step 3: Compute variance (divide by N)
            N_f32 = ct.full((TILE_SIZE_M, 1), N * 1.0, dtype=ct.float32)
            variance = ct.truediv(x2_sum, N_f32)

            # Step 4: Add epsilon and compute rsqrt
            eps_tensor = ct.full((TILE_SIZE_M, 1), eps, dtype=ct.float32)
            variance_eps = ct.add(variance, eps_tensor)
            rsqrt_var = ct.rsqrt(variance_eps)

            # Step 5: Apply normalization
            x_normalized = ct.mul(x, rsqrt_var)

            # Step 6: Apply linear transformation
            # Broadcast weight to match input shape
            w_broadcasted = ct.reshape(w, (1, TILE_SIZE_N))
            b_broadcasted = ct.full((1, TILE_SIZE_N), 0.0, dtype=ct.float32)

            # Apply linear transformation: y = x_normalized * w + b
            y = ct.mul(x_normalized, w_broadcasted)
            y = ct.add(y, b_broadcasted)

            # Convert back to original dtype
            y = ct.astype(y, X.dtype)

            # Store result
            ct.store(
                Y, index=(current_block, 0), tile=y,
                allow_tma=False,  # +30% perf
                latency=3,  # +3% perf from this hint
            )
```
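The three variants implement the same math, so a plain PyTorch reference is handy for checking them. The sketch below is written for this note and is not part of the commit; it mirrors the kernels' semantics: mean of squares over the last axis, `rsqrt` with `eps` added inside, the optional Gemma-style `w + 1` bias, and the reciprocal RMS that the first two kernels store into `Rstd`:

```python
import torch


def rms_norm_ref(x: torch.Tensor, w: torch.Tensor, eps: float, use_gemma: bool = False):
    """Reference RMSNorm matching the cuda.tile kernels above (sketch)."""
    x_f32 = x.float()
    # Reciprocal RMS per row: rsqrt(mean(x^2) + eps); this is what Rstd holds.
    rstd = torch.rsqrt(x_f32.pow(2).mean(dim=-1, keepdim=True) + eps)
    w_f32 = w.float() + 1.0 if use_gemma else w.float()
    out = (x_f32 * rstd * w_f32).to(x.dtype)
    return out, rstd.squeeze(-1)


# Illustrative shapes; compare `out` and `rstd` against the kernel outputs.
x = torch.randn(8, 4096)
w = torch.randn(4096)
out, rstd = rms_norm_ref(x, w, eps=1e-6, use_gemma=True)
```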
