Population Based Training (PBT)

PBT is an optimization method that jointly optimizes a population of models and their hyperparameters to maximize performance.

To use PBT in QDax to train SAC, one can use the following two components (see the examples for how to use them appropriately):

qdax.baselines.sac_pbt.PBTSAC (SAC)

Source code in qdax/baselines/sac_pbt.py
class PBTSAC(SAC):
    """SAC agent whose hyper-parameters are tuned by Population Based Training.

    The discount, the policy/critic/alpha learning rates and the reward scaling
    are not read from the SAC configuration: they live in the
    ``PBTSacTrainingState`` of each agent and are sampled / resampled there.
    """

    def __init__(self, config: PBTSacConfig, action_size: int) -> None:
        """Builds the underlying SAC agent from a PBT-SAC configuration.

        Args:
            config: the PBT-SAC configuration. Only the fields that are not
                tuned by PBT are forwarded to the SAC configuration.
            action_size: the size of the environment's action space
        """

        sac_config = SacConfig(
            batch_size=config.batch_size,
            episode_length=config.episode_length,
            tau=config.tau,
            normalize_observations=config.normalize_observations,
            alpha_init=config.alpha_init,
            policy_hidden_layer_size=config.policy_hidden_layer_size,
            critic_hidden_layer_size=config.critic_hidden_layer_size,
            fix_alpha=config.fix_alpha,
            # unused default values for parameters that will be learnt as part of PBT
            learning_rate=3e-4,
            discount=0.97,
            reward_scaling=1.0,
        )
        SAC.__init__(self, config=sac_config, action_size=action_size)

    def init(
        self, random_key: RNGKey, action_size: int, observation_size: int
    ) -> PBTSacTrainingState:
        """Initialise the training state of the algorithm.

        Args:
            random_key: a jax random key
            action_size: the size of the environment's action space
            observation_size: the size of the environment's observation space

        Returns:
            the initial training state of PBT-SAC
        """

        sac_training_state = SAC.init(self, random_key, action_size, observation_size)

        # Copy the SAC state and leave the PBT hyper-parameters empty; they are
        # filled in by ``resample_hyperparams`` right after.
        training_state = PBTSacTrainingState(
            policy_optimizer_state=sac_training_state.policy_optimizer_state,
            policy_params=sac_training_state.policy_params,
            critic_optimizer_state=sac_training_state.critic_optimizer_state,
            critic_params=sac_training_state.critic_params,
            alpha_optimizer_state=sac_training_state.alpha_optimizer_state,
            alpha_params=sac_training_state.alpha_params,
            target_critic_params=sac_training_state.target_critic_params,
            normalization_running_stats=sac_training_state.normalization_running_stats,
            random_key=sac_training_state.random_key,
            steps=sac_training_state.steps,
            discount=None,
            policy_lr=None,
            critic_lr=None,
            alpha_lr=None,
            reward_scaling=None,
        )

        # Sample hyper-params
        training_state = PBTSacTrainingState.resample_hyperparams(training_state)

        return training_state  # type: ignore

    @partial(jax.jit, static_argnames=("self",))
    def update(
        self,
        training_state: PBTSacTrainingState,
        replay_buffer: ReplayBuffer,
    ) -> Tuple[PBTSacTrainingState, ReplayBuffer, Metrics]:
        """Performs a training step to update the policy and the critic parameters.

        The learning rates, discount and reward scaling are read from the
        training state (they are the PBT-tuned hyper-parameters of this agent).

        Args:
            training_state: the current PBT-SAC training state
            replay_buffer: the replay buffer

        Returns:
            the updated PBT-SAC training state
            the replay buffer
            the training metrics
        """

        # sample a batch of transitions in the buffer
        random_key = training_state.random_key
        transitions, random_key = replay_buffer.sample(
            random_key,
            sample_size=self._config.batch_size,
        )

        # normalise observations if necessary
        if self._config.normalize_observations:
            normalization_running_stats = training_state.normalization_running_stats
            normalized_obs = normalize_with_rmstd(
                transitions.obs, normalization_running_stats
            )
            normalized_next_obs = normalize_with_rmstd(
                transitions.next_obs, normalization_running_stats
            )
            transitions = transitions.replace(
                obs=normalized_obs, next_obs=normalized_next_obs
            )

        # update alpha
        (
            alpha_params,
            alpha_optimizer_state,
            alpha_loss,
            random_key,
        ) = self._update_alpha(
            alpha_lr=training_state.alpha_lr,
            training_state=training_state,
            transitions=transitions,
            random_key=random_key,
        )

        # update critic
        (
            critic_params,
            target_critic_params,
            critic_optimizer_state,
            critic_loss,
            random_key,
        ) = self._update_critic(
            critic_lr=training_state.critic_lr,
            reward_scaling=training_state.reward_scaling,
            discount=training_state.discount,
            training_state=training_state,
            transitions=transitions,
            random_key=random_key,
        )

        # update actor
        (
            policy_params,
            policy_optimizer_state,
            policy_loss,
            random_key,
        ) = self._update_actor(
            policy_lr=training_state.policy_lr,
            training_state=training_state,
            transitions=transitions,
            random_key=random_key,
        )

        # create new training state; hyper-parameters are carried over unchanged
        new_training_state = PBTSacTrainingState(
            policy_optimizer_state=policy_optimizer_state,
            policy_params=policy_params,
            critic_optimizer_state=critic_optimizer_state,
            critic_params=critic_params,
            alpha_optimizer_state=alpha_optimizer_state,
            alpha_params=alpha_params,
            normalization_running_stats=training_state.normalization_running_stats,
            target_critic_params=target_critic_params,
            random_key=random_key,
            steps=training_state.steps + 1,
            discount=training_state.discount,
            policy_lr=training_state.policy_lr,
            critic_lr=training_state.critic_lr,
            alpha_lr=training_state.alpha_lr,
            reward_scaling=training_state.reward_scaling,
        )
        metrics = {
            "actor_loss": policy_loss,
            "critic_loss": critic_loss,
            "alpha_loss": alpha_loss,
            "obs_mean": jnp.mean(transitions.obs),
            "obs_std": jnp.std(transitions.obs),
        }
        return new_training_state, replay_buffer, metrics

    def get_init_fn(
        self,
        population_size: int,
        action_size: int,
        observation_size: int,
        buffer_size: int,
    ) -> Callable:
        """
        Returns a function to initialize the population.

        Args:
            population_size: size of the population.
            action_size: action space size.
            observation_size: observation space size.
            buffer_size: replay buffer size.

        Returns:
            a function that takes as input a random key and returns a new random
            key, the PBT population training state and the replay buffers
        """

        def _init_fn(
            random_key: RNGKey,
        ) -> Tuple[RNGKey, PBTSacTrainingState, ReplayBuffer]:

            # one key is kept for the caller, one key per agent
            random_key, *keys = jax.random.split(random_key, num=1 + population_size)
            keys = jnp.stack(keys)

            init_dummy_transition = partial(
                Transition.init_dummy,
                observation_dim=observation_size,
                action_dim=action_size,
            )
            init_dummy_transition = jax.vmap(
                init_dummy_transition, axis_size=population_size
            )
            dummy_transitions = init_dummy_transition()

            replay_buffer_init = partial(
                ReplayBuffer.init,
                buffer_size=buffer_size,
            )
            replay_buffer_init = jax.vmap(replay_buffer_init)
            replay_buffers = replay_buffer_init(transition=dummy_transitions)
            agent_init = partial(
                self.init, action_size=action_size, observation_size=observation_size
            )
            training_states = jax.vmap(agent_init)(keys)
            return random_key, training_states, replay_buffers

        return _init_fn

    def get_eval_fn(
        self,
        eval_env: Env,
    ) -> Callable:
        """
        Returns a function that evaluates the PBT population.

        Args:
            eval_env: evaluation environment. Might be different from training env
                if needed.

        Returns:
            The function to evaluate the population. It takes as input the population
            training state as well as first eval environment states and returns the
            population agents mean returns over episodes as well as all returns from all
            agents over all episodes.
        """
        # evaluation uses the deterministic policy
        play_eval_step = partial(
            self.play_step_fn,
            env=eval_env,
            deterministic=True,
        )

        eval_policy = partial(
            self.eval_policy_fn,
            play_step_fn=play_eval_step,
        )
        # vectorise the single-agent evaluation over the population axis
        return jax.vmap(eval_policy)  # type: ignore

    def get_eval_qd_fn(
        self,
        eval_env: Env,
        bd_extraction_fn: Callable[[QDTransition, Mask], Descriptor],
    ) -> Callable:
        """
        Returns a function that evaluates the PBT population and extracts
        behaviour descriptors.

        Args:
            eval_env: evaluation environment. Might be different from training env
                if needed.
            bd_extraction_fn: function to extract the bd from an episode.

        Returns:
            The function to evaluate the population. It takes as input the population
            training state as well as first eval environment states and returns the
            population agents mean returns and mean bds over episodes as well as all
            returns and bds from all agents over all episodes.
        """
        # evaluation uses the deterministic policy
        play_eval_step = partial(
            self.play_qd_step_fn,
            env=eval_env,
            deterministic=True,
        )

        eval_policy = partial(
            self.eval_qd_policy_fn,
            play_step_fn=play_eval_step,
            bd_extraction_fn=bd_extraction_fn,
        )
        # vectorise the single-agent evaluation over the population axis
        return jax.vmap(eval_policy)  # type: ignore

    def get_train_fn(
        self,
        env: Env,
        num_iterations: int,
        env_batch_size: int,
        grad_updates_per_step: float,
    ) -> Callable:
        """
        Returns the function to update the population of agents.

        Args:
            env: training environment.
            num_iterations: number of training iterations to perform.
            env_batch_size: number of batched environments.
            grad_updates_per_step: number of gradient updates to apply per step
                in the environment.

        Returns:
            the function to update the population which takes as input the population
            training state, environment starting states and replay buffers and returns
            updated training states, environment states, replay buffers and metrics.
        """
        # training uses the stochastic policy
        play_step = partial(
            self.play_step_fn,
            env=env,
            deterministic=False,
        )

        do_iteration = partial(
            do_iteration_fn,
            env_batch_size=env_batch_size,
            grad_updates_per_step=grad_updates_per_step,
            play_step_fn=play_step,
            update_fn=self.update,
        )

        def _scan_do_iteration(
            carry: Tuple[PBTSacTrainingState, EnvState, ReplayBuffer],
            unused_arg: Any,
        ) -> Tuple[Tuple[PBTSacTrainingState, EnvState, ReplayBuffer], Any]:
            (
                training_state,
                env_state,
                replay_buffer,
                metrics,
            ) = do_iteration(*carry)
            return (training_state, env_state, replay_buffer), metrics

        def train_fn(
            training_state: PBTSacTrainingState,
            env_state: EnvState,
            replay_buffer: ReplayBuffer,
        ) -> Tuple[Tuple[PBTSacTrainingState, EnvState, ReplayBuffer], Metrics]:
            # run ``num_iterations`` iterations for a single agent, stacking metrics
            (training_state, env_state, replay_buffer), metrics = jax.lax.scan(
                _scan_do_iteration,
                (training_state, env_state, replay_buffer),
                None,
                length=num_iterations,
            )
            return (training_state, env_state, replay_buffer), metrics

        # vectorise the single-agent training loop over the population axis
        return jax.vmap(train_fn)  # type: ignore

