modelopt/torch/quantization/plugins: 1 file changed (+8, -5 lines)

@@ -581,7 +581,7 @@ def sync_moe_local_experts_amax(self):
         This function is called to synchronize the amax values across local experts s.t. all local experts will
         share the same amax.
         """
-        torch.distributed.barrier()
+        # torch.distributed.barrier()
         # Collect amax from all local experts
         amax_dict = {}
         for expert in self.local_experts:
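
The hunk above ends at the start of the amax-collection loop. For context, here is a minimal sketch of the collect-and-broadcast synchronization the method appears to perform; the `local_experts` attribute and the collection comment come from the diff, while the quantizer traversal and the elementwise-max reduction are assumptions, not the file's actual code:

    import torch

    def sync_moe_local_experts_amax(self):
        """Share one amax per submodule name across all local experts (sketch)."""
        # Collect amax from all local experts, keeping the elementwise max
        # observed for each submodule name (assumed reduction).
        amax_dict = {}
        for expert in self.local_experts:
            for name, module in expert.named_modules():
                amax = getattr(module, "amax", None)
                if amax is not None:
                    seen = amax_dict.get(name)
                    amax_dict[name] = amax.detach().clone() if seen is None else torch.maximum(seen, amax)
        # Write the shared maximum back so every expert quantizes identically.
        for expert in self.local_experts:
            for name, module in expert.named_modules():
                if name in amax_dict and getattr(module, "amax", None) is not None:
                    module.amax = amax_dict[name].clone()

The `torch.distributed.barrier()` call is commented out rather than deleted; presumably the synchronization shown here is local to each rank and needs no cross-rank rendezvous, though the diff does not state the motivation.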
@@ -754,8 +754,11 @@ def _setup(self):
 
     def forward(self, hidden_states):
         if any(getattr(m, "_if_calib", False) for m in self.experts.modules()):
-            original_top_k = self.router.topk
-            self.router.topk = self.router.num_experts
-            super().forward(hidden_states)
-            self.router.topk = original_top_k
+            if self.config.moe_router_num_groups is None:
+                original_top_k = self.router.topk
+                self.router.topk = self.router.num_experts
+                super().forward(hidden_states)
+                self.router.topk = original_top_k
+            else:
+                super().forward(hidden_states)
         return super().forward(hidden_states)
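
The second hunk makes the route-to-all-experts calibration trick conditional on the routing mode. A standalone sketch of that pattern follows; `router`, `config`, and `layer_forward` are hypothetical stand-ins for the Megatron objects the diff touches, not its API:

    import types

    # Hypothetical stand-ins; the real router and config belong to the MoE layer.
    router = types.SimpleNamespace(num_experts=8, topk=2)
    config = types.SimpleNamespace(moe_router_num_groups=None)

    def calibration_forward(layer_forward, hidden_states):
        # With plain top-k routing, raising topk to num_experts for one pass
        # sends every token through every expert, so each expert's input
        # quantizers observe real activations during calibration.
        if config.moe_router_num_groups is None:
            original_top_k = router.topk
            router.topk = router.num_experts   # route every token to all experts
            layer_forward(hidden_states)       # calibration-only pass, output discarded
            router.topk = original_top_k       # restore normal routing
        else:
            # Group-limited routing restricts each token's candidate experts,
            # so topk == num_experts is not a valid setting there; calibrate
            # with the normal forward instead.
            layer_forward(hidden_states)
        return layer_forward(hidden_states)    # the pass whose output is returned

As in the diff, the calibration branch runs one extra forward pass whose output is discarded; only the final call's output is returned.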