
Conversation

@Fridah-nv (Contributor) commented Jan 9, 2026

What does this PR do?

Type of change: new feature

Overview: Adds an fp8_scale_sweep mode to the MSE calibrator for optimizing the FP8-quantized per-block scales used in the NVFP4 format.

Usage

Tested with this config

NVFP4_WEIGHT_MSE_FP8_SWEEP_CFG = {
    "quant_cfg": {
        "*weight_quantizer": {
            "num_bits": (2, 1),
            "block_sizes": {-1: 16, "type": "static", "scale_bits": (4, 3)},
            "axis": None,
            "enable": True,
        },
        "*input_quantizer": {
            "enable": False,
        },
        **_default_disabled_quantizer_cfg,
    },
    "algorithm": {
        "method": "mse",
        "fp8_scale_sweep": True,
    },
}
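
A minimal usage sketch, assuming the standard ModelOpt entry point mtq.quantize (the model and calib_dataloader below are placeholders; exact invocation may differ):

import modelopt.torch.quantization as mtq

def forward_loop(model):
    # Hypothetical calibration loop over a placeholder dataloader.
    for batch in calib_dataloader:
        model(batch)

# Apply the weight-only NVFP4 MSE config with the FP8 scale sweep enabled.
model = mtq.quantize(model, NVFP4_WEIGHT_MSE_FP8_SWEEP_CFG, forward_loop)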

Testing

Before your PR is "Ready for review"

  • Make sure you read and follow Contributor guidelines and your commits are signed.
  • Is this change backward compatible?: Yes/No
  • Did you write any new necessary tests?: Yes/No
  • Did you add or update any necessary documentation?: Yes/No
  • Did you update Changelog?: Yes/No

Additional Information

Summary by CodeRabbit

Release Notes

New Features

  • Added FP8 scale sweep option for quantization calibration, enabling optimized scale value sweeping for NVFP4 per-block quantization.
  • Introduced new NVFP4_WEIGHT_MSE_CFG configuration preset for improved weight quantization workflows.

Tests

  • Added test coverage validating FP8 scale sweep functionality and reset behavior.


@Fridah-nv Fridah-nv requested a review from a team as a code owner January 9, 2026 22:45
@Fridah-nv Fridah-nv requested review from jingyu-ml and removed request for a team January 9, 2026 22:45
@Fridah-nv Fridah-nv marked this pull request as draft January 9, 2026 22:45
copy-pr-bot bot commented Jan 9, 2026

Auto-sync is disabled for draft pull requests in this repository. Workflows must be run manually.

Contributors can view more details about this message here.

Base automatically changed from fridah/block-mse to main January 13, 2026 22:43
@Fridah-nv Fridah-nv force-pushed the fridah/mse-fp8-sweep branch from 4017e8d to 3c360cb Compare January 15, 2026 19:33
@Fridah-nv Fridah-nv force-pushed the fridah/mse-fp8-sweep branch from 3c360cb to 51f5e86 Compare January 15, 2026 19:45
coderabbitai bot commented Jan 15, 2026

📝 Walkthrough

This pull request introduces FP8 scale sweep functionality to the MSE calibrator. The feature enables sweeping over FP8 E4M3 scale values instead of traditional multiplier steps when fp8_scale_sweep is enabled. The implementation spans configuration updates, calibrator modifications, and corresponding test coverage.
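
As a rough illustration of where the candidate scale values come from (a sketch only, assuming torch.float8_e4m3fn is available; this is not the PR's exact code):

import torch

# Reinterpret all 256 FP8 E4M3 bit patterns and keep the positive, finite,
# nonzero ones; this yields the 126 candidate values the sweep iterates over.
bits = torch.arange(256, dtype=torch.uint8)
fp8_vals = bits.view(torch.float8_e4m3fn).float()
candidates = fp8_vals[torch.isfinite(fp8_vals) & (fp8_vals > 0)].sort().values
assert candidates.numel() == 126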

Changes

  • MSE Calibrator Core Logic (modelopt/torch/quantization/calib/mse.py):
    Adds an fp8_scale_sweep parameter to MseCalibrator.__init__. Implements the FP8 sweep path in collect(): computes a global amax, generates the 128 FP8 E4M3 values, filters them to the valid finite positive candidates, normalizes them to multipliers, and iterates through the candidates computing losses against the expanded global amax instead of the initial amax. Per-tensor/per-channel behavior is unchanged when the sweep is disabled.
  • Quantization Configuration (modelopt/torch/quantization/config.py):
    Adds an fp8_scale_sweep boolean field to MseCalibConfig, documented as applying to NVFP4 per-block quantizers. Introduces the new NVFP4_WEIGHT_MSE_CFG configuration constant for weight-only per-block NVFP4 quantization.
  • Model Calibration Integration (modelopt/torch/quantization/model_calib.py):
    Adds an fp8_scale_sweep parameter to mse_calibrate(). Derives a local is_nvfp4_per_block flag to propagate the FP8 sweep behavior to the MseCalibrator construction and quant_func usage.
  • Test Configuration Updates (tests/gpu/torch/quantization/test_quantize_cuda.py):
    Updates parameterized test configurations to use mtq.NVFP4_WEIGHT_ACT_MSE_CFG instead of a local reference and adds the new mtq.NVFP4_WEIGHT_MSE_CFG configuration to the test parameter lists.
  • FP8 Scale Sweep Test Coverage (tests/unit/torch/quantization/test_mse_calibrator.py):
    Adds a new test method, test_fp8_scale_sweep_with_fixed_values_and_reset, that validates FP8 scale sweep behavior including step count, amax bounds, and reproducibility after reset operations.
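
To make the mse.py summary above concrete, a hypothetical skeleton of the sweep (identifiers such as fp8_sweep_amax and quant_func are illustrative, not the PR's actual names):

import torch

def fp8_sweep_amax(x, candidates, quant_func):
    # Global amax over the whole tensor, then one normalized multiplier per candidate.
    global_amax = x.abs().amax()
    multipliers = candidates / candidates.max()
    best_loss, best_amax = float("inf"), None
    for multiplier in multipliers:
        candidate_amax = global_amax * multiplier
        xq = quant_func(x, candidate_amax)      # fake-quantize with this candidate amax
        loss = torch.mean((x - xq) ** 2)        # MSE against the original tensor
        if loss.item() < best_loss:
            best_loss, best_amax = loss.item(), candidate_amax
    return best_amax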

Estimated code review effort

🎯 3 (Moderate) | ⏱️ ~20 minutes

🚥 Pre-merge checks | ✅ 2 passed | ❌ 1 failed

❌ Failed checks (1 warning)
  • Docstring Coverage: ⚠️ Warning. Docstring coverage is 75.00%, below the required threshold of 80.00%. Resolution: write docstrings for the functions missing them to satisfy the coverage threshold.

✅ Passed checks (2 passed)
  • Description Check: ✅ Passed. Check skipped because CodeRabbit's high-level summary is enabled.
  • Title check: ✅ Passed. The title 'add FP8 sweep option for static NVFP4 MSE' accurately reflects the main feature addition: enabling FP8 scale sweep for NVFP4 per-block MSE calibration.



@Fridah-nv Fridah-nv self-assigned this Jan 15, 2026
@Fridah-nv Fridah-nv marked this pull request as ready for review January 15, 2026 19:45

@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 3

🤖 Fix all issues with AI agents
In `@modelopt/torch/quantization/calib/mse.py`:
- Around lines 113-118: The two debug print() calls in the FP8 scale sweep code
(the lines referencing fp8_scale_sweep and candidates, printing "FP8 scale
sweep: trying..." and "Multiplier range: ...") should not print unconditionally
during calibration. Either remove them, replace them with a logger call (e.g.,
logger.debug or logger.info), or guard them behind a verbose flag passed to the
calibration routine (add a verbose parameter to the function that runs the FP8
sweep and wrap the prints with if verbose:). Locate the prints by searching for
fp8_scale_sweep and candidates in mse.py, then replace the print() calls with a
logging call or a conditional on the verbose parameter so that output is
controlled by log level or caller-provided verbosity.

In `@modelopt/torch/quantization/config.py`:
- Around lines 414-434: The new config constants NVFP4_WEIGHT_MSE_CFG and
NVFP4_WEIGHT_ACT_MSE_CFG are not registered in the module's choices set. Update
the module-level choices set (the set that lists supported quantization format
names) to include the strings "NVFP4_WEIGHT_MSE" and "NVFP4_WEIGHT_ACT_MSE"
alongside the other format names so these configs are discoverable via the
public API, and make sure the added names exactly match the naming convention
used for the other entries in choices.

In `@modelopt/torch/quantization/model_calib.py`:
- Around lines 268-273: The current is_nvfp4_per_block condition may raise an
AttributeError when module._block_sizes is None. Update the boolean expression
to explicitly ensure module._block_sizes is not None (or use getattr to obtain
a dict) before calling .get("scale_bits"), e.g., add a module._block_sizes is
not None check to the chain so .get only runs when _block_sizes exists. The
condition can be located via the is_nvfp4_per_block variable, fp8_scale_sweep,
module.is_static_block_quant, and module._block_sizes; a sketch of one possible
guard follows.
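
One possible shape of that guard (a sketch only; attribute names follow the review context and may differ from the final code):

block_sizes = getattr(module, "_block_sizes", None) or {}
is_nvfp4_per_block = (
    fp8_scale_sweep
    and module.is_static_block_quant
    and block_sizes.get("scale_bits") == (4, 3)
)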
🧹 Nitpick comments (2)
tests/unit/torch/quantization/test_mse_calibrator.py (1)

530-594: Good test coverage for FP8 scale sweep functionality.

The test comprehensively validates:

  • Step count (126 valid FP8 E4M3 values)
  • Optimal amax bounds
  • Reset behavior
  • Reproducibility after reset

One observation: this test uses per-tensor FP8 quantization (axis=None, no block_sizes), but the FP8 scale sweep feature is documented as being "specifically for NVFP4 per-block quantization." Consider adding a separate test case that more closely mirrors the NVFP4 per-block configuration with block_sizes to ensure the feature works correctly in its intended context.

tests/gpu/torch/quantization/test_quantize_cuda.py (1)

31-47: Remove unused local NVFP4_WEIGHT_ACT_MSE_CFG definition.

The local configuration defined at lines 31-47 is not referenced anywhere in the file. The test uses mtq.NVFP4_WEIGHT_ACT_MSE_CFG instead. Remove the local definition to eliminate dead code.

📜 Review details

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 849a350 and 51f5e86.

📒 Files selected for processing (5)
  • modelopt/torch/quantization/calib/mse.py
  • modelopt/torch/quantization/config.py
  • modelopt/torch/quantization/model_calib.py
  • tests/gpu/torch/quantization/test_quantize_cuda.py
  • tests/unit/torch/quantization/test_mse_calibrator.py
🧰 Additional context used
🧬 Code graph analysis (2)
modelopt/torch/quantization/model_calib.py (1)
modelopt/torch/quantization/nn/modules/tensor_quantizer.py (1)
  • is_static_block_quant (518-524)
modelopt/torch/quantization/config.py (1)
modelopt/torch/opt/config.py (1)
  • ModeloptField (50-53)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (5)
  • GitHub Check: wait-checks / wait
  • GitHub Check: wait-checks / wait
  • GitHub Check: linux
  • GitHub Check: code-quality
  • GitHub Check: build-docs
🔇 Additional comments (7)
tests/gpu/torch/quantization/test_quantize_cuda.py (1)

73-74: LGTM!

The new configurations are correctly added to both the test parameterization and the skip condition list, ensuring proper test coverage while respecting the cuda_ext_mx requirement and Conv layer limitations.

Also applies to: 91-92

modelopt/torch/quantization/config.py (1)

1093-1101: LGTM! The fp8_scale_sweep field is well-documented.

The field description clearly explains:

  • The purpose (sweeping FP8 E4M3 scale values for NVFP4 per-block quantization)
  • When parameters are ignored (num_steps, step_size, start_multiplier, stop_multiplier)
  • The specific use case (FP8-quantized per-block scales in NVFP4 format)
modelopt/torch/quantization/model_calib.py (1)

208-208: LGTM!

The new fp8_scale_sweep parameter is:

  • Backward compatible with default False
  • Well-documented in the docstring
  • Correctly propagated to the MseCalibrator

Also applies to: 224-227

modelopt/torch/quantization/calib/mse.py (4)

72-78: LGTM!

The initialization correctly sets up 126 steps for FP8 scale sweep, corresponding to all valid positive finite FP8 E4M3 values (excluding zero and NaN).


126-130: LGTM!

The branching logic correctly uses global_amax_expanded for FP8 scale sweep (adapting to current data) and initial_amax for traditional multiplier-based sweep.


140-141: Potential inconsistency with multiple batches in FP8 scale sweep mode.

When collect() is called multiple times with fp8_scale_sweep=True, each batch computes its own global_amax (line 99), but candidate_amaxs[step] is only set on the first call (line 140-141 guard). This means:

  • Later batches use candidate_amax values computed from the first batch's global_amax
  • But losses are computed using each batch's own global_amax * candidate

This could lead to inconsistent results if batches have significantly different data ranges. Verify this is the intended behavior for the FP8 scale sweep use case.


98-111: The concern about PyTorch version compatibility is not applicable. The codebase already requires torch>=2.6 (specified in setup.py), which is well above the actual minimum version needed for torch.float8_e4m3fn (PyTorch 2.2+). The code at line 105 can safely use the dtype without version guards.

Likely an incorrect or invalid review comment.


Signed-off-by: Fridah-nv <[email protected]>
Signed-off-by: Fridah-nv <[email protected]>

codecov bot commented Jan 15, 2026

Codecov Report

✅ All modified and coverable lines are covered by tests.
✅ Project coverage is 74.37%. Comparing base (849a350) to head (75e3ccd).
⚠️ Report is 1 commit behind head on main.

Additional details and impacted files
@@            Coverage Diff             @@
##             main     #758      +/-   ##
==========================================
+ Coverage   74.23%   74.37%   +0.13%     
==========================================
  Files         192      192              
  Lines       19033    19051      +18     
==========================================
+ Hits        14129    14169      +40     
+ Misses       4904     4882      -22     

☔ View full report in Codecov by Sentry.

@Fridah-nv Fridah-nv changed the title add FP8 sweep and step_size flag add FP8 sweep option for static NVFP4 MSE Jan 15, 2026
Comment on lines +414 to +432
NVFP4_WEIGHT_MSE_FP8_SWEEP_CFG = {
"quant_cfg": {
"*weight_quantizer": {
"num_bits": (2, 1),
"block_sizes": {-1: 16, "type": "static", "scale_bits": (4, 3)},
"axis": None,
"enable": True,
},
"*input_quantizer": {
"enable": False,
},
**_default_disabled_quantizer_cfg,
},
"algorithm": {
"method": "mse",
"fp8_scale_sweep": True,
},
}


Let's not add this config now; let's wait until we get compelling results before adding new configs. Configs are ModelOpt's official recommendations.

In addition, this recipe is not supported for deployment (inputs are not quantized, and weight-only NVFP4 is not supported yet).

"This is specifically designed for optimizing the FP8-quantized per-block scales "
"in NVFP4 format. When enabled, num_steps, step_size, start_multiplier, and "
"stop_multiplier are ignored for NVFP4 per-block quantizers.",
)

Suggested change
)
This is ignored for all quantizations except NVFP4 weight quantization.)


return xq

is_nvfp4_per_block = (

nit:

Suggested change
is_nvfp4_per_block = (
is_nvfp4_static = (
"weight_quantizer" in name and


if self._fp8_scale_sweep:
global_amax = quant_utils.reduce_amax(x, axis=None, keepdims=False, squeeze_scalar=True)
global_amax_expanded = global_amax * torch.ones_like(self._initial_amax)

nit:
this should not be needed (since global_amax is a scalar)

Suggested change
global_amax_expanded = global_amax * torch.ones_like(self._initial_amax)
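
A quick check of that claim (illustrative shapes only): a 0-d global_amax broadcasts against the per-block amax shape, so the explicit ones_like expansion is redundant.

import torch

global_amax = torch.tensor(4.5)    # 0-d scalar, as reduce_amax with axis=None returns
initial_amax = torch.rand(64, 1)   # stand-in for self._initial_amax's shape
candidate = 0.75

expanded = global_amax * torch.ones_like(initial_amax) * candidate
direct = global_amax * candidate   # no ones_like expansion

# Any downstream elementwise use broadcasts the scalar to the right shape anyway.
per_block = torch.rand(64, 1)
assert torch.allclose(per_block / expanded, per_block / direct)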

candidate_amax = self._initial_amax * multiplier
for step, candidate in enumerate(candidates):
if self._fp8_scale_sweep:
candidate_amax = global_amax_expanded * candidate

Where is the /6.0?
Shouldn't this be:

Suggested change
candidate_amax = global_amax_expanded * candidate
candidate_amax = (global_amax/6.0) * candidate_by_448


My understanding is that this is handled somewhere else. Is that correct?


Found it:

outputs = static_blockwise_fp4_fake_quant(

Should we support static_blockwise_fp4_fake_quant accepting either the scale or the amax?

Then we wouldn't need to handle the amax/6.0 detail in TensorQuantizer.
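
For reference, a rough sketch of the two-level NVFP4 scale arithmetic this thread is about (constants from the format: FP4 E2M1 max = 6.0, FP8 E4M3 max = 448.0; function and variable names are illustrative, and the real handling lives in static_blockwise_fp4_fake_quant as noted above):

FP4_MAX = 6.0      # largest representable FP4 E2M1 magnitude
FP8_MAX = 448.0    # largest representable FP8 E4M3 magnitude

def nvfp4_scales(block_amax, global_amax):
    # Per-tensor scale in which the FP8 per-block scales are expressed.
    global_scale = global_amax / (FP4_MAX * FP8_MAX)
    # Per-block scale (block_amax / 6.0), stored as FP8 relative to the global scale.
    block_scale_fp8 = (block_amax / FP4_MAX) / global_scale
    return global_scale, block_scale_fp8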

quantizer._amax = original_amax
else:
delattr(quantizer, "_amax")


nit: can we move the quant_func outside of the for loop and do:

module._calibrator = MseCalibrator(
                    amax=initial_amax,
                    axis=module._calibrator._axis,
                    step_size=step_size,
                    start_multiplier=start_multiplier,
                    stop_multiplier=stop_multiplier,
                    quant_func= partial(quant_func, module=quantizer),
                    fp8_scale_sweep=is_nvfp4_per_block,
                )


@realAsma realAsma left a comment


Looks great overall, could you please address my comments?
