Population Based Training (PBT)

PBT is an optimization method that jointly optimises a population of models and their hyperparameters to maximise performance.

To use PBT in QDax to train SAC, one can use the following two components (see the examples for how to use them appropriately):

Bases: SAC

Source code in qdax/baselines/sac_pbt.py
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
class PBTSAC(SAC):
    """SAC variant whose learning rates, discount and reward scaling are set by PBT.

    These hyper-parameters live per agent in ``PBTSacTrainingState``. They are
    sampled in ``init`` and read back from the training state in ``update``,
    rather than being taken from the static ``SacConfig``. The ``get_*_fn``
    helpers return functions vectorised with ``jax.vmap`` over the population.
    """

    def __init__(self, config: PBTSacConfig, action_size: int) -> None:
        """Build the underlying SAC agent from a PBT-SAC config.

        Args:
            config: the PBT-SAC configuration. Only the fields that are not
                learnt by PBT are forwarded to ``SacConfig``.
            action_size: the size of the environment's action space
        """
        sac_config = SacConfig(
            batch_size=config.batch_size,
            episode_length=config.episode_length,
            tau=config.tau,
            normalize_observations=config.normalize_observations,
            alpha_init=config.alpha_init,
            policy_hidden_layer_size=config.policy_hidden_layer_size,
            critic_hidden_layer_size=config.critic_hidden_layer_size,
            fix_alpha=config.fix_alpha,
            # unused default values for parameters that will be learnt as part of PBT
            learning_rate=3e-4,
            discount=0.97,
            reward_scaling=1.0,
        )
        super().__init__(config=sac_config, action_size=action_size)

    def init(
        self, key: RNGKey, action_size: int, observation_size: int
    ) -> PBTSacTrainingState:
        """Initialise the training state of the algorithm.

        The SAC networks and optimizers are initialised by ``SAC.init``. The PBT
        hyper-parameters start as ``None`` and are then sampled by
        ``PBTSacTrainingState.resample_hyperparams``.

        Args:
            key: a jax random key
            action_size: the size of the environment's action space
            observation_size: the size of the environment's observation space

        Returns:
            the initial training state of PBT-SAC
        """
        sac_training_state = super().init(key, action_size, observation_size)

        training_state = PBTSacTrainingState(
            policy_optimizer_state=sac_training_state.policy_optimizer_state,
            policy_params=sac_training_state.policy_params,
            critic_optimizer_state=sac_training_state.critic_optimizer_state,
            critic_params=sac_training_state.critic_params,
            alpha_optimizer_state=sac_training_state.alpha_optimizer_state,
            alpha_params=sac_training_state.alpha_params,
            target_critic_params=sac_training_state.target_critic_params,
            normalization_running_stats=sac_training_state.normalization_running_stats,
            key=sac_training_state.key,
            steps=sac_training_state.steps,
            # Placeholders, filled in by resample_hyperparams just below.
            discount=None,
            policy_lr=None,
            critic_lr=None,
            alpha_lr=None,
            reward_scaling=None,
        )

        # Sample hyper-params
        training_state = PBTSacTrainingState.resample_hyperparams(training_state)

        return training_state  # type: ignore

    def update(
        self,
        training_state: PBTSacTrainingState,
        replay_buffer: ReplayBuffer,
    ) -> Tuple[PBTSacTrainingState, ReplayBuffer, Metrics]:
        """Performs a training step updating alpha, the critic and the policy.

        The learning rates, discount and reward scaling used by this step are
        the per-agent values stored in ``training_state``.

        Args:
            training_state: the current PBT-SAC training state
            replay_buffer: the replay buffer

        Returns:
            the updated PBT-SAC training state
            the replay buffer (returned unchanged)
            the training metrics
        """

        # sample a batch of transitions in the buffer
        key = training_state.key

        key, subkey = jax.random.split(key)
        transitions = replay_buffer.sample(
            subkey,
            sample_size=self._config.batch_size,
        )

        # normalise observations if necessary
        if self._config.normalize_observations:
            normalization_running_stats = training_state.normalization_running_stats
            normalized_obs = normalize_with_rmstd(
                transitions.obs, normalization_running_stats
            )
            normalized_next_obs = normalize_with_rmstd(
                transitions.next_obs, normalization_running_stats
            )
            transitions = transitions.replace(
                obs=normalized_obs, next_obs=normalized_next_obs
            )

        # update alpha
        key, subkey = jax.random.split(key)
        (
            alpha_params,
            alpha_optimizer_state,
            alpha_loss,
        ) = self._update_alpha(
            alpha_lr=training_state.alpha_lr,
            training_state=training_state,
            transitions=transitions,
            key=subkey,
        )

        # update critic
        key, subkey = jax.random.split(key)
        (
            critic_params,
            target_critic_params,
            critic_optimizer_state,
            critic_loss,
        ) = self._update_critic(
            critic_lr=training_state.critic_lr,
            reward_scaling=training_state.reward_scaling,
            discount=training_state.discount,
            training_state=training_state,
            transitions=transitions,
            key=subkey,
        )

        # update actor
        key, subkey = jax.random.split(key)
        (
            policy_params,
            policy_optimizer_state,
            policy_loss,
        ) = self._update_actor(
            policy_lr=training_state.policy_lr,
            training_state=training_state,
            transitions=transitions,
            key=subkey,
        )

        # create new training state; hyper-parameters are carried over as-is
        key, subkey = jax.random.split(key)
        new_training_state = PBTSacTrainingState(
            policy_optimizer_state=policy_optimizer_state,
            policy_params=policy_params,
            critic_optimizer_state=critic_optimizer_state,
            critic_params=critic_params,
            alpha_optimizer_state=alpha_optimizer_state,
            alpha_params=alpha_params,
            normalization_running_stats=training_state.normalization_running_stats,
            target_critic_params=target_critic_params,
            key=subkey,
            steps=training_state.steps + 1,
            discount=training_state.discount,
            policy_lr=training_state.policy_lr,
            critic_lr=training_state.critic_lr,
            alpha_lr=training_state.alpha_lr,
            reward_scaling=training_state.reward_scaling,
        )
        # obs statistics are computed on the (possibly normalised) sampled batch
        metrics = {
            "actor_loss": policy_loss,
            "critic_loss": critic_loss,
            "alpha_loss": alpha_loss,
            "obs_mean": jnp.mean(transitions.obs),
            "obs_std": jnp.std(transitions.obs),
        }
        return new_training_state, replay_buffer, metrics

    def get_init_fn(
        self,
        population_size: int,
        action_size: int,
        observation_size: int,
        buffer_size: int,
    ) -> Callable:
        """
        Returns a function to initialize the population.

        Args:
            population_size: size of the population.
            action_size: action space size.
            observation_size: observation space size.
            buffer_size: replay buffer size.

        Returns:
            a function that takes as input a random key and returns the PBT
            population training states and the replay buffers (one per agent,
            stacked along a leading population axis). The input key is consumed;
            no new key is returned.
        """

        def _init_fn(
            key: RNGKey,
        ) -> Tuple[PBTSacTrainingState, ReplayBuffer]:
            # Same per-agent keys as splitting into population_size + 1 keys and
            # discarding the first; slicing avoids unpacking the JAX array into a
            # Python list only to re-stack it.
            keys = jax.random.split(key, num=population_size + 1)[1:]

            # One dummy transition per agent, used to shape each replay buffer.
            init_dummy_transition = partial(
                Transition.init_dummy,
                observation_dim=observation_size,
                action_dim=action_size,
            )
            init_dummy_transition = jax.vmap(
                init_dummy_transition, axis_size=population_size
            )
            dummy_transitions = init_dummy_transition()

            replay_buffer_init = partial(
                ReplayBuffer.init,
                buffer_size=buffer_size,
            )
            replay_buffer_init = jax.vmap(replay_buffer_init)
            replay_buffers = replay_buffer_init(transition=dummy_transitions)
            agent_init = partial(
                self.init, action_size=action_size, observation_size=observation_size
            )
            training_states = jax.vmap(agent_init)(keys)
            return training_states, replay_buffers

        return _init_fn

    def get_eval_fn(
        self,
        eval_env: Env,
    ) -> Callable:
        """
        Returns the function to evaluate the PBT population.

        Args:
            eval_env: evaluation environment. Might be different from training env
                if needed.

        Returns:
            The function to evaluate the population. It takes as input the population
            training state as well as first eval environment states and returns the
            population agents' mean returns over episodes as well as all returns from
            all agents over all episodes.
        """
        # evaluation acts deterministically
        play_eval_step = partial(
            self.play_step_fn,
            env=eval_env,
            deterministic=True,
        )

        eval_policy = partial(
            self.eval_policy_fn,
            play_step_fn=play_eval_step,
        )
        return jax.vmap(eval_policy)  # type: ignore

    def get_eval_qd_fn(
        self,
        eval_env: Env,
        descriptor_extraction_fn: Callable[[QDTransition, Mask], Descriptor],
    ) -> Callable:
        """
        Returns the QD evaluation function of the PBT population.

        Args:
            eval_env: evaluation environment. Might be different from training env
                if needed.
            descriptor_extraction_fn: function to extract the descriptor from an
                episode.

        Returns:
            The function to evaluate the population. It takes as input the population
            training state as well as first eval environment states and returns the
            population agents' mean returns and mean descriptors over episodes,
            as well as all returns and descriptors from all agents over all episodes.
        """
        # evaluation acts deterministically
        play_eval_step = partial(
            self.play_qd_step_fn,
            env=eval_env,
            deterministic=True,
        )

        eval_policy = partial(
            self.eval_qd_policy_fn,
            play_step_fn=play_eval_step,
            descriptor_extraction_fn=descriptor_extraction_fn,
        )
        return jax.vmap(eval_policy)  # type: ignore

    def get_train_fn(
        self,
        env: Env,
        num_iterations: int,
        env_batch_size: int,
        grad_updates_per_step: float,
    ) -> Callable:
        """
        Returns the function to update the population of agents.

        Args:
            env: training environment.
            num_iterations: number of training iterations to perform.
            env_batch_size: number of batched environments.
            grad_updates_per_step: number of gradient updates to apply per step in
                the environment.

        Returns:
            the function to update the population which takes as input the population
            training state, environment starting states and replay buffers and returns
            updated training states, environment states, replay buffers and metrics.
        """
        # training acts stochastically
        play_step = partial(
            self.play_step_fn,
            env=env,
            deterministic=False,
        )

        do_iteration = partial(
            do_iteration_fn,
            env_batch_size=env_batch_size,
            grad_updates_per_step=grad_updates_per_step,
            play_step_fn=play_step,
            update_fn=self.update,
        )

        def _scan_do_iteration(
            carry: Tuple[PBTSacTrainingState, EnvState, ReplayBuffer],
            unused_arg: Any,
        ) -> Tuple[Tuple[PBTSacTrainingState, EnvState, ReplayBuffer], Any]:
            (
                training_state,
                env_state,
                replay_buffer,
                metrics,
            ) = do_iteration(*carry)
            return (training_state, env_state, replay_buffer), metrics

        def train_fn(
            training_state: PBTSacTrainingState,
            env_state: EnvState,
            replay_buffer: ReplayBuffer,
        ) -> Tuple[Tuple[PBTSacTrainingState, EnvState, ReplayBuffer], Metrics]:
            # metrics are stacked along a leading num_iterations axis by scan
            (training_state, env_state, replay_buffer), metrics = jax.lax.scan(
                _scan_do_iteration,
                (training_state, env_state, replay_buffer),
                None,
                length=num_iterations,
            )
            return (training_state, env_state, replay_buffer), metrics

        return jax.vmap(train_fn)  # type: ignore

