diff --git a/torchtitan/distributed/utils.py b/torchtitan/distributed/utils.py
index 7790ab6683..cfe5fc19b3 100644
--- a/torchtitan/distributed/utils.py
+++ b/torchtitan/distributed/utils.py
@@ -250,24 +250,19 @@ def maybe_enable_amp(
     parallel_dims: ParallelDims, mixed_precision_param: str, device_type: str
 ) -> contextlib.AbstractContextManager[None]:
     if parallel_dims.fsdp_enabled:
-        # FSDP handles mixed precision internally
+        # FSDP handles mixed precision internally via MixedPrecisionPolicy
         logger.info("Mixed precision training is handled by fully_shard")
         return contextlib.nullcontext()
     else:
-        if parallel_dims.tp_enabled or parallel_dims.pp_enabled:
-            logger.warning(
-                "Mixed precision training with TP or PP is only supported when FSDP/HSDP/CP is enabled."
-            )
-            logger.info("Mixed precision training is disabled")
-            return contextlib.nullcontext()
-        else:
-            # the following code will only be executed for DDP or single-device training
-            logger.info("Mixed precision training is handled by AMP")
-            # pyrefly: ignore [bad-return]
-            return torch.autocast(
-                device_type,
-                dtype=TORCH_DTYPE_MAP[mixed_precision_param],
-            )
+        # Enable autocast for non-FSDP cases (DDP, TP-only, single-device)
+        # Note: PP uses its own schedule-based execution and doesn't use this context
+        # torch.autocast works at the operator level and is orthogonal to parallelism strategy
+        logger.info("Mixed precision training is handled by AMP")
+        # pyrefly: ignore [bad-return]
+        return torch.autocast(
+            device_type,
+            dtype=TORCH_DTYPE_MAP[mixed_precision_param],
+        )
 
 
 def init_fake_mode(world_size: int, comm_mode: str = "fake_backend"):
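
For reference, here is a minimal runnable sketch of the behavior this hunk ends up with: when FSDP is enabled, mixed precision is left to `fully_shard` and a null context is returned; every other case (DDP, TP-only, single device) gets operator-level `torch.autocast`. The `sketch_maybe_enable_amp` helper and the toy training step are illustrative stand-ins, not torchtitan's actual `Trainer` wiring, and `ParallelDims` is reduced to a plain boolean flag to keep the example self-contained.

```python
import contextlib

import torch
import torch.nn as nn

# Stand-in for torchtitan's dtype mapping; illustrative only.
TORCH_DTYPE_MAP = {"bfloat16": torch.bfloat16, "float16": torch.float16}


def sketch_maybe_enable_amp(
    fsdp_enabled: bool, mixed_precision_param: str, device_type: str
) -> contextlib.AbstractContextManager[None]:
    # Mirrors the patched control flow: fully_shard owns mixed precision,
    # everything else falls back to operator-level autocast.
    if fsdp_enabled:
        return contextlib.nullcontext()
    return torch.autocast(device_type, dtype=TORCH_DTYPE_MAP[mixed_precision_param])


# Toy training step showing how the returned context manager is consumed.
model = nn.Linear(16, 16)
maybe_amp = sketch_maybe_enable_amp(
    fsdp_enabled=False, mixed_precision_param="bfloat16", device_type="cpu"
)

x = torch.randn(4, 16)
with maybe_amp:  # forward and loss run under autocast when it is enabled
    out = model(x)
    loss = out.float().sum()
loss.backward()
```

With `fsdp_enabled=True` the same call site receives a `nullcontext()` and the `with` block is a no-op, which is why the removed TP/PP warning branch is no longer needed: the autocast path is orthogonal to how the model is parallelized.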