Skip to content

Commit 12c29bd

Browse files
btabacopybara-github
authored andcommitted
Hook into per_step metrics for training.
PiperOrigin-RevId: 836502152 Change-Id: I03026537261f530e06aa9bede07da9e1437a05ea
1 parent 853f207 commit 12c29bd

File tree

1 file changed

+5
-2
lines changed

1 file changed

+5
-2
lines changed

brax/training/logger.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -44,9 +44,12 @@ def __init__(
4444
def update_episode_metrics(self, episode_metrics, dones, train_metrics):
4545
self._num_steps += np.prod(dones.shape)
4646
if jnp.sum(dones) > 0:
47+
lengths = episode_metrics['length'][dones.astype(bool)].flatten()
4748
for name, metric in episode_metrics.items():
48-
done_metrics = metric[dones.astype(bool)].flatten().tolist()
49-
self._ep_metrics_buffer[name].extend(done_metrics)
49+
done_metrics = metric[dones.astype(bool)].flatten()
50+
if name.endswith('_per_step'):
51+
done_metrics = done_metrics / (lengths + 1e-8)
52+
self._ep_metrics_buffer[name].extend(done_metrics.tolist())
5053
for name, metric in train_metrics.items():
5154
self._train_metrics_buffer[name].extend(metric.flatten().tolist())
5255
if self._num_steps - self._last_log_steps >= self._steps_between_logging:

0 commit comments

Comments
 (0)