get_eval_fn(eval_env)

Returns the function to evaluate the PBT population.

Parameters:
  • eval_env (Env) –

    evaluation environment. Might be different from training env if needed.

Returns:
  • Callable

    The function to evaluate the population. It takes as input the population
    training state as well as first eval environment states and returns the
    population agents' mean returns over episodes as well as all returns from all
    agents over all episodes.

Source code in qdax/baselines/sac_pbt.py
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
def get_eval_fn(
    self,
    eval_env: Env,
) -> Callable:
    """
    Build the function that evaluates every agent of the PBT population.

    Args:
        eval_env: environment used for evaluation; it may differ from the
            training environment if needed.

    Returns:
        A function, vectorised over the population with ``jax.vmap``, that takes
        the population training state and the initial evaluation environment
        states and returns the agents' mean returns over episodes as well as all
        returns from all agents over all episodes.
    """
    # Evaluation always acts deterministically.
    deterministic_step = partial(self.play_step_fn, env=eval_env, deterministic=True)
    single_agent_eval = partial(self.eval_policy_fn, play_step_fn=deterministic_step)
    # vmap (default in_axes=0) maps over the leading axis of every argument.
    return jax.vmap(single_agent_eval)  # type: ignore

get_eval_qd_fn(eval_env, descriptor_extraction_fn)

Returns the evaluation function of the PBT population.

Parameters:
  • eval_env (Env) –

    evaluation environment. Might be different from training env if needed.

  • descriptor_extraction_fn (Callable[[QDTransition, Mask], Descriptor]) –

    function to extract the descriptor from an episode.

Returns:
  • Callable

    The function to evaluate the population. It takes as input the population
    training state as well as first eval environment states and returns the
    population agents' mean returns and mean descriptors over episodes,
    as well as all returns and descriptors from all agents over all episodes.

Source code in qdax/baselines/sac_pbt.py
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
def get_eval_qd_fn(
    self,
    eval_env: Env,
    descriptor_extraction_fn: Callable[[QDTransition, Mask], Descriptor],
) -> Callable:
    """
    Build the QD evaluation function for every agent of the PBT population.

    Args:
        eval_env: environment used for evaluation; it may differ from the
            training environment if needed.
        descriptor_extraction_fn: function that extracts the descriptor from an
            episode.

    Returns:
        A function, vectorised over the population with ``jax.vmap``, that takes
        the population training state and the initial evaluation environment
        states and returns the agents' mean returns and mean descriptors over
        episodes, as well as all returns and descriptors from all agents over all
        episodes.
    """
    # Evaluation always acts deterministically.
    deterministic_step = partial(
        self.play_qd_step_fn, env=eval_env, deterministic=True
    )
    single_agent_eval = partial(
        self.eval_qd_policy_fn,
        play_step_fn=deterministic_step,
        descriptor_extraction_fn=descriptor_extraction_fn,
    )
    return jax.vmap(single_agent_eval)  # type: ignore

get_init_fn(population_size, action_size, observation_size, buffer_size)

Returns a function to initialize the population.

Parameters:
  • population_size (int) –

    size of the population.

  • action_size (int) –

    action space size.

  • observation_size (int) –

    observation space size.

  • buffer_size (int) –

    replay buffer size.

Returns:
  • Callable

    a function that takes as input a random key and returns the PBT population
    training state and the replay buffers

Source code in qdax/baselines/sac_pbt.py
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
def get_init_fn(
    self,
    population_size: int,
    action_size: int,
    observation_size: int,
    buffer_size: int,
) -> Callable:
    """
    Returns a function to initialize the population.

    Args:
        population_size: size of the population.
        action_size: action space size.
        observation_size: observation space size.
        buffer_size: replay buffer size.

    Returns:
        a function that takes as input a random key and returns the PBT
        population training states and the replay buffers (one per agent,
        stacked along a leading population axis). The input key is consumed;
        no new key is returned.
    """

    def _init_fn(
        key: RNGKey,
    ) -> Tuple[PBTSacTrainingState, ReplayBuffer]:
        # Same per-agent keys as splitting into population_size + 1 keys and
        # discarding the first; slicing avoids unpacking the JAX array into a
        # Python list only to re-stack it.
        keys = jax.random.split(key, num=population_size + 1)[1:]

        # One dummy transition per agent, used to shape each replay buffer.
        init_dummy_transition = partial(
            Transition.init_dummy,
            observation_dim=observation_size,
            action_dim=action_size,
        )
        init_dummy_transition = jax.vmap(
            init_dummy_transition, axis_size=population_size
        )
        dummy_transitions = init_dummy_transition()

        replay_buffer_init = partial(
            ReplayBuffer.init,
            buffer_size=buffer_size,
        )
        replay_buffer_init = jax.vmap(replay_buffer_init)
        replay_buffers = replay_buffer_init(transition=dummy_transitions)
        agent_init = partial(
            self.init, action_size=action_size, observation_size=observation_size
        )
        training_states = jax.vmap(agent_init)(keys)
        return training_states, replay_buffers

    return _init_fn

get_train_fn(env, num_iterations, env_batch_size, grad_updates_per_step)

Returns the function to update the population of agents.

Parameters:
  • env (Env) –

    training environment.

  • num_iterations (int) –

    number of training iterations to perform.

  • env_batch_size (int) –

    number of batched environments.

  • grad_updates_per_step (float) –

    number of gradient updates to apply per step in the environment.

