[mxfp8 moe training] mxfp8 all to all #2250
base: danielvegamyhre/stack/3
Conversation
stack-info: PR: #2250, branch: danielvegamyhre/stack/4
1763af6 to 8323531
8323531 to 518835a
@tianyu-l I will polish this up, but any early thoughts on the high-level design?
def get_a2a_splits(
Given the current code organization, this function is a distributed util, so probably should be in expert_parallel.py.
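For reference, a minimal sketch of what such a dispatch-splits util typically does (the name `get_a2a_splits_sketch` and the exact signature are illustrative, not the PR's actual code): count the tokens this rank routes to each EP rank, then exchange those counts so each rank knows its input/output splits for `all_to_all_single`.

```python
# Illustrative sketch only -- not the PR's get_a2a_splits implementation.
import torch
import torch.distributed as dist


def get_a2a_splits_sketch(
    tokens_per_expert: torch.Tensor,  # (num_experts,) local routing counts
    ep_group: dist.ProcessGroup,
) -> tuple[list[int], list[int]]:
    """Compute input/output split sizes for all_to_all_single token dispatch."""
    ep_size = dist.get_world_size(ep_group)
    # Tokens this rank sends to each EP rank (experts assumed sharded by rank).
    input_splits = tokens_per_expert.view(ep_size, -1).sum(dim=1)
    # Exchange counts so each rank learns how many tokens it will receive.
    output_splits = torch.empty_like(input_splits)
    dist.all_to_all_single(output_splits, input_splits, group=ep_group)
    return input_splits.tolist(), output_splits.tolist()
```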
    https://github.com/deepseek-ai/DeepEP.
    """
expert_parallel_a2a_dispatch_fwd_precision: Literal["default", "mxfp8"] = "default"
Is it not a net benefit to enable mxfp8 dispatch in the forward and mxfp8 combine in the backward? If so, they don't need to be configurable -- we can just enable them whenever EP and mxfp8 grouped mm are used.
It is, but there is an important limitation: the mxfp8 grouped mm recipe must be mxfp8_wgrad_with_hp (computes forward output and dgrad in mxfp8, computes wgrad in high precision). It is not compatible with the mxfp8 recipe yet (which uses mxfp8 for all of output/dgrad/wgrad).
We could automatically enable it when the recipe is mxfp8_wgrad_with_hp, perhaps?
Context on this requirement if you are curious:
This is because the inputs come in pre-quantized along dim0, but in the backward pass wgrad = grad_out_t @ input needs inputs quantized along dim1, so we would need to dequant along dim0 and then requant along dim1. That is definitely an option, especially if we have a fast fused kernel for it, but there is some debate about the numerical implications (it is not equivalent to doing dim1 quant on the original bf16 tensor).
Therefore, for now we assert that the recipe is mxfp8_wgrad_with_hp, pending (1) the fast dequant/requant kernel, and (2) numerics experiments.
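A toy, self-contained sketch of the numerics caveat described above (this is plain per-block fp8 scaling standing in for real mxfp8, not torchao's MXTensor code): re-quantizing along dim1 after a dim0 dequant round trip is generally not the same as quantizing the original bf16 tensor along dim1.

```python
# Toy illustration only: 32-element scaling blocks mimic mxfp8's 1x32 layout.
import torch

BLOCK = 32
FP8_MAX = 448.0  # max magnitude representable in float8_e4m3fn


def quantize_blockwise(x: torch.Tensor, dim: int):
    """One scale per 32 contiguous elements along `dim`, cast payload to fp8."""
    xb = x.movedim(dim, -1).unflatten(-1, (-1, BLOCK))
    scales = xb.abs().amax(dim=-1, keepdim=True).clamp(min=1e-12) / FP8_MAX
    return (xb / scales).to(torch.float8_e4m3fn), scales


def dequantize_blockwise(q: torch.Tensor, scales: torch.Tensor, dim: int):
    x = (q.to(torch.bfloat16) * scales).flatten(-2)
    return x.movedim(-1, dim)


x = torch.randn(128, 256, dtype=torch.bfloat16)

# Path A: quantize the original bf16 activation along dim1 (what wgrad wants).
qa, sa = quantize_blockwise(x, dim=1)

# Path B: activation arrives from dispatch quantized along dim0, so we have to
# dequantize along dim0 and requantize along dim1 before the wgrad GEMM.
q0, s0 = quantize_blockwise(x, dim=0)
x_rt = dequantize_blockwise(q0, s0, dim=0)
qb, sb = quantize_blockwise(x_rt, dim=1)

# The two dim1 quantizations generally differ -> the numerics debate above.
print(torch.equal(qa.to(torch.bfloat16) * sa, qb.to(torch.bfloat16) * sb))
```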
> It is but there is an important limitation, which is that the mxfp8 grouped mm recipe must be mxfp8_wgrad_with_hp

Then we should couple it with mxfp8_wgrad_with_hp, unless you see reasons not to?

> This is because if inputs come in pre-quantized along dim0, in the backward pass wgrad = grad_out_t @ input where we need inputs quantized along dim1, we would need to dequant along dim0 then requant along dim1.

I get that wgrad_with_hp makes sense, but I didn't get what its relationship is to doing the fwd combine / bwd dispatch of dgrad in high precision.
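For illustration, the coupling being proposed could look roughly like this (the function and recipe-string names are hypothetical stand-ins, not the actual torchtitan/torchao config):

```python
# Hypothetical illustration of deriving mxfp8 a2a from the grouped-mm recipe
# instead of exposing a separate precision knob.
def use_mxfp8_all_to_all(ep_enabled: bool, grouped_mm_recipe: str) -> bool:
    # Only the wgrad-in-high-precision recipe is compatible with
    # pre-quantized dispatch outputs (see the discussion above).
    return ep_enabled and grouped_mm_recipe == "mxfp8_wgrad_with_hp"
```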
> Then we should couple it with mxfp8_wgrad_with_hp, unless you see reasons not to?

I agree this sounds good; the only blocker right now is that torch.compile works in unit tests but not in torchtitan e2e training with these mxfp8 EP building blocks. cc @bdhirsh: the nonstrict_trace dynamo decorator works in the torchao integration test w/ compile, but in torchtitan the same "tensor metadata mismatch" rears its head again. Will DM you about this.
I made a diagram to describe why things are breaking:

torch.compile has a hard requirement today: when you compile a given region of code, the forward outputs and their corresponding gradients (the backward inputs to that compiled region) must have the same subclass type.
You can see from the diagram that in the mxfp8 codepath, both GroupedExperts.forward and MXFP8ExpertParallel._token_dispatch violate this assumption.
That doesn't mean that we can't support this case. But it does mean that we are unable to compile these two blocks of code in isolation. Instead, our options are either:
(1) don't compile the token_dispatch or the GroupedExperts.forward at all
(2) compile a single graph, containing all of _token_dispatch + GroupedExperts.forward + _token_combine
I have a short diff for both 1 and 2, although I think there are a few things worth thinking more about.
diff for not compiling GroupedExperts.forward: https://gist.github.com/bdhirsh/970a671b84c35cc95a76f33657ca4d69
diff for compiling all 3 into a single graph: https://gist.github.com/bdhirsh/8b0ced2f6381b52eebd56fdff8e62093
This was a bit strange: it looks like when we try to compile GroupedExperts, dynamo ends up creating a separate graph for GroupedExperts.forward and for the module's forward pre-hooks / post-hooks. I "fixed" it by moving the GroupedExperts into an nn module wrapper and compiling that, but I imagine this isn't actually landable since it will e.g. muck up the state dict.
I also saw some previous comments from @xmfan indicating that there were earlier problems trying to compile _token_dispatch/combine: #1940. I didn't understand the issue though.
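A minimal sketch of option (2), assuming illustrative placeholder names (`RoutedExpertsBlock` and the dispatch/combine callables are not torchtitan's real APIs): wrap dispatch, the grouped experts, and combine in one nn.Module so torch.compile sees a single region and subclass types never cross a graph boundary.

```python
# Illustrative-only sketch of compiling dispatch + experts + combine as one graph.
import torch
import torch.nn as nn


class RoutedExpertsBlock(nn.Module):
    """Runs a2a dispatch -> grouped experts -> a2a combine inside a single
    forward, instead of attaching dispatch/combine as module pre/post hooks."""

    def __init__(self, experts: nn.Module, dispatch, combine):
        super().__init__()
        self.experts = experts      # e.g. GroupedExperts
        self.dispatch = dispatch    # e.g. token_dispatch (may emit MXTensor)
        self.combine = combine      # e.g. token_combine

    def forward(self, tokens: torch.Tensor, num_tokens_per_expert: torch.Tensor):
        x, meta = self.dispatch(tokens, num_tokens_per_expert)
        y = self.experts(x, meta)
        return self.combine(y, meta)


# Compile the wrapper as one region so MXTensor <-> bf16 transitions stay
# internal to the graph (caveat: this changes module structure / state dict):
# block = RoutedExpertsBlock(grouped_experts, token_dispatch, token_combine)
# block = torch.compile(block, fullgraph=True)
```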
To me, the best option would be if we can tweak FSDP to handle multiple device meshes, so we can avoid having to FSDP-wrap the experts region (cc @weifengpy), since I think that would solve this particular issue and would also avoid a lot of the other issues we've seen in titan around compiling MoE.
I'm not sure what the timeline for that is, so depending on @danielvegamyhre's timeline, we might want to either just skip compiling the grouped_mm to unblock, or figure out how to land (2) above.
Thanks for the clear diagram! Are there implementation changes we could make so that the MXTensor <-> bf16 conversions move into a different graph? Otherwise, I think the FSDP2 changes are necessary to compile a single graph with dispatch/experts/combine.
@danielvegamyhre can speak to that better than me, but I would imagine... probably not, unless you are willing to completely abandon today's structure where _token_dispatch/_token_combine are module hooks that are logically separated from the grouped_mm forward call.

> Otherwise, I think the FSDP2 changes are necessary to compile a single graph with dispatch/experts/combine

I agree that this feels like an overall much cleaner state. One dumb question I have: I was able to get token_dispatch + grouped_mm + token_combine into a single graph with this change: https://gist.github.com/bdhirsh/8b0ced2f6381b52eebd56fdff8e62093. I agree it is ugly (and potentially not landable because it changes the state dict), but it seemed to work for me on Daniel's example. It sounds like there are other issues you ran into trying to compile the token_dispatch/combine - do you know what they are and what caused them?
It seemed to be something with the unbacked shapes; I'm not sure if the error is still present. I believe the main consideration for the module hooks was to provide the same clean model definition for both EP and non-EP, so changing APIs seems okay as long as we can still ensure that.
Stacked PRs:
[mxfp8 moe training] mxfp8 all to all
WIP