MAP Elites Population Based Training (ME PBT)

ME PBT is an algorithm combining MAP-Elites with Population Based Training (PBT) to evolve a population of diverse RL agents.

To create an instance of ME PBT, one needs to use an instance of Distributed MAP-Elites together with the PBTEmitter detailed below; a minimal construction sketch follows.
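
As a rough construction sketch, assuming DistributedMAPElites follows the standard MAPElites constructor (agent, env, variation_fn, scoring_fn and metrics_fn are placeholders to be built for the task at hand):

import jax

from qdax.core.distributed_map_elites import DistributedMAPElites
from qdax.core.emitters.pbt_me_emitter import PBTEmitter, PBTEmitterConfig

# Illustrative values; the fields are the ones consumed by PBTEmitter below.
emitter_config = PBTEmitterConfig(
    buffer_size=100_000,
    num_training_iterations=100,
    env_batch_size=250,
    grad_updates_per_step=1.0,
    pg_population_size_per_device=10,
    ga_population_size_per_device=30,
    num_devices=jax.device_count(),
    fraction_best_to_replace_from=0.1,
    fraction_to_replace_from_best=0.2,
    fraction_to_replace_from_samples=0.4,
    fraction_sort_exchange=0.1,
)

emitter = PBTEmitter(
    pbt_agent=agent,            # placeholder: a PBTSAC or PBTTD3 instance
    config=emitter_config,
    env=env,                    # placeholder: a QDEnv
    variation_fn=variation_fn,  # placeholder: crossover over training states
)

map_elites = DistributedMAPElites(
    scoring_function=scoring_fn,   # placeholder: computes fitnesses and descriptors
    emitter=emitter,
    metrics_function=metrics_fn,   # placeholder: computes repertoire metrics
)

The update functions must then be run with one replica per device; note that the jax.lax.all_gather calls in the source below assume the surrounding pmap uses axis_name "p".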

qdax.core.emitters.pbt_me_emitter.PBTEmitter (Emitter)

An emitter implementing the population update of MAP Elites Population Based Training (ME PBT): it maintains a population of RL agents trained with PBT and emits them, together with GA offspring, as candidates for the MAP-Elites repertoire.

Source code in qdax/core/emitters/pbt_me_emitter.py
class PBTEmitter(Emitter):
    """
    An emitter implementing the population update of MAP Elites Population
    Based Training (ME PBT): it maintains a population of RL agents trained
    with PBT and emits them, together with GA offspring, as candidates for
    the MAP-Elites repertoire.
    """

    def __init__(
        self,
        pbt_agent: Union[PBTSAC, PBTTD3],
        config: PBTEmitterConfig,
        env: QDEnv,
        variation_fn: Callable[[Params, Params, RNGKey], Tuple[Params, RNGKey]],
    ) -> None:

        # Parameters internalization
        self._env = env
        self._variation_fn = variation_fn
        self._config = config
        self._agent = pbt_agent
        self._train_fn = self._agent.get_train_fn(
            env=env,
            num_iterations=config.num_training_iterations,
            env_batch_size=config.env_batch_size,
            grad_updates_per_step=config.grad_updates_per_step,
        )

        # Compute numbers from fractions
        pg_population_size = config.pg_population_size_per_device * config.num_devices
        self._num_best_to_replace_from = int(
            pg_population_size * config.fraction_best_to_replace_from
        )
        self._num_to_replace_from_best = int(
            pg_population_size * config.fraction_to_replace_from_best
        )
        self._num_to_replace_from_samples = int(
            pg_population_size * config.fraction_to_replace_from_samples
        )
        self._num_to_exchange = int(
            config.pg_population_size_per_device * config.fraction_sort_exchange
        )

    def init(
        self, init_genotypes: Genotype, random_key: RNGKey
    ) -> Tuple[PBTEmitterState, RNGKey]:
        """Initializes the emitter state.

        Args:
            init_genotypes: The initial population.
            random_key: A random key.

        Returns:
            The initial state of the PBTEmitter, a new random key.
        """

        observation_size = self._env.observation_size
        action_size = self._env.action_size

        # Initialise replay buffers
        init_dummy_transition = partial(
            Transition.init_dummy,
            observation_dim=observation_size,
            action_dim=action_size,
        )
        init_dummy_transition = jax.vmap(
            init_dummy_transition, axis_size=self._config.pg_population_size_per_device
        )
        dummy_transitions = init_dummy_transition()

        replay_buffer_init = partial(
            ReplayBuffer.init,
            buffer_size=self._config.buffer_size,
        )
        replay_buffer_init = jax.vmap(replay_buffer_init)
        replay_buffers = replay_buffer_init(transition=dummy_transitions)

        # Initialise env states
        (random_key, subkey1, subkey2) = jax.random.split(random_key, num=3)
        env_states = jax.jit(self._env.reset)(rng=subkey1)

        reshape_fn = jax.jit(
            lambda tree: jax.tree_util.tree_map(
                lambda x: jnp.reshape(
                    x,
                    (
                        self._config.pg_population_size_per_device,
                        self._config.env_batch_size,
                    )
                    + x.shape[1:],
                ),
                tree,
            ),
        )
        env_states = reshape_fn(env_states)

        # Create emitter state
        # keep only pg population size training states if more are provided
        init_genotypes = jax.tree_util.tree_map(
            lambda x: x[: self._config.pg_population_size_per_device], init_genotypes
        )
        emitter_state = PBTEmitterState(
            replay_buffers=replay_buffers,
            env_states=env_states,
            training_states=init_genotypes,
            random_key=subkey2,
        )

        return emitter_state, random_key

    @partial(
        jax.jit,
        static_argnames=("self",),
    )
    def emit(
        self,
        repertoire: Repertoire,
        emitter_state: PBTEmitterState,
        random_key: RNGKey,
    ) -> Tuple[Genotype, RNGKey]:
        """Do a single PGA-ME iteration: train critics and greedy policy,
        make mutations (evo and pg), score solution, fill replay buffer and insert back
        in the MAP-Elites grid.

        Args:
            repertoire: the current repertoire of genotypes
            emitter_state: the state of the emitter used
            random_key: a random key

        Returns:
            A batch of offspring and a new random key.
        """

        # Mutation PG (the mutation has already been performed during the state update)
        x_mutation_pg = emitter_state.training_states

        # Mutation evo
        if self._config.ga_population_size_per_device > 0:
            mutation_ga_batch_size = self._config.ga_population_size_per_device
            x1, random_key = repertoire.sample(random_key, mutation_ga_batch_size)
            x2, random_key = repertoire.sample(random_key, mutation_ga_batch_size)
            x_mutation_ga, random_key = self._variation_fn(x1, x2, random_key)

            # Gather offspring
            genotypes = jax.tree_util.tree_map(
                lambda x, y: jnp.concatenate([x, y], axis=0),
                x_mutation_ga,
                x_mutation_pg,
            )
        else:
            genotypes = x_mutation_pg

        return genotypes, random_key

    @property
    def batch_size(self) -> int:
        """
        Returns:
            the batch size emitted by the emitter.
        """
        mutation_pg_batch_size = self._config.pg_population_size_per_device
        mutation_ga_batch_size = self._config.ga_population_size_per_device
        return mutation_pg_batch_size + mutation_ga_batch_size

    @partial(jax.jit, static_argnames=("self",))
    def state_update(
        self,
        emitter_state: PBTEmitterState,
        repertoire: Repertoire,
        genotypes: Optional[Genotype],
        fitnesses: Fitness,
        descriptors: Optional[Descriptor],
        extra_scores: ExtraScores,
    ) -> PBTEmitterState:
        """
        Update the internal emitter state. I.e. update the population replay buffers and
        agents.

        Args:
            emitter_state: current emitter state.
            repertoire: the current genotypes repertoire
            genotypes: unused here - but compulsory in the signature.
            fitnesses: unused here - but compulsory in the signature.
            descriptors: unused here - but compulsory in the signature.
            extra_scores: extra information coming from the scoring function,
                this contains the transitions added to the replay buffer.

        Returns:
            New emitter state where the replay buffer has been filled with
            the new experienced transitions.
        """
        # Look only at the fitness corresponding to emitter state individuals
        fitnesses = fitnesses[self._config.ga_population_size_per_device :]
        fitnesses = jnp.ravel(fitnesses)
        training_states = emitter_state.training_states
        replay_buffers = emitter_state.replay_buffers
        genotypes = (training_states, replay_buffers)

        # Incremental algorithm to gather top best among the population on each device
        # First exchange
        indices_to_share = jnp.arange(self._config.pg_population_size_per_device)
        num_best_local = int(
            self._config.pg_population_size_per_device
            * self._config.fraction_best_to_replace_from
        )
        indices_to_share = indices_to_share[:num_best_local]
        genotypes_to_share, fitnesses_to_share = jax.tree_util.tree_map(
            lambda x: x[indices_to_share], (genotypes, fitnesses)
        )
        gathered_genotypes, gathered_fitnesses = jax.tree_util.tree_map(
            lambda x: jnp.concatenate(jax.lax.all_gather(x, axis_name="p"), axis=0),
            (genotypes_to_share, fitnesses_to_share),
        )

        genotypes_stacked, fitnesses_stacked = gathered_genotypes, gathered_fitnesses
        best_indices_stacked = jnp.argsort(-fitnesses_stacked)
        best_indices_stacked = best_indices_stacked[: self._num_best_to_replace_from]
        best_genotypes_local, best_fitnesses_local = jax.tree_util.tree_map(
            lambda x: x[best_indices_stacked], (genotypes_stacked, fitnesses_stacked)
        )

        # Define loop fn for the other exchanges
        def _loop_fn(i, val):  # type: ignore
            best_genotypes_local, best_fitnesses_local = val
            indices_to_share = jax.lax.dynamic_slice(
                jnp.arange(self._config.pg_population_size_per_device),
                [i * self._num_to_exchange],
                [self._num_to_exchange],
            )
            genotypes_to_share, fitnesses_to_share = jax.tree_util.tree_map(
                lambda x: x[indices_to_share], (genotypes, fitnesses)
            )
            gathered_genotypes, gathered_fitnesses = jax.tree_util.tree_map(
                lambda x: jnp.concatenate(jax.lax.all_gather(x, axis_name="p"), axis=0),
                (genotypes_to_share, fitnesses_to_share),
            )

            genotypes_stacked, fitnesses_stacked = jax.tree_util.tree_map(
                lambda x, y: jnp.concatenate([x, y], axis=0),
                (gathered_genotypes, gathered_fitnesses),
                (best_genotypes_local, best_fitnesses_local),
            )

            best_indices_stacked = jnp.argsort(-fitnesses_stacked)
            best_indices_stacked = best_indices_stacked[
                : self._num_best_to_replace_from
            ]
            best_genotypes_local, best_fitnesses_local = jax.tree_util.tree_map(
                lambda x: x[best_indices_stacked],
                (genotypes_stacked, fitnesses_stacked),
            )
            return (best_genotypes_local, best_fitnesses_local)  # type: ignore

        # Incrementally get the top fraction_best_to_replace_from best individuals
        # on each device
        (best_genotypes_local, best_fitnesses_local) = jax.lax.fori_loop(
            lower=1,
            upper=int(1.0 // self._config.fraction_sort_exchange) + 1,
            body_fun=_loop_fn,
            init_val=(best_genotypes_local, best_fitnesses_local),
        )

        # Gather fitnesses from all devices to rank locally against it
        all_fitnesses = jax.tree_util.tree_map(
            lambda x: jnp.concatenate(jax.lax.all_gather(x, axis_name="p"), axis=0),
            fitnesses,
        )
        all_fitnesses = jnp.ravel(all_fitnesses)
        all_fitnesses = -jnp.sort(-all_fitnesses)
        random_key = emitter_state.random_key
        random_key, sub_key = jax.random.split(random_key)
        best_genotypes = jax.tree_util.tree_map(
            lambda x: jax.random.choice(
                sub_key, x, shape=(len(fitnesses),), replace=True
            ),
            best_genotypes_local,
        )
        best_training_states, best_replay_buffers = best_genotypes

        # Resample hyper-params
        best_training_states = jax.vmap(
            best_training_states.__class__.resample_hyperparams
        )(best_training_states)

        # Replace by individuals from the best
        lower_bound = all_fitnesses[-self._num_to_replace_from_best]
        cond = fitnesses <= lower_bound

        training_states = jax.tree_util.tree_map(
            lambda x, y: jnp.where(
                jnp.expand_dims(
                    cond, axis=tuple([-(i + 1) for i in range(x.ndim - 1)])
                ),
                x,
                y,
            ),
            best_training_states,
            training_states,
        )
        replay_buffers = jax.tree_util.tree_map(
            lambda x, y: jnp.where(
                jnp.expand_dims(
                    cond, axis=tuple([-(i + 1) for i in range(x.ndim - 1)])
                ),
                x,
                y,
            ),
            best_replay_buffers,
            replay_buffers,
        )

        # Replacing with samples from the ME repertoire
        if self._num_to_replace_from_samples > 0:
            me_samples, random_key = repertoire.sample(
                random_key, self._config.pg_population_size_per_device
            )
            # Resample hyper-params
            me_samples = jax.vmap(me_samples.__class__.resample_hyperparams)(me_samples)
            upper_bound = all_fitnesses[
                -self._num_to_replace_from_best - self._num_to_replace_from_samples
            ]
            cond = jnp.logical_and(fitnesses <= upper_bound, fitnesses >= lower_bound)
            training_states = jax.tree_util.tree_map(
                lambda x, y: jnp.where(
                    jnp.expand_dims(
                        cond, axis=tuple([-(i + 1) for i in range(x.ndim - 1)])
                    ),
                    x,
                    y,
                ),
                me_samples,
                training_states,
            )

        # Train the agents
        env_states = emitter_state.env_states
        # Init optimizers state before training the population
        training_states = jax.vmap(training_states.__class__.init_optimizers_states)(
            training_states
        )
        (training_states, env_states, replay_buffers), metrics = self._train_fn(
            training_states, env_states, replay_buffers
        )
        # Empty optimizers states to avoid storing the info in the RAM
        # and having too heavy repertoires
        training_states = jax.vmap(training_states.__class__.empty_optimizers_states)(
            training_states
        )

        # Update emitter state
        emitter_state = emitter_state.replace(
            training_states=training_states,
            replay_buffers=replay_buffers,
            env_states=env_states,
            random_key=random_key,
        )
        return emitter_state  # type: ignore

batch_size: int property readonly

Returns:
  • int – the batch size emitted by the emitter.
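
For example, with pg_population_size_per_device=10 and ga_population_size_per_device=30, each call to emit produces a batch of 40 genotypes per device.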

init(self, init_genotypes, random_key)

Initializes the emitter state.

Parameters:
  • init_genotypes (Union[jax.Array, numpy.ndarray, numpy.bool_, numpy.number, Iterable[ArrayTree], Mapping[Any, ArrayTree]]) – The initial population.

  • random_key (Array) – A random key.

Returns:
  • Tuple[qdax.core.emitters.pbt_me_emitter.PBTEmitterState, jax.Array] – The initial state of the PBTEmitter, a new random key.
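
Note that self._env.reset is expected to return pg_population_size_per_device * env_batch_size environment states, which init reshapes to (pg_population_size_per_device, env_batch_size, ...) so that each agent of the population gets its own batch of environments.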

Source code in qdax/core/emitters/pbt_me_emitter.py
def init(
    self, init_genotypes: Genotype, random_key: RNGKey
) -> Tuple[PBTEmitterState, RNGKey]:
    """Initializes the emitter state.

    Args:
        init_genotypes: The initial population.
        random_key: A random key.

    Returns:
        The initial state of the PBTEmitter, a new random key.
    """

    observation_size = self._env.observation_size
    action_size = self._env.action_size

    # Initialise replay buffers
    init_dummy_transition = partial(
        Transition.init_dummy,
        observation_dim=observation_size,
        action_dim=action_size,
    )
    init_dummy_transition = jax.vmap(
        init_dummy_transition, axis_size=self._config.pg_population_size_per_device
    )
    dummy_transitions = init_dummy_transition()

    replay_buffer_init = partial(
        ReplayBuffer.init,
        buffer_size=self._config.buffer_size,
    )
    replay_buffer_init = jax.vmap(replay_buffer_init)
    replay_buffers = replay_buffer_init(transition=dummy_transitions)

    # Initialise env states
    (random_key, subkey1, subkey2) = jax.random.split(random_key, num=3)
    env_states = jax.jit(self._env.reset)(rng=subkey1)

    reshape_fn = jax.jit(
        lambda tree: jax.tree_util.tree_map(
            lambda x: jnp.reshape(
                x,
                (
                    self._config.pg_population_size_per_device,
                    self._config.env_batch_size,
                )
                + x.shape[1:],
            ),
            tree,
        ),
    )
    env_states = reshape_fn(env_states)

    # Create emitter state
    # keep only pg population size training states if more are provided
    init_genotypes = jax.tree_util.tree_map(
        lambda x: x[: self._config.pg_population_size_per_device], init_genotypes
    )
    emitter_state = PBTEmitterState(
        replay_buffers=replay_buffers,
        env_states=env_states,
        training_states=init_genotypes,
        random_key=subkey2,
    )

    return emitter_state, random_key

emit(self, repertoire, emitter_state, random_key)

Emit a batch of offspring: GA offspring obtained by applying the variation function to pairs of parents sampled from the repertoire, followed by the PBT-trained agents held in the emitter state (their PG mutation already happened during the last state update).
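
Note that the GA offspring come first in the emitted batch: state_update relies on this ordering when it drops the first ga_population_size_per_device fitnesses to recover those of the PBT agents.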

Parameters:
  • repertoire (Repertoire) – the current repertoire of genotypes

  • emitter_state (PBTEmitterState) – the state of the emitter used

  • random_key (Array) – a random key

Returns:
  • Tuple[Union[jax.Array, numpy.ndarray, numpy.bool_, numpy.number, Iterable[ArrayTree], Mapping[Any, ArrayTree]], jax.Array] – A batch of offspring and a new random key.

Source code in qdax/core/emitters/pbt_me_emitter.py
@partial(
    jax.jit,
    static_argnames=("self",),
)
def emit(
    self,
    repertoire: Repertoire,
    emitter_state: PBTEmitterState,
    random_key: RNGKey,
) -> Tuple[Genotype, RNGKey]:
    """Do a single PGA-ME iteration: train critics and greedy policy,
    make mutations (evo and pg), score solution, fill replay buffer and insert back
    in the MAP-Elites grid.

    Args:
        repertoire: the current repertoire of genotypes
        emitter_state: the state of the emitter used
        random_key: a random key

    Returns:
        A batch of offspring and a new random key.
    """

    # Mutation PG (the mutation has already been performed during the state update)
    x_mutation_pg = emitter_state.training_states

    # Mutation evo
    if self._config.ga_population_size_per_device > 0:
        mutation_ga_batch_size = self._config.ga_population_size_per_device
        x1, random_key = repertoire.sample(random_key, mutation_ga_batch_size)
        x2, random_key = repertoire.sample(random_key, mutation_ga_batch_size)
        x_mutation_ga, random_key = self._variation_fn(x1, x2, random_key)

        # Gather offspring
        genotypes = jax.tree_util.tree_map(
            lambda x, y: jnp.concatenate([x, y], axis=0),
            x_mutation_ga,
            x_mutation_pg,
        )
    else:
        genotypes = x_mutation_pg

    return genotypes, random_key

state_update(self, emitter_state, repertoire, genotypes, fitnesses, descriptors, extra_scores)

Update the internal emitter state, i.e. the population of agents and their replay buffers.
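
The update proceeds in four stages: (1) the best agents across all devices are gathered through chunked jax.lax.all_gather exchanges; (2) the worst agents are replaced by hyperparameter-resampled copies of those best agents; (3) the next-worst tier is replaced by samples drawn from the MAP-Elites repertoire; (4) every agent is trained for num_training_iterations, which also refills its replay buffer.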

Parameters:
  • emitter_state (PBTEmitterState) – current emitter state.

  • repertoire (Repertoire) – the current genotypes repertoire

  • genotypes (Union[jax.Array, numpy.ndarray, numpy.bool_, numpy.number, Iterable[ArrayTree], Mapping[Any, ArrayTree]]) – unused here - but compulsory in the signature.

  • fitnesses (Array) – fitnesses of the latest batch of offspring, used to rank the population for the PBT exchange.

  • descriptors (Optional[jax.Array]) – unused here - but compulsory in the signature.

  • extra_scores (Dict[str, Union[jax.Array, numpy.ndarray, numpy.bool_, numpy.number, Iterable[ArrayTree], Mapping[Any, ArrayTree]]]) – unused here - but compulsory in the signature.

Returns:
  • PBTEmitterState – The updated emitter state: underperforming agents have been replaced following the PBT procedure, the population has been trained, and the replay buffers hold the newly collected transitions.
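
To make the replacement rule concrete, here is a small self-contained sketch (illustrative numbers, not library code) of the two fitness thresholds computed in the source below; agents at or below lower_bound are replaced by copies of the best agents, agents between the two bounds by repertoire samples:

import jax.numpy as jnp

fitnesses = jnp.array([3.0, 9.0, 1.0, 7.0, 5.0])  # local population fitnesses
all_fitnesses = -jnp.sort(-fitnesses)             # all fitnesses, sorted descending
num_to_replace_from_best = 1                      # bottom agent -> copy of a best agent
num_to_replace_from_samples = 2                   # next two agents -> repertoire samples

lower_bound = all_fitnesses[-num_to_replace_from_best]  # here: 1.0
upper_bound = all_fitnesses[
    -num_to_replace_from_best - num_to_replace_from_samples
]  # here: 5.0

replaced_by_best = fitnesses <= lower_bound  # [False, False, True, False, False]
# Applied after the mask above in state_update, so an agent sitting exactly on
# lower_bound ends up with a repertoire sample rather than a best-agent copy.
replaced_by_samples = jnp.logical_and(
    fitnesses <= upper_bound, fitnesses >= lower_bound
)  # [True, False, True, False, True]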

Source code in qdax/core/emitters/pbt_me_emitter.py
@partial(jax.jit, static_argnames=("self",))
def state_update(
    self,
    emitter_state: PBTEmitterState,
    repertoire: Repertoire,
    genotypes: Optional[Genotype],
    fitnesses: Fitness,
    descriptors: Optional[Descriptor],
    extra_scores: ExtraScores,
) -> PBTEmitterState:
    """
    Update the internal emitter state. I.e. update the population replay buffers and
    agents.

    Args:
        emitter_state: current emitter state.
        repertoire: the current genotypes repertoire
        genotypes: unused here - but compulsory in the signature.
        fitnesses: unused here - but compulsory in the signature.
        descriptors: unused here - but compulsory in the signature.
        extra_scores: extra information coming from the scoring function,
            this contains the transitions added to the replay buffer.

    Returns:
        New emitter state where the replay buffer has been filled with
        the new experienced transitions.
    """
    # Look only at the fitness corresponding to emitter state individuals
    fitnesses = fitnesses[self._config.ga_population_size_per_device :]
    fitnesses = jnp.ravel(fitnesses)
    training_states = emitter_state.training_states
    replay_buffers = emitter_state.replay_buffers
    genotypes = (training_states, replay_buffers)

    # Incremental algorithm to gather top best among the population on each device
    # First exchange
    indices_to_share = jnp.arange(self._config.pg_population_size_per_device)
    num_best_local = int(
        self._config.pg_population_size_per_device
        * self._config.fraction_best_to_replace_from
    )
    indices_to_share = indices_to_share[:num_best_local]
    genotypes_to_share, fitnesses_to_share = jax.tree_util.tree_map(
        lambda x: x[indices_to_share], (genotypes, fitnesses)
    )
    gathered_genotypes, gathered_fitnesses = jax.tree_util.tree_map(
        lambda x: jnp.concatenate(jax.lax.all_gather(x, axis_name="p"), axis=0),
        (genotypes_to_share, fitnesses_to_share),
    )

    genotypes_stacked, fitnesses_stacked = gathered_genotypes, gathered_fitnesses
    best_indices_stacked = jnp.argsort(-fitnesses_stacked)
    best_indices_stacked = best_indices_stacked[: self._num_best_to_replace_from]
    best_genotypes_local, best_fitnesses_local = jax.tree_util.tree_map(
        lambda x: x[best_indices_stacked], (genotypes_stacked, fitnesses_stacked)
    )

    # Define loop fn for the other exchanges
    def _loop_fn(i, val):  # type: ignore
        best_genotypes_local, best_fitnesses_local = val
        indices_to_share = jax.lax.dynamic_slice(
            jnp.arange(self._config.pg_population_size_per_device),
            [i * self._num_to_exchange],
            [self._num_to_exchange],
        )
        genotypes_to_share, fitnesses_to_share = jax.tree_util.tree_map(
            lambda x: x[indices_to_share], (genotypes, fitnesses)
        )
        gathered_genotypes, gathered_fitnesses = jax.tree_util.tree_map(
            lambda x: jnp.concatenate(jax.lax.all_gather(x, axis_name="p"), axis=0),
            (genotypes_to_share, fitnesses_to_share),
        )

        genotypes_stacked, fitnesses_stacked = jax.tree_util.tree_map(
            lambda x, y: jnp.concatenate([x, y], axis=0),
            (gathered_genotypes, gathered_fitnesses),
            (best_genotypes_local, best_fitnesses_local),
        )

        best_indices_stacked = jnp.argsort(-fitnesses_stacked)
        best_indices_stacked = best_indices_stacked[
            : self._num_best_to_replace_from
        ]
        best_genotypes_local, best_fitnesses_local = jax.tree_util.tree_map(
            lambda x: x[best_indices_stacked],
            (genotypes_stacked, fitnesses_stacked),
        )
        return (best_genotypes_local, best_fitnesses_local)  # type: ignore

    # Incrementally get the top fraction_best_to_replace_from best individuals
    # on each device
    (best_genotypes_local, best_fitnesses_local) = jax.lax.fori_loop(
        lower=1,
        upper=int(1.0 // self._config.fraction_sort_exchange) + 1,
        body_fun=_loop_fn,
        init_val=(best_genotypes_local, best_fitnesses_local),
    )

    # Gather fitnesses from all devices to rank locally against it
    all_fitnesses = jax.tree_util.tree_map(
        lambda x: jnp.concatenate(jax.lax.all_gather(x, axis_name="p"), axis=0),
        fitnesses,
    )
    all_fitnesses = jnp.ravel(all_fitnesses)
    all_fitnesses = -jnp.sort(-all_fitnesses)
    random_key = emitter_state.random_key
    random_key, sub_key = jax.random.split(random_key)
    best_genotypes = jax.tree_util.tree_map(
        lambda x: jax.random.choice(
            sub_key, x, shape=(len(fitnesses),), replace=True
        ),
        best_genotypes_local,
    )
    best_training_states, best_replay_buffers = best_genotypes

    # Resample hyper-params
    best_training_states = jax.vmap(
        best_training_states.__class__.resample_hyperparams
    )(best_training_states)

    # Replace by individuals from the best
    lower_bound = all_fitnesses[-self._num_to_replace_from_best]
    cond = fitnesses <= lower_bound

    training_states = jax.tree_util.tree_map(
        lambda x, y: jnp.where(
            jnp.expand_dims(
                cond, axis=tuple([-(i + 1) for i in range(x.ndim - 1)])
            ),
            x,
            y,
        ),
        best_training_states,
        training_states,
    )
    replay_buffers = jax.tree_util.tree_map(
        lambda x, y: jnp.where(
            jnp.expand_dims(
                cond, axis=tuple([-(i + 1) for i in range(x.ndim - 1)])
            ),
            x,
            y,
        ),
        best_replay_buffers,
        replay_buffers,
    )

    # Replacing with samples from the ME repertoire
    if self._num_to_replace_from_samples > 0:
        me_samples, random_key = repertoire.sample(
            random_key, self._config.pg_population_size_per_device
        )
        # Resample hyper-params
        me_samples = jax.vmap(me_samples.__class__.resample_hyperparams)(me_samples)
        upper_bound = all_fitnesses[
            -self._num_to_replace_from_best - self._num_to_replace_from_samples
        ]
        cond = jnp.logical_and(fitnesses <= upper_bound, fitnesses >= lower_bound)
        training_states = jax.tree_util.tree_map(
            lambda x, y: jnp.where(
                jnp.expand_dims(
                    cond, axis=tuple([-(i + 1) for i in range(x.ndim - 1)])
                ),
                x,
                y,
            ),
            me_samples,
            training_states,
        )

    # Train the agents
    env_states = emitter_state.env_states
    # Init optimizers state before training the population
    training_states = jax.vmap(training_states.__class__.init_optimizers_states)(
        training_states
    )
    (training_states, env_states, replay_buffers), metrics = self._train_fn(
        training_states, env_states, replay_buffers
    )
    # Empty optimizers states to avoid storing the info in the RAM
    # and having too heavy repertoires
    training_states = jax.vmap(training_states.__class__.empty_optimizers_states)(
        training_states
    )

    # Update emitter state
    emitter_state = emitter_state.replace(
        training_states=training_states,
        replay_buffers=replay_buffers,
        env_states=env_states,
        random_key=random_key,
    )
    return emitter_state  # type: ignore