Returns:
  • Callable

    the function to update the population which takes as input the population
    training state, environment starting states and replay buffers and returns
    updated training states, environment states, replay buffers and metrics.

Source code in qdax/baselines/sac_pbt.py
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
def get_train_fn(
    self,
    env: Env,
    num_iterations: int,
    env_batch_size: int,
    grad_updates_per_step: float,
) -> Callable:
    """
    Build the function that trains every agent of the PBT population.

    Args:
        env: training environment.
        num_iterations: number of training iterations to run per call.
        env_batch_size: number of batched environments.
        grad_updates_per_step: number of gradient updates to apply per step in
            the environment.

    Returns:
        A function, vectorised over the population with ``jax.vmap``, that takes
        the population training state, the starting environment states and the
        replay buffers, and returns the updated training states, environment
        states and replay buffers together with the metrics of every iteration.
    """
    # One training iteration of a single agent; acting is stochastic here.
    iteration_fn = partial(
        do_iteration_fn,
        env_batch_size=env_batch_size,
        grad_updates_per_step=grad_updates_per_step,
        play_step_fn=partial(self.play_step_fn, env=env, deterministic=False),
        update_fn=self.update,
    )

    def _scan_do_iteration(
        carry: Tuple[PBTSacTrainingState, EnvState, ReplayBuffer],
        unused_arg: Any,
    ) -> Tuple[Tuple[PBTSacTrainingState, EnvState, ReplayBuffer], Any]:
        next_state, next_env_state, next_buffer, step_metrics = iteration_fn(*carry)
        return (next_state, next_env_state, next_buffer), step_metrics

    def train_fn(
        training_state: PBTSacTrainingState,
        env_state: EnvState,
        replay_buffer: ReplayBuffer,
    ) -> Tuple[Tuple[PBTSacTrainingState, EnvState, ReplayBuffer], Metrics]:
        # scan stacks the per-iteration metrics along a leading axis.
        final_carry, all_metrics = jax.lax.scan(
            _scan_do_iteration,
            (training_state, env_state, replay_buffer),
            None,
            length=num_iterations,
        )
        return final_carry, all_metrics

    return jax.vmap(train_fn)  # type: ignore

init(key, action_size, observation_size)

Initialise the training state of the algorithm.

Parameters:
  • key (RNGKey) –

    a jax random key

  • action_size (int) –

    the size of the environment's action space

  • observation_size (int) –

    the size of the environment's observation space

Returns:
  • PBTSacTrainingState

    the initial training state of PBT-SAC

Source code in qdax/baselines/sac_pbt.py
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
def init(
    self, key: RNGKey, action_size: int, observation_size: int
) -> PBTSacTrainingState:
    """Create the initial PBT-SAC training state of one agent.

    The networks and optimizers come from ``SAC.init``. The PBT hyper-parameters
    are first left unset and then drawn by
    ``PBTSacTrainingState.resample_hyperparams``.

    Args:
        key: a jax random key
        action_size: the size of the environment's action space
        observation_size: the size of the environment's observation space

    Returns:
        the initial training state of PBT-SAC
    """
    base_state = SAC.init(self, key, action_size, observation_size)

    # Hyper-parameters are placeholders until resample_hyperparams fills them in.
    unset_hyperparams = dict(
        discount=None,
        policy_lr=None,
        critic_lr=None,
        alpha_lr=None,
        reward_scaling=None,
    )
    placeholder_state = PBTSacTrainingState(
        policy_optimizer_state=base_state.policy_optimizer_state,
        policy_params=base_state.policy_params,
        critic_optimizer_state=base_state.critic_optimizer_state,
        critic_params=base_state.critic_params,
        alpha_optimizer_state=base_state.alpha_optimizer_state,
        alpha_params=base_state.alpha_params,
        target_critic_params=base_state.target_critic_params,
        normalization_running_stats=base_state.normalization_running_stats,
        key=base_state.key,
        steps=base_state.steps,
        **unset_hyperparams,
    )

    return PBTSacTrainingState.resample_hyperparams(placeholder_state)  # type: ignore

update(training_state, replay_buffer)

Performs a training step to update the alpha, critic and policy parameters.

Parameters:
  • training_state (PBTSacTrainingState) –

    the current PBT-SAC training state

  • replay_buffer (ReplayBuffer) –

    the replay buffer

Returns:
  • PBTSacTrainingState

    the updated PBT-SAC training state

  • ReplayBuffer

    the replay buffer

  • Metrics

    the training metrics

Source code in qdax/baselines/sac_pbt.py
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
def update(
    self,
    training_state: PBTSacTrainingState,
    replay_buffer: ReplayBuffer,
) -> Tuple[PBTSacTrainingState, ReplayBuffer, Metrics]:
    """Run one gradient step on alpha, the critic and the policy of one agent.

    The learning rates, discount and reward scaling are the per-agent values
    stored in ``training_state``, not the ones in the static config.

    Args:
        training_state: the current PBT-SAC training state
        replay_buffer: the replay buffer to sample transitions from

    Returns:
        the updated PBT-SAC training state
        the replay buffer (returned unchanged)
        the training metrics
    """
    rng = training_state.key

    # Draw a batch of transitions from the buffer.
    rng, sample_key = jax.random.split(rng)
    batch = replay_buffer.sample(sample_key, sample_size=self._config.batch_size)

    # Normalise observations with the running statistics when enabled.
    if self._config.normalize_observations:
        stats = training_state.normalization_running_stats
        batch = batch.replace(
            obs=normalize_with_rmstd(batch.obs, stats),
            next_obs=normalize_with_rmstd(batch.next_obs, stats),
        )

    # Alpha step.
    rng, alpha_key = jax.random.split(rng)
    alpha_params, alpha_optimizer_state, alpha_loss = self._update_alpha(
        alpha_lr=training_state.alpha_lr,
        training_state=training_state,
        transitions=batch,
        key=alpha_key,
    )

    # Critic step (also yields the new target critic parameters).
    rng, critic_key = jax.random.split(rng)
    (
        critic_params,
        target_critic_params,
        critic_optimizer_state,
        critic_loss,
    ) = self._update_critic(
        critic_lr=training_state.critic_lr,
        reward_scaling=training_state.reward_scaling,
        discount=training_state.discount,
        training_state=training_state,
        transitions=batch,
        key=critic_key,
    )

    # Policy step.
    rng, actor_key = jax.random.split(rng)
    policy_params, policy_optimizer_state, policy_loss = self._update_actor(
        policy_lr=training_state.policy_lr,
        training_state=training_state,
        transitions=batch,
        key=actor_key,
    )

    # Assemble the next state; the PBT hyper-parameters are carried over as-is.
    rng, next_key = jax.random.split(rng)
    new_training_state = PBTSacTrainingState(
        policy_optimizer_state=policy_optimizer_state,
        policy_params=policy_params,
        critic_optimizer_state=critic_optimizer_state,
        critic_params=critic_params,
        alpha_optimizer_state=alpha_optimizer_state,
        alpha_params=alpha_params,
        normalization_running_stats=training_state.normalization_running_stats,
        target_critic_params=target_critic_params,
        key=next_key,
        steps=training_state.steps + 1,
        discount=training_state.discount,
        policy_lr=training_state.policy_lr,
        critic_lr=training_state.critic_lr,
        alpha_lr=training_state.alpha_lr,
        reward_scaling=training_state.reward_scaling,
    )

    # Observation statistics are computed on the (possibly normalised) batch.
    metrics = dict(
        actor_loss=policy_loss,
        critic_loss=critic_loss,
        alpha_loss=alpha_loss,
        obs_mean=jnp.mean(batch.obs),
        obs_std=jnp.std(batch.obs),
    )
    return new_training_state, replay_buffer, metrics

and

This class serves as a template for algorithms that want to implement the standard Population Based Training (PBT) scheme.