init(self, random_key, action_size, observation_size)

Initialise the training state of the algorithm.

Parameters:
  • random_key (Array) – a jax random key

  • action_size (int) – the size of the environment's action space

  • observation_size (int) – the size of the environment's observation space

Returns:
  • PBTSacTrainingState – the initial training state of PBT-SAC

Source code in qdax/baselines/sac_pbt.py
def init(
    self, random_key: RNGKey, action_size: int, observation_size: int
) -> PBTSacTrainingState:
    """Create the initial PBT-SAC training state.

    The regular SAC state is built first. It is then extended with the PBT
    hyper-parameters, which start empty and are immediately resampled.

    Args:
        random_key: a jax random key
        action_size: the size of the environment's action space
        observation_size: the size of the environment's observation space

    Returns:
        the initial training state of PBT-SAC
    """
    base_state = SAC.init(self, random_key, action_size, observation_size)

    # Every field of the plain SAC state is carried over unchanged.
    inherited_fields = {
        field_name: getattr(base_state, field_name)
        for field_name in (
            "policy_optimizer_state",
            "policy_params",
            "critic_optimizer_state",
            "critic_params",
            "alpha_optimizer_state",
            "alpha_params",
            "target_critic_params",
            "normalization_running_stats",
            "random_key",
            "steps",
        )
    }
    # PBT-tuned hyper-parameters: left as None until sampled below.
    empty_hyperparams = dict.fromkeys(
        ("discount", "policy_lr", "critic_lr", "alpha_lr", "reward_scaling")
    )

    unsampled_state = PBTSacTrainingState(**inherited_fields, **empty_hyperparams)
    return PBTSacTrainingState.resample_hyperparams(unsampled_state)  # type: ignore

update(self, training_state, replay_buffer)

Performs a training step to update the policy and the critic parameters.

Parameters:
  • training_state (PBTSacTrainingState) – the current PBT-SAC training state

  • replay_buffer (ReplayBuffer) – the replay buffer

Returns:
  • Tuple[qdax.baselines.sac_pbt.PBTSacTrainingState, qdax.core.neuroevolution.buffers.buffer.ReplayBuffer, Dict[str, jax.Array]] – the updated PBT-SAC training state the replay buffer the training metrics

Source code in qdax/baselines/sac_pbt.py
@partial(jax.jit, static_argnames=("self",))
def update(
    self,
    training_state: PBTSacTrainingState,
    replay_buffer: ReplayBuffer,
) -> Tuple[PBTSacTrainingState, ReplayBuffer, Metrics]:
    """Performs a training step to update the policy and the critic parameters.

    The learning rates, discount and reward scaling are read from the
    training state (they are the PBT-tuned hyper-parameters of this agent).

    Args:
        training_state: the current PBT-SAC training state
        replay_buffer: the replay buffer

    Returns:
        the updated PBT-SAC training state
        the replay buffer
        the training metrics
    """

    # sample a batch of transitions in the buffer
    random_key = training_state.random_key
    transitions, random_key = replay_buffer.sample(
        random_key,
        sample_size=self._config.batch_size,
    )

    # normalise observations if necessary
    if self._config.normalize_observations:
        normalization_running_stats = training_state.normalization_running_stats
        normalized_obs = normalize_with_rmstd(
            transitions.obs, normalization_running_stats
        )
        normalized_next_obs = normalize_with_rmstd(
            transitions.next_obs, normalization_running_stats
        )
        transitions = transitions.replace(
            obs=normalized_obs, next_obs=normalized_next_obs
        )

    # update alpha
    (
        alpha_params,
        alpha_optimizer_state,
        alpha_loss,
        random_key,
    ) = self._update_alpha(
        alpha_lr=training_state.alpha_lr,
        training_state=training_state,
        transitions=transitions,
        random_key=random_key,
    )

    # update critic
    (
        critic_params,
        target_critic_params,
        critic_optimizer_state,
        critic_loss,
        random_key,
    ) = self._update_critic(
        critic_lr=training_state.critic_lr,
        reward_scaling=training_state.reward_scaling,
        discount=training_state.discount,
        training_state=training_state,
        transitions=transitions,
        random_key=random_key,
    )

    # update actor
    (
        policy_params,
        policy_optimizer_state,
        policy_loss,
        random_key,
    ) = self._update_actor(
        policy_lr=training_state.policy_lr,
        training_state=training_state,
        transitions=transitions,
        random_key=random_key,
    )

    # create new training state; hyper-parameters are carried over unchanged
    new_training_state = PBTSacTrainingState(
        policy_optimizer_state=policy_optimizer_state,
        policy_params=policy_params,
        critic_optimizer_state=critic_optimizer_state,
        critic_params=critic_params,
        alpha_optimizer_state=alpha_optimizer_state,
        alpha_params=alpha_params,
        normalization_running_stats=training_state.normalization_running_stats,
        target_critic_params=target_critic_params,
        random_key=random_key,
        steps=training_state.steps + 1,
        discount=training_state.discount,
        policy_lr=training_state.policy_lr,
        critic_lr=training_state.critic_lr,
        alpha_lr=training_state.alpha_lr,
        reward_scaling=training_state.reward_scaling,
    )
    metrics = {
        "actor_loss": policy_loss,
        "critic_loss": critic_loss,
        "alpha_loss": alpha_loss,
        "obs_mean": jnp.mean(transitions.obs),
        "obs_std": jnp.std(transitions.obs),
    }
    return new_training_state, replay_buffer, metrics

get_init_fn(self, population_size, action_size, observation_size, buffer_size)

Returns a function to initialize the population.

Parameters:
  • population_size (int) – size of the population.

  • action_size (int) – action space size.

  • observation_size (int) – observation space size.

  • buffer_size (int) – replay buffer size.

Returns:
  • Callable – a function that takes as input a random key and returns a new random key, the PBT population training state and the replay buffers

