Skip to content

Commit ca7649c

Browse files
btaba and copybara-github
authored and committed
Fix counting issue in running statistics.
PiperOrigin-RevId: 841283704 Change-Id: Ie2cf7e925293e89433609000c8876d70b0be8ce5
1 parent 12c29bd commit ca7649c

File tree

5 files changed

+57
-20
lines changed

5 files changed

+57
-20
lines changed

brax/training/acme/running_statistics.py

Lines changed: 36 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -19,8 +19,9 @@
1919
https://github.com/deepmind/acme/blob/master/acme/jax/running_statistics.py
2020
"""
2121

22-
from typing import Any, Optional, Tuple
22+
from typing import Optional, Tuple, Union
2323

24+
from brax.training import types as training_types
2425
from brax.training.acme import types
2526
from flax import struct
2627
import jax
@@ -45,21 +46,28 @@ class NestedMeanStd:
4546
@struct.dataclass
4647
class RunningStatisticsState(NestedMeanStd):
4748
"""Full state of running statistics computation."""
48-
count: jnp.ndarray
49+
count: Union[jnp.ndarray, training_types.UInt64]
4950
summed_variance: types.Nest
51+
std_eps: float = 0.0
5052

5153

52-
def init_state(nest: types.Nest) -> RunningStatisticsState:
53-
"""Initializes the running statistics for the given nested structure."""
54+
def init_state(nest: types.Nest, std_eps: float = 0.0) -> RunningStatisticsState:
55+
"""Initializes the running statistics for the given nested structure.
56+
57+
Args:
58+
nest: Nested structure to initialize statistics for.
59+
std_eps: Epsilon for numerical stability when getting std.
60+
"""
5461
dtype = jnp.float64 if jax.config.jax_enable_x64 else jnp.float32
5562

5663
return RunningStatisticsState(
57-
count=jnp.zeros((), dtype=dtype),
64+
count=training_types.UInt64(hi=0, lo=0),
5865
mean=_zeros_like(nest, dtype=dtype),
5966
summed_variance=_zeros_like(nest, dtype=dtype),
6067
# Initialize with ones to make sure normalization works correctly
6168
# in the initial state.
62-
std=_ones_like(nest, dtype=dtype))
69+
std=_ones_like(nest, dtype=dtype),
70+
std_eps=std_eps)
6371

6472

6573
def _validate_batch_shapes(batch: types.NestedArray,
@@ -99,10 +107,10 @@ def update(state: RunningStatisticsState,
99107
100108
Note: data batch and state elements (mean, etc.) must have the same structure.
101109
102-
Note: by default will use int32 for counts and float32 for accumulated
103-
variance. This results in an integer overflow after 2^31 data points and
104-
degrading precision after 2^24 batch updates or even earlier if variance
105-
updates have large dynamic range.
110+
Note: by default uses UInt64 for counts that get converted to float32 for division.
111+
This conversion has a small precision loss for large counts. float32 is used
112+
to accumulate variance, so can also suffer from precision loss due to the 24 bit
113+
mantissa for float32.
106114
To improve precision, consider setting jax_enable_x64 to True, see
107115
https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html#double-64bit-precision
108116
@@ -133,13 +141,21 @@ def update(state: RunningStatisticsState,
133141
jax.tree_util.tree_leaves(state.mean)[0].ndim]
134142
batch_axis = range(len(batch_dims))
135143
if weights is None:
136-
step_increment = jnp.prod(jnp.array(batch_dims))
144+
step_increment = jnp.prod(jnp.array(batch_dims)).astype(jnp.int32)
137145
else:
138-
step_increment = jnp.sum(weights)
146+
step_increment = jnp.sum(weights).astype(jnp.int32)
139147
if pmap_axis_name is not None:
140148
step_increment = jax.lax.psum(step_increment, axis_name=pmap_axis_name)
141149
count = state.count + step_increment
142150

151+
if isinstance(count, training_types.UInt64):
152+
# Convert UInt64 count to float32 for division operations.
153+
# Note: small precision loss due to float32's 24-bit mantissa.
154+
count_float = (jnp.float32(count.hi) * jnp.float32(2.0**32) +
155+
jnp.float32(count.lo))
156+
else:
157+
count_float = jnp.float32(count)
158+
143159
# Validation is important. If the shapes don't match exactly, but are
144160
# compatible, arrays will be silently broadcasted resulting in incorrect
145161
# statistics.
@@ -162,7 +178,7 @@ def _compute_node_statistics(
162178
weights,
163179
list(weights.shape) + [1] * (batch.ndim - weights.ndim))
164180
diff_to_old_mean = diff_to_old_mean * expanded_weights
165-
mean_update = jnp.sum(diff_to_old_mean, axis=batch_axis) / count
181+
mean_update = jnp.sum(diff_to_old_mean, axis=batch_axis) / count_float
166182
if pmap_axis_name is not None:
167183
mean_update = jax.lax.psum(
168184
mean_update, axis_name=pmap_axis_name)
@@ -188,14 +204,19 @@ def compute_std(summed_variance: jnp.ndarray,
188204
assert isinstance(summed_variance, jnp.ndarray)
189205
# Summed variance can get negative due to rounding errors.
190206
summed_variance = jnp.maximum(summed_variance, 0)
191-
std = jnp.sqrt(summed_variance / count)
207+
std = jnp.sqrt(summed_variance / count_float + state.std_eps)
192208
std = jnp.clip(std, std_min_value, std_max_value)
193209
return std
194210

195211
std = jax.tree_util.tree_map(compute_std, summed_variance, state.std)
196212

197213
return RunningStatisticsState(
198-
count=count, mean=mean, summed_variance=summed_variance, std=std)
214+
count=count,
215+
mean=mean,
216+
summed_variance=summed_variance,
217+
std=std,
218+
std_eps=state.std_eps,
219+
)
199220

200221

201222
def normalize(batch: types.NestedArray,

brax/training/agents/ppo/checkpoint_test.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -85,7 +85,8 @@ def test_save_and_load_checkpoint(self):
8585
value=ppo_network.value_network.init(dummy_key),
8686
)
8787
normalizer_params = running_statistics.init_state(
88-
jax.tree_util.tree_map(jp.zeros, config.observation_size)
88+
jax.tree_util.tree_map(jp.zeros, config.observation_size),
89+
std_eps=0.02,
8990
)
9091
params = (normalizer_params, network_params.policy, network_params.value)
9192

@@ -103,6 +104,10 @@ def test_save_and_load_checkpoint(self):
103104
out = policy_fn(jp.zeros(1), jax.random.PRNGKey(0))
104105
self.assertEqual(out[0].shape, (3,))
105106

107+
loaded_params = checkpoint.load(epath.Path(path.full_path) / "000000000001")
108+
loaded_normalizer = loaded_params[0]
109+
self.assertEqual(loaded_normalizer.std_eps, 0.02)
110+
106111

107112
if __name__ == "__main__":
108113
absltest.main()

brax/training/agents/ppo/train.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -216,6 +216,7 @@ def train(
216216
num_updates_per_batch: int = 2,
217217
num_resets_per_eval: int = 0,
218218
normalize_observations: bool = False,
219+
normalize_observations_std_eps: float = 0.0,
219220
reward_scaling: float = 1.0,
220221
clipping_epsilon: float = 0.3,
221222
gae_lambda: float = 0.95,
@@ -287,6 +288,8 @@ def train(
287288
num_resets_per_eval: the number of environment resets to run between each
288289
eval. The environment resets occur on the host
289290
normalize_observations: whether to normalize observations
291+
normalize_observations_std_eps: small value added to the standard deviation
292+
for obs normalization to improve numerical stability
290293
reward_scaling: float scaling for reward
291294
clipping_epsilon: clipping epsilon for PPO loss
292295
gae_lambda: General advantage estimation lambda
@@ -672,7 +675,7 @@ def training_epoch_with_timing(
672675
optimizer_state=optimizer.init(init_params), # pytype: disable=wrong-arg-types # numpy-scalars
673676
params=init_params,
674677
normalizer_params=running_statistics.init_state(
675-
_remove_pixels(obs_shape)
678+
_remove_pixels(obs_shape), std_eps=normalize_observations_std_eps
676679
),
677680
env_steps=types.UInt64(hi=0, lo=0),
678681
)

brax/training/checkpoint.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -193,7 +193,12 @@ def load(
193193
target = orbax_checkpointer.restore(
194194
path, ocp.args.PyTreeRestore(restore_args=restore_args), item=None
195195
)
196-
target[0] = running_statistics.RunningStatisticsState(**target[0])
196+
197+
# Reconstruct UInt64 count if it was saved as dict.
198+
state_dict = target[0]
199+
if isinstance(state_dict['count'], dict) and 'hi' in state_dict['count']:
200+
state_dict['count'] = types.UInt64(**state_dict['count'])
201+
target[0] = running_statistics.RunningStatisticsState(**state_dict)
197202

198203
return target
199204

brax/training/types.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -113,8 +113,11 @@ def to_numpy(self):
113113

114114
def __post_init__(self):
115115
"""Cast post init."""
116-
object.__setattr__(self, "hi", jnp.uint32(self.hi))
117-
object.__setattr__(self, "lo", jnp.uint32(self.lo))
116+
# Only convert known types - avoids issues with checkpoint serialization.
117+
if isinstance(self.hi, (int, np.integer, np.ndarray, jax.Array)):
118+
object.__setattr__(self, "hi", jnp.uint32(self.hi))
119+
if isinstance(self.lo, (int, np.integer, np.ndarray, jax.Array)):
120+
object.__setattr__(self, "lo", jnp.uint32(self.lo))
118121

119122
def __add__(self, other):
120123
other = _sanitize_uint64_input(other)

0 commit comments

Comments
 (0)