@@ -116,6 +116,9 @@ def collect(self, x: torch.Tensor):
         device = x.device
 
         if self._fp8_scale_sweep:
+            global_amax = quant_utils.reduce_amax(x, axis=None, keepdims=False, squeeze_scalar=True)
+            global_amax_expanded = global_amax * torch.ones_like(self._initial_amax)
+
             # Generate all 128 possible FP8 E4M3 values (0-127 as uint8, viewed as float8_e4m3fn)
             # Create uint8 tensor with values 0-127, view as float8_e4m3fn, then convert to float32
             uint8_values = torch.arange(0, 128, dtype=torch.uint8, device=device)
@@ -125,7 +128,6 @@ def collect(self, x: torch.Tensor):
             valid_mask = torch.isfinite(fp8_values) & (fp8_values > 0)
             fp8_values_valid = fp8_values[valid_mask]
 
-            # Scale down by 448 to ensure the range is appropriate for FP8 quantization
             candidates = fp8_values_valid / 448.0
 
             print(
@@ -146,10 +148,8 @@ def collect(self, x: torch.Tensor):
 
         for step, candidate in enumerate(candidates):
             if self._fp8_scale_sweep:
-                # For FP8 scale sweep, use FP8 values as multipliers of initial_amax
-                # This ensures we search in a reasonable range relative to max calibration
                 multiplier = candidate
-                candidate_amax = self._initial_amax * multiplier
+                candidate_amax = global_amax_expanded * multiplier
             else:
                 # For normal MSE calibration, multiply initial amax by the multiplier
                 candidate_amax = self._initial_amax * candidate
0 commit comments