Source code in qdax/baselines/sac_pbt.py
def get_init_fn(
    self,
    population_size: int,
    action_size: int,
    observation_size: int,
    buffer_size: int,
) -> Callable:
    """
    Build the function that initializes a whole PBT population.

    Args:
        population_size: size of the population.
        action_size: action space size.
        observation_size: observation space size.
        buffer_size: replay buffer size.

    Returns:
        a function mapping a random key to a tuple made of a fresh random key,
        the stacked training states of the population and one replay buffer
        per agent
    """

    def _init_fn(
        random_key: RNGKey,
    ) -> Tuple[RNGKey, PBTSacTrainingState, ReplayBuffer]:
        # First key goes back to the caller, the remaining ones seed the agents.
        all_keys = jax.random.split(random_key, num=1 + population_size)
        next_key = all_keys[0]
        agent_keys = all_keys[1:]

        # One dummy transition per agent, used to shape each replay buffer.
        make_dummy = jax.vmap(
            partial(
                Transition.init_dummy,
                observation_dim=observation_size,
                action_dim=action_size,
            ),
            axis_size=population_size,
        )
        population_buffers = jax.vmap(
            partial(ReplayBuffer.init, buffer_size=buffer_size)
        )(transition=make_dummy())

        population_states = jax.vmap(
            partial(
                self.init,
                action_size=action_size,
                observation_size=observation_size,
            )
        )(agent_keys)

        return next_key, population_states, population_buffers

    return _init_fn

get_eval_fn(self, eval_env)

Returns a function that evaluates the PBT population.

Parameters:
  • eval_env (Env) – evaluation environment. Might be different from training env if needed.

Returns:
  • Callable – The function to evaluate the population. It takes as input the population training state as well as first eval environment states and returns the population agents mean returns over episodes as well as all returns from all agents over all episodes.

Source code in qdax/baselines/sac_pbt.py
def get_eval_fn(
    self,
    eval_env: Env,
) -> Callable:
    """
    Returns a function that evaluates the PBT population.

    Args:
        eval_env: evaluation environment. Might be different from training env
            if needed.

    Returns:
        The function to evaluate the population. It takes as input the population
        training state as well as first eval environment states and returns the
        population agents mean returns over episodes as well as all returns from all
        agents over all episodes.
    """
    # evaluation uses the deterministic policy
    play_eval_step = partial(
        self.play_step_fn,
        env=eval_env,
        deterministic=True,
    )

    eval_policy = partial(
        self.eval_policy_fn,
        play_step_fn=play_eval_step,
    )
    # vectorise the single-agent evaluation over the population axis
    return jax.vmap(eval_policy)  # type: ignore

get_eval_qd_fn(self, eval_env, bd_extraction_fn)

Returns a function that evaluates the PBT population and extracts behaviour descriptors.

Parameters:
  • eval_env (Env) – evaluation environment. Might be different from training env if needed.

  • bd_extraction_fn (Callable[[qdax.core.neuroevolution.buffers.buffer.QDTransition, jax.Array], jax.Array]) – function to extract the bd from an episode.

Returns:
  • Callable – The function to evaluate the population. It takes as input the population training state as well as first eval environment states and returns the population agents mean returns and mean bds over episodes as well as all returns and bds from all agents over all episodes.

Source code in qdax/baselines/sac_pbt.py
def get_eval_qd_fn(
    self,
    eval_env: Env,
    bd_extraction_fn: Callable[[QDTransition, Mask], Descriptor],
) -> Callable:
    """
    Returns a function that evaluates the PBT population and extracts
    behaviour descriptors.

    Args:
        eval_env: evaluation environment. Might be different from training env
            if needed.
        bd_extraction_fn: function to extract the bd from an episode.

    Returns:
        The function to evaluate the population. It takes as input the population
        training state as well as first eval environment states and returns the
        population agents mean returns and mean bds over episodes as well as all
        returns and bds from all agents over all episodes.
    """
    # evaluation uses the deterministic policy
    play_eval_step = partial(
        self.play_qd_step_fn,
        env=eval_env,
        deterministic=True,
    )

    eval_policy = partial(
        self.eval_qd_policy_fn,
        play_step_fn=play_eval_step,
        bd_extraction_fn=bd_extraction_fn,
    )
    # vectorise the single-agent evaluation over the population axis
    return jax.vmap(eval_policy)  # type: ignore

get_train_fn(self, env, num_iterations, env_batch_size, grad_updates_per_step)

Returns the function to update the population of agents.

Parameters:
  • env (Env) – training environment.

  • num_iterations (int) – number of training iterations to perform.

  • env_batch_size (int) – number of batched environments.

  • grad_updates_per_step (float) – number of gradient updates to apply per step in the environment.

Returns:
  • Callable – the function to update the population which takes as input the population training state, environment starting states and replay buffers and returns updated training states, environment states, replay buffers and metrics.

Source code in qdax/baselines/sac_pbt.py
def get_train_fn(
    self,
    env: Env,
    num_iterations: int,
    env_batch_size: int,
    grad_updates_per_step: float,
) -> Callable:
    """
    Returns the function to update the population of agents.

    Args:
        env: training environment.
        num_iterations: number of training iterations to perform.
        env_batch_size: number of batched environments.
        grad_updates_per_step: number of gradient updates to apply per step
            in the environment.

    Returns:
        the function to update the population which takes as input the population
        training state, environment starting states and replay buffers and returns
        updated training states, environment states, replay buffers and metrics.
    """
    # training uses the stochastic policy
    play_step = partial(
        self.play_step_fn,
        env=env,
        deterministic=False,
    )

    do_iteration = partial(
        do_iteration_fn,
        env_batch_size=env_batch_size,
        grad_updates_per_step=grad_updates_per_step,
        play_step_fn=play_step,
        update_fn=self.update,
    )

    def _scan_do_iteration(
        carry: Tuple[PBTSacTrainingState, EnvState, ReplayBuffer],
        unused_arg: Any,
    ) -> Tuple[Tuple[PBTSacTrainingState, EnvState, ReplayBuffer], Any]:
        (
            training_state,
            env_state,
            replay_buffer,
            metrics,
        ) = do_iteration(*carry)
        return (training_state, env_state, replay_buffer), metrics

    def train_fn(
        training_state: PBTSacTrainingState,
        env_state: EnvState,
        replay_buffer: ReplayBuffer,
    ) -> Tuple[Tuple[PBTSacTrainingState, EnvState, ReplayBuffer], Metrics]:
        # run ``num_iterations`` iterations for a single agent, stacking metrics
        (training_state, env_state, replay_buffer), metrics = jax.lax.scan(
            _scan_do_iteration,
            (training_state, env_state, replay_buffer),
            None,
            length=num_iterations,
        )
        return (training_state, env_state, replay_buffer), metrics

    # vectorise the single-agent training loop over the population axis
    return jax.vmap(train_fn)  # type: ignore

and

qdax.baselines.pbt.PBT

This class serves as a template for algorithms that want to implement the standard Population Based Training (PBT) scheme.