Source code in qdax/baselines/pbt.py
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
class PBT:
    """
    This class serves as a template for algorithms that want to implement the
    standard Population Based Training (PBT) scheme.

    At each PBT step, the ``num_worse_to_replace`` agents with the lowest returns
    are overwritten by agents sampled (with replacement) among the
    ``num_best_to_replace_from`` agents with the highest returns. The copied
    training states are passed through ``resample_hyperparams`` and the replay
    buffers of the replaced agents are overwritten by those of the copied agents.
    """

    def __init__(
        self,
        population_size: int,
        num_best_to_replace_from: int,
        num_worse_to_replace: int,
    ):
        """

        Args:
            population_size: Size of the PBT population.
            num_best_to_replace_from: Number of top performing individuals to sample
                from when replacing low performers at each iteration.
            num_worse_to_replace: Number of low-performing individuals to replace at
                each iteration. Zero disables replacement.

        Raises:
            ValueError: if one of the counts is negative, if individuals must be
                replaced but there is no top performer to sample from, or if the
                two counts together exceed the population size.
        """
        if num_best_to_replace_from < 0 or num_worse_to_replace < 0:
            raise ValueError(
                "The number of best individuals to replace from and the number "
                "of worse individuals to replace must be non-negative."
            )
        if num_worse_to_replace > 0 and num_best_to_replace_from == 0:
            raise ValueError(
                "At least one best individual is needed to replace "
                "worse individuals from."
            )
        if num_best_to_replace_from + num_worse_to_replace > population_size:
            raise ValueError(
                "The sum of best number of individuals to replace "
                "from and worse individuals to replace exceeds the population size."
            )
        self._population_size = population_size
        self._num_best_to_replace_from = num_best_to_replace_from
        self._num_worse_to_replace = num_worse_to_replace

    def update_states_and_buffer(
        self,
        key: RNGKey,
        population_returns: jax.Array,
        training_state: PBTTrainingState,
        replay_buffer: ReplayBuffer,
    ) -> Tuple[PBTTrainingState, ReplayBuffer]:
        """
        Updates the agents of the population states as well as
        their shared replay buffer.

        Args:
            key: Random key.
            population_returns: Returns of the agents in the populations.
            training_state: The training state of the PBT scheme.
            replay_buffer: Shared replay buffer by the agents.

        Returns:
            Updated PBT training state and updated replay buffer.
        """
        # Replacing nobody is a no-op. Returning early also avoids the
        # ``indices_sorted[-0:]`` slice, which would select the *whole* population.
        if self._num_worse_to_replace == 0:
            return training_state, replay_buffer

        # Rank agents from highest to lowest return.
        indices_sorted = jnp.argsort(-population_returns)
        best_indices = indices_sorted[: self._num_best_to_replace_from]
        indices_to_replace = indices_sorted[-self._num_worse_to_replace :]

        # Each replaced agent receives a (possibly repeated) top agent.
        indices_used_to_replace = jax.random.choice(
            key, best_indices, shape=(self._num_worse_to_replace,), replace=True
        )

        # Copy the selected top agents' states, with resampled hyperparameters.
        training_state = jax.tree.map(
            lambda x, y: x.at[indices_to_replace].set(y[indices_used_to_replace]),
            training_state,
            jax.vmap(training_state.__class__.resample_hyperparams)(training_state),
        )

        # Overwrite the replaced agents' buffers with the selected agents' ones.
        replay_buffer = jax.tree.map(
            lambda x: x.at[indices_to_replace].set(x[indices_used_to_replace]),
            replay_buffer,
        )

        return training_state, replay_buffer

    def update_states_and_buffer_pmap(
        self,
        key: RNGKey,
        population_returns: jax.Array,
        training_state: PBTTrainingState,
        replay_buffer: ReplayBuffer,
    ) -> Tuple[PBTTrainingState, ReplayBuffer]:
        """
        Updates the agents of the population states as well as
        their shared replay buffer. This is the version of the function to be
        used within jax.pmap (mapped axis named "p"). It assumes the population
        is spread over several devices and implements a parallel update through
        communication between the devices: every device shares its local top
        agents, and the local worst agents are replaced by agents sampled among
        the best of all gathered top agents.

        Args:
            key: Random RNG key.
            population_returns: Returns of the agents in the local (per-device)
                population.
            training_state: The training state of the PBT scheme.
            replay_buffer: Shared replay buffer by the agents.

        Returns:
            Updated PBT training state and updated replay buffer.
        """
        # Replacing nobody is a no-op. The test is on a Python int, so it is
        # resolved at trace time. It also avoids the ``indices_sorted[-0:]``
        # slice, which would select the *whole* population.
        if self._num_worse_to_replace == 0:
            return training_state, replay_buffer

        # Rank local agents from highest to lowest return.
        indices_sorted = jnp.argsort(-population_returns)
        best_indices = indices_sorted[: self._num_best_to_replace_from]
        indices_to_replace = indices_sorted[-self._num_worse_to_replace :]

        # Share every device's local top agents (states, buffers and returns).
        best_individuals, best_buffers, best_returns = jax.tree.map(
            lambda x: x[best_indices],
            (training_state, replay_buffer, population_returns),
        )
        (
            gathered_best_individuals,
            gathered_best_buffers,
            gathered_best_returns,
        ) = jax.tree.map(
            lambda x: jnp.concatenate(jax.lax.all_gather(x, axis_name="p"), axis=0),
            (best_individuals, best_buffers, best_returns),
        )
        # Keep the overall top agents among the gathered candidates.
        pop_indices_sorted = jnp.argsort(-gathered_best_returns)
        best_pop_indices = pop_indices_sorted[: self._num_best_to_replace_from]

        _, subkey = jax.random.split(key)
        indices_used_to_replace = jax.random.choice(
            subkey, best_pop_indices, shape=(self._num_worse_to_replace,), replace=True
        )

        # Copy the selected gathered agents' states, with resampled hyperparameters.
        training_state = jax.tree.map(
            lambda x, y: x.at[indices_to_replace].set(y[indices_used_to_replace]),
            training_state,
            jax.vmap(gathered_best_individuals.__class__.resample_hyperparams)(
                gathered_best_individuals
            ),
        )

        # Overwrite the replaced agents' buffers with the selected agents' ones.
        replay_buffer = jax.tree.map(
            lambda x, y: x.at[indices_to_replace].set(y[indices_used_to_replace]),
            replay_buffer,
            gathered_best_buffers,
        )

        return training_state, replay_buffer

__init__(population_size, num_best_to_replace_from, num_worse_to_replace)

Parameters:
  • population_size (int) –

    Size of the PBT population.

  • num_best_to_replace_from (int) –

    Number of top performing individuals to sample from when replacing low performers at each iteration.

  • num_worse_to_replace (int) –

    Number of low-performing individuals to replace at each iteration.

Source code in qdax/baselines/pbt.py
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
def __init__(
    self,
    population_size: int,
    num_best_to_replace_from: int,
    num_worse_to_replace: int,
):
    """

    Args:
        population_size: Size of the PBT population.
        num_best_to_replace_from: Number of top performing individuals to sample
            from when replacing low performers at each iteration.
        num_worse_to_replace: Number of low-performing individuals to replace at
            each iteration. Zero disables replacement.

    Raises:
        ValueError: if one of the counts is negative, if individuals must be
            replaced but there is no top performer to sample from, or if the
            two counts together exceed the population size.
    """
    if num_best_to_replace_from < 0 or num_worse_to_replace < 0:
        raise ValueError(
            "The number of best individuals to replace from and the number "
            "of worse individuals to replace must be non-negative."
        )
    if num_worse_to_replace > 0 and num_best_to_replace_from == 0:
        raise ValueError(
            "At least one best individual is needed to replace "
            "worse individuals from."
        )
    if num_best_to_replace_from + num_worse_to_replace > population_size:
        raise ValueError(
            "The sum of best number of individuals to replace "
            "from and worse individuals to replace exceeds the population size."
        )
    self._population_size = population_size
    self._num_best_to_replace_from = num_best_to_replace_from
    self._num_worse_to_replace = num_worse_to_replace

update_states_and_buffer(key, population_returns, training_state, replay_buffer)

Updates the agents of the population states as well as their shared replay buffer.

Parameters:
  • key (RNGKey) –

    Random key.

  • population_returns (Array) –

    Returns of the agents in the populations.

  • training_state (PBTTrainingState) –

    The training state of the PBT scheme.

  • replay_buffer (ReplayBuffer) –

    Shared replay buffer by the agents.

Returns:
  • Tuple[PBTTrainingState, ReplayBuffer]

    Updated PBT training state and updated replay buffer.

