
Commit e8c9c18

btaba authored and copybara-github committed
Bootstrap on timeout.
PiperOrigin-RevId: 844813746 Change-Id: I0d2d18892ff41e9e1597f6bb0ebcf36d9020c9a0
1 parent a6b0c6b commit e8c9c18

File tree: 5 files changed (+46, −8 lines)

brax/envs/inverted_pendulum.py

Lines changed: 4 additions & 2 deletions
@@ -124,8 +124,9 @@ def reset(self, rng: jax.Array) -> State:
     obs = self._get_obs(pipeline_state)
     reward, done = jp.zeros(2)
     metrics = {}
+    info = {'time_out': done}  # allows bootstrap_on_timeout for PPO
 
-    return State(pipeline_state, obs, reward, done, metrics)
+    return State(pipeline_state, obs, reward, done, metrics, info)
 
   def step(self, state: State, action: jax.Array) -> State:
     """Run one timestep of the environment's dynamics."""
@@ -140,7 +141,8 @@ def step(self, state: State, action: jax.Array) -> State:
     reward = 1.0
     done = jp.where(jp.abs(obs[1]) > 0.2, 1.0, 0.0)
     return state.replace(
-        pipeline_state=pipeline_state, obs=obs, reward=reward, done=done
+        pipeline_state=pipeline_state, obs=obs, reward=reward, done=done,
+        info={**state.info, 'time_out': done}
     )
 
   @property
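The pendulum change above derives `time_out` from its existing `done` flag. An environment that ends episodes on a step limit would instead derive the flag from a counter; a minimal sketch of such a `step` ending, assuming an illustrative `steps` counter in `state.info` and a `max_steps` attribute (neither is part of the Brax API):

    # 'steps' and 'max_steps' are hypothetical names used only for illustration.
    steps = state.info['steps'] + 1
    time_out = jp.where(steps >= self.max_steps, 1.0, 0.0)
    done = jp.maximum(done, time_out)  # episode also ends when it times out
    return state.replace(
        pipeline_state=pipeline_state, obs=obs, reward=reward, done=done,
        info={**state.info, 'steps': steps, 'time_out': time_out},
    )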

brax/training/agents/ppo/networks.py

Lines changed: 13 additions & 3 deletions
@@ -32,8 +32,13 @@ class PPONetworks:
   parametric_action_distribution: distribution.ParametricDistribution
 
 
-def make_inference_fn(ppo_networks: PPONetworks):
-  """Creates params and inference function for the PPO agent."""
+def make_inference_fn(ppo_networks: PPONetworks, compute_value: bool = False):
+  """Creates params and inference function for the PPO agent.
+
+  Args:
+    ppo_networks: The PPO networks.
+    compute_value: If True, compute value during rollouts.
+  """
 
   def make_policy(
       params: types.Params, deterministic: bool = False
@@ -55,11 +60,16 @@ def policy(
       postprocessed_actions = parametric_action_distribution.postprocess(
           raw_actions
       )
-      return postprocessed_actions, {
+      extras = {
           'log_prob': log_prob,
           'raw_action': raw_actions,
           'distribution_params': logits,
       }
+      if compute_value:
+        extras['value'] = ppo_networks.value_network.apply(
+            params[0], params[2], observations
+        )
+      return postprocessed_actions, extras
 
     return policy
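With the new flag, callers that want per-step value estimates in the rollout extras request them when building the inference function. A minimal usage sketch; `ppo_network`, `params`, `observations`, and `key_sample` are assumed to come from the usual training setup and are not defined in this commit:

    make_policy = make_inference_fn(ppo_network, compute_value=True)
    policy = make_policy(params)
    actions, extras = policy(observations, key_sample)
    # 'value' is present only when compute_value=True. Note that params must
    # include the value-network params, since the apply call above reads
    # params[0] (normalizer) and params[2] (value network).
    values = extras['value']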

brax/training/agents/ppo/train.py

Lines changed: 24 additions & 2 deletions
@@ -224,6 +224,7 @@ def train(
     max_grad_norm: Optional[float] = None,
     normalize_advantage: bool = True,
     vf_loss_coefficient: float = 0.5,
+    bootstrap_on_timeout: bool = False,
     desired_kl: float = 0.01,
     learning_rate_schedule: Optional[
         Union[str, ppo_optimizer.LRSchedule]
@@ -299,6 +300,10 @@ def train(
     max_grad_norm: gradient clipping norm value. If None, no clipping is done
     normalize_advantage: whether to normalize advantage estimate
     vf_loss_coefficient: Coefficient for value function loss.
+    bootstrap_on_timeout: if True, bootstrap value on time_out steps using
+      reward += gamma * V(s) * time_out. Environments should set
+      state.info['time_out'] = 1.0 and done=True for steps where the episode
+      ends due to a time_out.
     desired_kl: Desired KL divergence for adaptive KL divergence learning rate
       schedule.
     learning_rate_schedule: Learning rate schedule for the optimizer.
@@ -431,7 +436,9 @@ def reset_fn_donated_env_state(env_state_donated, key_envs):
   ppo_network = network_factory(
       obs_shape, env.action_size, preprocess_observations_fn=normalize
   )
-  make_policy = ppo_networks.make_inference_fn(ppo_network)
+  make_policy = ppo_networks.make_inference_fn(
+      ppo_network, compute_value=bootstrap_on_timeout
+  )
 
   # Optimizer.
   base_optimizer = optax.adam(learning_rate=learning_rate)
@@ -551,13 +558,16 @@ def training_step(
     def f(carry, unused_t):
       current_state, current_key = carry
       current_key, next_key = jax.random.split(current_key)
+      extra_fields = ['truncation', 'episode_metrics', 'episode_done']
+      if bootstrap_on_timeout:
+        extra_fields.append('time_out')
       next_state, data = acting.generate_unroll(
           env,
           current_state,
           policy,
          current_key,
          unroll_length,
-          extra_fields=('truncation', 'episode_metrics', 'episode_done'),
+          extra_fields=tuple(extra_fields),
      )
       return (next_state, next_key), data
 
@@ -574,6 +584,18 @@ def f(carry, unused_t):
     )
     assert data.discount.shape[1:] == (unroll_length,)
 
+    if bootstrap_on_timeout:  # bootstrap reward on timeout
+      time_out = data.extras['state_extras']['time_out']
+      value = data.extras['policy_extras']['value']
+      data = types.Transition(
+          observation=data.observation,
+          action=data.action,
+          reward=data.reward + discounting * time_out * value,
+          discount=data.discount,
+          next_observation=data.next_observation,
+          extras=data.extras,
+      )
+
     normalizer_params = training_state.normalizer_params
     if not lr_is_adaptive_kl:
       # Update normalization params before SGD for backwards compatibility.
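The correction itself is the one-line reward adjustment from the docstring: on a timeout step the trajectory is cut short rather than failed, so the critic's estimate stands in for the missing future return. A small numeric sketch of the same arithmetic (the array values are made up for illustration):

    import jax.numpy as jp

    discounting = 0.99
    reward = jp.array([1.0, 1.0, 1.0])    # rewards along an unroll
    time_out = jp.array([0.0, 0.0, 1.0])  # only the last step hit the time limit
    value = jp.array([10.0, 10.0, 10.0])  # V(s) recorded by the policy at rollout time
    reward = reward + discounting * time_out * value
    # -> [1.0, 1.0, 10.9]; only the timeout step receives the bootstrapped value.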

brax/training/agents/ppo/train_test.py

Lines changed: 4 additions & 1 deletion
@@ -62,9 +62,11 @@ def testTrain(self, obs_mode):
           dict(distribution_type='tanh_normal', noise_std_type='log'),
       ),
       normalize_mode=['welford', 'ema'],
+      bootstrap_on_timeout=[True, False],
   )
   def testTrainWithNetworkParams(
-      self, distribution_type, noise_std_type, normalize_mode
+      self, distribution_type, noise_std_type, normalize_mode,
+      bootstrap_on_timeout
   ):
     """Test PPO runs with different network params."""
     network_factory = functools.partial(
@@ -99,6 +101,7 @@ def testTrainWithNetworkParams(
         network_factory=network_factory,
         learning_rate_schedule='ADAPTIVE_KL',
         normalize_observations_mode=normalize_mode,
+        bootstrap_on_timeout=bootstrap_on_timeout,
     )
 
   def testTrainAsymmetricActorCritic(self):

docs/release-notes/next-release.md

Lines changed: 1 addition & 0 deletions
@@ -9,3 +9,4 @@
 * Add `donate_argnums` to brax PPO to somewhat mitigate repeated graph captures when using MJX-Warp.
 * Add `normalize_observations_mode` to PPO to allow using EMA for running statistics instead of Welford. EMA is more stable for longer training runs.
 * Fix bug in PPO training metric logging frequency for multi-GPU devices.
+* Add value bootstrap on `timeout` for PPO. `reward += gamma * V(s) * time_out` if `bootstrap_on_timeout` is set to True.
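For context, a hedged end-to-end sketch of enabling the flag; the environment name and hyperparameter values below are illustrative and not part of this commit:

    from brax import envs
    from brax.training.agents.ppo import train as ppo

    env = envs.get_environment('inverted_pendulum')
    make_policy, params, metrics = ppo.train(
        environment=env,
        num_timesteps=1_000_000,    # illustrative settings only
        episode_length=1000,
        num_envs=256,
        bootstrap_on_timeout=True,  # new flag added in this commit
    )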