Source code in qdax/baselines/pbt.py
class PBT:
    """
    This class serves as a template for algorithms that want to implement the
    standard Population Based Training (PBT) scheme.

    At each update, the ``num_worse_to_replace`` agents with the lowest returns
    are overwritten. Each one becomes a copy of an agent sampled (with
    replacement) among the ``num_best_to_replace_from`` agents with the highest
    returns. The copied training states have their hyper-parameters resampled,
    and the copied agents' replay buffers are copied along with them.
    """

    def __init__(
        self,
        population_size: int,
        num_best_to_replace_from: int,
        num_worse_to_replace: int,
    ):
        """

        Args:
            population_size: Size of the PBT population.
            num_best_to_replace_from: Number of top performing individuals to sample
                from when replacing low performers at each iteration.
            num_worse_to_replace: Number of low-performing individuals to replace at
                each iteration.

        Raises:
            ValueError: if the two counts together exceed the population size,
                if either count is negative, or if individuals must be replaced
                (``num_worse_to_replace > 0``) while there is no top performer
                to copy from (``num_best_to_replace_from == 0``).
        """
        if num_best_to_replace_from + num_worse_to_replace > population_size:
            raise ValueError(
                "The sum of best number of individuals to replace "
                "from and worse individuals to replace exceeds the population size."
            )
        if num_best_to_replace_from < 0 or num_worse_to_replace < 0:
            raise ValueError(
                "The number of best individuals to replace from and the number "
                "of worse individuals to replace must be non-negative."
            )
        if num_worse_to_replace > 0 and num_best_to_replace_from == 0:
            # jax.random.choice cannot sample from an empty set of best agents.
            raise ValueError(
                "At least one best individual is required to replace "
                "worse individuals from."
            )
        self._population_size = population_size
        self._num_best_to_replace_from = num_best_to_replace_from
        self._num_worse_to_replace = num_worse_to_replace

    @partial(jax.jit, static_argnames=("self",))
    def update_states_and_buffer(
        self,
        random_key: RNGKey,
        population_returns: jnp.ndarray,
        training_state: PBTTrainingState,
        replay_buffer: ReplayBuffer,
    ) -> Tuple[RNGKey, PBTTrainingState, ReplayBuffer]:
        """
        Updates the agents of the population states as well as
        their shared replay buffer.

        Args:
            random_key: Random RNG key.
            population_returns: Returns of the agents in the populations.
            training_state: The training state of the PBT scheme.
            replay_buffer: Shared replay buffer by the agents.

        Returns:
            Updated random key, updated PBT training state and updated replay buffer.
        """
        # agents sorted from highest to lowest return
        indices_sorted = jnp.argsort(-population_returns)
        best_indices = indices_sorted[: self._num_best_to_replace_from]
        # Slice from an explicit start index: ``[-k:]`` with k == 0 would
        # select the whole population instead of nobody.
        first_replaced = indices_sorted.shape[0] - self._num_worse_to_replace
        indices_to_replace = indices_sorted[first_replaced:]

        random_key, key = jax.random.split(random_key)
        indices_used_to_replace = jax.random.choice(
            key, best_indices, shape=(self._num_worse_to_replace,), replace=True
        )

        # copies get freshly resampled hyper-parameters
        training_state = jax.tree_util.tree_map(
            lambda x, y: x.at[indices_to_replace].set(y[indices_used_to_replace]),
            training_state,
            jax.vmap(training_state.__class__.resample_hyperparams)(training_state),
        )

        # replaced agents inherit the buffer of the agent they were copied from
        replay_buffer = jax.tree_util.tree_map(
            lambda x, y: x.at[indices_to_replace].set(y[indices_used_to_replace]),
            replay_buffer,
            replay_buffer,
        )

        return random_key, training_state, replay_buffer

    @partial(jax.jit, static_argnames=("self",))
    def update_states_and_buffer_pmap(
        self,
        random_key: RNGKey,
        population_returns: jnp.ndarray,
        training_state: PBTTrainingState,
        replay_buffer: ReplayBuffer,
    ) -> Tuple[RNGKey, PBTTrainingState, ReplayBuffer]:
        """
        Updates the agents of the population states as well as
        their shared replay buffer. This version of the function is meant to be
        used within jax.pmap with the mapped axis named "p": the population is
        spread over several devices, and the update is done in parallel through
        communication between the devices.

        Args:
            random_key: Random RNG key.
            population_returns: Returns of the agents in the populations.
            training_state: The training state of the PBT scheme.
            replay_buffer: Shared replay buffer by the agents.

        Returns:
            Updated random key, updated PBT training state and updated replay buffer.
        """
        # local agents sorted from highest to lowest return
        indices_sorted = jnp.argsort(-population_returns)
        best_indices = indices_sorted[: self._num_best_to_replace_from]
        # Slice from an explicit start index: ``[-k:]`` with k == 0 would
        # select the whole local population instead of nobody.
        first_replaced = indices_sorted.shape[0] - self._num_worse_to_replace
        indices_to_replace = indices_sorted[first_replaced:]

        # gather the local best agents of every device
        best_individuals, best_buffers, best_returns = jax.tree_util.tree_map(
            lambda x: x[best_indices],
            (training_state, replay_buffer, population_returns),
        )
        (
            gathered_best_individuals,
            gathered_best_buffers,
            gathered_best_returns,
        ) = jax.tree_util.tree_map(
            lambda x: jnp.concatenate(jax.lax.all_gather(x, axis_name="p"), axis=0),
            (best_individuals, best_buffers, best_returns),
        )
        pop_indices_sorted = jnp.argsort(-gathered_best_returns)
        best_pop_indices = pop_indices_sorted[: self._num_best_to_replace_from]

        random_key, key = jax.random.split(random_key)
        indices_used_to_replace = jax.random.choice(
            key, best_pop_indices, shape=(self._num_worse_to_replace,), replace=True
        )

        # copies get freshly resampled hyper-parameters
        training_state = jax.tree_util.tree_map(
            lambda x, y: x.at[indices_to_replace].set(y[indices_used_to_replace]),
            training_state,
            jax.vmap(gathered_best_individuals.__class__.resample_hyperparams)(
                gathered_best_individuals
            ),
        )

        replay_buffer = jax.tree_util.tree_map(
            lambda x, y: x.at[indices_to_replace].set(y[indices_used_to_replace]),
            replay_buffer,
            gathered_best_buffers,
        )

        return random_key, training_state, replay_buffer

__init__(self, population_size, num_best_to_replace_from, num_worse_to_replace) special

Parameters:
  • population_size (int) – Size of the PBT population.

  • num_best_to_replace_from (int) – Number of top performing individuals to sample from when replacing low performers at each iteration.

  • num_worse_to_replace (int) – Number of low-performing individuals to replace at each iteration.

Source code in qdax/baselines/pbt.py
def __init__(
    self,
    population_size: int,
    num_best_to_replace_from: int,
    num_worse_to_replace: int,
):
    """
    Configure the exploit step of the PBT scheme.

    Args:
        population_size: Size of the PBT population.
        num_best_to_replace_from: Number of top performing individuals to sample
            from when replacing low performers at each iteration.
        num_worse_to_replace: Number of low-performing individuals to replace at
            each iteration.

    Raises:
        ValueError: if one of the counts is negative, if individuals have to be
            replaced while there is no top performer to copy from, or if the
            two groups together exceed the population size.
    """
    if num_best_to_replace_from < 0 or num_worse_to_replace < 0:
        raise ValueError(
            "The number of best individuals to replace from and the number of "
            "worse individuals to replace must both be non-negative."
        )
    if num_worse_to_replace > 0 and num_best_to_replace_from == 0:
        # Replacements are drawn with jax.random.choice from the best
        # individuals, which cannot sample from an empty pool.
        raise ValueError(
            "At least one best individual is required to replace worse "
            "individuals from."
        )
    if num_best_to_replace_from + num_worse_to_replace > population_size:
        raise ValueError(
            "The sum of best number of individuals to replace "
            "from and worse individuals to replace exceeds the population size."
        )
    self._population_size = population_size
    self._num_best_to_replace_from = num_best_to_replace_from
    self._num_worse_to_replace = num_worse_to_replace

update_states_and_buffer(self, random_key, population_returns, training_state, replay_buffer)

Updates the training states of the agents in the population as well as their replay buffers.

Parameters:
  • random_key (Array) – Random RNG key.

  • population_returns (Array) – Returns of the agents in the population.

  • training_state (PBTTrainingState) – The training state of the PBT scheme.

  • replay_buffer (ReplayBuffer) – Replay buffers of the agents, batched along the population axis.

Returns:
  • Tuple[jax.Array, qdax.baselines.pbt.PBTTrainingState, qdax.core.neuroevolution.buffers.buffer.ReplayBuffer] – Updated random key, updated PBT training state and updated replay buffer.

Source code in qdax/baselines/pbt.py
@partial(jax.jit, static_argnames=("self",))
def update_states_and_buffer(
    self,
    random_key: RNGKey,
    population_returns: jnp.ndarray,
    training_state: PBTTrainingState,
    replay_buffer: ReplayBuffer,
) -> Tuple[RNGKey, PBTTrainingState, ReplayBuffer]:
    """
    Updates the training states of the agents in the population as well as
    their replay buffers.

    The ``num_worse_to_replace`` lowest-return agents are overwritten by copies
    of agents drawn uniformly (with replacement) among the
    ``num_best_to_replace_from`` highest-return agents. The copies get
    resampled hyperparameters and the replay buffer of the agent they copy.

    Args:
        random_key: Random RNG key.
        population_returns: Returns of the agents in the population.
        training_state: The training state of the PBT scheme, batched along the
            population axis.
        replay_buffer: Replay buffers of the agents, batched along the
            population axis.

    Returns:
        Updated random key, updated PBT training state and updated replay buffer.
    """
    # Agents ordered from highest to lowest return.
    indices_sorted = jax.numpy.argsort(-population_returns)
    population_size = indices_sorted.shape[0]
    best_indices = indices_sorted[: self._num_best_to_replace_from]
    # Slice from an explicit start: `indices_sorted[-0:]` would select the whole
    # population instead of nobody when `num_worse_to_replace` is 0.
    indices_to_replace = indices_sorted[
        population_size - self._num_worse_to_replace :
    ]

    random_key, key = jax.random.split(random_key)
    indices_used_to_replace = jax.random.choice(
        key, best_indices, shape=(self._num_worse_to_replace,), replace=True
    )

    # Resample the hyperparameters of every agent, then keep only the
    # resampled copies of the chosen top performers in the replaced slots.
    training_state = jax.tree_util.tree_map(
        lambda x, y: x.at[indices_to_replace].set(y[indices_used_to_replace]),
        training_state,
        jax.vmap(training_state.__class__.resample_hyperparams)(training_state),
    )

    replay_buffer = jax.tree_util.tree_map(
        lambda x, y: x.at[indices_to_replace].set(y[indices_used_to_replace]),
        replay_buffer,
        replay_buffer,
    )

    return random_key, training_state, replay_buffer

