Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
46 changes: 34 additions & 12 deletions src/liger_kernel/ops/backends/_ascend/ops/geglu.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,20 +130,26 @@ def geglu_forward(a, b):
dtype_size = a.element_size()
# GEGLU forward tiling strategy:
# - Calculates maximum safe block size based on UB capacity
# - Memory analysis:
# * Inputs: a, b
# * Intermediates: a_cubed, tanh_arg, tanh_result, geglu_a
# * Output: c
# * Total: ~7x * BLOCK_SIZE * dtype_size
# - Uses memory_multiplier=7.0 * BLOCK_SIZE * dtype_size * 8 bits for safety
# - Memory analysis (only buffers that occupy UB, excluding temporary variables):
# * Inputs: a_row (4 bytes, float32), b_row (dtype_size bytes)
# * Output: c_row (dtype_size bytes)
# * Temporary variables (a_cubed, tanh_arg, tanh_result, geglu_a) are optimized to registers
# and don't occupy UB since they are only used once
# * For float16: a_row(4) + b_row(2) + c_row(2) = 8 bytes/element, ratio = 8/2 = 4.0
# * For float32: a_row(4) + b_row(4) + c_row(4) = 12 bytes/element, ratio = 12/4 = 3.0
# - Uses memory_multiplier=4.0 (float16) or 3.0 (float32) * BLOCK_SIZE * dtype_size * 8 bits
# - shapes: ((n_cols,),)
# - tiling_dims: (0,) means first dimension can be tiled
# - Returns: ((block_size,),)
shapes = ((n_cols,),)
if dtype_size == 2:
memory_multiplier = 4.0
else:
memory_multiplier = 3.0
tile_shapes = compute_default_tiling_strategy(
safety_margin=0.80,
dtype_size=dtype_size,
memory_multiplier=7.0,
memory_multiplier=memory_multiplier,
shapes=shapes,
tiling_dims=(0,),
)
Expand Down Expand Up @@ -187,18 +193,34 @@ def geglu_backward(a, b, dc):
dtype_size = dc.element_size()
# GEGLU backward tiling strategy:
# - Calculates maximum safe block size based on UB capacity
# - Memory analysis:
# * More intermediates for gradient computation compared to forward
# * Total: ~10x * BLOCK_SIZE * dtype_size
# - Uses memory_multiplier=10.0 * BLOCK_SIZE * dtype_size * 8 bits for safety
# - Memory analysis: Peak memory usage occurs when executing line 103 (term1 calculation)
# At this point, the following buffers simultaneously occupy UB:
# 1. dc_row = tl.load(dc + col_offsets, ...) # dtype_size bytes
# 2. a_row = tl.load(a + col_offsets, ...).to(tl.float32) # 4 bytes (float32)
# 3. b_row = tl.load(b + col_offsets, ...) # dtype_size bytes
# 4. tanh_result = tanh(tanh_arg) # 4 bytes (float32), used in lines 95, 103, 104
# 5. geglu_a = 0.5 * a_row * (1 + tanh_result) # 4 bytes (float32), used in lines 96, 98
# 6. db_row = dc_row.cast(tl.float32) * geglu_a # 4 bytes (float32, computed at line 98, stored at line 109)
# Note: term1 (line 103) is a temporary variable optimized to registers and doesn't occupy UB
# Temporary variables (a_cubed, tanh_arg, term1, tanh_sq, term2) are optimized to registers
# and don't occupy UB since they are only used once
# * For float16: dc_row(2) + a_row(4) + b_row(2) + tanh_result(4) + geglu_a(4) + db_row(4)
# = 20 bytes/element, ratio = 20/2 = 10.0
# * For float32: dc_row(4) + a_row(4) + b_row(4) + tanh_result(4) + geglu_a(4) + db_row(4)
# = 24 bytes/element, ratio = 24/4 = 6.0
# - Uses memory_multiplier=10.0 (float16) or 6.0 (float32) * BLOCK_SIZE * dtype_size * 8 bits
# - shapes: ((n_cols,),)
# - tiling_dims: (0,) means first dimension can be tiled
# - Returns: ((block_size,),)
shapes = ((n_cols,),)
if dtype_size == 2:
memory_multiplier = 10.0
else:
memory_multiplier = 6.0
tile_shapes = compute_default_tiling_strategy(
safety_margin=0.80,
dtype_size=dtype_size,
memory_multiplier=10.0,
memory_multiplier=memory_multiplier,
shapes=shapes,
tiling_dims=(0,),
)
Expand Down
2 changes: 1 addition & 1 deletion src/liger_kernel/ops/backends/_ascend/ub_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -241,7 +241,7 @@ def compute_default_tiling_strategy(
dtype_size: Size of data type in bytes (e.g., 2 for float16, 4 for float32).
Must be provided. If None or <= 0, defaults to 4 (float32).
memory_multiplier: Memory multiplier for estimating peak memory usage.
- For GEGLU: typically 10.0 for backward, 7.0 for forward
- For GEGLU: typically 10.0 for backward, 4.0 for forward
- For ROPE: typically 3.0
If None, defaults to 10.0 (conservative estimate).
shapes: Tuple of full shapes. Each shape is a tuple of dimension sizes.
Expand Down
Loading