
Conversation

@nyo16 (Contributor) commented on Jan 8, 2026

Summary

Add native FP8 quantized model support for models like Qwen3-FP8. This enables loading and running FP8 models with per-block
scale factors (scale_inv) for dequantization.

Changes

bumblebee.ex

  • Add :preserve_source_types option to load_model/2 to keep FP8 types during loading

pytorch_params.ex

  • Pass preserve_source_types through param loading pipeline
  • Modify ensure_type/3 to preserve FP8 types when the option is set (see the sketch after this list)
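
A minimal sketch of the idea behind the ensure_type/3 change (not the exact code in pytorch_params.ex): when :preserve_source_types is set and the loaded tensor is FP8, the usual cast to the target type is skipped. The FP8 type tuple below is an assumption; the concrete type name comes from the pending elixir-nx/nx FP8 support.

defp ensure_type(tensor, type, opts) do
  # Keep FP8 params as-is when the caller opted in; otherwise cast as before.
  if opts[:preserve_source_types] and fp8_type?(Nx.type(tensor)) do
    tensor
  else
    Nx.as_type(tensor, type)
  end
end

# Hypothetical helper; the real FP8 type tuple depends on the Nx FP8 PR.
defp fp8_type?({:f8_e4m3fn, 8}), do: true
defp fp8_type?(_), do: false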

layers.ex

  • Add fp8_aware_dense/3 layer that handles FP8 quantized weights
  • Implements block-wise dequantization using the scale_inv parameter (see the sketch after this list)
  • Automatically falls back to identity scaling (1.0) for non-FP8 models
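
A rough sketch of the block-wise dequantization that fp8_aware_dense performs, assuming a 128 x 128 quantization block size and a scale_inv tensor holding one inverse scale per block (the block size and function name here are illustrative, not taken from this PR):

defmodule Fp8DequantSketch do
  # weight: FP8 tensor of shape {rows, cols}
  # scale_inv: f32 tensor of shape {rows / block, cols / block}
  def dequantize(weight, scale_inv, block \\ 128) do
    {rows, cols} = Nx.shape(weight)
    {row_blocks, col_blocks} = {div(rows, block), div(cols, block)}

    # Expand each per-block inverse scale so it covers its full block...
    scales =
      scale_inv
      |> Nx.reshape({row_blocks, 1, col_blocks, 1})
      |> Nx.broadcast({row_blocks, block, col_blocks, block})
      |> Nx.reshape({rows, cols})

    # ...then upcast the FP8 weights and scale them elementwise.
    Nx.multiply(Nx.as_type(weight, :f32), scales)
  end
end

The dense layer then uses the dequantized kernel in a regular matmul; for non-FP8 models the scale_inv parameter defaults to 1.0, so the same code path reduces to an identity scaling.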

layers/transformer.ex

  • Add :attention_dense option to blocks/2, block/2, multi_head_attention/4
  • Allows a custom dense function for the Q, K, V, and output projections (call shape sketched after this list)
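
For illustration, the call shape a model could use to route the attention projections through the FP8-aware layer. Only :attention_dense comes from this PR; the remaining option names follow Bumblebee's existing transformer block options and are shown as assumed context, not copied from the diff.

# Custom dense builder applied to the Q, K, V, and output projections
Layers.Transformer.blocks(hidden_state,
  attention_dense: &Layers.fp8_aware_dense/3,
  num_blocks: spec.num_blocks,
  num_attention_heads: spec.num_attention_heads,
  hidden_size: spec.hidden_size,
  name: "decoder.blocks"
)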

text/qwen3.ex

  • Update decoder to use fp8_aware_dense for attention via attention_dense option
  • Update gated_ffn to use fp8_aware_dense for FFN layers
  • Add scale_inv to params_mapping for all attention and FFN layers (an illustrative mapping follows this list)
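
An illustrative shape for the extra mapping entries, with one attention projection shown. The checkpoint key weight_scale_inv and the builder functions are assumptions based on the Hugging Face FP8 export format, not copied from this PR, and the exact builder details may differ.

# The fp8_aware_dense query projection wires both its kernel and its
# scale_inv parameter to the checkpoint's tensors (per-block scales are
# stored as weight_scale_inv in the HF FP8 export).
"decoder.blocks.{n}.self_attention.query" => %{
  "kernel" =>
    {[{"model.layers.{n}.self_attn.q_proj", "weight"}], fn [kernel] -> kernel end},
  "scale_inv" =>
    {[{"model.layers.{n}.self_attn.q_proj", "weight_scale_inv"}], fn [scale] -> scale end}
}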

Test plan

  • FP8 model (Qwen3-0.6B-FP8) generates correct output ("Paris" for capital of France)
  • Non-FP8 model (Qwen3-0.6B) still works correctly (backward compatible)
  • Tested on RTX 5070 Ti (Blackwell, SM 12.0)

Dependencies

Requires (merge in order):

  1. elixir-nx/safetensors - FP8 file I/O
  2. elixir-nx/nx - FP8 type system support

Usage

# Load FP8 model with native weights
{:ok, model_info} = Bumblebee.load_model(
  {:hf, "Qwen/Qwen3-0.6B-FP8"},
  architecture: :for_causal_language_modeling,
  preserve_source_types: true
)

# Tokenizer and generation config load as usual
{:ok, tokenizer} = Bumblebee.load_tokenizer({:hf, "Qwen/Qwen3-0.6B-FP8"})
{:ok, generation_config} = Bumblebee.load_generation_config({:hf, "Qwen/Qwen3-0.6B-FP8"})

# Use normally - scale_inv dequantization happens automatically
serving = Bumblebee.Text.generation(model_info, tokenizer, generation_config)
Nx.Serving.run(serving, "The capital of France is")
# => "Paris..."

The implementation supports both:

  • Pre-dequantization: convert FP8 -> F32 before loading
  • Native FP8: load FP8 weights directly and apply scale_inv at runtime

Co-Authored-By: Claude Opus 4.5 <[email protected]>
@nyo16 marked this pull request as draft on January 8, 2026 at 17:39
@josevalim (Contributor)

Thank you! This PR should probably wait until this is done: elixir-nx/nx#1657 (comment)