update_states_and_buffer_pmap(self, random_key, population_returns, training_state, replay_buffer)

Updates the training states of the agents in the population as well as their replay buffers. This is the version of the function to be used within jax.pmap. It assumes the population is spread over several devices and implements a parallel update through communication between the devices.

Parameters:
  • random_key (Array) – Random RNG key.

  • population_returns (Array) – Returns of the agents in the population.

  • training_state (PBTTrainingState) – The training state of the PBT scheme.

  • replay_buffer (ReplayBuffer) – Replay buffers of the agents, batched along the population axis.

Returns:
  • Tuple[jax.Array, qdax.baselines.pbt.PBTTrainingState, qdax.core.neuroevolution.buffers.buffer.ReplayBuffer] – Updated random key, updated PBT training state and updated replay buffer.

Source code in qdax/baselines/pbt.py
@partial(jax.jit, static_argnames=("self",))
def update_states_and_buffer_pmap(
    self,
    random_key: RNGKey,
    population_returns: jnp.ndarray,
    training_state: PBTTrainingState,
    replay_buffer: ReplayBuffer,
) -> Tuple[RNGKey, PBTTrainingState, ReplayBuffer]:
    """
    Updates the training states of the agents in the population as well as
    their replay buffers. This is the version of the function to be used
    within jax.pmap: it assumes the population is spread over several devices
    and implements a parallel update through communication between the
    devices (collectives over the pmap axis named "p").

    Each device shares its ``num_best_to_replace_from`` best agents; the
    overall ``num_best_to_replace_from`` best of those gathered agents are the
    candidates copied over the device's ``num_worse_to_replace`` worst agents.

    Args:
        random_key: Random RNG key.
        population_returns: Returns of the agents in the population.
        training_state: The training state of the PBT scheme, batched along the
            (local) population axis.
        replay_buffer: Replay buffers of the agents, batched along the (local)
            population axis.

    Returns:
        Updated random key, updated PBT training state and updated replay buffer.
    """
    # Local agents ordered from highest to lowest return.
    indices_sorted = jax.numpy.argsort(-population_returns)
    local_population_size = indices_sorted.shape[0]
    best_indices = indices_sorted[: self._num_best_to_replace_from]
    # Slice from an explicit start: `indices_sorted[-0:]` would select every
    # local agent instead of none when `num_worse_to_replace` is 0.
    indices_to_replace = indices_sorted[
        local_population_size - self._num_worse_to_replace :
    ]

    # Share each device's best agents, buffers and returns with all devices.
    best_individuals, best_buffers, best_returns = jax.tree_util.tree_map(
        lambda x: x[best_indices],
        (training_state, replay_buffer, population_returns),
    )
    (
        gathered_best_individuals,
        gathered_best_buffers,
        gathered_best_returns,
    ) = jax.tree_util.tree_map(
        lambda x: jnp.concatenate(jax.lax.all_gather(x, axis_name="p"), axis=0),
        (best_individuals, best_buffers, best_returns),
    )
    pop_indices_sorted = jax.numpy.argsort(-gathered_best_returns)
    best_pop_indices = pop_indices_sorted[: self._num_best_to_replace_from]

    random_key, key = jax.random.split(random_key)
    indices_used_to_replace = jax.random.choice(
        key, best_pop_indices, shape=(self._num_worse_to_replace,), replace=True
    )

    # Replaced agents receive resampled hyperparameters from the gathered
    # top performers, together with those performers' replay buffers.
    training_state = jax.tree_util.tree_map(
        lambda x, y: x.at[indices_to_replace].set(y[indices_used_to_replace]),
        training_state,
        jax.vmap(gathered_best_individuals.__class__.resample_hyperparams)(
            gathered_best_individuals
        ),
    )

    replay_buffer = jax.tree_util.tree_map(
        lambda x, y: x.at[indices_to_replace].set(y[indices_used_to_replace]),
        replay_buffer,
        gathered_best_buffers,
    )

    return random_key, training_state, replay_buffer

To use PBT in order to train TD3 agents, please use the PBTTD3 class:

qdax.baselines.td3_pbt.PBTTD3 (TD3)

