add FP8 sweep option for static NVFP4 MSE #758
base: main
Conversation
Auto-sync is disabled for draft pull requests in this repository. Workflows must be run manually.
Force-pushed from 4017e8d to 3c360cb
Force-pushed from 3c360cb to 51f5e86
Signed-off-by: Fridah-nv <[email protected]>
📝 Walkthrough

This pull request introduces FP8 scale sweep functionality to the MSE calibrator. The feature enables sweeping over FP8 E4M3 scale values instead of traditional multiplier steps when fp8_scale_sweep is enabled.
Estimated code review effort: 🎯 3 (Moderate) | ⏱️ ~20 minutes

🚥 Pre-merge checks: ✅ 2 passed | ❌ 1 failed (1 warning)
Actionable comments posted: 3
🤖 Fix all issues with AI agents
In `@modelopt/torch/quantization/calib/mse.py`:
- Around lines 113-118: The two debug print() calls in the FP8 scale sweep code (the lines referencing fp8_scale_sweep and candidates, printing "FP8 scale sweep: trying..." and "Multiplier range: ...") should not print unconditionally during calibration. Either remove them, replace them with a logger call (e.g., logger.debug / logger.info), or guard them behind a verbose flag passed to the calibration routine (add a verbose parameter to the function that runs the FP8 sweep and wrap the prints with if verbose:). Locate the prints by searching for fp8_scale_sweep and candidates in mse.py, then replace the print() calls with a logging call or a conditional on the verbose parameter so output is controlled by log level or caller-provided verbosity; see the sketch below.
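A minimal sketch of that fix (the names `fp8_scale_sweep` and `candidates` follow the review text, not necessarily the actual mse.py source):

```python
import logging

logger = logging.getLogger(__name__)

def log_sweep_diagnostics(fp8_scale_sweep: bool, candidates: list[float]) -> None:
    # Route the former print() output through the logging module so callers
    # control verbosity via log level instead of unconditional prints.
    if fp8_scale_sweep:
        logger.debug("FP8 scale sweep: trying %d candidates", len(candidates))
        logger.debug("Multiplier range: [%g, %g]", min(candidates), max(candidates))
```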
In `@modelopt/torch/quantization/config.py`:
- Around lines 414-434: The new config constants NVFP4_WEIGHT_MSE_CFG and NVFP4_WEIGHT_ACT_MSE_CFG are not registered in the module's choices set. Update the module-level choices set (the set named choices that lists supported quantization format names) to include the new entries alongside the other format names so these configs are discoverable via the public API; ensure the added names exactly match the naming convention used by the other entries in choices, as sketched below.
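A sketch of the registration, assuming `choices` is the module-level set of config-name strings the comment describes (the listed existing entries are examples; the exact strings must mirror what is already in the set):

```python
# Hypothetical: `choices` stands in for the existing set in config.py.
choices: set[str] = {"INT8_DEFAULT_CFG", "NVFP4_DEFAULT_CFG"}  # example entries
choices.update({"NVFP4_WEIGHT_MSE_CFG", "NVFP4_WEIGHT_ACT_MSE_CFG"})
```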
In `@modelopt/torch/quantization/model_calib.py`:
- Around lines 268-273: The current is_nvfp4_per_block condition may raise AttributeError when module._block_sizes is None. Update the boolean expression so module._block_sizes is explicitly checked for None (or fetched via getattr with a dict default) before calling .get("scale_bits"), e.g., include a check like module._block_sizes is not None in the chain so the .get call only runs when _block_sizes exists. Refer to the is_nvfp4_per_block variable, fp8_scale_sweep, module.is_static_block_quant, and module._block_sizes to locate and modify the condition; a sketch follows.
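A sketch of the None-safe rewrite (names taken from the comment above; the real condition in model_calib.py may carry additional clauses):

```python
def _is_nvfp4_per_block(module, fp8_scale_sweep: bool) -> bool:
    # Fetch _block_sizes with a dict default so .get() never runs on None.
    block_sizes = getattr(module, "_block_sizes", None) or {}
    return (
        fp8_scale_sweep
        and module.is_static_block_quant
        and block_sizes.get("scale_bits") == (4, 3)  # FP8 E4M3 per-block scales
    )
```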
🧹 Nitpick comments (2)

tests/unit/torch/quantization/test_mse_calibrator.py (1)

530-594: Good test coverage for FP8 scale sweep functionality. The test comprehensively validates:
- Step count (126 valid FP8 E4M3 values)
- Optimal amax bounds
- Reset behavior
- Reproducibility after reset

One observation: this test uses per-tensor FP8 quantization (axis=None, no block_sizes), but the FP8 scale sweep feature is documented as being "specifically for NVFP4 per-block quantization." Consider adding a separate test case that more closely mirrors the NVFP4 per-block configuration with block_sizes to ensure the feature works correctly in its intended context.

tests/gpu/torch/quantization/test_quantize_cuda.py (1)

31-47: Remove unused local NVFP4_WEIGHT_ACT_MSE_CFG definition. The local configuration defined at lines 31-47 is not referenced anywhere in the file; the test uses mtq.NVFP4_WEIGHT_ACT_MSE_CFG instead. Remove the local definition to eliminate dead code.
📜 Review details
Configuration used: Path: .coderabbit.yaml
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (5)
- modelopt/torch/quantization/calib/mse.py
- modelopt/torch/quantization/config.py
- modelopt/torch/quantization/model_calib.py
- tests/gpu/torch/quantization/test_quantize_cuda.py
- tests/unit/torch/quantization/test_mse_calibrator.py
🧰 Additional context used
🧬 Code graph analysis (2)
modelopt/torch/quantization/model_calib.py (1)
modelopt/torch/quantization/nn/modules/tensor_quantizer.py (1)
is_static_block_quant (518-524)
modelopt/torch/quantization/config.py (1)
modelopt/torch/opt/config.py (1)
ModeloptField (50-53)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (5)
- GitHub Check: wait-checks / wait
- GitHub Check: wait-checks / wait
- GitHub Check: linux
- GitHub Check: code-quality
- GitHub Check: build-docs
🔇 Additional comments (7)
tests/gpu/torch/quantization/test_quantize_cuda.py (1)
73-74: LGTM! The new configurations are correctly added to both the test parameterization and the skip condition list, ensuring proper test coverage while respecting the cuda_ext_mx requirement and Conv layer limitations.

Also applies to: 91-92

modelopt/torch/quantization/config.py (1)
modelopt/torch/quantization/config.py (1)
1093-1101: LGTM! The fp8_scale_sweep field is well-documented. The field description clearly explains:
- The purpose (sweeping FP8 E4M3 scale values for NVFP4 per-block quantization)
- When parameters are ignored (num_steps, step_size, start_multiplier, stop_multiplier)
- The specific use case (FP8-quantized per-block scales in NVFP4 format)
modelopt/torch/quantization/model_calib.py (1)
208-208: LGTM! The new fp8_scale_sweep parameter is:
- Backward compatible with a default of False
- Well-documented in the docstring
- Correctly propagated to the MseCalibrator

Also applies to: 224-227
modelopt/torch/quantization/calib/mse.py (4)
72-78: LGTM! The initialization correctly sets up 126 steps for FP8 scale sweep, corresponding to all valid positive finite FP8 E4M3 values (excluding zero and NaN).
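One way to sanity-check that count, as a standalone snippet (assumes a PyTorch build with torch.float8_e4m3fn; not part of the PR):

```python
import torch

# Reinterpret all 256 bit patterns as FP8 E4M3 and keep the positive,
# finite, nonzero values: the candidates the sweep iterates over.
codes = torch.arange(256, dtype=torch.uint8).view(torch.float8_e4m3fn)
vals = codes.float()
candidates = vals[(vals > 0) & torch.isfinite(vals)]
print(candidates.numel())  # 126: (256 patterns - two zeros - two NaNs) / 2
```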
126-130: LGTM! The branching logic correctly uses global_amax_expanded for FP8 scale sweep (adapting to current data) and initial_amax for the traditional multiplier-based sweep.
140-141: Potential inconsistency with multiple batches in FP8 scale sweep mode. When collect() is called multiple times with fp8_scale_sweep=True, each batch computes its own global_amax (line 99), but candidate_amaxs[step] is only set on the first call (the guard at lines 140-141). This means:
- Later batches use candidate_amax values computed from the first batch's global_amax
- But losses are computed using each batch's own global_amax * candidate

This could lead to inconsistent results if batches have significantly different data ranges. Verify this is the intended behavior for the FP8 scale sweep use case.
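A toy illustration of the concern (names are illustrative, not the calibrator's internals):

```python
import torch

torch.manual_seed(0)
batch1 = torch.randn(64)          # first collect(): modest dynamic range
batch2 = torch.randn(64) * 10.0   # second collect(): much larger range

candidate = 0.5                   # one swept FP8 scale candidate
amax1 = batch1.abs().max()        # global_amax seen on the first call
amax2 = batch2.abs().max()        # global_amax seen on the second call

stored = amax1 * candidate        # what candidate_amaxs[step] records
scored = amax2 * candidate        # what the second batch's loss actually uses
print(stored.item(), scored.item())  # differ by roughly 10x here
```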
98-111: The concern about PyTorch version compatibility is not applicable. The codebase already requires torch>=2.6 (specified in setup.py), which is well above the actual minimum version needed for torch.float8_e4m3fn (PyTorch 2.2+). The code at line 105 can safely use the dtype without version guards.

Likely an incorrect or invalid review comment.
Codecov Report

✅ All modified and coverable lines are covered by tests.

Additional details and impacted files:

```
@@            Coverage Diff             @@
##             main     #758      +/-   ##
==========================================
+ Coverage   74.23%   74.37%   +0.13%
==========================================
  Files         192      192
  Lines       19033    19051      +18
==========================================
+ Hits        14129    14169      +40
+ Misses       4904     4882      -22
```

☔ View full report in Codecov by Sentry.
```python
NVFP4_WEIGHT_MSE_FP8_SWEEP_CFG = {
    "quant_cfg": {
        "*weight_quantizer": {
            "num_bits": (2, 1),
            "block_sizes": {-1: 16, "type": "static", "scale_bits": (4, 3)},
            "axis": None,
            "enable": True,
        },
        "*input_quantizer": {
            "enable": False,
        },
        **_default_disabled_quantizer_cfg,
    },
    "algorithm": {
        "method": "mse",
        "fp8_scale_sweep": True,
    },
}
```
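For context, a hypothetical usage sketch of the config above (the reviewer below asks to hold off on adding it; the tiny model and forward_loop here are placeholders):

```python
import torch
import modelopt.torch.quantization as mtq

model = torch.nn.Sequential(torch.nn.Linear(64, 64))

def forward_loop(m: torch.nn.Module) -> None:
    # Push a little calibration data through the model for the MSE sweep.
    for _ in range(8):
        m(torch.randn(2, 64))

model = mtq.quantize(model, NVFP4_WEIGHT_MSE_FP8_SWEEP_CFG, forward_loop)
```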
Let's not add this config now; let's wait until we get compelling results before adding new configs. Configs are ModelOpt's official recommendations.
In addition, this recipe is not deployment-supported (inputs are not quantized; weight-only NVFP4 is not supported yet).
| "This is specifically designed for optimizing the FP8-quantized per-block scales " | ||
| "in NVFP4 format. When enabled, num_steps, step_size, start_multiplier, and " | ||
| "stop_multiplier are ignored for NVFP4 per-block quantizers.", | ||
| ) |
Suggested change:

```diff
- )
+ "This is ignored for all quantizations except NVFP4 weight quantization.",
+ )
```
```python
is_nvfp4_per_block = (
```
nit:

```diff
- is_nvfp4_per_block = (
+ is_nvfp4_static = (
+     "weight_quantizer" in name and
```
```python
if self._fp8_scale_sweep:
    global_amax = quant_utils.reduce_amax(x, axis=None, keepdims=False, squeeze_scalar=True)
    global_amax_expanded = global_amax * torch.ones_like(self._initial_amax)
```
nit: this should not be needed (since global_amax is a scalar):

```diff
- global_amax_expanded = global_amax * torch.ones_like(self._initial_amax)
```
```python
candidate_amax = self._initial_amax * multiplier
for step, candidate in enumerate(candidates):
    if self._fp8_scale_sweep:
        candidate_amax = global_amax_expanded * candidate
```
Where is the /6.0? Shouldn't this be:

```diff
- candidate_amax = global_amax_expanded * candidate
+ candidate_amax = (global_amax / 6.0) * candidate_by_448
```
I understand this is handled somewhere else. Is that correct?
Found it:

```python
outputs = static_blockwise_fp4_fake_quant(
```

Should we allow static_blockwise_fp4_fake_quant to accept either the scale or the amax? Then we wouldn't need to handle the amax/6.0 detail in TensorQuantizer.
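A worked sketch of the arithmetic this thread circles, using only the format maxima (6.0 for FP4 E2M1, 448.0 for FP8 E4M3; variable names are illustrative):

```python
E2M1_MAX, E4M3_MAX = 6.0, 448.0

global_amax = 12.0                                      # example per-tensor amax
per_tensor_scale = global_amax / (E2M1_MAX * E4M3_MAX)  # second-level scale

# A swept FP8 E4M3 scale value `s` corresponds to a block amax of
# s * per_tensor_scale * 6.0, so the /6.0 must be applied exactly once;
# per the thread, it lives inside static_blockwise_fp4_fake_quant rather
# than in the calibrator.
s = 448.0                                  # largest E4M3 candidate
block_amax = s * per_tensor_scale * E2M1_MAX
assert abs(block_amax - global_amax) < 1e-12
```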
```python
    quantizer._amax = original_amax
else:
    delattr(quantizer, "_amax")
```
nit: can we move the quant_func outside of the for loop and do:

```python
from functools import partial

module._calibrator = MseCalibrator(
    amax=initial_amax,
    axis=module._calibrator._axis,
    step_size=step_size,
    start_multiplier=start_multiplier,
    stop_multiplier=stop_multiplier,
    quant_func=partial(quant_func, module=quantizer),
    fp8_scale_sweep=is_nvfp4_per_block,
)
```
realAsma left a comment:

Looks great overall, could you please address my comments?
What does this PR do?
Type of change: new feature

Overview: Adds an fp8_scale_sweep mode to the MSE calibrator for optimizing FP8-quantized per-block scales in NVFP4 format.
Usage
Tested with this config
Testing
Before your PR is "Ready for review"
Additional Information
Summary by CodeRabbit
Release Notes
New Features
Tests