MAP Elites Population Based Training (ME PBT)

ME PBT is a recent algorithm combining MAP Elites with Population Based Training to evolve a population of diverse RL agents.

To create an instance of PBT-ME, one needs to use an instance of Distributed MAP-Elites with the PBTEmitter, detailed below.

Bases: Emitter

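An emitter used to implement the MAP-Elites Population Based Training (ME-PBT) algorithm: it trains a population of PBT agents and emits them, optionally together with offspring produced by a variation operator.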

Source code in qdax/core/emitters/pbt_me_emitter.py
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
class PBTEmitter(Emitter):
    """
    Emitter used to implement the MAP-Elites Population Based Training (ME-PBT)
    algorithm.

    The emitter state holds a population of PBT agents (training states, replay
    buffers and environment states). `emit` returns those agents, optionally
    together with offspring produced by the variation operator, and
    `state_update` replaces the worst agents before training the whole population.

    `state_update` calls `jax.lax.all_gather` with `axis_name="p"`, so it has to
    run where that axis name is bound (presumably a `jax.pmap` set up by the
    distributed MAP-Elites loop - confirm against the caller).
    """

    def __init__(
        self,
        pbt_agent: Union[PBTSAC, PBTTD3],
        config: PBTEmitterConfig,
        env: QDEnv,
        variation_fn: Callable[[Params, Params, RNGKey], Tuple[Params, RNGKey]],
        selector: Optional[Selector] = None,
    ) -> None:
        """Store the emitter components and precompute the population counts.

        Args:
            pbt_agent: the PBT agent (PBTSAC or PBTTD3) whose `get_train_fn`
                builds the function used to train the population.
            config: the emitter configuration (population sizes, replacement
                fractions, buffer size and training settings).
            env: the environment, used to build the training function and, in
                `init`, to size the dummy transitions and reset the env states.
            variation_fn: the variation operator applied in `emit` to two
                batches of parents selected from the repertoire.
            selector: optional selector forwarded to `repertoire.select`.
        """

        # Parameters internalization
        self._env = env
        self._variation_fn = variation_fn
        self._config = config
        self._agent = pbt_agent
        # Training function applied to the whole local population in state_update
        self._train_fn = self._agent.get_train_fn(
            env=env,
            num_iterations=config.num_training_iterations,
            env_batch_size=config.env_batch_size,
            grad_updates_per_step=config.grad_updates_per_step,
        )

        # Compute numbers from fractions
        # The first three counts are fractions of the population summed over all
        # devices; int() truncates towards zero.
        pg_population_size = config.pg_population_size_per_device * config.num_devices
        self._num_best_to_replace_from = int(
            pg_population_size * config.fraction_best_to_replace_from
        )
        self._num_to_replace_from_best = int(
            pg_population_size * config.fraction_to_replace_from_best
        )
        self._num_to_replace_from_samples = int(
            pg_population_size * config.fraction_to_replace_from_samples
        )
        # Unlike the counts above, this one is a fraction of the per-device
        # population: it is the size of the slice shared at each exchange step.
        self._num_to_exchange = int(
            config.pg_population_size_per_device * config.fraction_sort_exchange
        )

        self._selector = selector

    def init(
        self,
        key: RNGKey,
        repertoire: Repertoire,
        genotypes: Genotype,
        fitnesses: Fitness,
        descriptors: Descriptor,
        extra_scores: ExtraScores,
    ) -> PBTEmitterState:
        """Initializes the emitter state.

        Args:
            key: A random key, split once to reset the environments.
            repertoire: unused here - but compulsory in the signature.
            genotypes: The initial population. Only the first
                `pg_population_size_per_device` entries are kept as training
                states.
            fitnesses: unused here - but compulsory in the signature.
            descriptors: unused here - but compulsory in the signature.
            extra_scores: unused here - but compulsory in the signature.

        Returns:
            The initial state of the PBTEmitter.
        """

        observation_size = self._env.observation_size
        action_size = self._env.action_size

        # Initialise replay buffers
        # vmap with axis_size and no argument stacks one dummy transition per agent
        init_dummy_transition = partial(
            Transition.init_dummy,
            observation_dim=observation_size,
            action_dim=action_size,
        )
        init_dummy_transition = jax.vmap(
            init_dummy_transition, axis_size=self._config.pg_population_size_per_device
        )
        dummy_transitions = init_dummy_transition()

        # One replay buffer per agent, each built from its own dummy transition
        replay_buffer_init = partial(
            ReplayBuffer.init,
            buffer_size=self._config.buffer_size,
        )
        replay_buffer_init = jax.vmap(replay_buffer_init)
        replay_buffers = replay_buffer_init(transition=dummy_transitions)

        # Initialise env states
        key, subkey = jax.random.split(key)
        env_states = jax.jit(self._env.reset)(rng=subkey)

        # Split the leading axis of every leaf into (agents, envs per agent).
        # Assumes the env was built with a batch size equal to
        # pg_population_size_per_device * env_batch_size - TODO confirm.
        reshape_fn = jax.jit(
            lambda tree: jax.tree.map(
                lambda x: jnp.reshape(
                    x,
                    (
                        self._config.pg_population_size_per_device,
                        self._config.env_batch_size,
                    )
                    + x.shape[1:],
                ),
                tree,
            ),
        )
        env_states = reshape_fn(env_states)

        # Create emitter state
        # keep only pg population size training states if more are provided
        genotypes = jax.tree.map(
            lambda x: x[: self._config.pg_population_size_per_device], genotypes
        )
        emitter_state = PBTEmitterState(
            replay_buffers=replay_buffers,
            env_states=env_states,
            training_states=genotypes,
            key=key,
        )

        return emitter_state

    def emit(  # type: ignore
        self,
        repertoire: GARepertoire,
        emitter_state: PBTEmitterState,
        key: RNGKey,
    ) -> Tuple[Genotype, ExtraScores]:
        """Emit the batch of offspring to evaluate.

        The batch contains the PBT agents stored in the emitter state. When
        `ga_population_size_per_device` is positive, offspring obtained by
        applying the variation operator to parents selected from the repertoire
        are placed in front of them. No training happens here: the agents are
        trained in `state_update`.

        Args:
            repertoire: the current repertoire of genotypes
            emitter_state: the state of the emitter used
            key: a random key, only used when GA offspring are produced

        Returns:
            A batch of offspring and an empty dictionary of extra information.
        """

        # Mutation PG (the mutation has already been performed during the state update)
        x_mutation_pg = emitter_state.training_states

        # Mutation evo
        if self._config.ga_population_size_per_device > 0:
            mutation_ga_batch_size = self._config.ga_population_size_per_device
            sample_key_1, sample_key_2, variation_key = jax.random.split(key, 3)
            x1 = repertoire.select(
                sample_key_1, mutation_ga_batch_size, selector=self._selector
            ).genotypes
            x2 = repertoire.select(
                sample_key_2, mutation_ga_batch_size, selector=self._selector
            ).genotypes
            # NOTE(review): the result is used directly as a genotype pytree below,
            # while the annotation in __init__ declares a (Params, RNGKey) tuple -
            # confirm which one is right.
            x_mutation_ga = self._variation_fn(x1, x2, variation_key)

            # Gather offspring
            # GA offspring come first; state_update relies on this order when it
            # slices the fitnesses of the PBT agents.
            genotypes = jax.tree.map(
                lambda x, y: jnp.concatenate([x, y], axis=0),
                x_mutation_ga,
                x_mutation_pg,
            )
        else:
            genotypes = x_mutation_pg

        return genotypes, {}

    @property
    def batch_size(self) -> int:
        """
        Returns:
            the batch size emitted by the emitter, i.e. the number of PBT agents
            plus the number of GA offspring per device.
        """
        mutation_pg_batch_size = self._config.pg_population_size_per_device
        mutation_ga_batch_size = self._config.ga_population_size_per_device
        return mutation_pg_batch_size + mutation_ga_batch_size

    def state_update(  # type: ignore
        self,
        emitter_state: PBTEmitterState,
        repertoire: GARepertoire,
        genotypes: Optional[Genotype],
        fitnesses: Fitness,
        descriptors: Optional[Descriptor],
        extra_scores: ExtraScores,
    ) -> PBTEmitterState:
        """
        Update the internal emitter state. I.e. replace the worst agents of the
        population and train all the agents.

        The steps are:
            1. Gather, through successive exchanges between devices, a set of
               `_num_best_to_replace_from` best individuals.
            2. Replace the agents whose fitness is among the worst by
               individuals sampled from that set (training state and replay
               buffer), with resampled hyper-parameters.
            3. Optionally replace the next worst agents by samples from the
               repertoire (training state only).
            4. Train every agent with the training function of the PBT agent.

        Args:
            emitter_state: current emitter state.
            repertoire: the current genotypes repertoire, sampled from when
                `_num_to_replace_from_samples` is positive.
            genotypes: unused here - but compulsory in the signature.
            fitnesses: fitnesses of the emitted batch; only the entries after
                the first `ga_population_size_per_device` ones are used.
            descriptors: unused here - but compulsory in the signature.
            extra_scores: extra information coming from the scoring function;
                not read in this method.

        Returns:
            New emitter state holding the trained agents, their replay buffers
            and environment states, and the updated random key.
        """
        # Look only at the fitness corresponding to emitter state individuals
        # (emit puts the GA offspring first in the batch)
        fitnesses = fitnesses[self._config.ga_population_size_per_device :]
        fitnesses = jnp.ravel(fitnesses)
        training_states = emitter_state.training_states
        replay_buffers = emitter_state.replay_buffers
        # The `genotypes` argument is discarded: from here on the name refers to
        # the (training states, replay buffers) pair of the local agents.
        genotypes = (training_states, replay_buffers)

        # Incremental algorithm to gather top best among the population on each device
        # First exchange
        # The shared individuals are the first `num_best_local` ones in index
        # order; they are ranked by fitness only after the gather.
        indices_to_share = jnp.arange(self._config.pg_population_size_per_device)
        num_best_local = int(
            self._config.pg_population_size_per_device
            * self._config.fraction_best_to_replace_from
        )
        indices_to_share = indices_to_share[:num_best_local]
        genotypes_to_share, fitnesses_to_share = jax.tree.map(
            lambda x: x[indices_to_share], (genotypes, fitnesses)
        )
        # all_gather stacks the contribution of every device along a new leading
        # axis; concatenate merges it back into a single population axis.
        gathered_genotypes, gathered_fitnesses = jax.tree.map(
            lambda x: jnp.concatenate(jax.lax.all_gather(x, axis_name="p"), axis=0),
            (genotypes_to_share, fitnesses_to_share),
        )

        # Keep the `_num_best_to_replace_from` fittest gathered individuals
        # (argsort of the negated fitnesses sorts in descending order)
        genotypes_stacked, fitnesses_stacked = gathered_genotypes, gathered_fitnesses
        best_indices_stacked = jnp.argsort(-fitnesses_stacked)
        best_indices_stacked = best_indices_stacked[: self._num_best_to_replace_from]
        best_genotypes_local, best_fitnesses_local = jax.tree.map(
            lambda x: x[best_indices_stacked], (genotypes_stacked, fitnesses_stacked)
        )

        # Define loop fn for the other exchanges
        def _loop_fn(i, val):  # type: ignore
            # val carries the current best individuals and their fitnesses
            best_genotypes_local, best_fitnesses_local = val
            # Share the i-th slice of `_num_to_exchange` local individuals.
            # NOTE(review): jax.lax.dynamic_slice clamps out-of-range start
            # indices, so the last iterations may share the tail of the
            # population again - confirm this is intended.
            indices_to_share = jax.lax.dynamic_slice(
                jnp.arange(self._config.pg_population_size_per_device),
                [i * self._num_to_exchange],
                [self._num_to_exchange],
            )
            genotypes_to_share, fitnesses_to_share = jax.tree.map(
                lambda x: x[indices_to_share], (genotypes, fitnesses)
            )
            gathered_genotypes, gathered_fitnesses = jax.tree.map(
                lambda x: jnp.concatenate(jax.lax.all_gather(x, axis_name="p"), axis=0),
                (genotypes_to_share, fitnesses_to_share),
            )

            # Merge the newly gathered individuals with the current best ones
            genotypes_stacked, fitnesses_stacked = jax.tree.map(
                lambda x, y: jnp.concatenate([x, y], axis=0),
                (gathered_genotypes, gathered_fitnesses),
                (best_genotypes_local, best_fitnesses_local),
            )

            # Keep the top `_num_best_to_replace_from` again, so the carried
            # value keeps the same shape across iterations
            best_indices_stacked = jnp.argsort(-fitnesses_stacked)
            best_indices_stacked = best_indices_stacked[
                : self._num_best_to_replace_from
            ]
            best_genotypes_local, best_fitnesses_local = jax.tree.map(
                lambda x: x[best_indices_stacked],
                (genotypes_stacked, fitnesses_stacked),
            )
            return (best_genotypes_local, best_fitnesses_local)  # type: ignore

        # Incrementally get the top fraction_best_to_replace_from best individuals
        # on each device
        # The loop index starts at 1 and `upper` is exclusive, so the body runs
        # int(1.0 // fraction_sort_exchange) times.
        (best_genotypes_local, best_fitnesses_local) = jax.lax.fori_loop(
            lower=1,
            upper=int(1.0 // self._config.fraction_sort_exchange) + 1,
            body_fun=_loop_fn,
            init_val=(best_genotypes_local, best_fitnesses_local),
        )

        # Gather fitnesses from all devices to rank locally against it
        all_fitnesses = jax.tree.map(
            lambda x: jnp.concatenate(jax.lax.all_gather(x, axis_name="p"), axis=0),
            fitnesses,
        )
        all_fitnesses = jnp.ravel(all_fitnesses)
        # Sorted in descending order: negative indices address the worst fitnesses
        all_fitnesses = -jnp.sort(-all_fitnesses)
        key = emitter_state.key
        key, subkey = jax.random.split(key)
        # Draw one candidate replacement per local agent, with replacement, from
        # the best individuals. The same subkey is reused for every leaf,
        # presumably so that all leaves are indexed by the same draws - confirm.
        best_genotypes = jax.tree.map(
            lambda x: jax.random.choice(
                subkey, x, shape=(len(fitnesses),), replace=True
            ),
            best_genotypes_local,
        )
        best_training_states, best_replay_buffers = best_genotypes

        # Resample hyper-params
        best_training_states = jax.vmap(
            best_training_states.__class__.resample_hyperparams
        )(best_training_states)

        # Replace by individuals from the best
        # NOTE(review): when `_num_to_replace_from_best` is 0 the index is -0,
        # i.e. 0, which selects the best fitness and makes `cond` true for every
        # agent - confirm the fraction is always large enough to avoid this.
        lower_bound = all_fitnesses[-self._num_to_replace_from_best]
        cond = fitnesses <= lower_bound

        # expand_dims appends trailing axes to `cond` so that the per-agent mask
        # broadcasts against leaves of any rank
        training_states = jax.tree.map(
            lambda x, y: jnp.where(
                jnp.expand_dims(
                    cond, axis=tuple([-(i + 1) for i in range(x.ndim - 1)])
                ),
                x,
                y,
            ),
            best_training_states,
            training_states,
        )
        # The replay buffer of a replaced agent is replaced as well
        replay_buffers = jax.tree.map(
            lambda x, y: jnp.where(
                jnp.expand_dims(
                    cond, axis=tuple([-(i + 1) for i in range(x.ndim - 1)])
                ),
                x,
                y,
            ),
            best_replay_buffers,
            replay_buffers,
        )

        # Replacing with samples from the ME repertoire
        if self._num_to_replace_from_samples > 0:
            key, subkey = jax.random.split(key)
            me_samples = repertoire.select(
                subkey,
                self._config.pg_population_size_per_device,
                selector=self._selector,
            ).genotypes
            # Resample hyper-params
            me_samples = jax.vmap(me_samples.__class__.resample_hyperparams)(me_samples)
            upper_bound = all_fitnesses[
                -self._num_to_replace_from_best - self._num_to_replace_from_samples
            ]
            # Both bounds are inclusive, so agents whose fitness equals
            # `lower_bound` match this mask as well as the previous one.
            cond = jnp.logical_and(fitnesses <= upper_bound, fitnesses >= lower_bound)
            # Only the training states are replaced here; the replay buffers are
            # left as they are.
            training_states = jax.tree.map(
                lambda x, y: jnp.where(
                    jnp.expand_dims(
                        cond, axis=tuple([-(i + 1) for i in range(x.ndim - 1)])
                    ),
                    x,
                    y,
                ),
                me_samples,
                training_states,
            )

        # Train the agents
        env_states = emitter_state.env_states
        # Init optimizers state before training the population
        training_states = jax.vmap(training_states.__class__.init_optimizers_states)(
            training_states
        )
        # The metrics returned by the training function are not used here
        (training_states, env_states, replay_buffers), metrics = self._train_fn(
            training_states, env_states, replay_buffers
        )
        # Empty optimizers states to avoid storing the info in the RAM
        # and having too heavy repertoires
        training_states = jax.vmap(training_states.__class__.empty_optimizers_states)(
            training_states
        )

        # Update emitter state
        emitter_state = emitter_state.replace(
            training_states=training_states,
            replay_buffers=replay_buffers,
            env_states=env_states,
            key=key,
        )
        return emitter_state  # type: ignore

batch_size property

Returns:
  • int

    the batch size emitted by the emitter.

emit(repertoire, emitter_state, key)

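Emit the batch of offspring to evaluate: the PBT agents stored in the emitter state, preceded, when the GA population size is positive, by offspring obtained by applying the variation operator to parents selected from the repertoire. The agents themselves are trained in state_update, not here.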

Parameters:
  • repertoire (GARepertoire) –

    the current repertoire of genotypes

  • emitter_state (PBTEmitterState) –

    the state of the emitter used

  • key (RNGKey) –

    a random key

Returns:
  • Tuple[Genotype, ExtraScores]

    A batch of offspring and an empty dictionary of extra information.

Source code in qdax/core/emitters/pbt_me_emitter.py
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
def emit(  # type: ignore
    self,
    repertoire: GARepertoire,
    emitter_state: PBTEmitterState,
    key: RNGKey,
) -> Tuple[Genotype, ExtraScores]:
    """Build the batch of offspring to evaluate.

    The PBT agents stored in the emitter state are always part of the batch;
    they are trained in the state update, not here. When the configured GA
    population size is positive, offspring produced by the variation operator
    from two batches of parents selected in the repertoire are placed in front
    of them.

    Args:
        repertoire: the current repertoire of genotypes
        emitter_state: the state of the emitter used
        key: a random key, only consumed when GA offspring are produced

    Returns:
        The batch of offspring and an empty dictionary of extra information.
    """
    # Agents coming out of the last training round
    pbt_offspring = emitter_state.training_states

    num_ga_offspring = self._config.ga_population_size_per_device
    if not num_ga_offspring > 0:
        # No GA part requested: the batch is exactly the PBT population
        return pbt_offspring, {}

    # One key per parent selection, one for the variation operator
    first_parents_key, second_parents_key, variation_key = jax.random.split(key, 3)
    first_parents = repertoire.select(
        first_parents_key, num_ga_offspring, selector=self._selector
    ).genotypes
    second_parents = repertoire.select(
        second_parents_key, num_ga_offspring, selector=self._selector
    ).genotypes
    ga_offspring = self._variation_fn(first_parents, second_parents, variation_key)

    def _stack(ga_leaf, pbt_leaf):  # type: ignore
        # GA offspring first, PBT agents after, along the population axis
        return jnp.concatenate([ga_leaf, pbt_leaf], axis=0)

    return jax.tree.map(_stack, ga_offspring, pbt_offspring), {}

init(key, repertoire, genotypes, fitnesses, descriptors, extra_scores)

Initializes the emitter state.

Parameters:
  • genotypes (Genotype) –

    The initial population.

  • key (RNGKey) –

    A random key.

Returns:
  • PBTEmitterState

    The initial state of the PBTEmitter.

Source code in qdax/core/emitters/pbt_me_emitter.py
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
def init(
    self,
    key: RNGKey,
    repertoire: Repertoire,
    genotypes: Genotype,
    fitnesses: Fitness,
    descriptors: Descriptor,
    extra_scores: ExtraScores,
) -> PBTEmitterState:
    """Build the initial emitter state of the PBTEmitter.

    Args:
        key: a random key, split once to reset the environments.
        repertoire: unused here - but compulsory in the signature.
        genotypes: the initial population; only the first
            `pg_population_size_per_device` entries are kept.
        fitnesses: unused here - but compulsory in the signature.
        descriptors: unused here - but compulsory in the signature.
        extra_scores: unused here - but compulsory in the signature.

    Returns:
        The emitter state holding one replay buffer per agent, the reshaped
        environment states, the training states and the remaining random key.
    """
    config = self._config
    population_size = config.pg_population_size_per_device

    # Replay buffers: one dummy transition per agent (vmap with axis_size and
    # no argument), then one buffer initialised from each of them
    make_dummy_transitions = jax.vmap(
        partial(
            Transition.init_dummy,
            observation_dim=self._env.observation_size,
            action_dim=self._env.action_size,
        ),
        axis_size=population_size,
    )
    make_replay_buffers = jax.vmap(
        partial(ReplayBuffer.init, buffer_size=config.buffer_size)
    )
    replay_buffers = make_replay_buffers(transition=make_dummy_transitions())

    # Environment states
    key, reset_key = jax.random.split(key)
    env_states = jax.jit(self._env.reset)(rng=reset_key)

    def _split_population_axis(leaf):  # type: ignore
        # Turn the leading axis into (agents, envs per agent)
        return jnp.reshape(
            leaf,
            (population_size, config.env_batch_size) + leaf.shape[1:],
        )

    split_env_states = jax.jit(
        lambda tree: jax.tree.map(_split_population_axis, tree)
    )
    env_states = split_env_states(env_states)

    # Drop any training state beyond the per-device population size
    training_states = jax.tree.map(lambda leaf: leaf[:population_size], genotypes)

    return PBTEmitterState(
        replay_buffers=replay_buffers,
        env_states=env_states,
        training_states=training_states,
        key=key,
    )

state_update(emitter_state, repertoire, genotypes, fitnesses, descriptors, extra_scores)

Update the internal emitter state. I.e. update the population replay buffers and agents.

Parameters:
  • emitter_state (PBTEmitterState) –

    current emitter state.

  • repertoire (GARepertoire) –

    the current genotypes repertoire

  • genotypes (Optional[Genotype]) –

    unused here - but compulsory in the signature.

  • fitnesses (Fitness) –

    fitnesses of the emitted batch; the entries that follow the GA offspring are used to rank the agents of the emitter state.

  • descriptors (Optional[Descriptor]) –

    unused here - but compulsory in the signature.

  • extra_scores (ExtraScores) –

    extra information coming from the scoring function, this contains the transitions added to the replay buffer.

Returns:
  • PBTEmitterState

    New emitter state where the replay buffer has been filled with the new experienced transitions.

Source code in qdax/core/emitters/pbt_me_emitter.py
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
def state_update(  # type: ignore
    self,
    emitter_state: PBTEmitterState,
    repertoire: GARepertoire,
    genotypes: Optional[Genotype],
    fitnesses: Fitness,
    descriptors: Optional[Descriptor],
    extra_scores: ExtraScores,
) -> PBTEmitterState:
    """
    Update the internal emitter state. I.e. replace the worst agents of the
    population and train all the agents.

    The steps are:
        1. Gather, through successive exchanges between devices, a set of
           `_num_best_to_replace_from` best individuals.
        2. Replace the agents whose fitness is among the worst by individuals
           sampled from that set (training state and replay buffer), with
           resampled hyper-parameters.
        3. Optionally replace the next worst agents by samples from the
           repertoire (training state only).
        4. Train every agent with the training function of the PBT agent.

    The method calls `jax.lax.all_gather` with `axis_name="p"`, so it has to run
    where that axis name is bound (presumably a `jax.pmap` - confirm against
    the caller).

    Args:
        emitter_state: current emitter state.
        repertoire: the current genotypes repertoire, sampled from when
            `_num_to_replace_from_samples` is positive.
        genotypes: unused here - but compulsory in the signature.
        fitnesses: fitnesses of the emitted batch; only the entries after the
            first `ga_population_size_per_device` ones are used.
        descriptors: unused here - but compulsory in the signature.
        extra_scores: extra information coming from the scoring function;
            not read in this method.

    Returns:
        New emitter state holding the trained agents, their replay buffers and
        environment states, and the updated random key.
    """
    # Look only at the fitness corresponding to emitter state individuals
    # (the slice assumes the GA offspring come first in the batch)
    fitnesses = fitnesses[self._config.ga_population_size_per_device :]
    fitnesses = jnp.ravel(fitnesses)
    training_states = emitter_state.training_states
    replay_buffers = emitter_state.replay_buffers
    # The `genotypes` argument is discarded: from here on the name refers to
    # the (training states, replay buffers) pair of the local agents.
    genotypes = (training_states, replay_buffers)

    # Incremental algorithm to gather top best among the population on each device
    # First exchange
    # The shared individuals are the first `num_best_local` ones in index
    # order; they are ranked by fitness only after the gather.
    indices_to_share = jnp.arange(self._config.pg_population_size_per_device)
    num_best_local = int(
        self._config.pg_population_size_per_device
        * self._config.fraction_best_to_replace_from
    )
    indices_to_share = indices_to_share[:num_best_local]
    genotypes_to_share, fitnesses_to_share = jax.tree.map(
        lambda x: x[indices_to_share], (genotypes, fitnesses)
    )
    # all_gather stacks the contribution of every device along a new leading
    # axis; concatenate merges it back into a single population axis.
    gathered_genotypes, gathered_fitnesses = jax.tree.map(
        lambda x: jnp.concatenate(jax.lax.all_gather(x, axis_name="p"), axis=0),
        (genotypes_to_share, fitnesses_to_share),
    )

    # Keep the `_num_best_to_replace_from` fittest gathered individuals
    # (argsort of the negated fitnesses sorts in descending order)
    genotypes_stacked, fitnesses_stacked = gathered_genotypes, gathered_fitnesses
    best_indices_stacked = jnp.argsort(-fitnesses_stacked)
    best_indices_stacked = best_indices_stacked[: self._num_best_to_replace_from]
    best_genotypes_local, best_fitnesses_local = jax.tree.map(
        lambda x: x[best_indices_stacked], (genotypes_stacked, fitnesses_stacked)
    )

    # Define loop fn for the other exchanges
    def _loop_fn(i, val):  # type: ignore
        # val carries the current best individuals and their fitnesses
        best_genotypes_local, best_fitnesses_local = val
        # Share the i-th slice of `_num_to_exchange` local individuals.
        # NOTE(review): jax.lax.dynamic_slice clamps out-of-range start
        # indices, so the last iterations may share the tail of the
        # population again - confirm this is intended.
        indices_to_share = jax.lax.dynamic_slice(
            jnp.arange(self._config.pg_population_size_per_device),
            [i * self._num_to_exchange],
            [self._num_to_exchange],
        )
        genotypes_to_share, fitnesses_to_share = jax.tree.map(
            lambda x: x[indices_to_share], (genotypes, fitnesses)
        )
        gathered_genotypes, gathered_fitnesses = jax.tree.map(
            lambda x: jnp.concatenate(jax.lax.all_gather(x, axis_name="p"), axis=0),
            (genotypes_to_share, fitnesses_to_share),
        )

        # Merge the newly gathered individuals with the current best ones
        genotypes_stacked, fitnesses_stacked = jax.tree.map(
            lambda x, y: jnp.concatenate([x, y], axis=0),
            (gathered_genotypes, gathered_fitnesses),
            (best_genotypes_local, best_fitnesses_local),
        )

        # Keep the top `_num_best_to_replace_from` again, so the carried value
        # keeps the same shape across iterations
        best_indices_stacked = jnp.argsort(-fitnesses_stacked)
        best_indices_stacked = best_indices_stacked[
            : self._num_best_to_replace_from
        ]
        best_genotypes_local, best_fitnesses_local = jax.tree.map(
            lambda x: x[best_indices_stacked],
            (genotypes_stacked, fitnesses_stacked),
        )
        return (best_genotypes_local, best_fitnesses_local)  # type: ignore

    # Incrementally get the top fraction_best_to_replace_from best individuals
    # on each device
    # The loop index starts at 1 and `upper` is exclusive, so the body runs
    # int(1.0 // fraction_sort_exchange) times.
    (best_genotypes_local, best_fitnesses_local) = jax.lax.fori_loop(
        lower=1,
        upper=int(1.0 // self._config.fraction_sort_exchange) + 1,
        body_fun=_loop_fn,
        init_val=(best_genotypes_local, best_fitnesses_local),
    )

    # Gather fitnesses from all devices to rank locally against it
    all_fitnesses = jax.tree.map(
        lambda x: jnp.concatenate(jax.lax.all_gather(x, axis_name="p"), axis=0),
        fitnesses,
    )
    all_fitnesses = jnp.ravel(all_fitnesses)
    # Sorted in descending order: negative indices address the worst fitnesses
    all_fitnesses = -jnp.sort(-all_fitnesses)
    key = emitter_state.key
    key, subkey = jax.random.split(key)
    # Draw one candidate replacement per local agent, with replacement, from
    # the best individuals. The same subkey is reused for every leaf,
    # presumably so that all leaves are indexed by the same draws - confirm.
    best_genotypes = jax.tree.map(
        lambda x: jax.random.choice(
            subkey, x, shape=(len(fitnesses),), replace=True
        ),
        best_genotypes_local,
    )
    best_training_states, best_replay_buffers = best_genotypes

    # Resample hyper-params
    best_training_states = jax.vmap(
        best_training_states.__class__.resample_hyperparams
    )(best_training_states)

    # Replace by individuals from the best
    # NOTE(review): when `_num_to_replace_from_best` is 0 the index is -0, i.e.
    # 0, which selects the best fitness and makes `cond` true for every agent -
    # confirm the fraction is always large enough to avoid this.
    lower_bound = all_fitnesses[-self._num_to_replace_from_best]
    cond = fitnesses <= lower_bound

    # expand_dims appends trailing axes to `cond` so that the per-agent mask
    # broadcasts against leaves of any rank
    training_states = jax.tree.map(
        lambda x, y: jnp.where(
            jnp.expand_dims(
                cond, axis=tuple([-(i + 1) for i in range(x.ndim - 1)])
            ),
            x,
            y,
        ),
        best_training_states,
        training_states,
    )
    # The replay buffer of a replaced agent is replaced as well
    replay_buffers = jax.tree.map(
        lambda x, y: jnp.where(
            jnp.expand_dims(
                cond, axis=tuple([-(i + 1) for i in range(x.ndim - 1)])
            ),
            x,
            y,
        ),
        best_replay_buffers,
        replay_buffers,
    )

    # Replacing with samples from the ME repertoire
    if self._num_to_replace_from_samples > 0:
        key, subkey = jax.random.split(key)
        me_samples = repertoire.select(
            subkey,
            self._config.pg_population_size_per_device,
            selector=self._selector,
        ).genotypes
        # Resample hyper-params
        me_samples = jax.vmap(me_samples.__class__.resample_hyperparams)(me_samples)
        upper_bound = all_fitnesses[
            -self._num_to_replace_from_best - self._num_to_replace_from_samples
        ]
        # Both bounds are inclusive, so agents whose fitness equals
        # `lower_bound` match this mask as well as the previous one.
        cond = jnp.logical_and(fitnesses <= upper_bound, fitnesses >= lower_bound)
        # Only the training states are replaced here; the replay buffers are
        # left as they are.
        training_states = jax.tree.map(
            lambda x, y: jnp.where(
                jnp.expand_dims(
                    cond, axis=tuple([-(i + 1) for i in range(x.ndim - 1)])
                ),
                x,
                y,
            ),
            me_samples,
            training_states,
        )

    # Train the agents
    env_states = emitter_state.env_states
    # Init optimizers state before training the population
    training_states = jax.vmap(training_states.__class__.init_optimizers_states)(
        training_states
    )
    # The metrics returned by the training function are not used here
    (training_states, env_states, replay_buffers), metrics = self._train_fn(
        training_states, env_states, replay_buffers
    )
    # Empty optimizers states to avoid storing the info in the RAM
    # and having too heavy repertoires
    training_states = jax.vmap(training_states.__class__.empty_optimizers_states)(
        training_states
    )

    # Update emitter state
    emitter_state = emitter_state.replace(
        training_states=training_states,
        replay_buffers=replay_buffers,
        env_states=env_states,
        key=key,
    )
    return emitter_state  # type: ignore