File tree Expand file tree Collapse file tree 1 file changed +5
-2
lines changed
Expand file tree Collapse file tree 1 file changed +5
-2
lines changed Original file line number Diff line number Diff line change @@ -44,9 +44,12 @@ def __init__(
4444 def update_episode_metrics (self , episode_metrics , dones , train_metrics ):
4545 self ._num_steps += np .prod (dones .shape )
4646 if jnp .sum (dones ) > 0 :
47+ lengths = episode_metrics ['length' ][dones .astype (bool )].flatten ()
4748 for name , metric in episode_metrics .items ():
48- done_metrics = metric [dones .astype (bool )].flatten ().tolist ()
49- self ._ep_metrics_buffer [name ].extend (done_metrics )
49+ done_metrics = metric [dones .astype (bool )].flatten ()
50+ if name .endswith ('_per_step' ):
51+ done_metrics = done_metrics / (lengths + 1e-8 )
52+ self ._ep_metrics_buffer [name ].extend (done_metrics .tolist ())
5053 for name , metric in train_metrics .items ():
5154 self ._train_metrics_buffer [name ].extend (metric .flatten ().tolist ())
5255 if self ._num_steps - self ._last_log_steps >= self ._steps_between_logging :
You can’t perform that action at this time.
0 commit comments