Commit 5bc77c5

chg: disable moe amax routing remedy when router group is not None.
chg: remove global barrier in SequentialMLP

Signed-off-by: Chenhan Yu <[email protected]>
1 parent: 945ee02

File tree

1 file changed (+8, -5 lines)


modelopt/torch/quantization/plugins/megatron.py

Lines changed: 8 additions & 5 deletions
@@ -581,7 +581,7 @@ def sync_moe_local_experts_amax(self):
         This function is called to synchronize the amax values across local experts s.t. all localexperts will
         share the same amax.
         """
-        torch.distributed.barrier()
+        # torch.distributed.barrier()
         # Collect amax from all local experts
         amax_dict = {}
         for expert in self.local_experts:
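
Why the barrier can go: sync_moe_local_experts_amax only reads and writes quantizer state on experts that live on the current rank, so there is no cross-rank data dependency to order with a global barrier at this point. Below is a minimal rank-local sketch of that pattern, assuming each expert exposes quantizer submodules with an amax tensor attribute; the names are illustrative, not the exact ModelOpt internals.

import torch

def sync_local_experts_amax(local_experts):
    # Collect the elementwise max of each quantizer's amax across local experts ...
    amax_dict = {}
    for expert in local_experts:
        for name, module in expert.named_modules():
            amax = getattr(module, "amax", None)
            if amax is None:
                continue
            prev = amax_dict.get(name)
            amax_dict[name] = amax.detach().clone() if prev is None else torch.maximum(prev, amax)
    # ... then write the shared value back so every local expert agrees.
    for expert in local_experts:
        for name, module in expert.named_modules():
            if getattr(module, "amax", None) is not None and name in amax_dict:
                module.amax = amax_dict[name].clone()
    # No torch.distributed.barrier() needed: every step above is rank-local.
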
@@ -754,8 +754,11 @@ def _setup(self):
 
     def forward(self, hidden_states):
         if any(getattr(m, "_if_calib", False) for m in self.experts.modules()):
-            original_top_k = self.router.topk
-            self.router.topk = self.router.num_experts
-            super().forward(hidden_states)
-            self.router.topk = original_top_k
+            if self.config.moe_router_num_groups is None:
+                original_top_k = self.router.topk
+                self.router.topk = self.router.num_experts
+                super().forward(hidden_states)
+                self.router.topk = original_top_k
+            else:
+                super().forward(hidden_states)
         return super().forward(hidden_states)
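
The forward change gates the calibration-time routing remedy on group-limited routing. When moe_router_num_groups is None, the router's top-k is temporarily widened to num_experts so every expert sees calibration activations and collects a meaningful amax; when group routing is enabled, the override is skipped, presumably because widening top-k to all experts conflicts with the group constraint, and the extra calibration pass runs with normal routing instead. The following is a self-contained sketch of the same pattern, with a hypothetical moe_block standing in for the Megatron MoE layer; the attribute names mirror the diff but are assumptions here.

def calibrate_then_forward(moe_block, hidden_states):
    # Run one extra forward pass during calibration so every expert observes data,
    # then do the real forward pass that produces the layer output.
    calibrating = any(getattr(m, "_if_calib", False) for m in moe_block.experts.modules())
    if calibrating:
        if moe_block.config.moe_router_num_groups is None:
            original_top_k = moe_block.router.topk
            moe_block.router.topk = moe_block.router.num_experts  # route tokens to all experts
            moe_block.forward(hidden_states)                      # calibration-only pass
            moe_block.router.topk = original_top_k                # restore normal routing
        else:
            # Group-limited routing: skip the top-k override, still give experts a pass.
            moe_block.forward(hidden_states)
    return moe_block.forward(hidden_states)
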
