[BUG] Bamba-9B-v2 model fails with torch.compile when using SDPA #43550

@harshaljanjani

Description

System Info

  • transformers version: 5.0.0.dev0
  • Platform: Linux-5.15.167.4-microsoft-standard-WSL2-x86_64-with-glibc2.39
  • Python version: 3.12.3
  • huggingface_hub version: 1.3.2
  • safetensors version: 0.7.0
  • accelerate version: 1.12.0
  • Accelerate config: not installed
  • DeepSpeed version: not installed
  • PyTorch version (accelerator?): 2.9.1+cu128 (CUDA)
  • GPU type: NVIDIA L4
  • NVIDIA driver version: 550.90.07
  • CUDA version: 12.4

Who can help?

@Rocketknight1 @ArthurZucker

Information

  • The official example scripts
  • My own modified scripts

Tasks

  • An officially supported task in the examples folder (such as GLUE/SQuAD, ...)
  • My own task or dataset (give details below)

Reproduction

import torch
from transformers import AutoModelForCausalLM

torch.compiler.reset()
model = AutoModelForCausalLM.from_pretrained(
    "ibm-ai-platform/Bamba-9B-v2",
    torch_dtype=torch.bfloat16,
    attn_implementation="sdpa",
    device_map="cuda"
)
model = torch.compile(model, dynamic=True)
input_ids = torch.tensor([[1, 2, 3, 4, 5]], device="cuda")
with torch.no_grad():
    output = model(input_ids)
print(output.logits.shape)

This pattern, which is followed by other models across the codebase such as Falcon and OPT, causes an SDPA compilation failure when applied in modeling_bamba.py. Fixing this issue should also resolve the latest failures in test_modeling_bamba.py.

Current Error:

(Screenshot of the error; not included in this copy.)

Current Reproduction Script Output (Local Environment):

(Screenshot of the script output; not included in this copy.)

Expected behavior

The model should compile successfully with SDPA attention. Additionally, this fix should resolve the latest failures in test_modeling_bamba.py (NVIDIA CI).

Expected Reproduction Script Output After Applying the Fix:

(Screenshot of the expected script output; not included in this copy.)