Source code in qdax/baselines/pbt.py
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
def update_states_and_buffer(
    self,
    key: RNGKey,
    population_returns: jax.Array,
    training_state: PBTTrainingState,
    replay_buffer: ReplayBuffer,
) -> Tuple[PBTTrainingState, ReplayBuffer]:
    """
    Updates the agents of the population states as well as
    their shared replay buffer.

    Args:
        key: Random key.
        population_returns: Returns of the agents in the populations.
        training_state: The training state of the PBT scheme.
        replay_buffer: Shared replay buffer by the agents.

    Returns:
        Updated PBT training state and updated replay buffer.
    """
    # Replacing nobody is a no-op. Returning early also avoids the
    # ``indices_sorted[-0:]`` slice, which would select the *whole* population.
    if self._num_worse_to_replace == 0:
        return training_state, replay_buffer

    # Rank agents from highest to lowest return.
    indices_sorted = jnp.argsort(-population_returns)
    best_indices = indices_sorted[: self._num_best_to_replace_from]
    indices_to_replace = indices_sorted[-self._num_worse_to_replace :]

    # Each replaced agent receives a (possibly repeated) top agent.
    indices_used_to_replace = jax.random.choice(
        key, best_indices, shape=(self._num_worse_to_replace,), replace=True
    )

    # Copy the selected top agents' states, with resampled hyperparameters.
    training_state = jax.tree.map(
        lambda x, y: x.at[indices_to_replace].set(y[indices_used_to_replace]),
        training_state,
        jax.vmap(training_state.__class__.resample_hyperparams)(training_state),
    )

    # Overwrite the replaced agents' buffers with the selected agents' ones.
    replay_buffer = jax.tree.map(
        lambda x: x.at[indices_to_replace].set(x[indices_used_to_replace]),
        replay_buffer,
    )

    return training_state, replay_buffer

update_states_and_buffer_pmap(key, population_returns, training_state, replay_buffer)

Updates the agents of the population states as well as their shared replay buffer. This is the version of the function to be used within jax.pmap. It assumes the population is spread over several devices and implements a parallel update through communication between the devices.

Parameters:
  • key (RNGKey) –

    Random RNG key.

  • population_returns (Array) –

    Returns of the agents in the populations.

  • training_state (PBTTrainingState) –

    The training state of the PBT scheme.

  • replay_buffer (ReplayBuffer) –

    Shared replay buffer by the agents.

Returns:
  • Tuple[PBTTrainingState, ReplayBuffer]

    Updated PBT training state and updated replay buffer.

Source code in qdax/baselines/pbt.py
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
def update_states_and_buffer_pmap(
    self,
    key: RNGKey,
    population_returns: jax.Array,
    training_state: PBTTrainingState,
    replay_buffer: ReplayBuffer,
) -> Tuple[PBTTrainingState, ReplayBuffer]:
    """
    Updates the agents of the population states as well as
    their shared replay buffer. This is the version of the function to be
    used within jax.pmap (mapped axis named "p"). It assumes the population
    is spread over several devices and implements a parallel update through
    communication between the devices: every device shares its local top
    agents, and the local worst agents are replaced by agents sampled among
    the best of all gathered top agents.

    Args:
        key: Random RNG key.
        population_returns: Returns of the agents in the local (per-device)
            population.
        training_state: The training state of the PBT scheme.
        replay_buffer: Shared replay buffer by the agents.

    Returns:
        Updated PBT training state and updated replay buffer.
    """
    # Replacing nobody is a no-op. The test is on a Python int, so it is
    # resolved at trace time. It also avoids the ``indices_sorted[-0:]``
    # slice, which would select the *whole* population.
    if self._num_worse_to_replace == 0:
        return training_state, replay_buffer

    # Rank local agents from highest to lowest return.
    indices_sorted = jnp.argsort(-population_returns)
    best_indices = indices_sorted[: self._num_best_to_replace_from]
    indices_to_replace = indices_sorted[-self._num_worse_to_replace :]

    # Share every device's local top agents (states, buffers and returns).
    best_individuals, best_buffers, best_returns = jax.tree.map(
        lambda x: x[best_indices],
        (training_state, replay_buffer, population_returns),
    )
    (
        gathered_best_individuals,
        gathered_best_buffers,
        gathered_best_returns,
    ) = jax.tree.map(
        lambda x: jnp.concatenate(jax.lax.all_gather(x, axis_name="p"), axis=0),
        (best_individuals, best_buffers, best_returns),
    )
    # Keep the overall top agents among the gathered candidates.
    pop_indices_sorted = jnp.argsort(-gathered_best_returns)
    best_pop_indices = pop_indices_sorted[: self._num_best_to_replace_from]

    _, subkey = jax.random.split(key)
    indices_used_to_replace = jax.random.choice(
        subkey, best_pop_indices, shape=(self._num_worse_to_replace,), replace=True
    )

    # Copy the selected gathered agents' states, with resampled hyperparameters.
    training_state = jax.tree.map(
        lambda x, y: x.at[indices_to_replace].set(y[indices_used_to_replace]),
        training_state,
        jax.vmap(gathered_best_individuals.__class__.resample_hyperparams)(
            gathered_best_individuals
        ),
    )

    # Overwrite the replaced agents' buffers with the selected agents' ones.
    replay_buffer = jax.tree.map(
        lambda x, y: x.at[indices_to_replace].set(y[indices_used_to_replace]),
        replay_buffer,
        gathered_best_buffers,
    )

    return training_state, replay_buffer

To use PBT in order to train TD3 agents, please use the PBTTD3 class:

Bases: TD3