Source code in qdax/baselines/td3_pbt.py
class PBTTD3(TD3):
    """TD3 agent whose hyperparameters are optimised by Population Based Training.

    The discount factor, the policy and critic learning rates, the target policy
    noise, the noise clip and the exploration noise are stored in each agent's
    ``PBTTD3TrainingState`` (sampled by ``resample_hyperparams``) instead of the
    configuration, so that every member of the population uses its own values.
    """

    def __init__(self, config: PBTTD3Config, action_size: int):
        """Build the underlying TD3 agent from the PBT-TD3 configuration.

        Hyperparameters handled by PBT are deliberately not forwarded to
        ``TD3Config``; they are read from the training state instead.

        Args:
            config: the PBT-TD3 configuration.
            action_size: the size of the environment's action space.
        """

        td3_config = TD3Config(
            episode_length=config.episode_length,
            batch_size=config.batch_size,
            policy_delay=config.policy_delay,
            reward_scaling=config.reward_scaling,
            soft_tau_update=config.soft_tau_update,
            critic_hidden_layer_size=config.critic_hidden_layer_size,
            policy_hidden_layer_size=config.policy_hidden_layer_size,
        )
        TD3.__init__(self, td3_config, action_size)

    def init(
        self, random_key: RNGKey, action_size: int, observation_size: int
    ) -> PBTTD3TrainingState:
        """Initialise the training state of the PBT-TD3 algorithm, through creation
        of optimizer states and params.

        Args:
            random_key: a random key used for random operations.
            action_size: the size of the action array needed to interact with the
                environment.
            observation_size: the size of the observation array retrieved from the
                environment.

        Returns:
            the initial training state.
        """

        training_state = TD3.init(self, random_key, action_size, observation_size)

        # Initial training state; PBT hyperparameters are filled in just below.
        training_state = PBTTD3TrainingState(
            policy_optimizer_state=training_state.policy_optimizer_state,
            policy_params=training_state.policy_params,
            critic_optimizer_state=training_state.critic_optimizer_state,
            critic_params=training_state.critic_params,
            target_policy_params=training_state.target_policy_params,
            target_critic_params=training_state.target_critic_params,
            random_key=training_state.random_key,
            steps=training_state.steps,
            discount=None,
            policy_lr=None,
            critic_lr=None,
            noise_clip=None,
            policy_noise=None,
            expl_noise=None,
        )

        # Sample hyperparameters
        training_state = PBTTD3TrainingState.resample_hyperparams(training_state)

        return training_state  # type: ignore

    @partial(jax.jit, static_argnames=("self", "env", "deterministic"))
    def play_step_fn(
        self,
        env_state: EnvState,
        training_state: PBTTD3TrainingState,
        env: Env,
        deterministic: bool = False,
    ) -> Tuple[EnvState, PBTTD3TrainingState, Transition]:
        """Plays a step in the environment. Selects an action according to TD3 rule and
        performs the environment step.

        The exploration noise is read from the agent's own training state since
        it is one of the hyperparameters optimised by PBT.

        Args:
            env_state: the current environment state
            training_state: the PBT-TD3 training state
            env: the environment
            deterministic: whether to select action in a deterministic way.
                Defaults to False.

        Returns:
            the new environment state
            the new PBT-TD3 training state
            the played transition
        """

        actions, random_key = self.select_action(
            obs=env_state.obs,
            policy_params=training_state.policy_params,
            random_key=training_state.random_key,
            expl_noise=training_state.expl_noise,
            deterministic=deterministic,
        )
        training_state = training_state.replace(
            random_key=random_key,
        )
        next_env_state = env.step(env_state, actions)
        transition = Transition(
            obs=env_state.obs,
            next_obs=next_env_state.obs,
            rewards=next_env_state.reward,
            dones=next_env_state.done,
            truncations=next_env_state.info["truncation"],
            actions=actions,
        )
        return next_env_state, training_state, transition

    @partial(jax.jit, static_argnames=("self",))
    def update(
        self,
        training_state: PBTTD3TrainingState,
        replay_buffer: ReplayBuffer,
    ) -> Tuple[PBTTD3TrainingState, ReplayBuffer, Metrics]:
        """Performs a single training step: updates policy params and critic params
        through gradient descent.

        The hyperparameters optimised by PBT (discount, critic and policy learning
        rates, target policy noise and noise clip) are read from the training
        state, not from the configuration.

        Args:
            training_state: the current training state, containing the optimizer states
                and the params of the policy and critic.
            replay_buffer: the replay buffer, filled with transitions experienced in
                the environment.

        Returns:
            A new training state, the buffer with new transitions and metrics about the
            training process.
        """

        # Sample a batch of transitions in the buffer
        random_key = training_state.random_key
        samples, random_key = replay_buffer.sample(
            random_key, sample_size=self._config.batch_size
        )

        # Update Critic
        random_key, subkey = jax.random.split(random_key)
        critic_loss, critic_gradient = jax.value_and_grad(td3_critic_loss_fn)(
            training_state.critic_params,
            target_policy_params=training_state.target_policy_params,
            target_critic_params=training_state.target_critic_params,
            policy_fn=self._policy.apply,
            critic_fn=self._critic.apply,
            policy_noise=training_state.policy_noise,
            noise_clip=training_state.noise_clip,
            reward_scaling=self._config.reward_scaling,
            # The discount is a PBT hyperparameter: use the agent's own value.
            # __init__ never forwards a discount to TD3Config, so the config
            # value would silently ignore what PBT samples.
            discount=training_state.discount,
            transitions=samples,
            random_key=subkey,
        )
        critic_optimizer = optax.adam(learning_rate=training_state.critic_lr)
        critic_updates, critic_optimizer_state = critic_optimizer.update(
            critic_gradient, training_state.critic_optimizer_state
        )
        critic_params = optax.apply_updates(
            training_state.critic_params, critic_updates
        )
        # Soft update of target critic network
        target_critic_params = jax.tree_util.tree_map(
            lambda x1, x2: (1.0 - self._config.soft_tau_update) * x1
            + self._config.soft_tau_update * x2,
            training_state.target_critic_params,
            critic_params,
        )

        # Update policy
        policy_loss, policy_gradient = jax.value_and_grad(td3_policy_loss_fn)(
            training_state.policy_params,
            critic_params=training_state.critic_params,
            policy_fn=self._policy.apply,
            critic_fn=self._critic.apply,
            transitions=samples,
        )

        def update_policy_step() -> Tuple[Params, Params, optax.OptState]:
            policy_optimizer = optax.adam(learning_rate=training_state.policy_lr)
            policy_updates, policy_optimizer_state = policy_optimizer.update(
                policy_gradient, training_state.policy_optimizer_state
            )
            policy_params = optax.apply_updates(
                training_state.policy_params, policy_updates
            )
            # Soft update of target policy
            target_policy_params = jax.tree_util.tree_map(
                lambda x1, x2: (1.0 - self._config.soft_tau_update) * x1
                + self._config.soft_tau_update * x2,
                training_state.target_policy_params,
                policy_params,
            )
            return policy_params, target_policy_params, policy_optimizer_state

        # Delayed update: the policy only moves every `policy_delay` steps.
        current_policy_state = (
            training_state.policy_params,
            training_state.target_policy_params,
            training_state.policy_optimizer_state,
        )
        policy_params, target_policy_params, policy_optimizer_state = jax.lax.cond(
            training_state.steps % self._config.policy_delay == 0,
            lambda _: update_policy_step(),
            lambda _: current_policy_state,
            operand=None,
        )

        # Create new training state
        new_training_state = training_state.replace(
            critic_params=critic_params,
            critic_optimizer_state=critic_optimizer_state,
            policy_params=policy_params,
            policy_optimizer_state=policy_optimizer_state,
            target_critic_params=target_critic_params,
            target_policy_params=target_policy_params,
            random_key=random_key,
            steps=training_state.steps + 1,
        )

        metrics = {
            "actor_loss": policy_loss,
            "critic_loss": critic_loss,
        }

        return new_training_state, replay_buffer, metrics

    def get_init_fn(
        self,
        population_size: int,
        action_size: int,
        observation_size: int,
        buffer_size: int,
    ) -> Callable:
        """
        Returns a function to initialize the population.

        Args:
            population_size: size of the population.
            action_size: action space size.
            observation_size: observation space size.
            buffer_size: replay buffer size.

        Returns:
            a function that takes as input a random key and returns a new random
            key, the PBT population training state and the replay buffers
        """

        def _init_fn(
            random_key: RNGKey,
        ) -> Tuple[RNGKey, PBTTD3TrainingState, ReplayBuffer]:
            random_key, *keys = jax.random.split(random_key, num=1 + population_size)
            keys = jnp.stack(keys)

            init_dummy_transition = partial(
                Transition.init_dummy,
                observation_dim=observation_size,
                action_dim=action_size,
            )
            init_dummy_transition = jax.vmap(
                init_dummy_transition, axis_size=population_size
            )
            dummy_transitions = init_dummy_transition()

            replay_buffer_init = partial(
                ReplayBuffer.init,
                buffer_size=buffer_size,
            )
            replay_buffer_init = jax.vmap(replay_buffer_init)
            replay_buffers = replay_buffer_init(transition=dummy_transitions)
            agent_init = partial(
                self.init, action_size=action_size, observation_size=observation_size
            )
            training_states = jax.vmap(agent_init)(keys)
            return random_key, training_states, replay_buffers

        return _init_fn

    def get_eval_fn(
        self,
        eval_env: Env,
    ) -> Callable:
        """
        Returns the function used to evaluate the PBT population.

        Args:
            eval_env: evaluation environment. Might be different from training env
                if needed.

        Returns:
            The function to evaluate the population. It takes as input the population
            training state as well as first eval environment states and returns the
            population agents mean returns over episodes as well as all returns from all
            agents over all episodes.
        """
        play_eval_step = partial(
            self.play_step_fn,
            env=eval_env,
            deterministic=True,
        )

        eval_policy = partial(
            self.eval_policy_fn,
            play_step_fn=play_eval_step,
        )
        return jax.vmap(eval_policy)  # type: ignore

    def get_eval_qd_fn(
        self,
        eval_env: Env,
        bd_extraction_fn: Callable[[QDTransition, Mask], Descriptor],
    ) -> Callable:
        """
        Returns the function used to evaluate the PBT population in a QD setting.

        Args:
            eval_env: evaluation environment. Might be different from training env
                if needed.
            bd_extraction_fn: function to extract the bd from an episode.

        Returns:
            The function to evaluate the population. It takes as input the population
            training state as well as first eval environment states and returns the
            population agents mean returns and mean bds over episodes as well as all
            returns and bds from all agents over all episodes.
        """
        play_eval_step = partial(
            self.play_qd_step_fn,
            env=eval_env,
            deterministic=True,
        )

        eval_policy = partial(
            self.eval_qd_policy_fn,
            play_step_fn=play_eval_step,
            bd_extraction_fn=bd_extraction_fn,
        )
        return jax.vmap(eval_policy)  # type: ignore

    def get_train_fn(
        self,
        env: Env,
        num_iterations: int,
        env_batch_size: int,
        grad_updates_per_step: float,
    ) -> Callable:
        """
        Returns the function to update the population of agents.

        Args:
            env: training environment.
            num_iterations: number of training iterations to perform.
            env_batch_size: number of batched environments.
            grad_updates_per_step: number of gradient updates to apply per step in
                the environment.

        Returns:
            the function to update the population which takes as input the population
            training state, environment starting states and replay buffers and returns
            updated training states, environment states, replay buffers and metrics.
        """
        play_step = partial(
            self.play_step_fn,
            env=env,
            deterministic=False,
        )

        do_iteration = partial(
            do_iteration_fn,
            env_batch_size=env_batch_size,
            grad_updates_per_step=grad_updates_per_step,
            play_step_fn=play_step,
            update_fn=self.update,
        )

        def _scan_do_iteration(
            carry: Tuple[PBTTD3TrainingState, EnvState, ReplayBuffer],
            unused_arg: Any,
        ) -> Tuple[Tuple[PBTTD3TrainingState, EnvState, ReplayBuffer], Any]:
            (
                training_state,
                env_state,
                replay_buffer,
                metrics,
            ) = do_iteration(*carry)
            return (training_state, env_state, replay_buffer), metrics

        def train_fn(
            training_state: PBTTD3TrainingState,
            env_state: EnvState,
            replay_buffer: ReplayBuffer,
        ) -> Tuple[Tuple[PBTTD3TrainingState, EnvState, ReplayBuffer], Metrics]:
            (training_state, env_state, replay_buffer), metrics = jax.lax.scan(
                _scan_do_iteration,
                (training_state, env_state, replay_buffer),
                None,
                length=num_iterations,
            )
            return (training_state, env_state, replay_buffer), metrics

        return jax.vmap(train_fn)  # type: ignore

