Add Quantizers for Qwen3VLMoeTextDecoderLayer #666
Changes from 8 commits
@@ -571,6 +571,78 @@ def forward(self, x: torch.Tensor, expert_idx: int) -> torch.Tensor:

        return self.w2_linear[expert_idx](x1)


class _QuantQwen3VLMoeTextExperts(QuantModule):
    def _setup(self):
        """Modify the Qwen3VLMoeTextExperts by using nn.Linear layers."""
        from accelerate import init_empty_weights

        dtype, device = self.gate_up_proj.dtype, self.gate_up_proj.device

        def _copy_weight(module, weight):
            module.to_empty(device=device)
            with torch.no_grad():
                module.weight.data = weight.detach().data.to(dtype=dtype, device=device)

        with init_empty_weights():
            gate_proj = nn.ModuleList(
                [
                    nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
                    for _ in range(self.num_experts)
                ]
            )
            up_proj = nn.ModuleList(
                [
                    nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
                    for _ in range(self.num_experts)
                ]
            )
            down_proj = nn.ModuleList(
                [
                    nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
                    for _ in range(self.num_experts)
                ]
            )

        for idx in range(self.num_experts):
            _copy_weight(gate_proj[idx], self.gate_up_proj[idx, :, : self.expert_dim].T)
            _copy_weight(up_proj[idx], self.gate_up_proj[idx, :, self.expert_dim :].T)
            _copy_weight(down_proj[idx], self.down_proj[idx, :].T)

        delattr(self, "gate_up_proj")
        delattr(self, "down_proj")
        self.gate_proj = gate_proj
        self.up_proj = up_proj
        self.down_proj = down_proj

    def forward(
        self,
        hidden_states: torch.Tensor,
        routing_weights: torch.Tensor,
        router_indices: torch.Tensor,
    ) -> torch.Tensor:
        batch_size = hidden_states.shape[0]
        hidden_states = hidden_states.reshape(-1, self.hidden_size)
        next_states = torch.zeros_like(hidden_states)
        with torch.no_grad():
            expert_mask = torch.nn.functional.one_hot(router_indices, num_classes=self.num_experts)
            expert_mask = expert_mask.permute(2, 1, 0)
            expert_hit = torch.greater(expert_mask.sum(dim=(-1, -2)), 0).nonzero()
        for expert_idx in expert_hit[:]:
            assert expert_idx.numel() == 1, expert_idx
            with torch.no_grad():
                _, token_idx = torch.where(expert_mask[expert_idx[0]])
            current_state = hidden_states[token_idx]
            gate = self.gate_proj[expert_idx](current_state)
            up = self.up_proj[expert_idx](current_state)
            gated_output = up * self.act_fn(gate)
            out = self.down_proj[expert_idx](gated_output)
            weighted_output = out * routing_weights[token_idx, expert_idx, None]
            next_states.index_add_(0, token_idx, weighted_output.to(hidden_states.dtype))
        next_states = next_states.view(batch_size, -1, self.hidden_size)
        return next_states


class _QuantDbrxFFN(_QuantSparseMoe):
    @property
    def num_experts(self):
@@ -660,6 +732,27 @@ def top_k(self, value):

except ImportError:
    pass

try:
    from transformers.models.qwen3_vl_moe.modeling_qwen3_vl_moe import Qwen3VLMoeTextSparseMoeBlock

    if Qwen3VLMoeTextSparseMoeBlock not in QuantModuleRegistry:
        QuantModuleRegistry.register(
            {Qwen3VLMoeTextSparseMoeBlock: "hf.Qwen3VLMoeTextSparseMoeBlock"}
        )(_QuantSparseMoe)
except ImportError:
    pass


try:
    from transformers.models.qwen3_vl_moe.modeling_qwen3_vl_moe import Qwen3VLMoeTextExperts

    if Qwen3VLMoeTextExperts not in QuantModuleRegistry:
        QuantModuleRegistry.register({Qwen3VLMoeTextExperts: "hf.Qwen3VLMoeTextExperts"})(
            _QuantQwen3VLMoeTextExperts
        )
except ImportError:
    pass


class _QuantGptOssExperts(_QuantFunctionalMixin):
    """Quantized wrapper for `transformers.GptOssExperts`.
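For context, a hedged sketch of how these registry entries would typically be exercised: a standard ModelOpt `mtq.quantize(...)` call converts registered Hugging Face modules to their quantized counterparts during calibration. The checkpoint id, auto class, and config choice below are placeholders, not something this PR prescribes.

```python
# Hedged usage sketch; assumes the usual modelopt quantization entry point and a
# placeholder Qwen3-VL-MoE checkpoint id. Not part of the diff above.
import modelopt.torch.quantization as mtq
import torch
from transformers import AutoModelForImageTextToText

model = AutoModelForImageTextToText.from_pretrained(
    "<qwen3-vl-moe-checkpoint>",  # placeholder model id
    torch_dtype=torch.bfloat16,
)

def forward_loop(model):
    # Placeholder: run a few calibration batches through the model here.
    ...

# With the registrations above, Qwen3VLMoeTextSparseMoeBlock / Qwen3VLMoeTextExperts
# modules inside the model should be converted to _QuantSparseMoe /
# _QuantQwen3VLMoeTextExperts as part of quantization.
model = mtq.quantize(model, mtq.FP8_DEFAULT_CFG, forward_loop)
```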
🌐 Web query: Qwen3VLMoeTextExperts expert_dim attribute huggingface transformers

💡 Result: Short answer: Hugging Face's Qwen3/Qwen3-VL-MoE configs do not expose an attribute named `expert_dim`. If you expected an `expert_dim` field in transformers, use `moe_intermediate_size` for the per-expert hidden size and `num_experts` for the expert count. References: Hugging Face transformers Qwen3/Qwen3-VL-MoE docs and MindSpore MoE docs. [1][2][3][4]

Sources:
[1] Hugging Face — Qwen3MoE config docs.
[2] Hugging Face — Qwen3-VL-MoE config docs.
[3] Hugging Face (alternate version) — Qwen3-VL-MoE parameters.
[4] MindSpore MoE implementation (expert_dim usage).
[5] MindSpore MoE (other release).
Remove the undefined `expert_dim` attribute reference: it will cause an `AttributeError` at runtime.

The `expert_dim` attribute does not exist on Hugging Face's `Qwen3VLMoeTextExperts` class. Lines 607–608 access `self.expert_dim`, which will fail at runtime. The correct attribute for the per-expert intermediate size is `moe_intermediate_size`; replace the `self.expert_dim` references accordingly.
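A minimal sketch of the substitution the comment appears to suggest, assuming the per-expert width is read from the module's own `self.intermediate_size` (which `_setup` already uses when constructing the `nn.Linear` layers and which mirrors the config's `moe_intermediate_size`):

```python
# Hedged fix sketch: slice the fused tensors with the per-expert intermediate size
# instead of the reportedly undefined `expert_dim` attribute.
for idx in range(self.num_experts):
    _copy_weight(gate_proj[idx], self.gate_up_proj[idx, :, : self.intermediate_size].T)
    _copy_weight(up_proj[idx], self.gate_up_proj[idx, :, self.intermediate_size :].T)
    _copy_weight(down_proj[idx], self.down_proj[idx, :].T)
```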