93 changes: 93 additions & 0 deletions modelopt/torch/quantization/plugins/huggingface.py
@@ -571,6 +571,78 @@ def forward(self, x: torch.Tensor, expert_idx: int) -> torch.Tensor:
        return self.w2_linear[expert_idx](x1)


class _QuantQwen3VLMoeTextExperts(QuantModule):
    def _setup(self):
        """Modify the Qwen3VLMoeTextExperts by using nn.Linear layers."""
        from accelerate import init_empty_weights

        dtype, device = self.gate_up_proj.dtype, self.gate_up_proj.device

        def _copy_weight(module, weight):
            module.to_empty(device=device)
            with torch.no_grad():
                module.weight.data = weight.detach().data.to(dtype=dtype, device=device)

        with init_empty_weights():
            gate_proj = nn.ModuleList(
                [
                    nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
                    for _ in range(self.num_experts)
                ]
            )
            up_proj = nn.ModuleList(
                [
                    nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
                    for _ in range(self.num_experts)
                ]
            )
            down_proj = nn.ModuleList(
                [
                    nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
                    for _ in range(self.num_experts)
                ]
            )

        for idx in range(self.num_experts):
            _copy_weight(gate_proj[idx], self.gate_up_proj[idx, :, : self.expert_dim].T)
            _copy_weight(up_proj[idx], self.gate_up_proj[idx, :, self.expert_dim :].T)
            _copy_weight(down_proj[idx], self.down_proj[idx, :].T)

        delattr(self, "gate_up_proj")
        delattr(self, "down_proj")
        self.gate_proj = gate_proj
        self.up_proj = up_proj
        self.down_proj = down_proj

Comment on lines 574 to 625

⚠️ Potential issue | 🔴 Critical

🧩 Analysis chain

🏁 Script executed:

#!/bin/bash
# First check if transformers is available and inspect Qwen3VLMoeTextExperts
python3 << 'EOF'
import sys
try:
    # Try to import the class
    from transformers.models.qwen3_vl_moe.modeling_qwen3_vl_moe import Qwen3VLMoeTextExperts
    import inspect
    
    print("=== Qwen3VLMoeTextExperts Class Found ===")
    
    # Get source code
    try:
        source = inspect.getsource(Qwen3VLMoeTextExperts)
        # Check for expert_dim
        if 'expert_dim' in source:
            print("✓ expert_dim FOUND in class source")
        else:
            print("✗ expert_dim NOT found in class source")
        
        # Check for intermediate_size
        if 'intermediate_size' in source:
            print("✓ intermediate_size found in class source")
        
        # Check __init__ signature
        init_sig = inspect.signature(Qwen3VLMoeTextExperts.__init__)
        print(f"__init__ parameters: {list(init_sig.parameters.keys())}")
        
    except Exception as e:
        print(f"Could not get source: {e}")
        
except ImportError as e:
    print(f"✗ Could not import Qwen3VLMoeTextExperts: {e}")
    print("\nTrying alternative import paths...")
    
    try:
        from transformers import AutoModel
        print("AutoModel available, but Qwen3VLMoeTextExperts not directly importable")
    except:
        print("Transformers not fully available")
EOF

Repository: NVIDIA/Model-Optimizer

Length of output: 152


🏁 Script executed:

#!/bin/bash
# Also search the reviewed file for any expert_dim references
echo "=== Searching reviewed file for expert_dim references ==="
if [ -f "modelopt/torch/quantization/plugins/huggingface.py" ]; then
    rg "expert_dim" modelopt/torch/quantization/plugins/huggingface.py -n -C 2
else
    echo "File not found at expected location"
fi

Repository: NVIDIA/Model-Optimizer

Length of output: 436


🌐 Web query:

Qwen3VLMoeTextExperts expert_dim attribute huggingface transformers

💡 Result:

Short answer: Hugging Face's Qwen3/VL-MoE configs do not expose an attribute named expert_dim. For MoE in transformers you should use:

  • num_experts — number of experts. [1][2]
  • num_experts_per_tok (top‑k) — number of experts routed per token. [1][2]
  • moe_intermediate_size — the hidden (intermediate) size used inside each routed expert (what some codebases call an expert hidden-dimension). [1][3]

Evidence:

  • Qwen3MoE / Qwen3VLMoeConfig docs list num_experts, num_experts_per_tok and moe_intermediate_size but do not define expert_dim. [1][2][3]
  • Other libraries sometimes use the name expert_dim (e.g., MindSpore examples) to mean either the number of experts or the per‑expert dimension depending on their API — so naming differs between frameworks. [4][5]

If you expected an expert_dim field in transformers, use moe_intermediate_size for per-expert hidden size and num_experts for expert count. References: Hugging Face transformers Qwen3/Qwen3-VL-MoE docs and MindSpore MoE docs. [1][2][3][4]

Sources
[1] Hugging Face — Qwen3MoE config docs.
[2] Hugging Face — Qwen3-VL-Moe config docs.
[3] Hugging Face (alternate version) — Qwen3-VL-Moe parameters.
[4] MindSpore MoE implementation (expert_dim usage).
[5] MindSpore MoE (other release).
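
To make the cited attribute names concrete, here is a minimal sketch reading them from a transformers MoE config. Qwen3MoeConfig is used as a stand-in for the Qwen3-VL-MoE text config (an assumption based on the docs cited above), and the snippet requires a transformers release that ships Qwen3-MoE:

# Minimal sketch; Qwen3MoeConfig stands in for the Qwen3-VL-MoE text config,
# which per the cited docs exposes the same MoE fields (assumption).
from transformers import Qwen3MoeConfig

cfg = Qwen3MoeConfig()
print(cfg.num_experts)            # number of routed experts
print(cfg.num_experts_per_tok)    # top-k experts routed per token
print(cfg.moe_intermediate_size)  # per-expert intermediate size (no expert_dim field)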


Remove undefined expert_dim attribute reference — this will cause AttributeError at runtime.

The expert_dim attribute does not exist on HuggingFace's Qwen3VLMoeTextExperts class. Lines 607–608 access self.expert_dim, which will fail at runtime. The correct attribute for the per-expert intermediate size is moe_intermediate_size. Replace:

self.gate_up_proj[idx, :, : self.expert_dim].T
self.gate_up_proj[idx, :, self.expert_dim :].T

with:

self.gate_up_proj[idx, :, : self.moe_intermediate_size].T
self.gate_up_proj[idx, :, self.moe_intermediate_size :].T
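
A minimal sketch of how the suggested slicing could look inside _setup, assuming moe_intermediate_size is indeed the attribute exposed by the installed transformers release; the getattr fallback to intermediate_size is purely illustrative and not part of the proposed change:

# Illustrative only; the fallback is an assumption, not the reviewed code.
split_dim = getattr(self, "moe_intermediate_size", None) or self.intermediate_size

for idx in range(self.num_experts):
    # Split the fused gate_up_proj weight into per-expert gate and up slices.
    _copy_weight(gate_proj[idx], self.gate_up_proj[idx, :, :split_dim].T)
    _copy_weight(up_proj[idx], self.gate_up_proj[idx, :, split_dim:].T)
    _copy_weight(down_proj[idx], self.down_proj[idx, :].T)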
🤖 Prompt for AI Agents
In `@modelopt/torch/quantization/plugins/huggingface.py` around lines 574 - 616,
In _QuantQwen3VLMoeTextExperts._setup the code references a non-existent
attribute self.expert_dim when slicing self.gate_up_proj, which will raise
AttributeError; update the slices to use the correct per-expert size attribute
self.moe_intermediate_size (i.e. replace uses of self.expert_dim in the
gate_up_proj slicing with self.moe_intermediate_size) so the _copy_weight calls
for gate_proj and up_proj use the proper slice, leaving down_proj and attribute
reassignment (gate_up_proj → gate_proj, down_proj → down_proj) unchanged.

    def forward(
        self,
        hidden_states: torch.Tensor,
        routing_weights: torch.Tensor,
        router_indices: torch.Tensor,
    ) -> torch.Tensor:
        batch_size = hidden_states.shape[0]
        hidden_states = hidden_states.reshape(-1, self.hidden_size)
        next_states = torch.zeros_like(hidden_states)
        with torch.no_grad():
            expert_mask = torch.nn.functional.one_hot(router_indices, num_classes=self.num_experts)
            expert_mask = expert_mask.permute(2, 1, 0)
            expert_hit = torch.greater(expert_mask.sum(dim=(-1, -2)), 0).nonzero()
        for expert_idx in expert_hit[:]:
            assert expert_idx.numel() == 1, expert_idx
            with torch.no_grad():
                _, token_idx = torch.where(expert_mask[expert_idx[0]])
            current_state = hidden_states[token_idx]
            gate = self.gate_proj[expert_idx](current_state)
            up = self.up_proj[expert_idx](current_state)
            gated_output = up * self.act_fn(gate)
            out = self.down_proj[expert_idx](gated_output)
            weighted_output = out * routing_weights[token_idx, expert_idx, None]
            next_states.index_add_(0, token_idx, weighted_output.to(hidden_states.dtype))

⚠️ Potential issue | 🔴 Critical

Bug: Using tensor to index nn.ModuleList will fail.

When iterating over expert_hit (from nonzero()), each expert_idx is a 1D tensor of shape (1,), not a Python integer. nn.ModuleList.__getitem__ expects an integer index, so self.gate_proj[expert_idx] will raise a TypeError.

Convert expert_idx to an integer before indexing:

🐛 Proposed fix
-    for expert_idx in expert_hit[:]:
-        assert expert_idx.numel() == 1, expert_idx
+    for expert_idx in expert_hit:
+        expert_idx = expert_idx.item()
         with torch.no_grad():
-            _, token_idx = torch.where(expert_mask[expert_idx[0]])
+            _, token_idx = torch.where(expert_mask[expert_idx])
         current_state = hidden_states[token_idx]
         gate = self.gate_proj[expert_idx](current_state)
         up = self.up_proj[expert_idx](current_state)
         gated_output = up * self.act_fn(gate)
         out = self.down_proj[expert_idx](gated_output)
-        weighted_output = out * routing_weights[token_idx, expert_idx, None]
+        weighted_output = out * routing_weights[token_idx, expert_idx:expert_idx+1, None]
         next_states.index_add_(0, token_idx, weighted_output.to(hidden_states.dtype))

Note: The routing_weights[token_idx, expert_idx, None] indexing at line 639 may also need adjustment after converting to int, depending on the expected tensor shape.
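
For reference, a standalone sketch of the conversion described above; the sizes, module list, and index values are arbitrary placeholders, not taken from the model:

# Standalone sketch of the indexing concern flagged above; sizes are arbitrary.
import torch
import torch.nn as nn

experts = nn.ModuleList([nn.Linear(4, 4) for _ in range(3)])
expert_hit = torch.tensor([[1], [2]])  # shape (num_hit, 1), as nonzero() returns

for expert_idx in expert_hit:
    # expert_idx is a shape-(1,) tensor; converting to a Python int keeps
    # nn.ModuleList indexing (and routing_weights indexing) unambiguous.
    idx = int(expert_idx.item())
    out = experts[idx](torch.randn(2, 4))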

🤖 Prompt for AI Agents
In `@modelopt/torch/quantization/plugins/huggingface.py` around lines 630 - 640,
expert_idx is a 1-element tensor and is being used to index nn.ModuleList
(self.gate_proj/self.up_proj/self.down_proj) and routing_weights, which raises a
TypeError; convert expert_idx to a Python int (e.g., expert_idx_int =
int(expert_idx.item())) before using it to index the
ModuleList entries and use that int for routing_weights indexing
(routing_weights[token_idx, expert_idx_int, None]) while keeping token_idx as a
tensor for per-token selection, then proceed with gate_proj/up_proj/down_proj
using the integer index and next_states.index_add_ as before.

        next_states = next_states.view(batch_size, -1, self.hidden_size)

        return next_states


class _QuantDbrxFFN(_QuantSparseMoe):
    @property
    def num_experts(self):
@@ -660,6 +732,27 @@ def top_k(self, value):
except ImportError:
    pass

try:
    from transformers.models.qwen3_vl_moe.modeling_qwen3_vl_moe import Qwen3VLMoeTextSparseMoeBlock

    if Qwen3VLMoeTextSparseMoeBlock not in QuantModuleRegistry:
        QuantModuleRegistry.register(
            {Qwen3VLMoeTextSparseMoeBlock: "hf.Qwen3VLMoeTextSparseMoeBlock"}
        )(_QuantSparseMoe)
except ImportError:
    pass


try:
    from transformers.models.qwen3_vl_moe.modeling_qwen3_vl_moe import Qwen3VLMoeTextExperts

    if Qwen3VLMoeTextExperts not in QuantModuleRegistry:
        QuantModuleRegistry.register({Qwen3VLMoeTextExperts: "hf.Qwen3VLMoeTextExperts"})(
            _QuantQwen3VLMoeTextExperts
        )
except ImportError:
    pass


class _QuantGptOssExperts(_QuantFunctionalMixin):
"""Quantized wrapper for `transformers.GptOssExperts`.