init(self, random_key, action_size, observation_size)

Initialise the training state of the PBT-TD3 algorithm, through creation of optimizer states and params.

Parameters:
  • random_key (Array) – a random key used for random operations.

  • action_size (int) – the size of the action array needed to interact with the environment.

  • observation_size (int) – the size of the observation array retrieved from the environment.

Returns:
  • PBTTD3TrainingState – the initial training state.

Source code in qdax/baselines/td3_pbt.py
def init(
    self, random_key: RNGKey, action_size: int, observation_size: int
) -> PBTTD3TrainingState:
    """Build the initial PBT-TD3 training state.

    The networks, their optimizer states and the target networks come from
    ``TD3.init``; the PBT hyperparameters (discount, learning rates, target
    policy noise, noise clip and exploration noise) are then drawn with
    ``PBTTD3TrainingState.resample_hyperparams``.

    Args:
        random_key: a random key used for random operations.
        action_size: the size of the action array needed to interact with the
            environment.
        observation_size: the size of the observation array retrieved from the
            environment.

    Returns:
        the initial training state.
    """
    td3_state = TD3.init(self, random_key, action_size, observation_size)

    # Fields carried over unchanged from the plain TD3 training state.
    td3_fields = dict(
        policy_optimizer_state=td3_state.policy_optimizer_state,
        policy_params=td3_state.policy_params,
        critic_optimizer_state=td3_state.critic_optimizer_state,
        critic_params=td3_state.critic_params,
        target_policy_params=td3_state.target_policy_params,
        target_critic_params=td3_state.target_critic_params,
        random_key=td3_state.random_key,
        steps=td3_state.steps,
    )
    # PBT hyperparameters start empty and are drawn by the resampling below.
    pbt_hyperparams = dict.fromkeys(
        (
            "discount",
            "policy_lr",
            "critic_lr",
            "noise_clip",
            "policy_noise",
            "expl_noise",
        )
    )
    placeholder_state = PBTTD3TrainingState(**td3_fields, **pbt_hyperparams)

    return PBTTD3TrainingState.resample_hyperparams(placeholder_state)  # type: ignore

play_step_fn(self, env_state, training_state, env, deterministic=False)

Plays a step in the environment. Selects an action according to TD3 rule and performs the environment step.

Parameters:
  • env_state (State) – the current environment state

  • training_state (TD3TrainingState) – the PBT-TD3 training state

  • env (Env) – the environment

  • deterministic (bool) – whether to select action in a deterministic way. Defaults to False.

Returns:
  • Tuple[brax.envs.base.State, qdax.baselines.td3.TD3TrainingState, qdax.core.neuroevolution.buffers.buffer.Transition] – the new environment state, the new PBT-TD3 training state and the played transition.

Source code in qdax/baselines/td3_pbt.py
@partial(jax.jit, static_argnames=("self", "env", "deterministic"))
def play_step_fn(
    self,
    env_state: EnvState,
    training_state: PBTTD3TrainingState,
    env: Env,
    deterministic: bool = False,
) -> Tuple[EnvState, PBTTD3TrainingState, Transition]:
    """Plays a step in the environment. Selects an action according to TD3 rule and
    performs the environment step.

    The exploration noise is read from the agent's own training state since it
    is one of the hyperparameters optimised by PBT; the state must therefore be
    a PBT-TD3 training state, not a plain TD3 one.

    Args:
        env_state: the current environment state
        training_state: the PBT-TD3 training state
        env: the environment
        deterministic: whether to select action in a deterministic way.
            Defaults to False.

    Returns:
        the new environment state
        the new PBT-TD3 training state
        the played transition
    """

    actions, random_key = self.select_action(
        obs=env_state.obs,
        policy_params=training_state.policy_params,
        random_key=training_state.random_key,
        expl_noise=training_state.expl_noise,
        deterministic=deterministic,
    )
    # Store the advanced key so the next step draws fresh exploration noise.
    training_state = training_state.replace(
        random_key=random_key,
    )
    next_env_state = env.step(env_state, actions)
    transition = Transition(
        obs=env_state.obs,
        next_obs=next_env_state.obs,
        rewards=next_env_state.reward,
        dones=next_env_state.done,
        truncations=next_env_state.info["truncation"],
        actions=actions,
    )
    return next_env_state, training_state, transition

update(self, training_state, replay_buffer)

Performs a single training step: updates policy params and critic params through gradient descent.

Parameters:
  • training_state (PBTTD3TrainingState) – the current training state, containing the optimizer states and the params of the policy and critic.

  • replay_buffer (ReplayBuffer) – the replay buffer, filled with transitions experienced in the environment.

Returns:
  • Tuple[qdax.baselines.td3_pbt.PBTTD3TrainingState, qdax.core.neuroevolution.buffers.buffer.ReplayBuffer, Dict[str, jax.Array]] – A new training state, the buffer with new transitions and metrics about the training process.

Source code in qdax/baselines/td3_pbt.py
@partial(jax.jit, static_argnames=("self",))
def update(
    self,
    training_state: PBTTD3TrainingState,
    replay_buffer: ReplayBuffer,
) -> Tuple[PBTTD3TrainingState, ReplayBuffer, Metrics]:
    """Performs a single training step: updates policy params and critic params
    through gradient descent.

    The hyperparameters optimised by PBT (discount, critic and policy learning
    rates, target policy noise and noise clip) are read from the training
    state, not from the configuration.

    Args:
        training_state: the current training state, containing the optimizer states
            and the params of the policy and critic.
        replay_buffer: the replay buffer, filled with transitions experienced in
            the environment.

    Returns:
        A new training state, the buffer with new transitions and metrics about the
        training process.
    """

    # Sample a batch of transitions in the buffer
    random_key = training_state.random_key
    samples, random_key = replay_buffer.sample(
        random_key, sample_size=self._config.batch_size
    )

    # Update Critic
    random_key, subkey = jax.random.split(random_key)
    critic_loss, critic_gradient = jax.value_and_grad(td3_critic_loss_fn)(
        training_state.critic_params,
        target_policy_params=training_state.target_policy_params,
        target_critic_params=training_state.target_critic_params,
        policy_fn=self._policy.apply,
        critic_fn=self._critic.apply,
        policy_noise=training_state.policy_noise,
        noise_clip=training_state.noise_clip,
        reward_scaling=self._config.reward_scaling,
        # The discount is a PBT hyperparameter: use the agent's own value.
        # PBTTD3.__init__ never forwards a discount to TD3Config, so the
        # config value would silently ignore what PBT samples.
        discount=training_state.discount,
        transitions=samples,
        random_key=subkey,
    )
    critic_optimizer = optax.adam(learning_rate=training_state.critic_lr)
    critic_updates, critic_optimizer_state = critic_optimizer.update(
        critic_gradient, training_state.critic_optimizer_state
    )
    critic_params = optax.apply_updates(
        training_state.critic_params, critic_updates
    )
    # Soft update of target critic network
    target_critic_params = jax.tree_util.tree_map(
        lambda x1, x2: (1.0 - self._config.soft_tau_update) * x1
        + self._config.soft_tau_update * x2,
        training_state.target_critic_params,
        critic_params,
    )

    # Update policy
    policy_loss, policy_gradient = jax.value_and_grad(td3_policy_loss_fn)(
        training_state.policy_params,
        critic_params=training_state.critic_params,
        policy_fn=self._policy.apply,
        critic_fn=self._critic.apply,
        transitions=samples,
    )

    def update_policy_step() -> Tuple[Params, Params, optax.OptState]:
        policy_optimizer = optax.adam(learning_rate=training_state.policy_lr)
        policy_updates, policy_optimizer_state = policy_optimizer.update(
            policy_gradient, training_state.policy_optimizer_state
        )
        policy_params = optax.apply_updates(
            training_state.policy_params, policy_updates
        )
        # Soft update of target policy
        target_policy_params = jax.tree_util.tree_map(
            lambda x1, x2: (1.0 - self._config.soft_tau_update) * x1
            + self._config.soft_tau_update * x2,
            training_state.target_policy_params,
            policy_params,
        )
        return policy_params, target_policy_params, policy_optimizer_state

    # Delayed update: the policy only moves every `policy_delay` steps.
    current_policy_state = (
        training_state.policy_params,
        training_state.target_policy_params,
        training_state.policy_optimizer_state,
    )
    policy_params, target_policy_params, policy_optimizer_state = jax.lax.cond(
        training_state.steps % self._config.policy_delay == 0,
        lambda _: update_policy_step(),
        lambda _: current_policy_state,
        operand=None,
    )

    # Create new training state
    new_training_state = training_state.replace(
        critic_params=critic_params,
        critic_optimizer_state=critic_optimizer_state,
        policy_params=policy_params,
        policy_optimizer_state=policy_optimizer_state,
        target_critic_params=target_critic_params,
        target_policy_params=target_policy_params,
        random_key=random_key,
        steps=training_state.steps + 1,
    )

    metrics = {
        "actor_loss": policy_loss,
        "critic_loss": critic_loss,
    }

    return new_training_state, replay_buffer, metrics

