[NPU]: Add NPU support for the tvd operator #998
Conversation
Force-pushed from 1a31a76 to d8740a3
Hi @Tcc0403 @zheliuyu @noemotiovon
```python
# Fallback to desired block size if no best practice found (no tiling needed)
BLOCK_SIZE = min(MAX_FUSED_SIZE, triton.next_power_of_2(V))

MAX_BATCH_PER_KERNEL = 65535  # 每个kernel最大处理量
```
Please use English comments.
```python
def tv_distance_forward_triton(p, q, shift_labels, reduction, ignore_index, has_label):
    BT, V = p.shape
    # NPU does not support bfloat16 type
    p = p.to(torch.float32)
```
In some previous operators, I recall that there are test cases using the bf16 dtype that can run correctly on the Ascend Triton backend. It would be helpful to explicitly document the reason why this operator requires conversion to float32.
Additionally, once the NPU backend adds native support for bf16, this cast should be removed. Consider adding a TODO comment here, for example:
```python
# TODO: Remove float32 conversion after Ascend NPU supports bf16 natively
```
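For illustration, a documented version of the cast could look something like the sketch below; the explicit dtype check and the handling of `q` are assumptions for the example, not necessarily how this PR structures it:

```python
# The Ascend Triton backend cannot run this kernel on bf16 inputs yet,
# so compute in float32 for now (assumed guard, shown only for illustration).
# TODO: Remove float32 conversion after Ascend NPU supports bf16 natively
if p.dtype == torch.bfloat16:
    p = p.to(torch.float32)
    q = q.to(torch.float32)
```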
```python
grads_chunk = grads[start:end]
labels_chunk = shift_labels[start:end] if has_label else torch.empty(1, device=p.device)

_tv_distance_kernel[grid](
```
It is possible to redesign the kernel to process multiple rows per program using a kernel-side loop with a fixed grid (e.g., 65535 programs).
Given that the current kernel is a simple per-row implementation, the overhead of multiple launches is likely negligible.
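A rough sketch of that alternative, assuming a hypothetical per-row TVD kernel: the pointer/stride arguments and the plain `0.5 * |p - q|` loss/gradient math are illustrative, not the exact signature or kernel used in this PR. The point is only the fixed grid with a row-stride loop:

```python
import triton
import triton.language as tl

@triton.jit
def _tv_distance_persistent_kernel(
    p_ptr, p_stride,
    q_ptr, q_stride,
    loss_ptr,
    grads_ptr, grads_stride,
    n_rows,                      # total rows (BT), may exceed the grid limit
    n_cols,                      # vocabulary size (V)
    BLOCK_SIZE: tl.constexpr,    # >= n_cols, i.e. one block covers a row
    NUM_PROGRAMS: tl.constexpr,  # fixed grid size, e.g. 65535
):
    # Each program strides over rows, so a fixed-size grid covers any BT
    # with a single kernel launch.
    offsets = tl.arange(0, BLOCK_SIZE)
    mask = offsets < n_cols
    for row in range(tl.program_id(0), n_rows, NUM_PROGRAMS):
        p = tl.load(p_ptr + row * p_stride + offsets, mask=mask, other=0.0)
        q = tl.load(q_ptr + row * q_stride + offsets, mask=mask, other=0.0)
        diff = p - q
        # TVD(p, q) = 0.5 * sum|p - q|; d/dp = 0.5 * sign(p - q)
        tl.store(loss_ptr + row, tl.sum(0.5 * tl.abs(diff), axis=0))
        grad = tl.where(diff >= 0, 0.5, -0.5)
        tl.store(grads_ptr + row * grads_stride + offsets, grad, mask=mask)

# Launch sketch (argument names are the hypothetical ones above):
# grid = (min(BT, 65535),)
# _tv_distance_persistent_kernel[grid](p, p.stride(0), q, q.stride(0),
#                                      loss, grads, grads.stride(0),
#                                      BT, V, BLOCK_SIZE=BLOCK_SIZE,
#                                      NUM_PROGRAMS=grid[0])
```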
Force-pushed from 5ae19a8 to d7b5254
```python
REDUCTION_LITERAL = Literal["none", "sum", "mean", "batchmean"]

_REDUCTION_MODE_NONE = tl.constexpr(0)
_REDUCTION_MODE_SUM = tl.constexpr(1)
_REDUCTION_MODE_MEAN = tl.constexpr(2)
_REDUCTION_MODE_BATCHMEAN = tl.constexpr(3)

_str_to_reduction_mode = {
    "none": _REDUCTION_MODE_NONE.value,
    "sum": _REDUCTION_MODE_SUM.value,
    "mean": _REDUCTION_MODE_MEAN.value,
    "batchmean": _REDUCTION_MODE_BATCHMEAN.value,
}
```
You should be able to assign a str to a tl.constexpr variable directly, without having to manually convert it to an int. For instance, reduction in Liger's cross_entropy is just a str; it doesn't need an extra mapping onto ints.
This mapping exists in the original tvd implementation, but I think we should remove it for readability as well.
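A minimal sketch of that string-constexpr pattern, using a toy kernel rather than the tvd kernel in this PR; Triton specializes the kernel per distinct constexpr value, so the string comparison is resolved at compile time:

```python
import torch
import triton
import triton.language as tl

@triton.jit
def _demo_kernel(x_ptr, out_ptr, n_cols, BLOCK_SIZE: tl.constexpr, reduction: tl.constexpr):
    # `reduction` is an ordinary Python string passed as a constexpr,
    # so this branch costs nothing at runtime.
    offsets = tl.arange(0, BLOCK_SIZE)
    mask = offsets < n_cols
    x = tl.load(x_ptr + offsets, mask=mask, other=0.0)
    if reduction == "none":
        tl.store(out_ptr + offsets, x, mask=mask)
    else:  # scalar reduction variants collapsed into one branch for this toy example
        tl.store(out_ptr, tl.sum(x, axis=0))

x = torch.randn(1000, device="cuda")  # or "npu" with torch_npu + triton-ascend
out = torch.empty_like(x)
_demo_kernel[(1,)](x, out, x.numel(), BLOCK_SIZE=1024, reduction="none")
```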
It has been modified and passed the test.
```python
if reduction == _REDUCTION_MODE_BATCHMEAN.value:
    # TODO: Remove float32 conversion after Ascend NPU supports bf16 natively
    return output_tensor.sum().to(torch.float32) / n_non_ignore, grads.to(torch.float32) / n_non_ignore
elif reduction == _REDUCTION_MODE_SUM.value:
    return output_tensor.sum(dim=0), grads
elif reduction == _REDUCTION_MODE_MEAN.value:
    return output_tensor.sum().to(torch.float32) / (n_non_ignore * V), grads.to(torch.float32) / (n_non_ignore * V)
else:
    return output_tensor, grads
```
Without the str-to-int mapping, this part will be more readable.
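A rough sketch of what that epilogue could look like once `reduction` stays a string; `_reduce_tvd` is a hypothetical helper name and the float32 casts from the original code are omitted for brevity:

```python
def _reduce_tvd(output_tensor, grads, reduction, n_non_ignore, V):
    # Hypothetical helper: plain string comparisons instead of mapped int codes.
    if reduction == "batchmean":
        return output_tensor.sum() / n_non_ignore, grads / n_non_ignore
    elif reduction == "sum":
        return output_tensor.sum(dim=0), grads
    elif reduction == "mean":
        return output_tensor.sum() / (n_non_ignore * V), grads / (n_non_ignore * V)
    else:  # "none"
        return output_tensor, grads
```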
modified
```python
tl.store(grads_row_ptr + offsets, grad_res, mask=mask)

if reduction == _REDUCTION_MODE_NONE:
```
And it will be a simple `if reduction == "none":`.
modified
Feel free to re-request review when it's ready.
Force-pushed from cdae03e to 7b7f071
On NPU, bfloat16 execution of the tvd operator involves low-precision accumulation in the underlying matmul, and the stricter tolerance may therefore lead to false negatives in correctness tests.
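For context, correctness tests typically widen tolerances for low-precision dtypes rather than reuse float32 thresholds. A rough pytest-style sketch; the reference computation and the specific tolerance values are illustrative stand-ins, not the actual tests in this repo:

```python
import pytest
import torch

@pytest.mark.parametrize(
    "dtype, atol, rtol",
    [
        (torch.float32, 1e-8, 1e-6),
        # bf16 accumulates in low precision, so far looser thresholds are
        # needed to avoid false negatives.
        (torch.bfloat16, 1e-2, 1e-2),
    ],
)
def test_tvd_matches_reference(dtype, atol, rtol):
    p = torch.softmax(torch.randn(8, 128), dim=-1).to(dtype)
    q = torch.softmax(torch.randn(8, 128), dim=-1).to(dtype)
    expected = 0.5 * (p.float() - q.float()).abs().sum(dim=-1)
    actual = 0.5 * (p - q).abs().sum(dim=-1).float()  # stand-in for the NPU kernel output
    assert torch.allclose(actual, expected, atol=atol, rtol=rtol)
```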
Meanwhile, the checkstyle failure does not involve any files I modified in this PR; it might be caused by other commits.
@Tcc0403 It's ready for review, if you have time.
Tcc0403 left a comment
Thanks!
Summary
This PR mainly completes the adaptation of the tvd operator on the NPU:
1. Solve the operator UB (unified buffer) overflow problem.
2. Use a chunking strategy to work around the triton-ascend maximum grid size of 65535 (see the sketch after this list).
3. bf16 is not supported by the backend, so inputs are converted to fp32.
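A minimal sketch of the chunking idea, assuming a hypothetical per-row Triton kernel `_tv_distance_kernel` (not defined here) whose signature matches the call below; argument names and the fp32 output buffers are illustrative, not the exact wrapper in this PR:

```python
import torch

MAX_BATCH_PER_KERNEL = 65535  # triton-ascend caps the launch grid at 65535 programs

def tv_distance_forward_chunked(p, q, BLOCK_SIZE):
    BT, V = p.shape
    loss = torch.zeros(BT, dtype=torch.float32, device=p.device)
    grads = torch.empty(BT, V, dtype=torch.float32, device=p.device)
    # Launch the per-row kernel once per chunk so each grid stays within the limit.
    for start in range(0, BT, MAX_BATCH_PER_KERNEL):
        end = min(start + MAX_BATCH_PER_KERNEL, BT)
        n_rows = end - start
        _tv_distance_kernel[(n_rows,)](
            p[start:end], p.stride(0),
            q[start:end], q.stride(0),
            loss[start:end],
            grads[start:end], grads.stride(0),
            V,
            BLOCK_SIZE=BLOCK_SIZE,
        )
    return loss, grads
```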
Testing Done
tvd forward and backward pass tests
- make test to ensure correctness
- make checkstyle to ensure code style
- make test-convergence to ensure convergence