Source code in qdax/baselines/td3_pbt.py
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
class PBTTD3(TD3):
    """TD3 agent whose hyperparameters are stored in its training state.

    The discount factor, learning rates and noise parameters live in a
    ``PBTTD3TrainingState`` (sampled at init through ``resample_hyperparams``),
    and the training step reads them from there, so that every member of a PBT
    population trains with its own values.
    """

    def __init__(self, config: PBTTD3Config, action_size: int):

        td3_config = TD3Config(
            episode_length=config.episode_length,
            batch_size=config.batch_size,
            policy_delay=config.policy_delay,
            reward_scaling=config.reward_scaling,
            soft_tau_update=config.soft_tau_update,
            critic_hidden_layer_size=config.critic_hidden_layer_size,
            policy_hidden_layer_size=config.policy_hidden_layer_size,
        )
        TD3.__init__(self, td3_config, action_size)

    def init(
        self, key: RNGKey, action_size: int, observation_size: int
    ) -> PBTTD3TrainingState:
        """Initialise the training state of the PBT-TD3 algorithm, through creation
        of optimizer states and params.

        Args:
            key: a random key used for random operations.
            action_size: the size of the action array needed to interact with the
                environment.
            observation_size: the size of the observation array retrieved from the
                environment.

        Returns:
            the initial training state.
        """

        training_state = TD3.init(self, key, action_size, observation_size)

        # Initial training state
        training_state = PBTTD3TrainingState(
            policy_optimizer_state=training_state.policy_optimizer_state,
            policy_params=training_state.policy_params,
            critic_optimizer_state=training_state.critic_optimizer_state,
            critic_params=training_state.critic_params,
            target_policy_params=training_state.target_policy_params,
            target_critic_params=training_state.target_critic_params,
            key=training_state.key,
            steps=training_state.steps,
            discount=None,
            policy_lr=None,
            critic_lr=None,
            noise_clip=None,
            policy_noise=None,
            expl_noise=None,
        )

        # Sample hyperparameters
        training_state = PBTTD3TrainingState.resample_hyperparams(training_state)

        return training_state  # type: ignore

    def play_step_fn(
        self,
        env_state: EnvState,
        training_state: TD3TrainingState,
        env: Env,
        deterministic: bool = False,
    ) -> Tuple[EnvState, TD3TrainingState, Transition]:
        """Plays a step in the environment. Selects an action according to TD3 rule and
        performs the environment step.

        Args:
            env_state: the current environment state
            training_state: the PBT-TD3 training state
            env: the environment
            deterministic: whether to select action in a deterministic way.
                Defaults to False.

        Returns:
            the new environment state
            the new PBT-TD3 training state
            the played transition
        """
        key, subkey = jax.random.split(training_state.key)

        # Exploration noise is the agent's own hyperparameter.
        actions = self.select_action(
            obs=env_state.obs,
            policy_params=training_state.policy_params,
            key=subkey,
            expl_noise=training_state.expl_noise,
            deterministic=deterministic,
        )
        training_state = training_state.replace(
            key=key,
        )
        next_env_state = env.step(env_state, actions)
        transition = Transition(
            obs=env_state.obs,
            next_obs=next_env_state.obs,
            rewards=next_env_state.reward,
            dones=next_env_state.done,
            truncations=next_env_state.info["truncation"],
            actions=actions,
        )
        return next_env_state, training_state, transition

    def update(
        self,
        training_state: PBTTD3TrainingState,
        replay_buffer: ReplayBuffer,
    ) -> Tuple[PBTTD3TrainingState, ReplayBuffer, Metrics]:
        """Performs a single training step: updates policy params and critic params
        through gradient descent.

        The discount, learning rates and target-policy noise parameters are read
        from ``training_state`` (the per-agent PBT hyperparameters).

        Args:
            training_state: the current training state, containing the optimizer states
                and the params of the policy and critic.
            replay_buffer: the replay buffer, filled with transitions experienced in
                the environment.

        Returns:
            A new training state, the buffer with new transitions and metrics about the
            training process.
        """

        # Sample a batch of transitions in the buffer
        key = training_state.key

        key, subkey = jax.random.split(key)
        samples = replay_buffer.sample(subkey, sample_size=self._config.batch_size)

        # Update Critic
        key, subkey = jax.random.split(key)
        critic_loss, critic_gradient = jax.value_and_grad(td3_critic_loss_fn)(
            training_state.critic_params,
            target_policy_params=training_state.target_policy_params,
            target_critic_params=training_state.target_critic_params,
            policy_fn=self._policy.apply,
            critic_fn=self._critic.apply,
            policy_noise=training_state.policy_noise,
            noise_clip=training_state.noise_clip,
            reward_scaling=self._config.reward_scaling,
            # Use the agent's own PBT-sampled discount. The TD3 config is built
            # without a discount, so ``self._config.discount`` would silently
            # ignore the value sampled by ``resample_hyperparams``.
            discount=training_state.discount,
            transitions=samples,
            key=subkey,
        )
        critic_optimizer = optax.adam(learning_rate=training_state.critic_lr)
        critic_updates, critic_optimizer_state = critic_optimizer.update(
            critic_gradient, training_state.critic_optimizer_state
        )
        critic_params = optax.apply_updates(
            training_state.critic_params, critic_updates
        )
        # Soft update of target critic network
        target_critic_params = jax.tree.map(
            lambda x1, x2: (1.0 - self._config.soft_tau_update) * x1
            + self._config.soft_tau_update * x2,
            training_state.target_critic_params,
            critic_params,
        )

        # Update policy
        policy_loss, policy_gradient = jax.value_and_grad(td3_policy_loss_fn)(
            training_state.policy_params,
            critic_params=training_state.critic_params,
            policy_fn=self._policy.apply,
            critic_fn=self._critic.apply,
            transitions=samples,
        )

        def update_policy_step() -> Tuple[Params, Params, optax.OptState]:
            policy_optimizer = optax.adam(learning_rate=training_state.policy_lr)
            (
                policy_updates,
                policy_optimizer_state,
            ) = policy_optimizer.update(
                policy_gradient, training_state.policy_optimizer_state
            )
            policy_params = optax.apply_updates(
                training_state.policy_params, policy_updates
            )
            # Soft update of target policy
            target_policy_params = jax.tree.map(
                lambda x1, x2: (1.0 - self._config.soft_tau_update) * x1
                + self._config.soft_tau_update * x2,
                training_state.target_policy_params,
                policy_params,
            )
            return policy_params, target_policy_params, policy_optimizer_state

        # Delayed update: the policy is only updated every `policy_delay` steps.
        current_policy_state = (
            training_state.policy_params,
            training_state.target_policy_params,
            training_state.policy_optimizer_state,
        )
        policy_params, target_policy_params, policy_optimizer_state = jax.lax.cond(
            training_state.steps % self._config.policy_delay == 0,
            lambda _: update_policy_step(),
            lambda _: current_policy_state,
            operand=None,
        )

        # Create new training state
        new_training_state = training_state.replace(
            critic_params=critic_params,
            critic_optimizer_state=critic_optimizer_state,
            policy_params=policy_params,
            policy_optimizer_state=policy_optimizer_state,
            target_critic_params=target_critic_params,
            target_policy_params=target_policy_params,
            key=key,
            steps=training_state.steps + 1,
        )

        metrics = {
            "actor_loss": policy_loss,
            "critic_loss": critic_loss,
        }

        return new_training_state, replay_buffer, metrics

    def get_init_fn(
        self,
        population_size: int,
        action_size: int,
        observation_size: int,
        buffer_size: int,
    ) -> Callable:
        """
        Returns a function to initialize the population.

        Args:
            population_size: size of the population.
            action_size: action space size.
            observation_size: observation space size.
            buffer_size: replay buffer size.

        Returns:
            a function that takes as input a random key and returns the PBT
            population training states and the population replay buffers.
        """

        def _init_fn(
            key: RNGKey,
        ) -> Tuple[PBTTD3TrainingState, ReplayBuffer]:
            # One key per agent; the remaining key is not needed afterwards.
            _, *keys = jax.random.split(key, num=population_size + 1)
            keys = jnp.stack(keys)

            init_dummy_transition = partial(
                Transition.init_dummy,
                observation_dim=observation_size,
                action_dim=action_size,
            )
            init_dummy_transition = jax.vmap(
                init_dummy_transition, axis_size=population_size
            )
            dummy_transitions = init_dummy_transition()

            replay_buffer_init = partial(
                ReplayBuffer.init,
                buffer_size=buffer_size,
            )
            replay_buffer_init = jax.vmap(replay_buffer_init)
            replay_buffers = replay_buffer_init(transition=dummy_transitions)
            agent_init = partial(
                self.init, action_size=action_size, observation_size=observation_size
            )
            training_states = jax.vmap(agent_init)(keys)
            return training_states, replay_buffers

        return _init_fn

    def get_eval_fn(
        self,
        eval_env: Env,
    ) -> Callable:
        """
        Returns the function used to evaluate the PBT population.

        Args:
            eval_env: evaluation environment. Might be different from training env
                if needed.

        Returns:
            The function to evaluate the population. It takes as input the population
            training state as well as first eval environment states and returns the
            population agents mean returns over episodes as well as all returns from all
            agents over all episodes.
        """
        play_eval_step = partial(
            self.play_step_fn,
            env=eval_env,
            deterministic=True,
        )

        eval_policy = partial(
            self.eval_policy_fn,
            play_step_fn=play_eval_step,
        )
        return jax.vmap(eval_policy)  # type: ignore

    def get_eval_qd_fn(
        self,
        eval_env: Env,
        descriptor_extraction_fn: Callable[[QDTransition, Mask], Descriptor],
    ) -> Callable:
        """
        Returns the function used to evaluate the PBT population.

        Args:
            eval_env: evaluation environment. Might be different from training env
                if needed.
            descriptor_extraction_fn: function to extract the descriptor from an
                episode.

        Returns:
            The function to evaluate the population. It takes as input the population
            training state as well as first eval environment states and returns the
            population agents mean returns and mean descriptors over episodes,
            as well as all returns and descriptors from all agents over all episodes.
        """
        play_eval_step = partial(
            self.play_qd_step_fn,
            env=eval_env,
            deterministic=True,
        )

        eval_policy = partial(
            self.eval_qd_policy_fn,
            play_step_fn=play_eval_step,
            descriptor_extraction_fn=descriptor_extraction_fn,
        )
        return jax.vmap(eval_policy)  # type: ignore

    def get_train_fn(
        self,
        env: Env,
        num_iterations: int,
        env_batch_size: int,
        grad_updates_per_step: float,
    ) -> Callable:
        """
        Returns the function to update the population of agents.

        Args:
            env: training environment.
            num_iterations: number of training iterations to perform.
            env_batch_size: number of batched environments.
            grad_updates_per_step: number of gradient updates to apply per
                environment step.

        Returns:
            the function to update the population which takes as input the population
            training state, environment starting states and replay buffers and returns
            updated training states, environment states, replay buffers and metrics.
        """
        play_step = partial(
            self.play_step_fn,
            env=env,
            deterministic=False,
        )

        do_iteration = partial(
            do_iteration_fn,
            env_batch_size=env_batch_size,
            grad_updates_per_step=grad_updates_per_step,
            play_step_fn=play_step,
            update_fn=self.update,
        )

        def _scan_do_iteration(
            carry: Tuple[PBTTD3TrainingState, EnvState, ReplayBuffer],
            unused_arg: Any,
        ) -> Tuple[Tuple[PBTTD3TrainingState, EnvState, ReplayBuffer], Any]:
            (
                training_state,
                env_state,
                replay_buffer,
                metrics,
            ) = do_iteration(*carry)
            return (training_state, env_state, replay_buffer), metrics

        def train_fn(
            training_state: PBTTD3TrainingState,
            env_state: EnvState,
            replay_buffer: ReplayBuffer,
        ) -> Tuple[Tuple[PBTTD3TrainingState, EnvState, ReplayBuffer], Metrics]:
            (training_state, env_state, replay_buffer), metrics = jax.lax.scan(
                _scan_do_iteration,
                (training_state, env_state, replay_buffer),
                None,
                length=num_iterations,
            )
            return (training_state, env_state, replay_buffer), metrics

        return jax.vmap(train_fn)  # type: ignore

