
Conversation

@rakkit (Contributor) commented Jan 20, 2026:

Fixes #2225.

Context: the Solar Open-102B technical report points out that in BF16 mode, expert parallel does unnecessary token padding (this also applies to the non-EP case).
This PR sets TOKEN_GROUP_ALIGN_SIZE_M = 1 by default.
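
As a rough illustration (hypothetical code, not the torchtitan implementation), the alignment value controls how far each expert's token count is rounded up before the grouped GEMM; an alignment of 8 can add up to 7 dummy tokens per expert, while an alignment of 1 is a no-op:

def pad_token_counts(tokens_per_expert, align_size_m):
    # Round each expert's token count up to the nearest multiple of align_size_m;
    # the difference is filled with dummy (padding) tokens before the grouped GEMM.
    return [((n + align_size_m - 1) // align_size_m) * align_size_m
            for n in tokens_per_expert]

counts = [1023, 1025]                 # tokens routed to two experts (made-up numbers)
print(pad_token_counts(counts, 8))    # [1024, 1032] -> up to 7 padding tokens per expert
print(pad_token_counts(counts, 1))    # [1023, 1025] -> no padding at all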

Test:

Original implementation (with TOKEN_GROUP_ALIGN_SIZE_M=8):
CONFIG_FILE="./torchtitan/models/deepseek_v3/train_configs/debug_model.toml" NGPU=4 ./run_train.sh --debug.seed 10
[training log screenshot]
CONFIG_FILE="./torchtitan/models/deepseek_v3/train_configs/debug_model.toml" NGPU=4 ./run_train.sh --debug.seed 10 --parallelism.expert_parallel_degree=2
[training log screenshot]

And with this PR:
CONFIG_FILE="./torchtitan/models/deepseek_v3/train_configs/debug_model.toml" NGPU=4 ./run_train.sh --debug.seed 10
[training log screenshot]
CONFIG_FILE="./torchtitan/models/deepseek_v3/train_configs/debug_model.toml" NGPU=4 ./run_train.sh --debug.seed 10 --parallelism.expert_parallel_degree=2
[training log screenshot]


- TOKEN_GROUP_ALIGN_SIZE_M = 8
  ValidTokenGroupAlignmentSize = Literal[8, 16, 32]
+ TOKEN_GROUP_ALIGN_SIZE_M = 1
Contributor:
This fix is "soft", in the sense that the padding code path still exists for bf16.

I wonder whether it's viable to go one step further -- remove all padding logic for bf16 and keep padding only in the quantized paths. cc @danielvegamyhre

@danielvegamyhre (Contributor) commented Jan 27, 2026:

Either way is fine. The 8-token alignment is needed if we want to use TMA in any kernels operating on each token group (8 tokens * 2 bytes per element = 16-byte alignment). However, if we are only doing that in the low-precision code path, then there's no reason to pad for bf16.

Feel free to remove bf16 padding entirely.
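
To spell out the arithmetic in the comment above (a minimal sketch; token_align_for_tma is a made-up helper, not a torchtitan or torchao API): a 16-byte TMA boundary divided by the element size gives the required token alignment, which is 8 for bf16 and 16 for 1-byte float8 types.

import torch

def token_align_for_tma(dtype: torch.dtype, byte_alignment: int = 16) -> int:
    # Tokens per group must be a multiple of this for each group's start
    # to land on a byte_alignment boundary along the M dimension.
    return byte_alignment // dtype.itemsize

print(token_align_for_tma(torch.bfloat16))       # 8  (2 bytes per element)
print(token_align_for_tma(torch.float8_e4m3fn))  # 16 (1 byte per element)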

Comment:

For bf16, 8-token alignment is not needed anywhere; see:

import torch
import torch.nn.functional as F

# bf16 inputs for a grouped GEMM over two expert groups
x = torch.randn(2048, 4096, device="cuda", dtype=torch.bfloat16).requires_grad_(True)
w = torch.randn(2, 4096, 7168, device="cuda", dtype=torch.bfloat16).requires_grad_(True)

# group boundaries deliberately NOT aligned to a multiple of 8 tokens
# (group 0: rows 0..1022, group 1: rows 1023..2047)
offs = torch.tensor([1023, 2048], device="cuda", dtype=torch.int32)

out = F.grouped_mm(x, w, offs=offs)
gO = torch.rand_like(out)
out.backward(gO)

# check that gradients are computed despite the unaligned offsets
print(x.grad.sum(), w.grad.sum())

Contributor:

@danielvegamyhre Right now we are mixing padding and permutation into one kernel. Since bf16 doesn't require padding, I wonder if it makes sense to move padding into the quantization kernel? The argument is that the kernel itself should be general and not require the user to do padding from outside.
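
One way to picture the proposed split (purely illustrative; the function names below are hypothetical, not torchtitan or torchao APIs): the dispatch path only permutes tokens by expert, and any padding lives inside the optional low-precision step.

import torch

def permute_by_expert(tokens: torch.Tensor, expert_ids: torch.Tensor):
    # Sort tokens by assigned expert and return the group-end offsets.
    order = torch.argsort(expert_ids, stable=True)
    counts = torch.bincount(expert_ids, minlength=int(expert_ids.max()) + 1)
    return tokens[order], torch.cumsum(counts, dim=0)

def dispatch(tokens, expert_ids, quantize_and_pad=None):
    permuted, offs = permute_by_expert(tokens, expert_ids)   # no padding here
    if quantize_and_pad is None:
        return permuted, offs                 # bf16 path: unaligned offsets are fine
    return quantize_and_pad(permuted, offs)   # low-precision path owns alignment/padding

x = torch.randn(10, 4, dtype=torch.bfloat16)
ids = torch.tensor([0, 1, 0, 1, 1, 0, 0, 1, 1, 0])
_, offs = dispatch(x, ids)
print(offs)   # tensor([5, 10]) -- group boundaries, not padded to any alignment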

Contributor (Author):

Sure, if we agree to move the padding logic to the quant paths, then I will refactor to remove TOKEN_GROUP_ALIGN_SIZE_M from torchtitan.

Contributor:

@tianyu-l we have a version of the permutation and pad/fill kernel in torchao now, used in the MXFP8 EP primitives. It is not fused with quantization, though. To clarify, are you asking if we can delete the permute+pad kernel from torchtitan and replace it with a fused permute+pad+quantize kernel in torchao?

Contributor:

@danielvegamyhre My request is that we remove padding from torchtitan entirely, while keeping correctness.

In the past we had the permute+pad kernel to avoid a d2h sync on the padding part. Now that we no longer need padding for bf16, I'd hope we can remove the kernel altogether, but that requires torchao to handle padding.
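
For context on the d2h sync (my illustration, not torchtitan code): computing padded group sizes on the host requires copying the routed token counts off the GPU, which blocks the stream, whereas doing the padding arithmetic on device stays asynchronous.

import torch

if torch.cuda.is_available():
    counts = torch.tensor([1023, 1025], device="cuda")

    # Host-side padding: .tolist() copies the counts to the CPU and waits for the
    # GPU to catch up -- the d2h sync the fused permute+pad kernel was avoiding.
    padded_on_host = [((n + 7) // 8) * 8 for n in counts.tolist()]

    # Device-side padding: pure tensor ops, no host round-trip, no sync.
    padded_on_device = (counts + 7) // 8 * 8

    print(padded_on_host, padded_on_device)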

Contributor:

That is doable

@danielvegamyhre (Contributor) commented:
The Solar Open-102B technical report is very interesting, @rakkit, thanks for sharing it!
