
Conversation

RunningLeon (Collaborator) commented Nov 27, 2025

Motivation

Support an FP32 LM head for Qwen and InternLM models so that logits can be computed in full precision for better numerical stability.

Modification

from lmdeploy import pipeline, PytorchEngineConfig

if __name__ == '__main__':
    # Request the FP32 LM head (and embedding) through hf_overrides.
    backend_config = PytorchEngineConfig(hf_overrides=dict(enforce_fp32_head=True))
    model_path = 'Qwen/Qwen3-30B-A3B'
    pipe = pipeline(model_path, backend_config=backend_config)

    resps = pipe(['Hi.'])
    for res in resps:
        print(res)
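
For context, the flag travels from PytorchEngineConfig.hf_overrides onto the HF model config, where the model builders read it with getattr. The snippet below only illustrates that idea; it is not the actual code in lmdeploy/pytorch/config.py, and apply_hf_overrides is a made-up name for this sketch.

# Illustrative sketch only: how an hf_overrides entry such as enforce_fp32_head
# could be lifted onto the HF config object so that model builders can later
# read it via getattr(config, 'enforce_fp32_head', False).
def apply_hf_overrides(hf_config, hf_overrides=None):
    for key, value in (hf_overrides or {}).items():
        setattr(hf_config, key, value)
    # make sure the flag always exists with a safe default
    if not hasattr(hf_config, 'enforce_fp32_head'):
        hf_config.enforce_fp32_head = False
    return hf_config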

BC-breaking (Optional)

Does the modification introduce changes that break backward compatibility with downstream repositories?
If so, please describe how it breaks compatibility and how downstream projects should modify their code to stay compatible with this PR.

Use cases (Optional)

If this PR introduces a new feature, it is better to list some use cases here, and update the documentation.

Checklist

  1. Pre-commit or other linting tools are used to fix the potential lint issues.
  2. The modification is covered by complete unit tests. If not, please add more unit tests to ensure the correctness.
  3. If the modification has a dependency on downstream projects of a newer version, this PR should be tested with all supported versions of downstream projects.
  4. The documentation has been modified accordingly, like docstring or example tutorials.

@RunningLeon RunningLeon requested a review from grimoire November 27, 2025 13:07
@RunningLeon RunningLeon marked this pull request as ready for review January 21, 2026 11:53
Copilot AI review requested due to automatic review settings January 21, 2026 11:53
@RunningLeon RunningLeon changed the title from "[WIP]: Support fp32 head for qwen and internlm models" to "Support fp32 head for qwen and internlm models" Jan 21, 2026
@RunningLeon RunningLeon requested a review from lvhan028 January 21, 2026 11:53
Copilot AI (Contributor) left a comment

Pull request overview

This PR adds support for FP32-precision language model heads for the Qwen and InternLM model families. The feature allows the embedding and lm_head layers to compute in FP32 for improved numerical stability while the rest of the model runs in lower precision (e.g., FP16/BF16). This is enabled through a new enforce_fp32_head configuration option passed via hf_overrides.

Changes:

  • Added DeployModelMixinV1 base class with FP32 head support (see the sketch after this list)
  • Enhanced ParallelEmbedding with force_dtype parameter for FP32 weight storage
  • Refactored 12+ model classes to use the new mixin and build methods
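
The following is a minimal sketch of the mixin behavior described above, assuming the method names from this summary (build_lm_head, get_lm_head, get_logits) and using a plain nn.Linear as a stand-in for lmdeploy's parallel linear layers; the actual class lives in lmdeploy/pytorch/models/utils/model.py and will differ in detail.

# Hedged sketch, not the actual DeployModelMixinV1 implementation.
import torch
from torch import nn


class DeployModelMixinV1Sketch:
    """Builds an (optionally FP32) lm_head and casts hidden states before projection."""

    config = None  # the concrete model assigns its HF config in __init__

    def build_lm_head(self, hidden_size, vocab_size, bias=False, dtype=None, device=None):
        """Use FP32 weights for the head when enforce_fp32_head is set on the config."""
        head_dtype = torch.float32 if getattr(self.config, 'enforce_fp32_head', False) else dtype
        return nn.Linear(hidden_size, vocab_size, bias=bias, dtype=head_dtype, device=device)

    def get_lm_head(self):
        """Models whose head is not named `lm_head` (e.g. InternLM2's `output`) override this."""
        return self.lm_head

    def get_logits(self, hidden_states):
        """Cast hidden states to the head dtype, then project to vocabulary logits."""
        lm_head = self.get_lm_head()
        head_dtype = lm_head.weight.dtype
        if hidden_states.dtype != head_dtype:
            hidden_states = hidden_states.to(dtype=head_dtype)
        return lm_head(hidden_states)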

Reviewed changes

Copilot reviewed 18 out of 18 changed files in this pull request and generated 13 comments.

Summary per file:

  • lmdeploy/pytorch/nn/embedding.py: Added force_dtype parameter to ParallelEmbedding for fp32 embeddings with dtype conversion on output (see the sketch after this table)
  • lmdeploy/pytorch/models/utils/model.py: Introduced DeployModelMixinV1 with build_lm_head() and get_logits() methods supporting the fp32 head
  • lmdeploy/pytorch/config.py: Added config handling to extract and propagate enforce_fp32_head from hf_overrides
  • lmdeploy/pytorch/models/qwen*.py: Refactored Qwen models to use DeployModelMixinV1 and ParallelEmbedding with fp32 support
  • lmdeploy/pytorch/models/internlm*.py: Refactored InternLM models to use DeployModelMixinV1 and ParallelEmbedding with fp32 support
  • lmdeploy/pytorch/models/internvl*.py: Updated InternVL models to use DeployModelMixinV1 and delegate get_lm_head to language_model
  • lmdeploy/pytorch/models/phi3*.py: Updated Phi3 models to use DeployModelMixinV1
  • lmdeploy/pytorch/models/qwen*_vl.py: Updated Qwen VL models to use DeployModelMixinV1 and ParallelEmbedding
  • lmdeploy/pytorch/models/gpt_oss.py: Updated GPT OSS model to use DeployModelMixinV1 and ParallelEmbedding
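
To make the embedding.py change concrete, here is a simplified, single-device sketch of what a force_dtype parameter with dtype conversion on output can look like; it is not the actual ParallelEmbedding code and ignores tensor parallelism entirely.

# Simplified stand-in for the force_dtype behavior described above: store the
# embedding weight in the forced dtype (e.g. FP32), then cast the lookup result
# back to the model's compute dtype so the rest of the network stays in FP16/BF16.
import torch
from torch import nn


class EmbeddingWithForceDtype(nn.Module):

    def __init__(self, num_embeddings, embedding_dim, padding_idx=None,
                 dtype=None, device=None, force_dtype=None):
        super().__init__()
        self.out_dtype = dtype
        weight_dtype = force_dtype if force_dtype is not None else dtype
        self.embed = nn.Embedding(num_embeddings, embedding_dim,
                                  padding_idx=padding_idx,
                                  dtype=weight_dtype, device=device)

    def forward(self, input_ids):
        out = self.embed(input_ids)
        if self.out_dtype is not None and out.dtype != self.out_dtype:
            out = out.to(self.out_dtype)  # convert back to the compute dtype
        return out
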
Comments suppressed due to low confidence (1)

lmdeploy/pytorch/models/internlm2.py:318

  • The InternLM2ForCausalLM class uses self.output as the name for its language model head, but the parent class DeployModelMixinV1.get_lm_head() expects the attribute to be named self.lm_head. This mismatch will cause an AttributeError when the inherited get_logits method tries to access self.get_lm_head().weight.dtype.

You need to either:

  1. Override get_lm_head() in InternLM2ForCausalLM to return self.output (see the sketch after the quoted class below), or
  2. Keep the existing get_logits() override and update it to match the fp32 head behavior from DeployModelMixinV1
class InternLM2ForCausalLM(nn.Module, DeployModelMixinV1, CudaGraphMixin):
    """Rewrote model of InternLM2ForCausalLM."""

    packed_modules_mapping = {
        'gate_up_proj': [
            'w1',
            'w3',
        ],
    }

    def __init__(self,
                 config: PretrainedConfig,
                 ctx_mgr: StepContextManager,
                 dtype: torch.dtype = None,
                 device: torch.device = None):
        super().__init__()
        self.config = config
        self.ctx_mgr = ctx_mgr
        # build Model
        self.model = InternLM2Model(config, dtype=dtype, device=device)
        # build lm_head
        self.output = self.build_lm_head(config.hidden_size, config.vocab_size, bias=False, dtype=dtype, device=device)

    def forward(
        self,
        input_ids: torch.Tensor,
        position_ids: torch.Tensor,
        past_key_values: List[List[torch.Tensor]],
        attn_metadata: Any = None,
        inputs_embeds: torch.Tensor = None,
        **kwargs,
    ):
        """Model forward, return logits."""
        hidden_states = self.model(
            input_ids=input_ids,
            position_ids=position_ids,
            past_key_values=past_key_values,
            attn_metadata=attn_metadata,
            inputs_embeds=inputs_embeds,
        )
        return hidden_states

    def get_logits(self, hidden_states: torch.Tensor):
        """Compute logits of the model output."""
        return self.output(hidden_states)
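
For reference, option 1 could be as small as the override below (a hypothetical sketch, not code from this PR), combined with dropping the local get_logits() so the FP32-aware version from DeployModelMixinV1 is inherited:

    # Hypothetical fix sketch for InternLM2ForCausalLM: expose `output` as the lm_head
    # so the mixin's get_logits() can find the head and its (possibly FP32) dtype.
    def get_lm_head(self):
        return self.output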


self.padding_idx,
dtype=dtype,
device=device,
force_dtype=torch.float32 if getattr(config, 'enforce_fp32_head') else None)
Copilot AI commented Jan 21, 2026

The use of getattr(config, 'enforce_fp32_head') without a default value will raise an AttributeError if the attribute doesn't exist. This can happen if the config was created before the enforce_fp32_head attribute was set, or if the configuration flow is bypassed.

Use getattr(config, 'enforce_fp32_head', False) instead to provide a safe default value.

Suggested change
force_dtype=torch.float32 if getattr(config, 'enforce_fp32_head') else None)
force_dtype=torch.float32 if getattr(config, 'enforce_fp32_head', False) else None)

device: Optional[torch.device] = None,
**kwargs):
"""Build LM Head."""
head_dtype = torch.float32 if getattr(self.config, 'enforce_fp32_head') else dtype
Copilot AI commented Jan 21, 2026

The use of getattr(self.config, 'enforce_fp32_head') without a default value will raise an AttributeError if the attribute doesn't exist. This can happen if the config was created before the enforce_fp32_head attribute was set, or if the configuration flow is bypassed.

Use getattr(self.config, 'enforce_fp32_head', False) instead to provide a safe default value.

Suggested change
head_dtype = torch.float32 if getattr(self.config, 'enforce_fp32_head') else dtype
head_dtype = torch.float32 if getattr(self.config, 'enforce_fp32_head', False) else dtype

self.padding_idx,
dtype=dtype,
device=device,
force_dtype=torch.float32 if getattr(config, 'enforce_fp32_head') else None,
Copilot AI commented Jan 21, 2026

The use of getattr(config, 'enforce_fp32_head') without a default value will raise an AttributeError if the attribute doesn't exist. This can happen if the config was created before the enforce_fp32_head attribute was set, or if the configuration flow is bypassed.

Use getattr(config, 'enforce_fp32_head', False) instead to provide a safe default value.

Suggested change
force_dtype=torch.float32 if getattr(config, 'enforce_fp32_head') else None,
force_dtype=torch.float32 if getattr(config, 'enforce_fp32_head', False) else None,

config.pad_token_id,
dtype=dtype,
device=device,
force_dtype=torch.float32 if getattr(config, 'enforce_fp32_head') else None,
Copilot AI commented Jan 21, 2026

The use of getattr(config, 'enforce_fp32_head') without a default value will raise an AttributeError if the attribute doesn't exist. This can happen if the config was created before the enforce_fp32_head attribute was set, or if the configuration flow is bypassed.

Use getattr(config, 'enforce_fp32_head', False) instead to provide a safe default value.

Suggested change
force_dtype=torch.float32 if getattr(config, 'enforce_fp32_head') else None,
force_dtype=torch.float32 if getattr(config, 'enforce_fp32_head', False) else None,

Comment on lines +59 to +62
head_dtype = self.get_lm_head().weight.dtype
if hidden_states.dtype != head_dtype:
hidden_states = hidden_states.to(dtype=head_dtype)
hidden_states = self.get_lm_head()(hidden_states)
Copilot AI commented Jan 21, 2026

The get_logits method calls self.get_lm_head() twice (lines 59 and 62). While this works correctly, it's inefficient as it involves two method calls and could potentially cause issues if get_lm_head() has side effects. Consider storing the result in a local variable to avoid the duplicate call.

Suggested change
head_dtype = self.get_lm_head().weight.dtype
if hidden_states.dtype != head_dtype:
hidden_states = hidden_states.to(dtype=head_dtype)
hidden_states = self.get_lm_head()(hidden_states)
lm_head = self.get_lm_head()
head_dtype = lm_head.weight.dtype
if hidden_states.dtype != head_dtype:
hidden_states = hidden_states.to(dtype=head_dtype)
hidden_states = lm_head(hidden_states)

self.padding_idx,
dtype=dtype,
device=device,
force_dtype=torch.float32 if getattr(config, 'enforce_fp32_head') else None,
Copilot AI commented Jan 21, 2026

The use of getattr(config, 'enforce_fp32_head') without a default value will raise an AttributeError if the attribute doesn't exist. This can happen if the config was created before the enforce_fp32_head attribute was set, or if the configuration flow is bypassed.

Use getattr(config, 'enforce_fp32_head', False) instead to provide a safe default value.

Suggested change
force_dtype=torch.float32 if getattr(config, 'enforce_fp32_head') else None,
force_dtype=torch.float32 if getattr(config, 'enforce_fp32_head', False) else None,

self.padding_idx,
dtype=dtype,
device=device,
force_dtype=torch.float32 if getattr(config, 'enforce_fp32_head') else None,
Copilot AI commented Jan 21, 2026

The use of getattr(config, 'enforce_fp32_head') without a default value will raise an AttributeError if the attribute doesn't exist. This can happen if the config was created before the enforce_fp32_head attribute was set, or if the configuration flow is bypassed.

Use getattr(config, 'enforce_fp32_head', False) instead to provide a safe default value.

Suggested change
force_dtype=torch.float32 if getattr(config, 'enforce_fp32_head') else None,
force_dtype=torch.float32 if getattr(config, 'enforce_fp32_head', False) else None,

self.padding_idx,
dtype=dtype,
device=device,
force_dtype=torch.float32 if getattr(config, 'enforce_fp32_head') else None,
Copilot AI commented Jan 21, 2026

The use of getattr(config, 'enforce_fp32_head') without a default value will raise an AttributeError if the attribute doesn't exist. This can happen if the config was created before the enforce_fp32_head attribute was set, or if the configuration flow is bypassed.

Use getattr(config, 'enforce_fp32_head', False) instead to provide a safe default value.

Suggested change
force_dtype=torch.float32 if getattr(config, 'enforce_fp32_head') else None,
force_dtype=torch.float32 if getattr(config, 'enforce_fp32_head', False) else None,

self.padding_idx,
dtype=dtype,
device=device,
force_dtype=torch.float32 if getattr(config, 'enforce_fp32_head') else None,
Copilot AI commented Jan 21, 2026

The use of getattr(config, 'enforce_fp32_head') without a default value will raise an AttributeError if the attribute doesn't exist. This can happen if the config was created before the enforce_fp32_head attribute was set, or if the configuration flow is bypassed.

Use getattr(config, 'enforce_fp32_head', False) instead to provide a safe default value.

Suggested change
force_dtype=torch.float32 if getattr(config, 'enforce_fp32_head') else None,
force_dtype=torch.float32 if getattr(config, 'enforce_fp32_head', False) else None,

self.padding_idx,
dtype=dtype,
device=device,
force_dtype=torch.float32 if getattr(config, 'enforce_fp32_head') else None,
Copilot AI commented Jan 21, 2026

The use of getattr(config, 'enforce_fp32_head') without a default value will raise an AttributeError if the attribute doesn't exist. This can happen if the config was created before the enforce_fp32_head attribute was set, or if the configuration flow is bypassed.

Use getattr(config, 'enforce_fp32_head', False) instead to provide a safe default value.

Suggested change
force_dtype=torch.float32 if getattr(config, 'enforce_fp32_head') else None,
force_dtype=torch.float32 if getattr(config, 'enforce_fp32_head', False) else None,

grimoire (Collaborator) commented

Should we put force_fp32_head in BuildModelContext so we don't have to read and set args in every model?
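
A rough illustration of that idea; BuildModelContext's real fields and plumbing in lmdeploy are not shown here, so treat the names below as assumptions:

# Hypothetical sketch of carrying the flag in a shared build context instead of
# reading and re-setting it on every model's config.
from dataclasses import dataclass

import torch


@dataclass
class BuildModelContextSketch:
    enforce_fp32_head: bool = False


def pick_head_dtype(ctx, default_dtype):
    """Return FP32 for head/embedding weights when the build context requests it."""
    return torch.float32 if getattr(ctx, 'enforce_fp32_head', False) else default_dtype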
