diff --git a/src/liger_kernel/ops/backends/_ascend/ops/geglu.py b/src/liger_kernel/ops/backends/_ascend/ops/geglu.py index 0a8cedbae..ef7ee51a7 100644 --- a/src/liger_kernel/ops/backends/_ascend/ops/geglu.py +++ b/src/liger_kernel/ops/backends/_ascend/ops/geglu.py @@ -130,20 +130,26 @@ def geglu_forward(a, b): dtype_size = a.element_size() # GEGLU forward tiling strategy: # - Calculates maximum safe block size based on UB capacity - # - Memory analysis: - # * Inputs: a, b - # * Intermediates: a_cubed, tanh_arg, tanh_result, geglu_a - # * Output: c - # * Total: ~7x * BLOCK_SIZE * dtype_size - # - Uses memory_multiplier=7.0 * BLOCK_SIZE * dtype_size * 8 bits for safety + # - Memory analysis (only buffers that occupy UB, excluding temporary variables): + # * Inputs: a_row (4 bytes, float32), b_row (dtype_size bytes) + # * Output: c_row (dtype_size bytes) + # * Temporary variables (a_cubed, tanh_arg, tanh_result, geglu_a) are optimized to registers + # and don't occupy UB since they are only used once + # * For float16: a_row(4) + b_row(2) + c_row(2) = 8 bytes/element, ratio = 8/2 = 4.0 + # * For float32: a_row(4) + b_row(4) + c_row(4) = 12 bytes/element, ratio = 12/4 = 3.0 + # - Uses memory_multiplier=4.0 (float16) or 3.0 (float32) * BLOCK_SIZE * dtype_size * 8 bits # - shapes: ((n_cols,),) # - tiling_dims: (0,) means first dimension can be tiled # - Returns: ((block_size,),) shapes = ((n_cols,),) + if dtype_size == 2: + memory_multiplier = 4.0 + else: + memory_multiplier = 3.0 tile_shapes = compute_default_tiling_strategy( safety_margin=0.80, dtype_size=dtype_size, - memory_multiplier=7.0, + memory_multiplier=memory_multiplier, shapes=shapes, tiling_dims=(0,), ) @@ -187,18 +193,34 @@ def geglu_backward(a, b, dc): dtype_size = dc.element_size() # GEGLU backward tiling strategy: # - Calculates maximum safe block size based on UB capacity - # - Memory analysis: - # * More intermediates for gradient computation compared to forward - # * Total: ~10x * BLOCK_SIZE * dtype_size - # - Uses memory_multiplier=10.0 * BLOCK_SIZE * dtype_size * 8 bits for safety + # - Memory analysis: Peak memory usage occurs when executing line 103 (term1 calculation) + # At this point, the following buffers simultaneously occupy UB: + # 1. dc_row = tl.load(dc + col_offsets, ...) # dtype_size bytes + # 2. a_row = tl.load(a + col_offsets, ...).to(tl.float32) # 4 bytes (float32) + # 3. b_row = tl.load(b + col_offsets, ...) # dtype_size bytes + # 4. tanh_result = tanh(tanh_arg) # 4 bytes (float32), used in lines 95, 103, 104 + # 5. geglu_a = 0.5 * a_row * (1 + tanh_result) # 4 bytes (float32), used in lines 96, 98 + # 6. db_row = dc_row.cast(tl.float32) * geglu_a # 4 bytes (float32, computed at line 98, stored at line 109) + # Note: term1 (line 103) is a temporary variable optimized to registers and doesn't occupy UB + # Temporary variables (a_cubed, tanh_arg, term1, tanh_sq, term2) are optimized to registers + # and don't occupy UB since they are only used once + # * For float16: dc_row(2) + a_row(4) + b_row(2) + tanh_result(4) + geglu_a(4) + db_row(4) + # = 20 bytes/element, ratio = 20/2 = 10.0 + # * For float32: dc_row(4) + a_row(4) + b_row(4) + tanh_result(4) + geglu_a(4) + db_row(4) + # = 24 bytes/element, ratio = 24/4 = 6.0 + # - Uses memory_multiplier=10.0 (float16) or 6.0 (float32) * BLOCK_SIZE * dtype_size * 8 bits # - shapes: ((n_cols,),) # - tiling_dims: (0,) means first dimension can be tiled # - Returns: ((block_size,),) shapes = ((n_cols,),) + if dtype_size == 2: + memory_multiplier = 10.0 + else: + memory_multiplier = 6.0 tile_shapes = compute_default_tiling_strategy( safety_margin=0.80, dtype_size=dtype_size, - memory_multiplier=10.0, + memory_multiplier=memory_multiplier, shapes=shapes, tiling_dims=(0,), ) diff --git a/src/liger_kernel/ops/backends/_ascend/ub_manager.py b/src/liger_kernel/ops/backends/_ascend/ub_manager.py index b30e5f839..36d10eba2 100644 --- a/src/liger_kernel/ops/backends/_ascend/ub_manager.py +++ b/src/liger_kernel/ops/backends/_ascend/ub_manager.py @@ -241,7 +241,7 @@ def compute_default_tiling_strategy( dtype_size: Size of data type in bytes (e.g., 2 for float16, 4 for float32). Must be provided. If None or <= 0, defaults to 4 (float32). memory_multiplier: Memory multiplier for estimating peak memory usage. - - For GEGLU: typically 10.0 for backward, 7.0 for forward + - For GEGLU: typically 10.0 for backward, 4.0 for forward - For ROPE: typically 3.0 If None, defaults to 10.0 (conservative estimate). shapes: Tuple of full shapes. Each shape is a tuple of dimension sizes.