Skip to content

Commit 4017e8d

Browse files
committed
fix FP8 amax calculation
Signed-off-by: Fridah-nv <[email protected]>
1 parent 61004e3 commit 4017e8d

File tree

1 file changed

+4
-4
lines changed
  • modelopt/torch/quantization/calib

1 file changed

+4
-4
lines changed

modelopt/torch/quantization/calib/mse.py

Lines changed: 4 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -116,6 +116,9 @@ def collect(self, x: torch.Tensor):
116116
device = x.device
117117

118118
if self._fp8_scale_sweep:
119+
global_amax = quant_utils.reduce_amax(x, axis=None, keepdims=False, squeeze_scalar=True)
120+
global_amax_expanded = global_amax * torch.ones_like(self._initial_amax)
121+
119122
# Generate all 128 possible FP8 E4M3 values (0-127 as uint8, viewed as float8_e4m3fn)
120123
# Create uint8 tensor with values 0-127, view as float8_e4m3fn, then convert to float32
121124
uint8_values = torch.arange(0, 128, dtype=torch.uint8, device=device)
@@ -125,7 +128,6 @@ def collect(self, x: torch.Tensor):
125128
valid_mask = torch.isfinite(fp8_values) & (fp8_values > 0)
126129
fp8_values_valid = fp8_values[valid_mask]
127130

128-
# Scale down by 448 to ensure the range is appropriate for FP8 quantization
129131
candidates = fp8_values_valid / 448.0
130132

131133
print(
@@ -146,10 +148,8 @@ def collect(self, x: torch.Tensor):
146148

147149
for step, candidate in enumerate(candidates):
148150
if self._fp8_scale_sweep:
149-
# For FP8 scale sweep, use FP8 values as multipliers of initial_amax
150-
# This ensures we search in a reasonable range relative to max calibration
151151
multiplier = candidate
152-
candidate_amax = self._initial_amax * multiplier
152+
candidate_amax = global_amax_expanded * multiplier
153153
else:
154154
# For normal MSE calibration, multiply initial amax by the multiplier
155155
candidate_amax = self._initial_amax * candidate

0 commit comments

Comments
 (0)