@TianHao324
Contributor

@TianHao324 TianHao324 commented Jan 7, 2026

Summary

This PR completes the adaptation of the tvd operator to the NPU:
1. Fixes the operator's UB (unified buffer) overflow.
2. Uses a chunking strategy to work around the triton-ascend grid size limit of 65535.
3. Converts all inputs to f32, because bf16 is not supported for this operator on the NPU.
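
The chunking idea in item 2 can be sketched in plain Python. This is only an illustration of how row ranges might be split so that no single launch exceeds the grid limit; `iter_row_chunks` is a hypothetical helper name, not a function from this PR, and the real code slices tensors and launches the Triton kernel once per chunk.

```python
from typing import Iterator, Tuple

# Grid limit quoted in the PR summary for triton-ascend.
MAX_BATCH_PER_KERNEL = 65535


def iter_row_chunks(
    total_rows: int, max_rows: int = MAX_BATCH_PER_KERNEL
) -> Iterator[Tuple[int, int]]:
    """Yield half-open (start, end) row ranges, each at most max_rows long."""
    for start in range(0, total_rows, max_rows):
        yield start, min(start + max_rows, total_rows)


# 150,000 rows need three launches: two full chunks and one partial chunk.
chunks = list(iter_row_chunks(150_000))
for start, end in chunks:
    print(f"launch kernel on rows [{start}, {end}) -> grid size {end - start}")
```

Each `(start, end)` pair corresponds to one slice such as `grads[start:end]`, so every launch uses a grid no larger than the limit.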

Testing Done

Verified on Ascend NPU 910B4:

tvd forward and backward pass tests

  • Hardware Type:
  • run make test to ensure correctness
  • run make checkstyle to ensure code style
  • run make test-convergence to ensure convergence

@TianHao324 TianHao324 force-pushed the tvd branch 3 times, most recently from 1a31a76 to d8740a3 January 7, 2026 06:51
@TianHao324
Contributor Author

Hi @Tcc0403 @zheliuyu @noemotiovon
When you have a moment, could you help take a look at this code? Thanks!

@TianHao324 TianHao324 changed the title from "Tvd" to "Add NPU support for the tvd operator" Jan 7, 2026
# Fallback to desired block size if no best practice found (no tiling needed)
BLOCK_SIZE = min(MAX_FUSED_SIZE, triton.next_power_of_2(V))

MAX_BATCH_PER_KERNEL = 65535 # 每个kernel最大处理量 (maximum amount processed per kernel)
Contributor

Please use English comments.

Contributor Author

edited

def tv_distance_forward_triton(p, q, shift_labels, reduction, ignore_index, has_label):
    BT, V = p.shape
    # NPU does not support bfloat16 type
    p = p.to(torch.float32)
Contributor

In some previous operators, I recall that there are test cases using the bf16 dtype that can run correctly on the Ascend Triton backend. It would be helpful to explicitly document the reason why this operator requires conversion to float32.
Additionally, once the NPU backend adds native support for bf16, this cast should be removed. Consider adding a TODO comment here, for example:

# TODO: Remove float32 conversion after Ascend NPU supports bf16 natively

Contributor Author

edited

grads_chunk = grads[start:end]
labels_chunk = shift_labels[start:end] if has_label else torch.empty(1, device=p.device)

_tv_distance_kernel[grid](
Contributor

It is possible to redesign the kernel to process multiple rows per program using a kernel-side loop with a fixed grid (e.g., 65535 programs).
Given that the current kernel is a simple per-row implementation, the overhead of multiple launches is likely negligible.
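
As a rough illustration of the fixed-grid alternative mentioned above, the following pure-Python simulation shows how rows could be assigned when each program loops over rows with a stride equal to the number of programs. It is not Triton code; `rows_for_program` and `simulate_fixed_grid` are hypothetical names. In an actual kernel the program index and program count would presumably come from `tl.program_id(0)` and `tl.num_programs(0)`.

```python
from typing import List


def rows_for_program(pid: int, num_programs: int, total_rows: int) -> List[int]:
    """Rows handled by one program: pid, pid + num_programs, pid + 2 * num_programs, ..."""
    return list(range(pid, total_rows, num_programs))


def simulate_fixed_grid(total_rows: int, max_programs: int = 65535) -> List[int]:
    """Count how many times each row is visited under a fixed-size grid."""
    # Never launch more programs than there are rows.
    grid = min(max_programs, total_rows)
    visits = [0] * total_rows
    for pid in range(grid):
        for row in rows_for_program(pid, grid, total_rows):
            visits[row] += 1
    return visits


# With 150,000 rows and 65,535 programs, program 0 handles three rows.
print(rows_for_program(0, 65535, 150_000))
# Every row is covered exactly once by a single launch.
print(all(count == 1 for count in simulate_fixed_grid(150_000)))
```

The trade-off is the one stated in the comment: a single launch with a kernel-side loop versus several host-side launches of a simpler per-row kernel.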

Contributor Author

edited

@TianHao324 TianHao324 force-pushed the tvd branch 7 times, most recently from 5ae19a8 to d7b5254 January 8, 2026 02:19
Comment on lines 13 to 25
REDUCTION_LITERAL = Literal["none", "sum", "mean", "batchmean"]

_REDUCTION_MODE_NONE = tl.constexpr(0)
_REDUCTION_MODE_SUM = tl.constexpr(1)
_REDUCTION_MODE_MEAN = tl.constexpr(2)
_REDUCTION_MODE_BATCHMEAN = tl.constexpr(3)

_str_to_reduction_mode = {
    "none": _REDUCTION_MODE_NONE.value,
    "sum": _REDUCTION_MODE_SUM.value,
    "mean": _REDUCTION_MODE_MEAN.value,
    "batchmean": _REDUCTION_MODE_BATCHMEAN.value,
}
Collaborator

You should be able to assign str to tl.constexpr variable directly without having to manually convert it to int. For instance, reduction in liger's cross_entropy is just a str. It doesn't need extra mapping onto int.

This mapping exists in the original tvd implementation, but I think we should remove it for readability as well.

Contributor Author

It has been modified and passed the test.

Comment on lines 162 to 170
if reduction == _REDUCTION_MODE_BATCHMEAN.value:
    # TODO: Remove float32 conversion after Ascend NPU supports bf16 natively
    return output_tensor.sum().to(torch.float32) / n_non_ignore, grads.to(torch.float32) / n_non_ignore
elif reduction == _REDUCTION_MODE_SUM.value:
    return output_tensor.sum(dim=0), grads
elif reduction == _REDUCTION_MODE_MEAN.value:
    return output_tensor.sum().to(torch.float32) / (n_non_ignore * V), grads.to(torch.float32) / (n_non_ignore * V)
else:
    return output_tensor, grads
Collaborator

Without str->int mapping, this part will be more readable.
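
A minimal sketch of what the string-based dispatch could look like, using plain Python lists in place of tensors so it runs without torch. `reduce_tvd_loss` and its return convention (reduced loss plus the factor to divide gradients by) are assumptions made for illustration, not the code from this PR.

```python
from typing import List, Tuple, Union


def reduce_tvd_loss(
    row_losses: List[float], reduction: str, n_non_ignore: int, V: int
) -> Tuple[Union[float, List[float]], float]:
    """Return (reduced loss, divisor to apply to the gradients)."""
    if reduction == "batchmean":
        return sum(row_losses) / n_non_ignore, float(n_non_ignore)
    elif reduction == "sum":
        return sum(row_losses), 1.0
    elif reduction == "mean":
        return sum(row_losses) / (n_non_ignore * V), float(n_non_ignore * V)
    elif reduction == "none":
        return row_losses, 1.0
    raise ValueError(f"unsupported reduction: {reduction!r}")


# Two rows, vocabulary size 4, no ignored rows.
print(reduce_tvd_loss([1.0, 3.0], "batchmean", n_non_ignore=2, V=4))
print(reduce_tvd_loss([1.0, 3.0], "mean", n_non_ignore=2, V=4))
```

Comparing against the literal strings removes the need for the `_str_to_reduction_mode` lookup and the `.value` accesses.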

Contributor Author

modified


tl.store(grads_row_ptr + offsets, grad_res, mask=mask)

if reduction == _REDUCTION_MODE_NONE:
Collaborator

@Tcc0403 Tcc0403 Jan 8, 2026

And it will be a simple if reduction == "none":

Contributor Author

modified

@TianHao324 TianHao324 changed the title from "Add NPU support for the tvd operator" to "[NPU]: Add NPU support for the tvd operator" Jan 9, 2026
@Tcc0403
Collaborator

Tcc0403 commented Jan 12, 2026

Feel free to re-request review when it's ready

@TianHao324 TianHao324 force-pushed the tvd branch 2 times, most recently from cdae03e to 7b7f071 January 12, 2026 08:21
@TianHao324
Contributor Author

On NPU, bfloat16 execution of the tvd operator involves low-precision accumulation in the underlying matmul, so the stricter tolerance may cause spurious failures in the correctness tests.
For this reason, the tvd tests use a more lenient tolerance (1e-8 -> 5e-5).
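
To make the effect of the tolerance change concrete, here is a small sketch of the usual closeness check `|actual - expected| <= atol + rtol * |expected|`. It assumes the two numbers in the test IDs below are the absolute and relative tolerance in that order; the sample values are invented for illustration and are not measurements from the NPU.

```python
def is_close(actual: float, expected: float, atol: float, rtol: float) -> bool:
    """Closeness check of the form |actual - expected| <= atol + rtol * |expected|."""
    return abs(actual - expected) <= atol + rtol * abs(expected)


expected = 0.5
actual = 0.50002  # hypothetical result with an error of about 2e-5

# Allowed error is 1e-8 + 1e-6 * 0.5 = 5.1e-7, so the check fails.
print(is_close(actual, expected, atol=1e-8, rtol=1e-6))
# Allowed error is 5e-5 + 1e-6 * 0.5 = 5.05e-5, so the check passes.
print(is_close(actual, expected, atol=5e-5, rtol=1e-6))
```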

test_tvd.py::test_correctness[dtype0-5e-05-1e-06-batchmean-1-4096-32000] PASSED                                                                                                                [  0%]
test_tvd.py::test_correctness[dtype0-5e-05-1e-06-batchmean-32-4096-1024] PASSED                                                                                                                [  1%]
test_tvd.py::test_correctness[dtype0-5e-05-1e-06-batchmean-41-401-1271] PASSED                                                                                                                 [  1%]
test_tvd.py::test_correctness[dtype0-5e-05-1e-06-batchmean-1-4096-128256] SKIPPED (This test requires a GPU with at least 36GB of memory)                                                      [  2%]
test_tvd.py::test_correctness[dtype0-5e-05-1e-06-batchmean-3-423-32000] PASSED                                                                                                                 [  2%]
test_tvd.py::test_correctness[dtype0-5e-05-1e-06-sum-1-4096-32000] PASSED                                                                                                                      [  3%]
test_tvd.py::test_correctness[dtype0-5e-05-1e-06-sum-32-4096-1024] PASSED                                                                                                                      [  3%]
test_tvd.py::test_correctness[dtype0-5e-05-1e-06-sum-41-401-1271] PASSED                                                                                                                       [  4%]
test_tvd.py::test_correctness[dtype0-5e-05-1e-06-sum-1-4096-128256] SKIPPED (This test requires a GPU with at least 36GB of memory)                                                            [  4%]
test_tvd.py::test_correctness[dtype0-5e-05-1e-06-sum-3-423-32000] PASSED                                                                                                                       [  5%]
test_tvd.py::test_correctness[dtype0-5e-05-1e-06-mean-1-4096-32000] PASSED                                                                                                                     [  5%]
test_tvd.py::test_correctness[dtype0-5e-05-1e-06-mean-32-4096-1024] PASSED                                                                                                                     [  6%]
test_tvd.py::test_correctness[dtype0-5e-05-1e-06-mean-41-401-1271] PASSED                                                                                                                      [  6%]
test_tvd.py::test_correctness[dtype0-5e-05-1e-06-mean-1-4096-128256] SKIPPED (This test requires a GPU with at least 36GB of memory)                                                           [  7%]
test_tvd.py::test_correctness[dtype0-5e-05-1e-06-mean-3-423-32000] PASSED                                                                                                                      [  7%]
test_tvd.py::test_correctness[dtype0-5e-05-1e-06-none-1-4096-32000] PASSED                                                                                                                     [  8%]
test_tvd.py::test_correctness[dtype0-5e-05-1e-06-none-32-4096-1024] PASSED                                                                                                                     [  8%]
test_tvd.py::test_correctness[dtype0-5e-05-1e-06-none-41-401-1271] PASSED                                                                                                                      [  9%]
test_tvd.py::test_correctness[dtype0-5e-05-1e-06-none-1-4096-128256] SKIPPED (This test requires a GPU with at least 36GB of memory)                                                           [  9%]
test_tvd.py::test_correctness[dtype0-5e-05-1e-06-none-3-423-32000] PASSED                                                                                                                      [ 10%]
test_tvd.py::test_correctness[dtype1-1e-08-1e-06-batchmean-1-4096-32000] PASSED                                                                                                                [ 10%]
test_tvd.py::test_correctness[dtype1-1e-08-1e-06-batchmean-32-4096-1024] PASSED                                                                                                                [ 11%]
test_tvd.py::test_correctness[dtype1-1e-08-1e-06-batchmean-41-401-1271] PASSED                                                                                                                 [ 11%]
test_tvd.py::test_correctness[dtype1-1e-08-1e-06-batchmean-1-4096-128256] SKIPPED (This test requires a GPU with at least 36GB of memory)                                                      [ 12%]
test_tvd.py::test_correctness[dtype1-1e-08-1e-06-batchmean-3-423-32000] PASSED                                                                                                                 [ 12%]
test_tvd.py::test_correctness[dtype1-1e-08-1e-06-sum-1-4096-32000] PASSED                                                                                                                      [ 13%]
test_tvd.py::test_correctness[dtype1-1e-08-1e-06-sum-32-4096-1024] PASSED                                                                                                                      [ 13%]
test_tvd.py::test_correctness[dtype1-1e-08-1e-06-sum-41-401-1271] PASSED                                                                                                                       [ 14%]
test_tvd.py::test_correctness[dtype1-1e-08-1e-06-sum-1-4096-128256] SKIPPED (This test requires a GPU with at least 36GB of memory)                                                            [ 14%]
test_tvd.py::test_correctness[dtype1-1e-08-1e-06-sum-3-423-32000] PASSED                                                                                                                       [ 15%]
test_tvd.py::test_correctness[dtype1-1e-08-1e-06-mean-1-4096-32000] PASSED                                                                                                                     [ 15%]
test_tvd.py::test_correctness[dtype1-1e-08-1e-06-mean-32-4096-1024] PASSED                                                                                                                     [ 16%]
test_tvd.py::test_correctness[dtype1-1e-08-1e-06-mean-41-401-1271] PASSED                                                                                                                      [ 16%]
test_tvd.py::test_correctness[dtype1-1e-08-1e-06-mean-1-4096-128256] SKIPPED (This test requires a GPU with at least 36GB of memory)                                                           [ 17%]
test_tvd.py::test_correctness[dtype1-1e-08-1e-06-mean-3-423-32000] PASSED                                                                                                                      [ 17%]
test_tvd.py::test_correctness[dtype1-1e-08-1e-06-none-1-4096-32000] PASSED                                                                                                                     [ 18%]
test_tvd.py::test_correctness[dtype1-1e-08-1e-06-none-32-4096-1024] PASSED                                                                                                                     [ 18%]
test_tvd.py::test_correctness[dtype1-1e-08-1e-06-none-41-401-1271] PASSED                                                                                                                      [ 19%]
test_tvd.py::test_correctness[dtype1-1e-08-1e-06-none-1-4096-128256] SKIPPED (This test requires a GPU with at least 36GB of memory)                                                           [ 19%]
test_tvd.py::test_correctness[dtype1-1e-08-1e-06-none-3-423-32000] PASSED                                                                                                                      [ 20%]
test_tvd.py::test_correctness_not_last[dtype0-5e-05-1e-06-batchmean-1-4096-32000] PASSED                                                                                                       [ 20%]
test_tvd.py::test_correctness_not_last[dtype0-5e-05-1e-06-batchmean-32-4096-1024] PASSED                                                                                                       [ 21%]
test_tvd.py::test_correctness_not_last[dtype0-5e-05-1e-06-batchmean-41-401-1271] PASSED                                                                                                        [ 21%]
test_tvd.py::test_correctness_not_last[dtype0-5e-05-1e-06-batchmean-1-4096-128256] SKIPPED (This test requires a GPU with at least 36GB of memory)                                             [ 22%]
test_tvd.py::test_correctness_not_last[dtype0-5e-05-1e-06-batchmean-3-423-32000] PASSED                                                                                                        [ 22%]
test_tvd.py::test_correctness_not_last[dtype0-5e-05-1e-06-sum-1-4096-32000] PASSED                                                                                                             [ 23%]
test_tvd.py::test_correctness_not_last[dtype0-5e-05-1e-06-sum-32-4096-1024] PASSED                                                                                                             [ 23%]
test_tvd.py::test_correctness_not_last[dtype0-5e-05-1e-06-sum-41-401-1271] PASSED                                                                                                              [ 24%]
test_tvd.py::test_correctness_not_last[dtype0-5e-05-1e-06-sum-1-4096-128256] SKIPPED (This test requires a GPU with at least 36GB of memory)                                                   [ 24%]
test_tvd.py::test_correctness_not_last[dtype0-5e-05-1e-06-sum-3-423-32000] PASSED                                                                                                              [ 25%]
test_tvd.py::test_correctness_not_last[dtype0-5e-05-1e-06-mean-1-4096-32000] PASSED                                                                                                            [ 25%]
test_tvd.py::test_correctness_not_last[dtype0-5e-05-1e-06-mean-32-4096-1024] PASSED                                                                                                            [ 26%]
test_tvd.py::test_correctness_not_last[dtype0-5e-05-1e-06-mean-41-401-1271] PASSED                                                                                                             [ 26%]
test_tvd.py::test_correctness_not_last[dtype0-5e-05-1e-06-mean-1-4096-128256] SKIPPED (This test requires a GPU with at least 36GB of memory)                                                  [ 27%]
test_tvd.py::test_correctness_not_last[dtype0-5e-05-1e-06-mean-3-423-32000] PASSED                                                                                                             [ 27%]
test_tvd.py::test_correctness_not_last[dtype0-5e-05-1e-06-none-1-4096-32000] PASSED                                                                                                            [ 28%]
test_tvd.py::test_correctness_not_last[dtype0-5e-05-1e-06-none-32-4096-1024] PASSED                                                                                                            [ 28%]
test_tvd.py::test_correctness_not_last[dtype0-5e-05-1e-06-none-41-401-1271] PASSED                                                                                                             [ 29%]
test_tvd.py::test_correctness_not_last[dtype0-5e-05-1e-06-none-1-4096-128256] SKIPPED (This test requires a GPU with at least 36GB of memory)                                                  [ 29%]
test_tvd.py::test_correctness_not_last[dtype0-5e-05-1e-06-none-3-423-32000] PASSED                                                                                                             [ 30%]
test_tvd.py::test_correctness_not_last[dtype1-1e-08-1e-06-batchmean-1-4096-32000] PASSED                                                                                                       [ 30%]
test_tvd.py::test_correctness_not_last[dtype1-1e-08-1e-06-batchmean-32-4096-1024] PASSED                                                                                                       [ 31%]
test_tvd.py::test_correctness_not_last[dtype1-1e-08-1e-06-batchmean-41-401-1271] PASSED                                                                                                        [ 31%]
test_tvd.py::test_correctness_not_last[dtype1-1e-08-1e-06-batchmean-1-4096-128256] SKIPPED (This test requires a GPU with at least 36GB of memory)                                             [ 32%]
test_tvd.py::test_correctness_not_last[dtype1-1e-08-1e-06-batchmean-3-423-32000] PASSED                                                                                                        [ 32%]
test_tvd.py::test_correctness_not_last[dtype1-1e-08-1e-06-sum-1-4096-32000] PASSED                                                                                                             [ 33%]
test_tvd.py::test_correctness_not_last[dtype1-1e-08-1e-06-sum-32-4096-1024] PASSED                                                                                                             [ 33%]
test_tvd.py::test_correctness_not_last[dtype1-1e-08-1e-06-sum-41-401-1271] PASSED                                                                                                              [ 34%]
test_tvd.py::test_correctness_not_last[dtype1-1e-08-1e-06-sum-1-4096-128256] SKIPPED (This test requires a GPU with at least 36GB of memory)                                                   [ 34%]
test_tvd.py::test_correctness_not_last[dtype1-1e-08-1e-06-sum-3-423-32000] PASSED                                                                                                              [ 35%]
test_tvd.py::test_correctness_not_last[dtype1-1e-08-1e-06-mean-1-4096-32000] PASSED                                                                                                            [ 35%]
test_tvd.py::test_correctness_not_last[dtype1-1e-08-1e-06-mean-32-4096-1024] PASSED                                                                                                            [ 36%]
test_tvd.py::test_correctness_not_last[dtype1-1e-08-1e-06-mean-41-401-1271] PASSED                                                                                                             [ 36%]
test_tvd.py::test_correctness_not_last[dtype1-1e-08-1e-06-mean-1-4096-128256] SKIPPED (This test requires a GPU with at least 36GB of memory)                                                  [ 37%]
test_tvd.py::test_correctness_not_last[dtype1-1e-08-1e-06-mean-3-423-32000] PASSED                                                                                                             [ 37%]
test_tvd.py::test_correctness_not_last[dtype1-1e-08-1e-06-none-1-4096-32000] PASSED                                                                                                            [ 38%]
test_tvd.py::test_correctness_not_last[dtype1-1e-08-1e-06-none-32-4096-1024] PASSED                                                                                                            [ 38%]
test_tvd.py::test_correctness_not_last[dtype1-1e-08-1e-06-none-41-401-1271] PASSED                                                                                                             [ 39%]
test_tvd.py::test_correctness_not_last[dtype1-1e-08-1e-06-none-1-4096-128256] SKIPPED (This test requires a GPU with at least 36GB of memory)                                                  [ 39%]
test_tvd.py::test_correctness_not_last[dtype1-1e-08-1e-06-none-3-423-32000] PASSED                                                                                                             [ 40%]
test_tvd.py::test_correctness_with_ignore_index[-100-dtype0-5e-05-1e-06-batchmean-1-4096-32000] PASSED                                                                                         [ 40%]
test_tvd.py::test_correctness_with_ignore_index[-100-dtype0-5e-05-1e-06-batchmean-32-4096-1024] PASSED                                                                                         [ 41%]
test_tvd.py::test_correctness_with_ignore_index[-100-dtype0-5e-05-1e-06-batchmean-41-401-1271] PASSED                                                                                          [ 41%]
test_tvd.py::test_correctness_with_ignore_index[-100-dtype0-5e-05-1e-06-batchmean-1-4096-128256] SKIPPED (This test requires a GPU with at least 36GB of memory)                               [ 42%]
test_tvd.py::test_correctness_with_ignore_index[-100-dtype0-5e-05-1e-06-batchmean-3-423-32000] PASSED                                                                                          [ 42%]
test_tvd.py::test_correctness_with_ignore_index[-100-dtype0-5e-05-1e-06-sum-1-4096-32000] PASSED                                                                                               [ 43%]
test_tvd.py::test_correctness_with_ignore_index[-100-dtype0-5e-05-1e-06-sum-32-4096-1024] PASSED                                                                                               [ 43%]
test_tvd.py::test_correctness_with_ignore_index[-100-dtype0-5e-05-1e-06-sum-41-401-1271] PASSED                                                                                                [ 44%]
test_tvd.py::test_correctness_with_ignore_index[-100-dtype0-5e-05-1e-06-sum-1-4096-128256] SKIPPED (This test requires a GPU with at least 36GB of memory)                                     [ 44%]
test_tvd.py::test_correctness_with_ignore_index[-100-dtype0-5e-05-1e-06-sum-3-423-32000] PASSED                                                                                                [ 45%]
test_tvd.py::test_correctness_with_ignore_index[-100-dtype0-5e-05-1e-06-mean-1-4096-32000] PASSED                                                                                              [ 45%]
test_tvd.py::test_correctness_with_ignore_index[-100-dtype0-5e-05-1e-06-mean-32-4096-1024] PASSED                                                                                              [ 46%]
test_tvd.py::test_correctness_with_ignore_index[-100-dtype0-5e-05-1e-06-mean-41-401-1271] PASSED                                                                                               [ 46%]
test_tvd.py::test_correctness_with_ignore_index[-100-dtype0-5e-05-1e-06-mean-1-4096-128256] SKIPPED (This test requires a GPU with at least 36GB of memory)                                    [ 47%]
test_tvd.py::test_correctness_with_ignore_index[-100-dtype0-5e-05-1e-06-mean-3-423-32000] PASSED                                                                                               [ 47%]
test_tvd.py::test_correctness_with_ignore_index[-100-dtype0-5e-05-1e-06-none-1-4096-32000] PASSED                                                                                              [ 48%]
test_tvd.py::test_correctness_with_ignore_index[-100-dtype0-5e-05-1e-06-none-32-4096-1024] PASSED                                                                                              [ 48%]
test_tvd.py::test_correctness_with_ignore_index[-100-dtype0-5e-05-1e-06-none-41-401-1271] PASSED                                                                                               [ 49%]
test_tvd.py::test_correctness_with_ignore_index[-100-dtype0-5e-05-1e-06-none-1-4096-128256] SKIPPED (This test requires a GPU with at least 36GB of memory)                                    [ 49%]
test_tvd.py::test_correctness_with_ignore_index[-100-dtype0-5e-05-1e-06-none-3-423-32000] PASSED                                                                                               [ 50%]
test_tvd.py::test_correctness_with_ignore_index[-100-dtype1-1e-08-1e-06-batchmean-1-4096-32000] PASSED                                                                                         [ 50%]
test_tvd.py::test_correctness_with_ignore_index[-100-dtype1-1e-08-1e-06-batchmean-32-4096-1024] PASSED                                                                                         [ 51%]
test_tvd.py::test_correctness_with_ignore_index[-100-dtype1-1e-08-1e-06-batchmean-41-401-1271] PASSED                                                                                          [ 51%]
test_tvd.py::test_correctness_with_ignore_index[-100-dtype1-1e-08-1e-06-batchmean-1-4096-128256] SKIPPED (This test requires a GPU with at least 36GB of memory)                               [ 52%]
test_tvd.py::test_correctness_with_ignore_index[-100-dtype1-1e-08-1e-06-batchmean-3-423-32000] PASSED                                                                                          [ 52%]
test_tvd.py::test_correctness_with_ignore_index[-100-dtype1-1e-08-1e-06-sum-1-4096-32000] PASSED                                                                                               [ 53%]
test_tvd.py::test_correctness_with_ignore_index[-100-dtype1-1e-08-1e-06-sum-32-4096-1024] PASSED                                                                                               [ 53%]
test_tvd.py::test_correctness_with_ignore_index[-100-dtype1-1e-08-1e-06-sum-41-401-1271] PASSED                                                                                                [ 54%]
test_tvd.py::test_correctness_with_ignore_index[-100-dtype1-1e-08-1e-06-sum-1-4096-128256] SKIPPED (This test requires a GPU with at least 36GB of memory)                                     [ 54%]
test_tvd.py::test_correctness_with_ignore_index[-100-dtype1-1e-08-1e-06-sum-3-423-32000] PASSED                                                                                                [ 55%]
test_tvd.py::test_correctness_with_ignore_index[-100-dtype1-1e-08-1e-06-mean-1-4096-32000] PASSED                                                                                              [ 55%]
test_tvd.py::test_correctness_with_ignore_index[-100-dtype1-1e-08-1e-06-mean-32-4096-1024] PASSED                                                                                              [ 56%]
test_tvd.py::test_correctness_with_ignore_index[-100-dtype1-1e-08-1e-06-mean-41-401-1271] PASSED                                                                                               [ 56%]
test_tvd.py::test_correctness_with_ignore_index[-100-dtype1-1e-08-1e-06-mean-1-4096-128256] SKIPPED (This test requires a GPU with at least 36GB of memory)                                    [ 57%]
test_tvd.py::test_correctness_with_ignore_index[-100-dtype1-1e-08-1e-06-mean-3-423-32000] PASSED                                                                                               [ 57%]
test_tvd.py::test_correctness_with_ignore_index[-100-dtype1-1e-08-1e-06-none-1-4096-32000] PASSED                                                                                              [ 58%]
test_tvd.py::test_correctness_with_ignore_index[-100-dtype1-1e-08-1e-06-none-32-4096-1024] PASSED                                                                                              [ 58%]
test_tvd.py::test_correctness_with_ignore_index[-100-dtype1-1e-08-1e-06-none-41-401-1271] PASSED                                                                                               [ 59%]
test_tvd.py::test_correctness_with_ignore_index[-100-dtype1-1e-08-1e-06-none-1-4096-128256] SKIPPED (This test requires a GPU with at least 36GB of memory)                                    [ 59%]
test_tvd.py::test_correctness_with_ignore_index[-100-dtype1-1e-08-1e-06-none-3-423-32000] PASSED                                                                                               [ 60%]
test_tvd.py::test_correctness_with_ignore_index[0-dtype0-5e-05-1e-06-batchmean-1-4096-32000] PASSED                                                                                            [ 60%]
test_tvd.py::test_correctness_with_ignore_index[0-dtype0-5e-05-1e-06-batchmean-32-4096-1024] PASSED                                                                                            [ 61%]
test_tvd.py::test_correctness_with_ignore_index[0-dtype0-5e-05-1e-06-batchmean-41-401-1271] PASSED                                                                                             [ 61%]
test_tvd.py::test_correctness_with_ignore_index[0-dtype0-5e-05-1e-06-batchmean-1-4096-128256] SKIPPED (This test requires a GPU with at least 36GB of memory)                                  [ 62%]
test_tvd.py::test_correctness_with_ignore_index[0-dtype0-5e-05-1e-06-batchmean-3-423-32000] PASSED                                                                                             [ 62%]
test_tvd.py::test_correctness_with_ignore_index[0-dtype0-5e-05-1e-06-sum-1-4096-32000] PASSED                                                                                                  [ 63%]
test_tvd.py::test_correctness_with_ignore_index[0-dtype0-5e-05-1e-06-sum-32-4096-1024] PASSED                                                                                                  [ 63%]
test_tvd.py::test_correctness_with_ignore_index[0-dtype0-5e-05-1e-06-sum-41-401-1271] PASSED                                                                                                   [ 64%]
test_tvd.py::test_correctness_with_ignore_index[0-dtype0-5e-05-1e-06-sum-1-4096-128256] SKIPPED (This test requires a GPU with at least 36GB of memory)                                        [ 64%]
test_tvd.py::test_correctness_with_ignore_index[0-dtype0-5e-05-1e-06-sum-3-423-32000] PASSED                                                                                                   [ 65%]
test_tvd.py::test_correctness_with_ignore_index[0-dtype0-5e-05-1e-06-mean-1-4096-32000] PASSED                                                                                                 [ 65%]
test_tvd.py::test_correctness_with_ignore_index[0-dtype0-5e-05-1e-06-mean-32-4096-1024] PASSED                                                                                                 [ 66%]
test_tvd.py::test_correctness_with_ignore_index[0-dtype0-5e-05-1e-06-mean-41-401-1271] PASSED                                                                                                  [ 66%]
test_tvd.py::test_correctness_with_ignore_index[0-dtype0-5e-05-1e-06-mean-1-4096-128256] SKIPPED (This test requires a GPU with at least 36GB of memory)                                       [ 67%]
test_tvd.py::test_correctness_with_ignore_index[0-dtype0-5e-05-1e-06-mean-3-423-32000] PASSED                                                                                                  [ 67%]
test_tvd.py::test_correctness_with_ignore_index[0-dtype0-5e-05-1e-06-none-1-4096-32000] PASSED                                                                                                 [ 68%]
test_tvd.py::test_correctness_with_ignore_index[0-dtype0-5e-05-1e-06-none-32-4096-1024] PASSED                                                                                                 [ 68%]
test_tvd.py::test_correctness_with_ignore_index[0-dtype0-5e-05-1e-06-none-41-401-1271] PASSED                                                                                                  [ 69%]
test_tvd.py::test_correctness_with_ignore_index[0-dtype0-5e-05-1e-06-none-1-4096-128256] SKIPPED (This test requires a GPU with at least 36GB of memory)                                       [ 69%]
test_tvd.py::test_correctness_with_ignore_index[0-dtype0-5e-05-1e-06-none-3-423-32000] PASSED                                                                                                  [ 70%]
test_tvd.py::test_correctness_with_ignore_index[0-dtype1-1e-08-1e-06-batchmean-1-4096-32000] PASSED                                                                                            [ 70%]
test_tvd.py::test_correctness_with_ignore_index[0-dtype1-1e-08-1e-06-batchmean-32-4096-1024] PASSED                                                                                            [ 71%]
test_tvd.py::test_correctness_with_ignore_index[0-dtype1-1e-08-1e-06-batchmean-41-401-1271] PASSED                                                                                             [ 71%]
test_tvd.py::test_correctness_with_ignore_index[0-dtype1-1e-08-1e-06-batchmean-1-4096-128256] SKIPPED (This test requires a GPU with at least 36GB of memory)                                  [ 72%]
test_tvd.py::test_correctness_with_ignore_index[0-dtype1-1e-08-1e-06-batchmean-3-423-32000] PASSED                                                                                             [ 72%]
test_tvd.py::test_correctness_with_ignore_index[0-dtype1-1e-08-1e-06-sum-1-4096-32000] PASSED                                                                                                  [ 73%]
test_tvd.py::test_correctness_with_ignore_index[0-dtype1-1e-08-1e-06-sum-32-4096-1024] PASSED                                                                                                  [ 73%]
test_tvd.py::test_correctness_with_ignore_index[0-dtype1-1e-08-1e-06-sum-41-401-1271] PASSED                                                                                                   [ 74%]
test_tvd.py::test_correctness_with_ignore_index[0-dtype1-1e-08-1e-06-sum-1-4096-128256] SKIPPED (This test requires a GPU with at least 36GB of memory)                                        [ 74%]
test_tvd.py::test_correctness_with_ignore_index[0-dtype1-1e-08-1e-06-sum-3-423-32000] PASSED                                                                                                   [ 75%]
test_tvd.py::test_correctness_with_ignore_index[0-dtype1-1e-08-1e-06-mean-1-4096-32000] PASSED                                                                                                 [ 75%]
test_tvd.py::test_correctness_with_ignore_index[0-dtype1-1e-08-1e-06-mean-32-4096-1024] PASSED                                                                                                 [ 76%]
test_tvd.py::test_correctness_with_ignore_index[0-dtype1-1e-08-1e-06-mean-41-401-1271] PASSED                                                                                                  [ 76%]
test_tvd.py::test_correctness_with_ignore_index[0-dtype1-1e-08-1e-06-mean-1-4096-128256] SKIPPED (This test requires a GPU with at least 36GB of memory)                                       [ 77%]
test_tvd.py::test_correctness_with_ignore_index[0-dtype1-1e-08-1e-06-mean-3-423-32000] PASSED                                                                                                  [ 77%]
test_tvd.py::test_correctness_with_ignore_index[0-dtype1-1e-08-1e-06-none-1-4096-32000] PASSED                                                                                                 [ 78%]
test_tvd.py::test_correctness_with_ignore_index[0-dtype1-1e-08-1e-06-none-32-4096-1024] PASSED                                                                                                 [ 78%]
test_tvd.py::test_correctness_with_ignore_index[0-dtype1-1e-08-1e-06-none-41-401-1271] PASSED                                                                                                  [ 79%]
test_tvd.py::test_correctness_with_ignore_index[0-dtype1-1e-08-1e-06-none-1-4096-128256] SKIPPED (This test requires a GPU with at least 36GB of memory)                                       [ 79%]
test_tvd.py::test_correctness_with_ignore_index[0-dtype1-1e-08-1e-06-none-3-423-32000] PASSED                                                                                                  [ 80%]
test_tvd.py::test_correctness_with_ignore_index[1-dtype0-5e-05-1e-06-batchmean-1-4096-32000] PASSED                                                                                            [ 80%]
test_tvd.py::test_correctness_with_ignore_index[1-dtype0-5e-05-1e-06-batchmean-32-4096-1024] PASSED                                                                                            [ 81%]
test_tvd.py::test_correctness_with_ignore_index[1-dtype0-5e-05-1e-06-batchmean-41-401-1271] PASSED                                                                                             [ 81%]
test_tvd.py::test_correctness_with_ignore_index[1-dtype0-5e-05-1e-06-batchmean-1-4096-128256] SKIPPED (This test requires a GPU with at least 36GB of memory)                                  [ 82%]
test_tvd.py::test_correctness_with_ignore_index[1-dtype0-5e-05-1e-06-batchmean-3-423-32000] PASSED                                                                                             [ 82%]
test_tvd.py::test_correctness_with_ignore_index[1-dtype0-5e-05-1e-06-sum-1-4096-32000] PASSED                                                                                                  [ 83%]
test_tvd.py::test_correctness_with_ignore_index[1-dtype0-5e-05-1e-06-sum-32-4096-1024] PASSED                                                                                                  [ 83%]
test_tvd.py::test_correctness_with_ignore_index[1-dtype0-5e-05-1e-06-sum-41-401-1271] PASSED                                                                                                   [ 84%]
test_tvd.py::test_correctness_with_ignore_index[1-dtype0-5e-05-1e-06-sum-1-4096-128256] SKIPPED (This test requires a GPU with at least 36GB of memory)                                        [ 84%]
test_tvd.py::test_correctness_with_ignore_index[1-dtype0-5e-05-1e-06-sum-3-423-32000] PASSED                                                                                                   [ 85%]
test_tvd.py::test_correctness_with_ignore_index[1-dtype0-5e-05-1e-06-mean-1-4096-32000] PASSED                                                                                                 [ 85%]
test_tvd.py::test_correctness_with_ignore_index[1-dtype0-5e-05-1e-06-mean-32-4096-1024] PASSED                                                                                                 [ 86%]
test_tvd.py::test_correctness_with_ignore_index[1-dtype0-5e-05-1e-06-mean-41-401-1271] PASSED                                                                                                  [ 86%]
test_tvd.py::test_correctness_with_ignore_index[1-dtype0-5e-05-1e-06-mean-1-4096-128256] SKIPPED (This test requires a GPU with at least 36GB of memory)                                       [ 87%]
test_tvd.py::test_correctness_with_ignore_index[1-dtype0-5e-05-1e-06-mean-3-423-32000] PASSED                                                                                                  [ 87%]
test_tvd.py::test_correctness_with_ignore_index[1-dtype0-5e-05-1e-06-none-1-4096-32000] PASSED                                                                                                 [ 88%]
test_tvd.py::test_correctness_with_ignore_index[1-dtype0-5e-05-1e-06-none-32-4096-1024] PASSED                                                                                                 [ 88%]
test_tvd.py::test_correctness_with_ignore_index[1-dtype0-5e-05-1e-06-none-41-401-1271] PASSED                                                                                                  [ 89%]
test_tvd.py::test_correctness_with_ignore_index[1-dtype0-5e-05-1e-06-none-1-4096-128256] SKIPPED (This test requires a GPU with at least 36GB of memory)                                       [ 89%]
test_tvd.py::test_correctness_with_ignore_index[1-dtype0-5e-05-1e-06-none-3-423-32000] PASSED                                                                                                  [ 90%]
test_tvd.py::test_correctness_with_ignore_index[1-dtype1-1e-08-1e-06-batchmean-1-4096-32000] PASSED                                                                                            [ 90%]
test_tvd.py::test_correctness_with_ignore_index[1-dtype1-1e-08-1e-06-batchmean-32-4096-1024] PASSED                                                                                            [ 91%]
test_tvd.py::test_correctness_with_ignore_index[1-dtype1-1e-08-1e-06-batchmean-41-401-1271] PASSED                                                                                             [ 91%]
test_tvd.py::test_correctness_with_ignore_index[1-dtype1-1e-08-1e-06-batchmean-1-4096-128256] SKIPPED (This test requires a GPU with at least 36GB of memory)                                  [ 92%]
test_tvd.py::test_correctness_with_ignore_index[1-dtype1-1e-08-1e-06-batchmean-3-423-32000] PASSED                                                                                             [ 92%]
test_tvd.py::test_correctness_with_ignore_index[1-dtype1-1e-08-1e-06-sum-1-4096-32000] PASSED                                                                                                  [ 93%]
test_tvd.py::test_correctness_with_ignore_index[1-dtype1-1e-08-1e-06-sum-32-4096-1024] PASSED                                                                                                  [ 93%]
test_tvd.py::test_correctness_with_ignore_index[1-dtype1-1e-08-1e-06-sum-41-401-1271] PASSED                                                                                                   [ 94%]
test_tvd.py::test_correctness_with_ignore_index[1-dtype1-1e-08-1e-06-sum-1-4096-128256] SKIPPED (This test requires a GPU with at least 36GB of memory)                                        [ 94%]
test_tvd.py::test_correctness_with_ignore_index[1-dtype1-1e-08-1e-06-sum-3-423-32000] PASSED                                                                                                   [ 95%]
test_tvd.py::test_correctness_with_ignore_index[1-dtype1-1e-08-1e-06-mean-1-4096-32000] PASSED                                                                                                 [ 95%]
test_tvd.py::test_correctness_with_ignore_index[1-dtype1-1e-08-1e-06-mean-32-4096-1024] PASSED                                                                                                 [ 96%]
test_tvd.py::test_correctness_with_ignore_index[1-dtype1-1e-08-1e-06-mean-41-401-1271] PASSED                                                                                                  [ 96%]
test_tvd.py::test_correctness_with_ignore_index[1-dtype1-1e-08-1e-06-mean-1-4096-128256] SKIPPED (This test requires a GPU with at least 36GB of memory)                                       [ 97%]
test_tvd.py::test_correctness_with_ignore_index[1-dtype1-1e-08-1e-06-mean-3-423-32000] PASSED                                                                                                  [ 97%]
test_tvd.py::test_correctness_with_ignore_index[1-dtype1-1e-08-1e-06-none-1-4096-32000] PASSED                                                                                                 [ 98%]
test_tvd.py::test_correctness_with_ignore_index[1-dtype1-1e-08-1e-06-none-32-4096-1024] PASSED                                                                                                 [ 98%]
test_tvd.py::test_correctness_with_ignore_index[1-dtype1-1e-08-1e-06-none-41-401-1271] PASSED                                                                                                  [ 99%]
test_tvd.py::test_correctness_with_ignore_index[1-dtype1-1e-08-1e-06-none-1-4096-128256] SKIPPED (This test requires a GPU with at least 36GB of memory)                                       [ 99%]
test_tvd.py::test_correctness_with_ignore_index[1-dtype1-1e-08-1e-06-none-3-423-32000] PASSED                                                                                                  [100%]

@TianHao324
Contributor Author

Separately, `make checkstyle` fails on src/liger_kernel/transformers/model/gemma3.py, a file this PR does not modify, so the failure was probably introduced by another commit:

I001 [*] Import block is un-sorted or un-formatted
  --> src/liger_kernel/transformers/model/gemma3.py:1:1
   |
 1 | / from typing import Optional
 2 | | from typing import Tuple
 3 | | from typing import Union
 4 | |
 5 | | import torch
 6 | | import torch.nn as nn
 7 | |
 8 | | from transformers.cache_utils import Cache
 9 | | from transformers.cache_utils import Cache
10 | | from transformers.utils import logging
11 | |
12 | | from liger_kernel.transformers.fused_linear_cross_entropy import LigerFusedLinearCrossEntropyLoss
13 | | from liger_kernel.transformers.model.loss_utils import LigerForCausalLMLoss
14 | | from liger_kernel.transformers.model.loss_utils import unpack_cross_entropy_result
15 | | from liger_kernel.transformers.model.output_classes import LigerCausalLMOutputWithPast
16 | | from liger_kernel.transformers.model.output_classes import LigerGemma3CausalLMOutputWithPast
   | |____________________________________________________________________________________________^
17 |
18 |   logger = logging.get_logger(__name__)
   |
help: Organize imports

F811 [*] Redefinition of unused `Cache` from line 8
  --> src/liger_kernel/transformers/model/gemma3.py:8:38
   |
 6 | import torch.nn as nn
 7 |
 8 | from transformers.cache_utils import Cache
   |                                      ----- previous definition of `Cache` here
 9 | from transformers.cache_utils import Cache
   |                                      ^^^^^ `Cache` redefined here
10 | from transformers.utils import logging
   |
help: Remove definition: `Cache`

Found 2 errors.
[*] 2 fixable with the `--fix` option.
224 files already formatted
Found 1 error (1 fixed, 0 remaining).
224 files left unchanged
make: *** [Makefile:20: checkstyle] Error 1
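
The F811 finding above comes down to line 9 of gemma3.py repeating the import on line 8. A minimal sketch of that check, using only Python's standard `ast` module, is shown here; the helper name `find_duplicate_imports` and the shortened `SOURCE` snippet are illustrative assumptions, not part of ruff or Liger-Kernel, and this is not how ruff itself implements the rule.

```python
import ast
from collections import defaultdict

# Shortened stand-in for the import block quoted in the ruff output.
# It is only parsed, never executed, so torch/transformers are not required.
SOURCE = """\
from typing import Optional
import torch.nn as nn
from transformers.cache_utils import Cache
from transformers.cache_utils import Cache
from transformers.utils import logging
"""


def find_duplicate_imports(source):
    """Return {bound_name: [line numbers]} for names imported more than once."""
    seen = defaultdict(list)
    for node in ast.walk(ast.parse(source)):
        if isinstance(node, (ast.Import, ast.ImportFrom)):
            for alias in node.names:
                # Name bound in the module namespace: the alias if present,
                # otherwise the first dotted component (e.g. `import a.b` binds `a`).
                bound = alias.asname or alias.name.split(".")[0]
                seen[bound].append(node.lineno)
    return {name: lines for name, lines in seen.items() if len(lines) > 1}


duplicates = find_duplicate_imports(SOURCE)
print(duplicates)
```

In the real file, deleting the second `Cache` import is the change the ruff `help:` line suggests.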

@TianHao324
Contributor Author

@Tcc0403 It's ready for review, if you have time.

Collaborator

@Tcc0403 Tcc0403 left a comment

Thanks!

@Tcc0403 Tcc0403 merged commit 70117b9 into linkedin:main Jan 12, 2026
4 of 7 checks passed