Remove unnecessary token padding for MoE in BF16 mode #2255
Conversation
```diff
-TOKEN_GROUP_ALIGN_SIZE_M = 8
+TOKEN_GROUP_ALIGN_SIZE_M = 1
 ValidTokenGroupAlignmentSize = Literal[8, 16, 32]
```
This fix is "soft", in the sense that the padding code path still exists for bf16.
I wonder whether it's viable to go one step further: remove all padding logic for bf16 and keep it only in the quantized paths. cc @danielvegamyhre
Either way is fine. The 8-token alignment is needed if we want to use TMA in any kernels operating on each token group (8 tokens * 2 bytes per element = 16-byte alignment). However, if we only do that in the low-precision code path, then there's no reason to pad in bf16.
Feel free to remove bf16 padding entirely.
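A quick back-of-envelope check of that arithmetic (illustrative only, not torchtitan code):

```python
# With 8-token group alignment and bf16 elements (2 bytes each), every token
# group boundary falls on a 16-byte multiple, which is what TMA requires.
TOKEN_GROUP_ALIGN_SIZE_M = 8
BF16_BYTES_PER_ELEM = 2
assert TOKEN_GROUP_ALIGN_SIZE_M * BF16_BYTES_PER_ELEM == 16
```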
For bf16, 8-token alignment is not needed anywhere; see:

```python
import torch
import torch.nn.functional as F

x = torch.randn(2048, 4096, device="cuda", dtype=torch.bfloat16).requires_grad_(True)
w = torch.randn(2, 4096, 7168, device="cuda", dtype=torch.bfloat16).requires_grad_(True)

# group offsets that are not multiples of 8
offs = torch.tensor([1023, 2048], device="cuda", dtype=torch.int32)

out = F.grouped_mm(x, w, offs=offs)
gO = torch.rand_like(out)
out.backward(gO)

# check that gradients are computed
print(x.grad.sum(), w.grad.sum())
```
@danielvegamyhre
Right now we are mixing padding and permutation in one kernel. Since bf16 doesn't require padding, I wonder if it makes sense to move the padding into the quantization kernel? The argument is that the permute kernel itself should be general and not require the user to do padding from outside.
Sure, if we agree to move the padding logic to the quant paths, then I will refactor to remove TOKEN_GROUP_ALIGN_SIZE_M in torchtitan.
@tianyu-l we have a version of the permutation and pad/fill kernel in torchao now, used in the MXFP8 EP primitives. It is not fused with quantization though. To clarify, are you asking if we can delete the permute+pad kernel from torchtitan and replace it with a fused permute+pad+quantize kernel in torchao?
@danielvegamyhre My request is that we remove padding from torchtitan entirely, while keeping correctness.
In the past we had the permute+pad kernel to avoid a d2h sync on the padding part. Now that we no longer need padding for bf16, I'd hope we can remove the kernel altogether, but that requires torchao to handle padding.
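For context, a minimal sketch of why host-side padding triggers that d2h sync (hypothetical code, not the torchtitan kernel):

```python
import torch

# The per-expert token counts live on the GPU after routing. Computing the
# padded buffer size on the host reads them back, blocking until the routing
# kernels finish -- this is the d2h sync the fused permute+pad kernel avoids.
tokens_per_expert = torch.tensor([1023, 1025], device="cuda")  # produced by the router
padded = ((tokens_per_expert + 7) // 8) * 8                    # round up to multiples of 8
total_padded = int(padded.sum().item())                        # .item() forces the d2h sync
buf = torch.empty(total_padded, 4096, device="cuda", dtype=torch.bfloat16)
```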
That is doable
The Solar Open-102B technical report is very interesting @rakkit, thanks for sharing it!
Fix of #2225.
Context: the Solar Open-102B technical report points out that in BF16 mode, expert parallel did unnecessary token padding (this also applies to the non-EP case).
This PR sets `TOKEN_GROUP_ALIGN_SIZE_M = 1` by default. `indices_padding_wrapper_permute` takes `TOKEN_GROUP_ALIGN_SIZE_M = 1` and `padded_max_len = x.shape[0]`, which avoids any padding.
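To make the effect concrete, here is a small illustrative sketch (not torchtitan's actual code) of how the alignment size determines per-group padding; with an alignment of 1, every group size is already aligned, so no padding tokens are inserted:

```python
import torch

def padded_group_sizes(tokens_per_group: torch.Tensor, align_m: int) -> torch.Tensor:
    # Round each group's token count up to the next multiple of align_m.
    return ((tokens_per_group + align_m - 1) // align_m) * align_m

tokens_per_group = torch.tensor([1023, 1025])
print(padded_group_sizes(tokens_per_group, 8))  # tensor([1024, 1032]) -> 8 padding tokens added
print(padded_group_sizes(tokens_per_group, 1))  # tensor([1023, 1025]) -> no padding
```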
Test:

Original implementation (with `TOKEN_GROUP_ALIGN_SIZE_M=8`):

```bash
CONFIG_FILE="./torchtitan/models/deepseek_v3/train_configs/debug_model.toml" NGPU=4 ./run_train.sh --debug.seed 10
CONFIG_FILE="./torchtitan/models/deepseek_v3/train_configs/debug_model.toml" NGPU=4 ./run_train.sh --debug.seed 10 --parallelism.expert_parallel_degree=2
```

[screenshots of the training runs]

And with this PR:

```bash
CONFIG_FILE="./torchtitan/models/deepseek_v3/train_configs/debug_model.toml" NGPU=4 ./run_train.sh --debug.seed 10
CONFIG_FILE="./torchtitan/models/deepseek_v3/train_configs/debug_model.toml" NGPU=4 ./run_train.sh --debug.seed 10 --parallelism.expert_parallel_degree=2
```

[screenshots of the training runs]