 https://github.com/deepmind/acme/blob/master/acme/jax/running_statistics.py
 """
 
-from typing import Any, Optional, Tuple
+from typing import Optional, Tuple, Union
 
+from brax.training import types as training_types
 from brax.training.acme import types
 from flax import struct
 import jax
@@ -45,21 +46,28 @@ class NestedMeanStd:
 @struct.dataclass
 class RunningStatisticsState(NestedMeanStd):
   """Full state of running statistics computation."""
-  count: jnp.ndarray
+  count: Union[jnp.ndarray, training_types.UInt64]
   summed_variance: types.Nest
+  std_eps: float = 0.0
 
 
-def init_state(nest: types.Nest) -> RunningStatisticsState:
-  """Initializes the running statistics for the given nested structure."""
+def init_state(nest: types.Nest, std_eps: float = 0.0) -> RunningStatisticsState:
+  """Initializes the running statistics for the given nested structure.
+
+  Args:
+    nest: Nested structure to initialize statistics for.
+    std_eps: Epsilon added for numerical stability when computing std.
+  """
   dtype = jnp.float64 if jax.config.jax_enable_x64 else jnp.float32
 
   return RunningStatisticsState(
-      count=jnp.zeros((), dtype=dtype),
+      count=training_types.UInt64(hi=0, lo=0),
       mean=_zeros_like(nest, dtype=dtype),
       summed_variance=_zeros_like(nest, dtype=dtype),
       # Initialize with ones to make sure normalization works correctly
       # in the initial state.
-      std=_ones_like(nest, dtype=dtype))
+      std=_ones_like(nest, dtype=dtype),
+      std_eps=std_eps)
 
 
 def _validate_batch_shapes(batch: types.NestedArray,
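A minimal sketch of what the reworked initial state carries (the 8-dim observation template and the 1e-6 epsilon are illustrative assumptions, not values from this commit):

    import jax.numpy as jnp
    # Assuming init_state from the hunk above is in scope:
    state = init_state(jnp.zeros((8,)), std_eps=1e-6)
    # state.count is a training_types.UInt64 with hi=0, lo=0 rather than jnp.zeros(()),
    # state.std starts at ones so normalizing before any update is a no-op,
    # and state.std_eps stores the epsilon used later when computing std.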
@@ -99,10 +107,10 @@ def update(state: RunningStatisticsState,
 
   Note: data batch and state elements (mean, etc.) must have the same structure.
 
-  Note: by default will use int32 for counts and float32 for accumulated
-  variance. This results in an integer overflow after 2^31 data points and
-  degrading precision after 2^24 batch updates or even earlier if variance
-  updates have large dynamic range.
+  Note: by default, counts are stored as UInt64 and converted to float32 for
+  division, which loses a small amount of precision for very large counts.
+  Variance is accumulated in float32, so it can also lose precision because
+  of float32's 24-bit mantissa.
   To improve precision, consider setting jax_enable_x64 to True, see
   https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html#double-64bit-precision
 
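To make the 24-bit mantissa caveat concrete, here is a tiny illustration (not part of the commit) of where float32 counts stop being exact:

    import jax.numpy as jnp

    # Integers up to 2**24 are exactly representable in float32; beyond that,
    # neighboring integers can collapse to the same value.
    exact = jnp.float32(2**24)        # 16777216.0
    rounded = jnp.float32(2**24 + 1)  # also 16777216.0 -- the +1 is lost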
@@ -133,13 +141,21 @@ def update(state: RunningStatisticsState,
       jax.tree_util.tree_leaves(state.mean)[0].ndim]
   batch_axis = range(len(batch_dims))
   if weights is None:
-    step_increment = jnp.prod(jnp.array(batch_dims))
+    step_increment = jnp.prod(jnp.array(batch_dims)).astype(jnp.int32)
   else:
-    step_increment = jnp.sum(weights)
+    step_increment = jnp.sum(weights).astype(jnp.int32)
   if pmap_axis_name is not None:
     step_increment = jax.lax.psum(step_increment, axis_name=pmap_axis_name)
   count = state.count + step_increment
 
+  if isinstance(count, training_types.UInt64):
+    # Convert UInt64 count to float32 for division operations.
+    # Note: small precision loss due to float32's 24-bit mantissa.
+    count_float = (jnp.float32(count.hi) * jnp.float32(2.0**32) +
+                   jnp.float32(count.lo))
+  else:
+    count_float = jnp.float32(count)
+
   # Validation is important. If the shapes don't match exactly, but are
   # compatible, arrays will be silently broadcasted resulting in incorrect
   # statistics.
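The two 32-bit halves of the count are recombined as hi * 2**32 + lo before division. A quick illustrative check of that arithmetic (the count value below is made up for the example, not taken from the commit):

    import jax.numpy as jnp

    count = 5_000_000_000                    # larger than a uint32 can hold
    hi, lo = divmod(count, 2**32)            # hi = 1, lo = 705_032_704
    recovered = jnp.float32(hi) * jnp.float32(2.0**32) + jnp.float32(lo)
    # recovered == 5.0e9 exactly in this case; counts needing more than 24
    # significant bits would be rounded to the nearest representable float32.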
@@ -162,7 +178,7 @@ def _compute_node_statistics(
           weights,
           list(weights.shape) + [1] * (batch.ndim - weights.ndim))
       diff_to_old_mean = diff_to_old_mean * expanded_weights
-    mean_update = jnp.sum(diff_to_old_mean, axis=batch_axis) / count
+    mean_update = jnp.sum(diff_to_old_mean, axis=batch_axis) / count_float
     if pmap_axis_name is not None:
       mean_update = jax.lax.psum(
           mean_update, axis_name=pmap_axis_name)
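For intuition, the mean update divides the summed deviation from the old mean by the new total count, which reproduces the exact mean over all points seen so far. A small self-contained check (values invented for the example, not taken from the commit):

    import jax.numpy as jnp

    old_mean, old_count = jnp.float32(2.0), 4          # pretend 4 points averaged to 2.0
    batch = jnp.array([3.0, 5.0, 7.0, 9.0], jnp.float32)
    count = old_count + batch.shape[0]                  # 8
    new_mean = old_mean + jnp.sum(batch - old_mean) / count
    # new_mean == 4.0, the mean of all 8 points: (4 * 2.0 + 3 + 5 + 7 + 9) / 8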
@@ -188,14 +204,19 @@ def compute_std(summed_variance: jnp.ndarray,
     assert isinstance(summed_variance, jnp.ndarray)
     # Summed variance can get negative due to rounding errors.
     summed_variance = jnp.maximum(summed_variance, 0)
-    std = jnp.sqrt(summed_variance / count)
+    std = jnp.sqrt(summed_variance / count_float + state.std_eps)
     std = jnp.clip(std, std_min_value, std_max_value)
     return std
 
   std = jax.tree_util.tree_map(compute_std, summed_variance, state.std)
 
   return RunningStatisticsState(
-      count=count, mean=mean, summed_variance=summed_variance, std=std)
+      count=count,
+      mean=mean,
+      summed_variance=summed_variance,
+      std=std,
+      std_eps=state.std_eps,
+  )
 
 
 def normalize(batch: types.NestedArray,
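Taken together, a hypothetical end-to-end use of the updated module might look like the sketch below; the module path, shapes, and 1e-6 epsilon are assumptions for illustration, not part of this commit:

    import jax.numpy as jnp
    from brax.training.acme import running_statistics

    state = running_statistics.init_state(jnp.zeros((8,)), std_eps=1e-6)
    batch = jnp.ones((32, 8))                        # leading axis is the batch
    state = running_statistics.update(state, batch)  # count grows to 32; std uses std_eps
    normalized = running_statistics.normalize(batch, state)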