fix: enable torch.autocast for TP parallelism without FSDP #2213
base: main
Conversation
Remove the overly conservative restriction that disabled mixed precision for TP-only configurations. torch.autocast operates at the operator level and is orthogonal to tensor parallelism.

Before: TP-only training showed a warning and disabled mixed precision.
After: TP-only training uses torch.autocast for mixed precision.

Note: PP-only training uses schedule-based execution and does not use maybe_enable_amp (unchanged by this PR).

Affected configurations:
- TP-only (now enabled)
- DDP-only (already enabled)
- Single-device (already enabled)
- FSDP/HSDP (unchanged; mixed precision handled internally by fully_shard)
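For context, a minimal sketch of the behavior the description targets; the helper name follows maybe_enable_amp from the description, but the argument names and signature here are illustrative, not the actual torchtitan API:

```python
import contextlib
import torch


def maybe_enable_amp(mixed_precision: bool, uses_fsdp: bool,
                     dtype: torch.dtype = torch.bfloat16,
                     device_type: str = "cuda"):
    """Return an autocast context for non-FSDP configurations.

    FSDP/HSDP cast parameters and gradients internally via fully_shard,
    so autocast is skipped there. For TP-only, DDP-only, or single-device
    runs, operator-level autocast supplies the mixed precision.
    """
    if not mixed_precision or uses_fsdp:
        return contextlib.nullcontext()
    return torch.autocast(device_type=device_type, dtype=dtype)
```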
Pull request overview
This PR removes an overly conservative restriction that disabled mixed precision training for Tensor Parallelism (TP) configurations without FSDP. The change enables torch.autocast for TP-only training, recognizing that autocast operates at the operator level and is orthogonal to the parallelism strategy.
Key changes:
- Simplified the maybe_enable_amp function logic to enable autocast for all non-FSDP configurations
- Improved code comments to clarify when mixed precision is handled by FSDP vs. AMP
- Added an explanation that PP uses schedule-based execution and does not utilize this context
tianyu-l left a comment
Have you verified it works properly? Could you show evidence, in terms of param / activation / grad dtype, and throughput comparison with mixed precision off?
I vaguely remember that I've tried it before and it didn't work as expected.
https://huggingface.co/eousphoros/persona_eta_20b_131k — this model was trained with TP=4 and no FSDP. The output with autocast was in line with what I expected, though I lack the depth of knowledge to formally confirm this.
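A hedged sketch of how the requested dtype evidence could be gathered; the toy nn.Linear stands in for the TP-sharded model, and none of this code is from the PR itself:

```python
import torch
import torch.nn as nn

model = nn.Linear(1024, 1024).cuda()  # stand-in for the TP-sharded model
seen = {}
# Record the dtype of the layer's output produced under autocast.
model.register_forward_hook(lambda mod, inp, out: seen.update(act=out.dtype))

with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
    out = model(torch.randn(8, 1024, device="cuda"))
out.float().sum().backward()

print("param dtype:     ", next(model.parameters()).dtype)       # expect torch.float32
print("activation dtype:", seen["act"])                          # expect torch.bfloat16
print("grad dtype:      ", next(model.parameters()).grad.dtype)  # expect torch.float32
```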
Thanks. It's hard to tell from the plots whether it is working properly. Also curious: why would you use AMP with TP but without FSDP?
autocast is the core feature, but autocast with parallelisms is not actively maintained. We have seen a performance gap between the autocast + DDP and FSDP implementations with the same world_size. IMO, if we can use FSDP, we should use FSDP.
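For comparison, a rough sketch of the two approaches being contrasted (assumes an initialized process group and a PyTorch version exposing FSDP2's fully_shard; the dtype choices are illustrative):

```python
import torch
from torch.distributed.fsdp import fully_shard, MixedPrecisionPolicy
from torch.nn.parallel import DistributedDataParallel as DDP


def wrap_with_fsdp(model: torch.nn.Module) -> torch.nn.Module:
    # FSDP2 handles mixed precision internally: params are cast to bf16
    # for compute, gradient reduction happens in fp32. No autocast needed.
    fully_shard(model, mp_policy=MixedPrecisionPolicy(
        param_dtype=torch.bfloat16, reduce_dtype=torch.float32))
    return model


def ddp_forward_with_amp(model: torch.nn.Module, batch: torch.Tensor) -> torch.Tensor:
    # DDP (or TP-only) path: operator-level autocast provides the bf16 compute.
    ddp_model = DDP(model)
    with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
        return ddp_model(batch)
```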