get_eval_fn(eval_env)

Returns the function used to evaluate the PBT population.

Parameters:
  • eval_env (Env) –

    evaluation environment. Might be different from training env if needed.

Returns:
  • Callable

    The function to evaluate the population. It takes as input the population training state as well as first eval environment states and returns the population agents mean returns over episodes as well as all returns from all agents over all episodes.

Source code in qdax/baselines/td3_pbt.py
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
def get_eval_fn(
    self,
    eval_env: Env,
) -> Callable:
    """
    Build the evaluation function for the whole PBT population.

    Args:
        eval_env: environment used for evaluation; it may differ from the
            training environment.

    Returns:
        A vmapped version of ``self.eval_policy_fn`` whose play-step function
        is ``self.play_step_fn`` bound to ``eval_env`` in deterministic mode.
        It maps over the leading (population) axis of its inputs.
    """
    deterministic_step = partial(self.play_step_fn, env=eval_env, deterministic=True)
    single_agent_eval = partial(self.eval_policy_fn, play_step_fn=deterministic_step)
    population_eval = jax.vmap(single_agent_eval)
    return population_eval  # type: ignore

get_eval_qd_fn(eval_env, descriptor_extraction_fn)

Returns the function used to evaluate the PBT population.

Parameters:
  • eval_env (Env) –

    evaluation environment. Might be different from training env if needed.

  • descriptor_extraction_fn (Callable[[QDTransition, Mask], Descriptor]) –

    function to extract the descriptor from an episode.

Returns:
  • Callable

    The function to evaluate the population. It takes as input the population training state as well as first eval environment states and returns the population agents mean returns and mean descriptors over episodes, as well as all returns and descriptors from all agents over all episodes.

Source code in qdax/baselines/td3_pbt.py
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
def get_eval_qd_fn(
    self,
    eval_env: Env,
    descriptor_extraction_fn: Callable[[QDTransition, Mask], Descriptor],
) -> Callable:
    """
    Build the QD evaluation function (returns and descriptors) for the whole
    PBT population.

    Args:
        eval_env: environment used for evaluation; it may differ from the
            training environment.
        descriptor_extraction_fn: callable extracting a descriptor from an
            episode's transitions and mask.

    Returns:
        A vmapped version of ``self.eval_qd_policy_fn``, configured with
        ``self.play_qd_step_fn`` bound to ``eval_env`` in deterministic mode
        and with ``descriptor_extraction_fn``. It maps over the leading
        (population) axis of its inputs.
    """
    deterministic_qd_step = partial(
        self.play_qd_step_fn, env=eval_env, deterministic=True
    )
    single_agent_eval = partial(
        self.eval_qd_policy_fn,
        play_step_fn=deterministic_qd_step,
        descriptor_extraction_fn=descriptor_extraction_fn,
    )
    population_eval = jax.vmap(single_agent_eval)
    return population_eval  # type: ignore

get_init_fn(population_size, action_size, observation_size, buffer_size)

Returns a function to initialize the population.

Parameters:
  • population_size (int) –

    size of the population.

  • action_size (int) –

    action space size.

  • observation_size (int) –

    observation space size.

  • buffer_size (int) –

    replay buffer size.

Returns:
  • Callable

    a function that takes as input a random key and returns the PBT population training states and the population replay buffers.

Source code in qdax/baselines/td3_pbt.py
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
def get_init_fn(
    self,
    population_size: int,
    action_size: int,
    observation_size: int,
    buffer_size: int,
) -> Callable:
    """
    Returns a function to initialize the population.

    Args:
        population_size: size of the population.
        action_size: action space size.
        observation_size: observation space size.
        buffer_size: replay buffer size.

    Returns:
        a function that takes as input a random key and returns the PBT
        population training state and the replay buffers, both batched over a
        leading population axis. No new random key is returned.
    """

    def _init_fn(
        key: RNGKey,
    ) -> Tuple[PBTTD3TrainingState, ReplayBuffer]:
        # One key per agent: split into population_size + 1 keys and drop the
        # leading one, which is not used anywhere.
        agent_keys = jax.random.split(key, num=population_size + 1)[1:]

        # One dummy transition per agent (vmap over no arguments, so the batch
        # size comes from axis_size); used as the template for the buffers.
        init_dummy_transition = jax.vmap(
            partial(
                Transition.init_dummy,
                observation_dim=observation_size,
                action_dim=action_size,
            ),
            axis_size=population_size,
        )
        dummy_transitions = init_dummy_transition()

        # One replay buffer per agent.
        replay_buffer_init = jax.vmap(
            partial(ReplayBuffer.init, buffer_size=buffer_size)
        )
        replay_buffers = replay_buffer_init(transition=dummy_transitions)

        # One training state (networks, optimizers, sampled hyperparameters)
        # per agent.
        agent_init = partial(
            self.init, action_size=action_size, observation_size=observation_size
        )
        training_states = jax.vmap(agent_init)(agent_keys)
        return training_states, replay_buffers

    return _init_fn

get_train_fn(env, num_iterations, env_batch_size, grad_updates_per_step)

Returns the function to update the population of agents.

Parameters:
  • env (Env) –

    training environment.

  • num_iterations (int) –

    number of training iterations to perform.

  • env_batch_size (int) –

    number of batched environments.

  • grad_updates_per_step (float) –

    number of gradient updates to apply per environment step.

Returns:
  • Callable

    the function to update the population, which takes as input the population training state, environment starting states and replay buffers and returns updated training states, environment states, replay buffers and metrics.

Source code in qdax/baselines/td3_pbt.py (lines 453–514)
def get_train_fn(
    self,
    env: Env,
    num_iterations: int,
    env_batch_size: int,
    grad_updates_per_step: float,
) -> Callable:
    """
    Build the function that trains every agent of the population.

    Args:
        env: training environment.
        num_iterations: number of training iterations to perform.
        env_batch_size: number of batched environments.
        grad_updates_per_step: number of gradient updates to apply per
            environment step.

    Returns:
        A population-level training function (vmapped over agents). It takes
        the population training state, environment starting states and replay
        buffers, runs `num_iterations` iterations with `jax.lax.scan`, and
        returns the updated training states, environment states and replay
        buffers together with the metrics collected over the iterations.
    """
    # Training rollouts use deterministic=False in the step function.
    iteration_fn = partial(
        do_iteration_fn,
        env_batch_size=env_batch_size,
        grad_updates_per_step=grad_updates_per_step,
        play_step_fn=partial(self.play_step_fn, env=env, deterministic=False),
        update_fn=self.update,
    )

    def _scan_body(
        carry: Tuple[PBTTD3TrainingState, EnvState, ReplayBuffer],
        unused_arg: Any,
    ) -> Tuple[Tuple[PBTTD3TrainingState, EnvState, ReplayBuffer], Any]:
        # Carry the (state, env, buffer) triple; emit the metrics as scan output.
        new_training_state, new_env_state, new_replay_buffer, step_metrics = (
            iteration_fn(*carry)
        )
        return (new_training_state, new_env_state, new_replay_buffer), step_metrics

    def train_fn(
        training_state: PBTTD3TrainingState,
        env_state: EnvState,
        replay_buffer: ReplayBuffer,
    ) -> Tuple[Tuple[PBTTD3TrainingState, EnvState, ReplayBuffer], Metrics]:
        final_carry, all_metrics = jax.lax.scan(
            _scan_body,
            (training_state, env_state, replay_buffer),
            None,
            length=num_iterations,
        )
        return final_carry, all_metrics

    return jax.vmap(train_fn)  # type: ignore

init(key, action_size, observation_size)

Initialise the training state of the PBT-TD3 algorithm by creating the optimizer states and params, then sampling the PBT hyperparameters.

Parameters:
  • key (RNGKey) –

    a random key used for random operations.

  • action_size (int) –

    the size of the action array needed to interact with the environment.

  • observation_size (int) –

    the size of the observation array retrieved from the environment.

Returns:
  • PBTTD3TrainingState

    the initial training state.

