23 changes: 18 additions & 5 deletions examples/llm_ptq/hf_ptq.py
@@ -48,7 +48,7 @@
import modelopt.torch.quantization as mtq
import modelopt.torch.sparsity as mts
from modelopt.torch.export import (
export_hf_checkpoint,
export_hf_vllm_fq_checkpoint,
export_tensorrt_llm_checkpoint,
get_model_type,
)
@@ -77,6 +77,9 @@
"int4_awq": mtq.INT4_AWQ_CFG,
"w4a8_awq": mtq.W4A8_AWQ_BETA_CFG,
"nvfp4": mtq.NVFP4_DEFAULT_CFG,
"nvfp4_mse": mtq.NVFP4_WEIGHT_MSE_FP8_SWEEP_CFG,
"nvfp4_lo_he": mtq.NVFP4_LOCAL_HESSIAN_CFG,
"nvfp4_gl_he": mtq.NVFP4_GLOBAL_HESSIAN_CFG,
"nvfp4_awq": mtq.NVFP4_AWQ_LITE_CFG,
"fp8_pb_wo": mtq.FP8_2D_BLOCKWISE_WEIGHT_ONLY_CFG,
"fp8_pc_pt": mtq.FP8_PER_CHANNEL_PER_TOKEN_CFG,
@@ -139,10 +142,10 @@ def make_calib_dataloader(
assert tokenizer is not None and isinstance(
tokenizer, (PreTrainedTokenizer, PreTrainedTokenizerFast)
), "The PreTrainedTokenizer must be set"
# Labels are only needed for gradient-based auto_quantize
# Labels are needed for gradient-based auto_quantize or global hessian calibration
include_labels = (
args.auto_quantize_bits is not None and args.auto_quantize_method == "gradient"
)
) or args.qformat == "nvfp4_gl_he" # Global hessian needs labels for backward pass
calib_dataloader = get_dataset_dataloader(
dataset_name=args.dataset,
tokenizer=tokenizer,
@@ -432,8 +435,18 @@ def mono_quantize(

if not use_calibration:
warnings.warn("Dynamic quantization. Calibration skipped.")

# Check if we need a backward pass for global-Hessian calibration.
# "algorithm" may also be a plain string (e.g. "max"), so guard with isinstance.
algorithm_cfg = quant_cfg.get("algorithm")
use_global_hessian = (
isinstance(algorithm_cfg, dict)
and algorithm_cfg.get("method") == "local_hessian"
and algorithm_cfg.get("hessian_type") == "global"
)

calibrate_loop = (
create_forward_loop(dataloader=calib_dataloader) if use_calibration else None
create_forward_loop(dataloader=calib_dataloader, enable_backward=use_global_hessian)
if use_calibration
else None
)

if calibration_only:
@@ -535,7 +548,7 @@ def export_quantized(
"They will be set at deployment time."
)

export_hf_checkpoint(
export_hf_vllm_fq_checkpoint(
full_model,
export_dir=export_path,
)
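The end-to-end flow this file now implements for the new global-Hessian format can be summarized in a short sketch. This is illustrative only, not code from the PR: the dataset name and export directory are placeholders, and it assumes the helpers touched in this diff (get_dataset_dataloader, create_forward_loop with its new enable_backward flag, export_hf_vllm_fq_checkpoint) keep the signatures shown above.

import modelopt.torch.quantization as mtq

# "nvfp4_gl_he" maps to the global-Hessian config added in config.py.
quant_cfg = mtq.NVFP4_GLOBAL_HESSIAN_CFG

# Global-Hessian calibration back-propagates through the calibration batches,
# so the dataloader must yield labels (mirrors include_labels above).
calib_dataloader = get_dataset_dataloader(
    dataset_name="cnn_dailymail",  # placeholder dataset
    tokenizer=tokenizer,
    include_labels=True,
)
forward_loop = create_forward_loop(dataloader=calib_dataloader, enable_backward=True)

model = mtq.quantize(model, quant_cfg, forward_loop)
export_hf_vllm_fq_checkpoint(model, export_dir="./exported_nvfp4_gl_he")  # placeholder dir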
50 changes: 43 additions & 7 deletions examples/vllm_serve/fakequant_worker.py
@@ -107,16 +107,42 @@ def convert_amax_hf2vllm(
# Copy other amax keys as-is (like o_proj, down_proj)
vllm_state_dict[key] = value

# Merge grouped amax values by taking the maximum
# Merge grouped amax values
for merged_key, key_value_pairs in merge_groups.items():
if len(key_value_pairs) > 1:
# Take the maximum across all values for this merged key
values = [value for _, value in key_value_pairs]
merged_value = torch.stack(values).max(dim=0)[0]
vllm_state_dict[merged_key] = merged_value
print(f"Merged {len(key_value_pairs)} keys into {merged_key}")
for orig_key, _ in key_value_pairs:
print(f" - {orig_key}")
shapes = [v.shape for v in values]
is_weight_quantizer = "weight_quantizer" in merged_key

if is_weight_quantizer:
# Weight quantizers: always concatenate because vLLM fuses weights
# by concatenation (qkv_proj = concat(q, k, v), gate_up_proj = concat(gate, up))
merged_value = torch.cat(values, dim=0)
vllm_state_dict[merged_key] = merged_value
print(
f"Concatenated {len(key_value_pairs)} weight amax keys into {merged_key} "
f"(shapes {shapes} -> {merged_value.shape})"
)
for orig_key, _ in key_value_pairs:
print(f" - {orig_key}")
# Input quantizers: take max (they share the same input tensor)
elif all(s == shapes[0] for s in shapes):
merged_value = torch.stack(values).max(dim=0)[0]
vllm_state_dict[merged_key] = merged_value
print(f"Merged {len(key_value_pairs)} input amax keys into {merged_key}")
for orig_key, _ in key_value_pairs:
print(f" - {orig_key}")
else:
# Input quantizer amax shapes differ - this shouldn't happen normally.
# torch.stack would fail on mismatched shapes, so fall back to a
# conservative per-tensor amax (scalar max over all values).
merged_value = torch.stack([v.max() for v in values]).max()
vllm_state_dict[merged_key] = merged_value
print(
f"Warning: Input quantizer amax shapes differ {shapes}, "
f"taking scalar max for {merged_key}"
)
for orig_key, _ in key_value_pairs:
print(f" - {orig_key}")
else:
# Single key, just rename it
_, value = key_value_pairs[0]
@@ -264,6 +290,16 @@ def calibrate_loop(model: Any = None) -> None:
{} if quant_config["kv_quant_cfg"] is None else getattr(mtq, quant_config["kv_quant_cfg"])
)

# When loading from an amax file, override the algorithm to "max" since calibration was done offline.
amax_file_path = quant_config["amax_file_path"]
if amax_file_path and quant_cfg:
original_algorithm = quant_cfg.get("algorithm")
if isinstance(original_algorithm, dict) or original_algorithm not in ["max", None]:
print(
f"Overriding algorithm from {original_algorithm} to 'max' since loading from amax file"
)
quant_cfg = {**quant_cfg, "algorithm": "max"}

model = self.model_runner.model
if hasattr(model, "unwrap"):
model = model.unwrap()
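As a toy illustration of the two merge rules above (not taken from the PR), consider per-output-channel weight amax tensors for q/k/v and per-tensor input amax values; the shapes and numbers below are made up:

import torch

# Per-channel weight amax for the three projections vLLM fuses into qkv_proj.
q_amax = torch.tensor([1.0, 2.0])  # q_proj.weight_quantizer._amax, 2 output channels
k_amax = torch.tensor([0.5])       # k_proj.weight_quantizer._amax, 1 output channel
v_amax = torch.tensor([3.0])       # v_proj.weight_quantizer._amax, 1 output channel

# vLLM stores qkv_proj as one concatenated weight, so the per-channel amax
# rows are concatenated in the same order: (2,) + (1,) + (1,) -> (4,).
qkv_weight_amax = torch.cat([q_amax, k_amax, v_amax], dim=0)

# q/k/v all read the same input tensor, so their input amax values are merged
# by taking the maximum instead of concatenating.
q_in, k_in, v_in = torch.tensor(1.1), torch.tensor(0.9), torch.tensor(1.3)
qkv_input_amax = torch.stack([q_in, k_in, v_in]).max(dim=0)[0]  # tensor(1.3)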
130 changes: 130 additions & 0 deletions modelopt/torch/quantization/config.py
@@ -430,6 +430,47 @@
},
}


NVFP4_LOCAL_HESSIAN_CFG = {
"quant_cfg": {
"*weight_quantizer": {
"num_bits": (2, 1),
"block_sizes": {-1: 16, "type": "static", "scale_bits": (4, 3)},
"axis": None,
"enable": True,
},
"*input_quantizer": {
"enable": False,
},
**_default_disabled_quantizer_cfg,
},
"algorithm": {
"method": "local_hessian",
"hessian_type": "local",
"fp8_scale_sweep": True,
},
}

NVFP4_GLOBAL_HESSIAN_CFG = {
"quant_cfg": {
"*weight_quantizer": {
"num_bits": (2, 1),
"block_sizes": {-1: 16, "type": "static", "scale_bits": (4, 3)},
"axis": None,
"enable": True,
},
"*input_quantizer": {
"enable": False,
},
**_default_disabled_quantizer_cfg,
},
"algorithm": {
"method": "local_hessian",
"hessian_type": "global",
"fp8_scale_sweep": True,
},
}

NVFP4_AWQ_LITE_CFG = {
"quant_cfg": {
"*weight_quantizer": {
@@ -1104,6 +1145,95 @@ class MseCalibConfig(QuantizeAlgorithmConfig):
)


class LocalHessianCalibConfig(QuantizeAlgorithmConfig):
"""Configuration for Hessian-weighted MSE calibration.

This algorithm uses activation information to optimize per-block scales for weight
quantization. It minimizes the output reconstruction error by weighting the loss
with the Hessian matrix computed from input activations (and optionally output gradients).

The Hessian loss for each block is: ``(dw @ H @ dw.T)`` where:
- ``dw = weight - quantized_weight`` (weight reconstruction error per block)
- ``H`` is the Hessian matrix (local or global, depending on ``hessian_type``)

Two Hessian types are supported:

- **local**: ``H = X @ X.T`` - uses only input activations. Faster, no backward pass needed.
- **global**: ``H = (X * grad²) @ X.T`` - weights by output gradient squared.
More accurate as it accounts for output importance, but requires a backward pass.

This method is particularly effective for NVFP4 weight-only quantization where
activation information helps select better per-block scales.
"""

method: Literal["local_hessian"] = ModeloptField("local_hessian")

hessian_type: Literal["local", "global"] = ModeloptField(
default="local",
title="Type of Hessian to compute.",
description="""Type of Hessian matrix to use for weighting quantization errors:

- ``"local"``: H = X @ X.T - Only uses input activations. Fast, forward-pass only.
- ``"global"``: H = (X * grad²) @ X.T - Weights by output gradient squared.
More accurate as it captures output importance, but requires a backward pass
during calibration.

The global Hessian is closer to the true Fisher Information and typically
gives better results, but at the cost of running backward passes.""",
)

step_size: float | None = ModeloptField(
default=0.1,
gt=0.0,
title="Step size for amax search.",
description="Step size between amax candidates. The number of candidates is computed as "
"ceil((stop_multiplier - start_multiplier) / step_size) + 1.",
)

start_multiplier: float | None = ModeloptField(
default=0.25,
gt=0.0,
title="Starting multiplier for amax search.",
description="Starting multiplier for amax search range (multiplies initial amax).",
)

stop_multiplier: float | None = ModeloptField(
default=4.0,
gt=0.0,
title="Ending multiplier for amax search.",
description="Ending multiplier for amax search range (multiplies initial amax).",
)

fp8_scale_sweep: bool | None = ModeloptField(
default=True,
title="Enable FP8 scale sweep for NVFP4 per-block quantization.",
description="If True, sweep over all 128 possible FP8 E4M3 scale values "
"for NVFP4 per-block quantization instead of using multipliers. "
"This is the recommended setting for NVFP4 quantization.",
)

block_size: int | None = ModeloptField(
default=16,
gt=0,
title="Block size for Hessian computation.",
description="The block size used for computing the Hessian matrix. "
"This should match the block size used in the quantization config. "
"Default is 16 for NVFP4.",
)

distributed_sync: bool | None = ModeloptField(
default=True,
title="Whether to sync the amax across the distributed processes.",
description="If True, the amax will be synced across the distributed processes.",
)

debug: bool | None = ModeloptField(
default=False,
title="Debug mode.",
description="If True, module's Hessian metadata will be kept as a module attribute.",
)


class SmoothQuantCalibConfig(QuantizeAlgorithmConfig):
"""The config for ``smoothquant`` algorithm (SmoothQuant).

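To make the docstring's per-block loss concrete, here is a small numerical sketch. The quantizer below is a crude uniform stand-in, not the real NVFP4 E2M1 codebook, and the multiplier sweep stands in for the FP8 scale sweep implemented in local_hessian_calibrate; treat everything here as assumption-level illustration.

import torch

block_size = 16
w = torch.randn(1, block_size)      # one weight block (row vector)
x = torch.randn(64, block_size)     # activations feeding this block: (tokens, block)
grad = torch.randn(64, 1)           # d(loss)/d(output) per token, for the global variant

def fake_quant(w, amax):
    # Stand-in for NVFP4 quantize/dequantize at a given block amax;
    # 6.0 is the largest representable E2M1 magnitude.
    scale = amax / 6.0
    return torch.clamp((w / scale).round(), -6, 6) * scale

H_local = x.T @ x                    # "H = X @ X.T" in the docstring's notation
H_global = (x * grad.pow(2)).T @ x   # "H = (X * grad^2) @ X.T"

def block_loss(amax, H):
    dw = w - fake_quant(w, amax)     # per-block weight reconstruction error
    return (dw @ H @ dw.T).item()    # Hessian-weighted error for this block

# Sweep amax candidates (start_multiplier..stop_multiplier of the observed max)
# and keep the one with the lowest Hessian-weighted loss.
candidates = torch.linspace(0.25, 4.0, steps=16) * w.abs().max()
best_amax = min(candidates, key=lambda a: block_loss(a, H_local))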
26 changes: 25 additions & 1 deletion modelopt/torch/quantization/mode.py
@@ -37,6 +37,7 @@
AWQFullCalibConfig,
AWQLiteCalibConfig,
CompressConfig,
LocalHessianCalibConfig,
MaxCalibConfig,
MseCalibConfig,
QuantizeAlgoCfgType,
@@ -55,7 +56,14 @@
restore_svdquant_model,
update_quantize_metadata,
)
from .model_calib import awq, max_calibrate, mse_calibrate, smoothquant, svdquant
from .model_calib import (
awq,
local_hessian_calibrate,
max_calibrate,
mse_calibrate,
smoothquant,
svdquant,
)

__all__ = ["BaseCalibrateModeDescriptor"]

@@ -376,6 +384,22 @@ def config_class(self) -> type[QuantizeAlgorithmConfig]:
_calib_func = mse_calibrate


@CalibrateModeRegistry.register_mode
class LocalHessianModeDescriptor(BaseCalibrateModeDescriptor):
"""Mode for local Hessian-weighted MSE calibration algorithm.

This algorithm uses activation information to optimize per-block scales for weight
quantization by minimizing output reconstruction error instead of weight reconstruction error.
"""

@property
def config_class(self) -> type[QuantizeAlgorithmConfig]:
"""Specifies the config class for the mode."""
return LocalHessianCalibConfig

_calib_func = local_hessian_calibrate


@CalibrateModeRegistry.register_mode
class SmoothQuantModeDescriptor(BaseCalibrateModeDescriptor):
"""Mode for smoothquant calibration algorithm."""
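For completeness, a minimal usage sketch of the new mode, assuming the registry wiring above is in place and that a model and calibration forward_loop are already set up as in the hf_ptq.py example. Passing an algorithm dict with method="local_hessian" routes calibration through LocalHessianModeDescriptor and therefore local_hessian_calibrate:

import modelopt.torch.quantization as mtq

cfg = dict(mtq.NVFP4_LOCAL_HESSIAN_CFG)
cfg["algorithm"] = {
    "method": "local_hessian",
    "hessian_type": "local",  # "global" additionally requires a backward-enabled forward_loop
    "fp8_scale_sweep": True,
}
model = mtq.quantize(model, cfg, forward_loop)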