get_init_fn(self, population_size, action_size, observation_size, buffer_size)

Returns a function to initialize the population.

Parameters:
  • population_size (int) – size of the population.

  • action_size (int) – action space size.

  • observation_size (int) – observation space size.

  • buffer_size (int) – replay buffer size.

Returns:
  • Callable – a function that takes as input a random key and returns a new random key, the PBT population training state and the replay buffers

Source code in qdax/baselines/td3_pbt.py
def get_init_fn(
    self,
    population_size: int,
    action_size: int,
    observation_size: int,
    buffer_size: int,
) -> Callable:
    """
    Build the initialisation function of the PBT population.

    Args:
        population_size: size of the population.
        action_size: action space size.
        observation_size: observation space size.
        buffer_size: replay buffer size.

    Returns:
        a function that takes as input a random key and returns a new random
        key, the PBT population training state and the replay buffers
    """

    def _init_fn(
        random_key: RNGKey,
    ) -> Tuple[RNGKey, PBTTD3TrainingState, ReplayBuffer]:
        # Keep one key for the caller and hand one key to each agent.
        random_key, *agent_keys = jax.random.split(
            random_key, num=1 + population_size
        )

        # ReplayBuffer.init receives one dummy transition per agent.
        make_dummy_transitions = jax.vmap(
            partial(
                Transition.init_dummy,
                observation_dim=observation_size,
                action_dim=action_size,
            ),
            axis_size=population_size,
        )
        make_buffers = jax.vmap(partial(ReplayBuffer.init, buffer_size=buffer_size))
        replay_buffers = make_buffers(transition=make_dummy_transitions())

        make_agents = jax.vmap(
            partial(
                self.init, action_size=action_size, observation_size=observation_size
            )
        )
        training_states = make_agents(jnp.stack(agent_keys))
        return random_key, training_states, replay_buffers

    return _init_fn

get_eval_fn(self, eval_env)

Returns the function used to evaluate the PBT population.

Parameters:
  • eval_env (Env) – evaluation environment. Might be different from training env if needed.

Returns:
  • Callable – The function to evaluate the population. It takes as input the population training state as well as first eval environment states and returns the population agents mean returns over episodes as well as all returns from all agents over all episodes.

Source code in qdax/baselines/td3_pbt.py
def get_eval_fn(
    self,
    eval_env: Env,
) -> Callable:
    """
    Build the function that evaluates every agent of the PBT population.

    Args:
        eval_env: environment used for evaluation. It may differ from the
            training environment if needed.

    Returns:
        A vmapped version of ``self.eval_policy_fn`` whose environment steps
        are played deterministically in ``eval_env``. It takes as input the
        population training state as well as the first evaluation environment
        states, and returns the population agents' mean returns over episodes
        as well as all returns from all agents over all episodes.
    """
    # Evaluation acts deterministically, unlike the training rollouts.
    deterministic_step = partial(
        self.play_step_fn, env=eval_env, deterministic=True
    )
    evaluate_single_agent = partial(
        self.eval_policy_fn, play_step_fn=deterministic_step
    )
    # jax.vmap maps over the leading axis of every input
    # (presumably one entry per agent of the population).
    return jax.vmap(evaluate_single_agent)  # type: ignore

get_eval_qd_fn(self, eval_env, bd_extraction_fn)

Returns the function that evaluates the PBT population.

Parameters:
  • eval_env (Env) – evaluation environment. Might be different from training env if needed.

  • bd_extraction_fn (Callable[[qdax.core.neuroevolution.buffers.buffer.QDTransition, jax.Array], jax.Array]) – function to extract the bd from an episode.

Returns:
  • Callable – The function to evaluate the population. It takes as input the population training state as well as the first evaluation environment states, and returns the population agents' mean returns and mean bds over episodes as well as all returns and bds from all agents over all episodes.

Source code in qdax/baselines/td3_pbt.py
def get_eval_qd_fn(
    self,
    eval_env: Env,
    bd_extraction_fn: Callable[[QDTransition, Mask], Descriptor],
) -> Callable:
    """
    Build the QD evaluation function for every agent of the PBT population.

    Args:
        eval_env: environment used for evaluation. It may differ from the
            training environment if needed.
        bd_extraction_fn: function that extracts the behaviour descriptor (bd)
            from an episode.

    Returns:
        A vmapped version of ``self.eval_qd_policy_fn`` whose environment
        steps are played deterministically (via ``self.play_qd_step_fn``) in
        ``eval_env``. It takes as input the population training state as well
        as the first evaluation environment states, and returns the population
        agents' mean returns and mean bds over episodes as well as all returns
        and bds from all agents over all episodes.
    """
    # QD evaluation uses the QD step function and acts deterministically.
    deterministic_qd_step = partial(
        self.play_qd_step_fn, env=eval_env, deterministic=True
    )
    evaluate_single_agent = partial(
        self.eval_qd_policy_fn,
        play_step_fn=deterministic_qd_step,
        bd_extraction_fn=bd_extraction_fn,
    )
    # jax.vmap maps over the leading axis of every input
    # (presumably one entry per agent of the population).
    return jax.vmap(evaluate_single_agent)  # type: ignore

get_train_fn(self, env, num_iterations, env_batch_size, grad_updates_per_step)

Returns the function to update the population of agents.

Parameters:
  • env (Env) – training environment.

  • num_iterations (int) – number of training iterations to perform.

  • env_batch_size (int) – number of batched environments.

  • grad_updates_per_step (float) – number of gradient updates to apply per step in the environment.

Returns:
  • Callable – the function to update the population which takes as input the population training state, environment starting states and replay buffers and returns updated training states, environment states, replay buffers and metrics.

Source code in qdax/baselines/td3_pbt.py
def get_train_fn(
    self,
    env: Env,
    num_iterations: int,
    env_batch_size: int,
    grad_updates_per_step: float,
) -> Callable:
    """
    Build the function that trains the whole population of agents.

    Args:
        env: training environment.
        num_iterations: number of training iterations performed per call.
        env_batch_size: number of batched environments.
        grad_updates_per_step: number of gradient updates to apply per
            environment step.

    Returns:
        A vmapped function that takes as input the population training state,
        the environment starting states and the replay buffers, and returns
        ``((training_state, env_state, replay_buffer), metrics)``. These are
        the updated training states, environment states, replay buffers and
        the metrics collected over ``num_iterations`` iterations.
    """
    # Training rollouts are stochastic, unlike evaluation rollouts.
    stochastic_step = partial(self.play_step_fn, env=env, deterministic=False)

    single_iteration = partial(
        do_iteration_fn,
        env_batch_size=env_batch_size,
        grad_updates_per_step=grad_updates_per_step,
        play_step_fn=stochastic_step,
        update_fn=self.update,
    )

    def _scan_body(
        carry: Tuple[PBTTD3TrainingState, EnvState, ReplayBuffer],
        unused_arg: Any,
    ) -> Tuple[Tuple[PBTTD3TrainingState, EnvState, ReplayBuffer], Any]:
        # One iteration yields the three carried values followed by metrics.
        new_state, new_env_state, new_buffer, step_metrics = single_iteration(
            *carry
        )
        return (new_state, new_env_state, new_buffer), step_metrics

    def train_fn(
        training_state: PBTTD3TrainingState,
        env_state: EnvState,
        replay_buffer: ReplayBuffer,
    ) -> Tuple[Tuple[PBTTD3TrainingState, EnvState, ReplayBuffer], Metrics]:
        # lax.scan runs the iterations in one compiled loop. The final carry
        # keeps the structure of the initial tuple, and the per-iteration
        # metrics are stacked along a new leading axis.
        initial_carry = (training_state, env_state, replay_buffer)
        final_carry, stacked_metrics = jax.lax.scan(
            _scan_body,
            initial_carry,
            None,
            length=num_iterations,
        )
        final_state, final_env_state, final_buffer = final_carry
        return (final_state, final_env_state, final_buffer), stacked_metrics

    # jax.vmap maps over the leading axis of every input
    # (presumably one entry per agent of the population).
    return jax.vmap(train_fn)  # type: ignore