Source code in qdax/baselines/td3_pbt.py (lines 137–177)
def init(
    self, key: RNGKey, action_size: int, observation_size: int
) -> PBTTD3TrainingState:
    """Create the initial PBT-TD3 training state.

    The networks, optimizer states, random key and step counter come from the
    standard TD3 initialisation. The PBT hyperparameters are then sampled with
    `PBTTD3TrainingState.resample_hyperparams`.

    Args:
        key: a random key used for random operations.
        action_size: the size of the action array needed to interact with the
            environment.
        observation_size: the size of the observation array retrieved from the
            environment.

    Returns:
        the initial training state.
    """
    # Plain TD3 initialisation (explicit base-class call).
    td3_state = TD3.init(self, key, action_size, observation_size)

    # Copy the TD3 fields; the PBT hyperparameters are left as None
    # placeholders until they are sampled below.
    placeholder_state = PBTTD3TrainingState(
        policy_optimizer_state=td3_state.policy_optimizer_state,
        policy_params=td3_state.policy_params,
        critic_optimizer_state=td3_state.critic_optimizer_state,
        critic_params=td3_state.critic_params,
        target_policy_params=td3_state.target_policy_params,
        target_critic_params=td3_state.target_critic_params,
        key=td3_state.key,
        steps=td3_state.steps,
        discount=None,
        policy_lr=None,
        critic_lr=None,
        noise_clip=None,
        policy_noise=None,
        expl_noise=None,
    )

    sampled_state = PBTTD3TrainingState.resample_hyperparams(placeholder_state)
    return sampled_state  # type: ignore

play_step_fn(env_state, training_state, env, deterministic=False)

Plays a step in the environment. Selects an action according to the TD3 rule and performs the environment step.

Parameters:
  • env_state (State) –

    the current environment state

  • training_state (TD3TrainingState) –

    the PBT-TD3 training state

  • env (Env) –

    the environment

  • deterministic (bool, default: False ) –

    whether to select action in a deterministic way. Defaults to False.

Returns:
  • State

    the new environment state

  • TD3TrainingState

    the new PBT-TD3 training state

  • Transition

    the played transition

Source code in qdax/baselines/td3_pbt.py (lines 179–222)
def play_step_fn(
    self,
    env_state: EnvState,
    training_state: TD3TrainingState,
    env: Env,
    deterministic: bool = False,
) -> Tuple[EnvState, TD3TrainingState, Transition]:
    """Play one environment step with an action chosen by the TD3 rule.

    Args:
        env_state: the current environment state
        training_state: the PBT-TD3 training state (must carry `expl_noise`)
        env: the environment
        deterministic: whether to select action in a deterministic way.
            Defaults to False.

    Returns:
        the new environment state
        the new PBT-TD3 training state (with a refreshed random key)
        the played transition
    """
    next_key, action_key = jax.random.split(training_state.key)

    # The exploration noise level is the agent's own PBT hyperparameter.
    actions = self.select_action(
        obs=env_state.obs,
        policy_params=training_state.policy_params,
        key=action_key,
        expl_noise=training_state.expl_noise,
        deterministic=deterministic,
    )
    next_env_state = env.step(env_state, actions)
    updated_training_state = training_state.replace(key=next_key)

    transition = Transition(
        obs=env_state.obs,
        next_obs=next_env_state.obs,
        rewards=next_env_state.reward,
        dones=next_env_state.done,
        truncations=next_env_state.info["truncation"],
        actions=actions,
    )
    return next_env_state, updated_training_state, transition

update(training_state, replay_buffer)

Performs a single training step: updates policy params and critic params through gradient descent.

Parameters:
  • training_state (PBTTD3TrainingState) –

    the current training state, containing the optimizer states and the params of the policy and critic.

  • replay_buffer (ReplayBuffer) –

    the replay buffer, filled with transitions experienced in the environment.

Returns:
  • PBTTD3TrainingState

    the new training state.

  • ReplayBuffer

    the replay buffer, returned unchanged (it is only sampled from).

  • Metrics

    metrics about the training process (actor and critic losses).

Source code in qdax/baselines/td3_pbt.py (lines 224–338)
def update(
    self,
    training_state: PBTTD3TrainingState,
    replay_buffer: ReplayBuffer,
) -> Tuple[PBTTD3TrainingState, ReplayBuffer, Metrics]:
    """Performs a single training step: updates the critic params and, every
    `policy_delay` steps, the policy params through gradient descent.

    The agent's own PBT hyperparameters stored in `training_state`
    (`discount`, `policy_noise`, `noise_clip`, `critic_lr` and `policy_lr`)
    are used for this step. This ensures that hyperparameters sampled or copied
    by PBT actually affect learning.

    Args:
        training_state: the current training state, containing the optimizer
            states, the params of the policy and critic and the PBT
            hyperparameters.
        replay_buffer: the replay buffer, filled with transitions experienced in
            the environment. It is only sampled from, not modified.

    Returns:
        The new training state, the replay buffer (returned as received) and a
        metrics dict with the actor and critic losses of this step.
    """

    # Sample a batch of transitions in the buffer
    key = training_state.key

    key, subkey = jax.random.split(key)
    samples = replay_buffer.sample(subkey, sample_size=self._config.batch_size)

    # Update Critic
    key, subkey = jax.random.split(key)
    critic_loss, critic_gradient = jax.value_and_grad(td3_critic_loss_fn)(
        training_state.critic_params,
        target_policy_params=training_state.target_policy_params,
        target_critic_params=training_state.target_critic_params,
        policy_fn=self._policy.apply,
        critic_fn=self._critic.apply,
        policy_noise=training_state.policy_noise,
        noise_clip=training_state.noise_clip,
        reward_scaling=self._config.reward_scaling,
        # The discount is a per-agent PBT hyperparameter (sampled in `init`);
        # reading it from the static config would silently ignore it.
        discount=training_state.discount,
        transitions=samples,
        key=subkey,
    )
    # The learning rate is a per-agent hyperparameter, so the optimizer is
    # rebuilt from the training state each step while its state is carried over.
    critic_optimizer = optax.adam(learning_rate=training_state.critic_lr)
    critic_updates, critic_optimizer_state = critic_optimizer.update(
        critic_gradient, training_state.critic_optimizer_state
    )
    critic_params = optax.apply_updates(
        training_state.critic_params, critic_updates
    )
    # Soft update of target critic network
    target_critic_params = jax.tree.map(
        lambda x1, x2: (1.0 - self._config.soft_tau_update) * x1
        + self._config.soft_tau_update * x2,
        training_state.target_critic_params,
        critic_params,
    )

    # Update policy. The loss is evaluated with the critic params from before
    # this step's critic update (training_state.critic_params).
    policy_loss, policy_gradient = jax.value_and_grad(td3_policy_loss_fn)(
        training_state.policy_params,
        critic_params=training_state.critic_params,
        policy_fn=self._policy.apply,
        critic_fn=self._critic.apply,
        transitions=samples,
    )

    def update_policy_step() -> Tuple[Params, Params, optax.OptState]:
        policy_optimizer = optax.adam(learning_rate=training_state.policy_lr)
        (
            policy_updates,
            policy_optimizer_state,
        ) = policy_optimizer.update(
            policy_gradient, training_state.policy_optimizer_state
        )
        policy_params = optax.apply_updates(
            training_state.policy_params, policy_updates
        )
        # Soft update of target policy
        target_policy_params = jax.tree.map(
            lambda x1, x2: (1.0 - self._config.soft_tau_update) * x1
            + self._config.soft_tau_update * x2,
            training_state.target_policy_params,
            policy_params,
        )
        return policy_params, target_policy_params, policy_optimizer_state

    # Delayed update: the policy (and its target) only change every
    # `policy_delay` steps; otherwise the current values are kept.
    current_policy_state = (
        training_state.policy_params,
        training_state.target_policy_params,
        training_state.policy_optimizer_state,
    )
    policy_params, target_policy_params, policy_optimizer_state = jax.lax.cond(
        training_state.steps % self._config.policy_delay == 0,
        lambda _: update_policy_step(),
        lambda _: current_policy_state,
        operand=None,
    )

    # Create new training state
    new_training_state = training_state.replace(
        critic_params=critic_params,
        critic_optimizer_state=critic_optimizer_state,
        policy_params=policy_params,
        policy_optimizer_state=policy_optimizer_state,
        target_critic_params=target_critic_params,
        target_policy_params=target_policy_params,
        key=key,
        steps=training_state.steps + 1,
    )

    metrics = {
        "actor_loss": policy_loss,
        "critic_loss": critic_loss,
    }

    return new_training_state, replay_buffer, metrics