MAP Elites with Evolution Strategies (ME-ES)

To create an instance of ME-ES, one needs to use an instance of MAP-Elites with the MEESEmitter, detailed below.

qdax.core.emitters.mees_emitter.MEESEmitter (Emitter)

Emitter reproducing the MAP-Elites-ES algorithm from "Scaling MAP-Elites to Deep Neuroevolution" by Colas et al: https://dl.acm.org/doi/pdf/10.1145/3377930.3390217

One can choose between the three variants by setting use_explore and use_exploit:

- ME-ES exploit-explore (use_exploit=True and use_explore=True): alternates between num_optimizer_steps of novelty gradients and num_optimizer_steps of fitness gradients, resampling the parent from the archive every num_optimizer_steps steps.
- ME-ES exploit (use_exploit=True and use_explore=False): only uses fitness gradients, no novelty gradients, but still resamples the parent from the archive every num_optimizer_steps steps.
- ME-ES explore (use_exploit=False and use_explore=True): only uses novelty gradients, no fitness gradients, but still resamples the parent from the archive every num_optimizer_steps steps.

Source code in qdax/core/emitters/mees_emitter.py
class MEESEmitter(Emitter):
    """
    Emitter reproducing the MAP-Elites-ES algorithm from
    "Scaling MAP-Elites to Deep Neuroevolution" by Colas et al:
    https://dl.acm.org/doi/pdf/10.1145/3377930.3390217

    One can choose between the three variants by setting use_explore and use_exploit:
        ME-ES exploit-explore: use_exploit=True and use_explore=True
            Alternates between num_optimizer_steps of novelty gradients and
            num_optimizer_steps of fitness gradients (starting with novelty),
            resampling the parent from the archive every num_optimizer_steps
            steps.
        ME-ES exploit: use_exploit=True and use_explore=False
            Only uses fitness gradients, no novelty gradients, but resamples the
            parent from the archive every num_optimizer_steps steps.
        ME-ES explore: use_exploit=False and use_explore=True
            Only uses novelty gradients, no fitness gradients, but resamples the
            parent from the archive every num_optimizer_steps steps.

    The emitter produces a single offspring per generation (batch_size == 1).
    """

    def __init__(
        self,
        config: MEESConfig,
        total_generations: int,
        scoring_fn: Callable[
            [Genotype, RNGKey], Tuple[Fitness, Descriptor, ExtraScores, RNGKey]
        ],
        num_descriptors: int,
    ) -> None:
        """Initialise the MAP-Elites-ES emitter.
        WARNING: total_generations is required to build the novelty archive.

        Args:
            config: algorithm config
            total_generations: total number of generations for which the
                emitter will run, used to initialise the novelty archive when
                config.use_explore is True.
            scoring_fn: used to evaluate the samples for the gradient estimate.
            num_descriptors: dimension of the descriptors, used to initialise
                the empty novelty archive.
        """
        self._config = config
        self._scoring_fn = scoring_fn
        self._total_generations = total_generations
        self._num_descriptors = num_descriptors

        # Initialise optimizer (Adam or plain SGD, same learning rate)
        if self._config.adam_optimizer:
            self._optimizer = optax.adam(learning_rate=config.learning_rate)
        else:
            self._optimizer = optax.sgd(learning_rate=config.learning_rate)

    @property
    def batch_size(self) -> int:
        """
        Returns:
            the batch size emitted by the emitter.
        """
        return 1

    @partial(
        jax.jit,
        static_argnames=("self",),
    )
    def init(
        self, init_genotypes: Genotype, random_key: RNGKey
    ) -> Tuple[MEESEmitterState, RNGKey]:
        """Initializes the emitter state.

        Only the first genotype of init_genotypes is used. It becomes the first
        offspring and the template for the optimizer state and the buffers.

        Args:
            init_genotypes: The initial population, with a leading batch axis.
            random_key: A random key.

        Returns:
            The initial state of the MEESEmitter, and the (unchanged) random key.
        """
        # Initialisation requires one initial genotype. Slice with [:1] rather
        # than [0] so the leading batch axis (of size 1) is kept: the rest of
        # the emitter (x.shape[1:], jnp.repeat(..., axis=0), the batch-size
        # check in state_update) relies on it.
        init_genotypes = jax.tree_util.tree_map(
            lambda x: x[:1],
            init_genotypes,
        )

        # Initialise optimizer
        initial_optimizer_state = self._optimizer.init(init_genotypes)

        # Create empty Novelty archive
        if self._config.use_explore:
            novelty_archive = NoveltyArchive.init(
                self._total_generations, self._num_descriptors
            )
        else:
            novelty_archive = NoveltyArchive.init(
                self._config.novelty_nearest_neighbors, self._num_descriptors
            )

        # Create empty updated genotypes and fitness; a -inf fitness marks an
        # empty slot. The genotype dtype is kept so that samples drawn from
        # this buffer match those drawn from the repertoire in lax.cond.
        last_updated_genotypes = jax.tree_util.tree_map(
            lambda x: jnp.zeros(
                shape=(self._config.last_updated_size,) + x.shape[1:],
                dtype=x.dtype,
            ),
            init_genotypes,
        )
        last_updated_fitnesses = -jnp.inf * jnp.ones(
            shape=self._config.last_updated_size
        )

        return (
            MEESEmitterState(
                initial_optimizer_state=initial_optimizer_state,
                optimizer_state=initial_optimizer_state,
                offspring=init_genotypes,
                generation_count=0,
                novelty_archive=novelty_archive,
                last_updated_genotypes=last_updated_genotypes,
                last_updated_fitnesses=last_updated_fitnesses,
                last_updated_position=0,
                random_key=random_key,
            ),
            random_key,
        )

    @partial(
        jax.jit,
        static_argnames=("self",),
    )
    def emit(
        self,
        repertoire: MapElitesRepertoire,
        emitter_state: MEESEmitterState,
        random_key: RNGKey,
    ) -> Tuple[Genotype, RNGKey]:
        """Return the offspring generated through gradient update.

        The offspring is computed beforehand by state_update; this method
        only hands it out.

        Args:
            repertoire: the MAP-Elites repertoire (unused here).
            emitter_state: the current emitter state holding the offspring.
            random_key: a jax PRNG random key.

        Returns:
            the gradient offspring, and the (unchanged) jax PRNG key.
        """

        return emitter_state.offspring, random_key

    @partial(
        jax.jit,
        static_argnames=("self",),
    )
    def _sample_exploit(
        self,
        emitter_state: MEESEmitterState,
        repertoire: MapElitesRepertoire,
        random_key: RNGKey,
    ) -> Tuple[Genotype, RNGKey]:
        """Sample half of the time uniformly from the exploit_num_cell_sample
        highest-performing cells of the repertoire and half of the time uniformly
        from the exploit_num_cell_sample highest-performing cells among the
        last updated cells. While the last-updated buffer is still empty, always
        sample from the repertoire.

        Args:
            emitter_state: current emitter_state
            repertoire: the current repertoire
            random_key: a jax PRNG random key

        Returns:
            samples: a genotype sampled in the repertoire
            random_key: an updated jax PRNG random key
        """

        def _sample(
            random_key: RNGKey,
            genotypes: Genotype,
            fitnesses: Fitness,
        ) -> Tuple[Genotype, RNGKey]:
            """Sample uniformly among the exploit_num_cell_sample highest
            fitness entries of (genotypes, fitnesses)."""

            max_fitnesses, _ = jax.lax.top_k(
                fitnesses, self._config.exploit_num_cell_sample
            )
            # Smallest finite fitness among the top-k (-inf marks empty cells)
            min_fitness = jnp.nanmin(
                jnp.where(max_fitnesses > -jnp.inf, max_fitnesses, jnp.inf)
            )
            genotypes_empty = fitnesses < min_fitness
            p = (1.0 - genotypes_empty) / jnp.sum(1.0 - genotypes_empty)
            random_key, subkey = jax.random.split(random_key)
            samples = jax.tree_util.tree_map(
                lambda x: jax.random.choice(subkey, x, shape=(1,), p=p),
                genotypes,
            )
            return samples, random_key

        random_key, subkey = jax.random.split(random_key)

        # Sample p uniformly
        p = jax.random.uniform(subkey)

        # The last-updated buffer starts filled with -inf fitnesses. Sampling
        # from it while it is empty would produce NaN probabilities (0 / 0) and
        # return the zero placeholder genotype, so fall back to the repertoire.
        last_updated_empty = jnp.all(
            emitter_state.last_updated_fitnesses == -jnp.inf
        )
        use_repertoire = jnp.logical_or(p < 0.5, last_updated_empty)

        # Depending on the value of p, use one of the two sampling options
        repertoire_sample = partial(
            _sample, genotypes=repertoire.genotypes, fitnesses=repertoire.fitnesses
        )
        last_updated_sample = partial(
            _sample,
            genotypes=emitter_state.last_updated_genotypes,
            fitnesses=emitter_state.last_updated_fitnesses,
        )
        samples, random_key = jax.lax.cond(
            use_repertoire,
            repertoire_sample,
            last_updated_sample,
            random_key,
        )

        return samples, random_key

    @partial(
        jax.jit,
        static_argnames=("self",),
    )
    def _sample_explore(
        self,
        emitter_state: MEESEmitterState,
        repertoire: MapElitesRepertoire,
        random_key: RNGKey,
    ) -> Tuple[Genotype, RNGKey]:
        """Sample uniformly from the explore_num_cell_sample most-novel genotypes.

        Args:
            emitter_state: current emitter state
            repertoire: the current genotypes repertoire
            random_key: a jax PRNG random key

        Returns:
            samples: a genotype sampled in the repertoire
            random_key: an updated jax PRNG random key
        """

        # Compute the novelty of all indivs in the archive; empty cells
        # (fitness -inf) get -inf novelty so they are never selected.
        novelties = emitter_state.novelty_archive.novelty(
            repertoire.descriptors, self._config.novelty_nearest_neighbors
        )
        novelties = jnp.where(repertoire.fitnesses > -jnp.inf, novelties, -jnp.inf)

        # Sample uniformly for the explore_num_cell_sample most novel cells
        max_novelties, _ = jax.lax.top_k(
            novelties, self._config.explore_num_cell_sample
        )
        min_novelty = jnp.nanmin(
            jnp.where(max_novelties > -jnp.inf, max_novelties, jnp.inf)
        )
        repertoire_empty = novelties < min_novelty
        p = (1.0 - repertoire_empty) / jnp.sum(1.0 - repertoire_empty)
        random_key, subkey = jax.random.split(random_key)
        samples = jax.tree_util.tree_map(
            lambda x: jax.random.choice(subkey, x, shape=(1,), p=p),
            repertoire.genotypes,
        )

        return samples, random_key

    @partial(
        jax.jit,
        static_argnames=("self", "scores_fn"),
    )
    def _es_emitter(
        self,
        parent: Genotype,
        optimizer_state: optax.OptState,
        random_key: RNGKey,
        scores_fn: Callable[[Fitness, Descriptor], jnp.ndarray],
    ) -> Tuple[Genotype, optax.OptState, RNGKey]:
        """Main ES component. Given a parent and a way to infer the score of
        its ES samples from their fitnesses and descriptors, perform one
        optimizer step along the estimated gradient and return the offspring.

        Args:
            parent: the considered parent, with a leading batch axis of size 1.
            optimizer_state: the optax state of the optimizer.
            random_key: a jax PRNG random key.
            scores_fn: a function to infer the score of the ES samples from
                their fitnesses and descriptors (higher scores are favoured).

        Returns:
            The approximated-gradient-generated offspring, the updated
            optimizer state and a new random_key.
        """

        random_key, subkey = jax.random.split(random_key)

        def _normal_like_parent(num: int) -> Genotype:
            """Independent N(0, 1) noise per parent leaf, leading axis x num."""
            # One key per leaf: reusing a single key would give identical
            # perturbations to same-shaped leaves and bias the ES estimate.
            leaves, treedef = jax.tree_util.tree_flatten(parent)
            leaf_keys = jax.random.split(subkey, len(leaves))
            noise_leaves = [
                jax.random.normal(
                    key=leaf_key,
                    shape=(leaf.shape[0] * num,) + leaf.shape[1:],
                )
                for leaf, leaf_key in zip(leaves, leaf_keys)
            ]
            return jax.tree_util.tree_unflatten(treedef, noise_leaves)

        # Sampling mirror noise
        total_sample_number = self._config.sample_number
        if self._config.sample_mirror:
            # Draw half of the noise and use each draw as a (+eps, -eps) pair.
            # NOTE(review): assumes sample_number is even, otherwise samples
            # and sample_noise below have different leading sizes.
            sample_number = total_sample_number // 2
            half_sample_noise = _normal_like_parent(sample_number)
            # Interleave as [eps_0, -eps_0, eps_1, -eps_1, ...]
            sample_noise = jax.tree_util.tree_map(
                lambda x: jnp.concatenate(
                    [jnp.expand_dims(x, axis=1), jnp.expand_dims(-x, axis=1)], axis=1
                ).reshape(jnp.repeat(x, 2, axis=0).shape),
                half_sample_noise,
            )
            gradient_noise = half_sample_noise

        # Sampling non-mirror noise
        else:
            sample_number = total_sample_number
            sample_noise = _normal_like_parent(sample_number)
            gradient_noise = sample_noise

        # Applying noise
        samples = jax.tree_util.tree_map(
            lambda x: jnp.repeat(x, total_sample_number, axis=0),
            parent,
        )
        samples = jax.tree_util.tree_map(
            lambda mean, noise: mean + self._config.sample_sigma * noise,
            samples,
            sample_noise,
        )

        # Evaluating samples
        fitnesses, descriptors, extra_scores, random_key = self._scoring_fn(
            samples, random_key
        )

        # Computing rank, with or without normalisation
        scores = scores_fn(fitnesses, descriptors)

        if self._config.sample_rank_norm:
            # Centred ranks in [-0.5, 0.5]
            ranking_indices = jnp.argsort(scores, axis=0)
            ranks = jnp.argsort(ranking_indices, axis=0)
            ranks = (ranks / (total_sample_number - 1)) - 0.5

        else:
            ranks = scores

        # Reshaping rank to match shape of genotype_noise
        if self._config.sample_mirror:
            # Rank difference within each (+eps, -eps) pair
            ranks = jnp.reshape(ranks, (sample_number, 2))
            ranks = ranks[:, 0] - ranks[:, 1]
        ranks = jax.tree_util.tree_map(
            lambda x: jnp.reshape(
                jnp.repeat(ranks.ravel(), x[0].ravel().shape[0], axis=0), x.shape
            ),
            gradient_noise,
        )

        # Computing the gradients
        gradient = jax.tree_util.tree_map(
            lambda noise, rank: jnp.multiply(noise, rank),
            gradient_noise,
            ranks,
        )
        gradient = jax.tree_util.tree_map(
            lambda x: jnp.reshape(x, (sample_number, -1)),
            gradient,
        )
        # Negated because optax optimizers minimise while ES ascends the score
        gradient = jax.tree_util.tree_map(
            lambda g, p: jnp.reshape(
                -jnp.sum(g, axis=0) / (total_sample_number * self._config.sample_sigma),
                p.shape,
            ),
            gradient,
            parent,
        )

        # Adding regularisation
        gradient = jax.tree_util.tree_map(
            lambda g, p: g + self._config.l2_coefficient * p,
            gradient,
            parent,
        )

        # Applying gradients
        (offspring_update, optimizer_state) = self._optimizer.update(
            gradient, optimizer_state
        )
        offspring = optax.apply_updates(parent, offspring_update)

        return offspring, optimizer_state, random_key

    @partial(
        jax.jit,
        static_argnames=("self",),
    )
    def _buffers_update(
        self,
        emitter_state: MEESEmitterState,
        repertoire: MapElitesRepertoire,
        genotypes: Genotype,
        fitnesses: Fitness,
        descriptors: Descriptor,
    ) -> MEESEmitterState:
        """Update the different buffers and archives in the emitter
        state to generate the offspring for the next generation.

        Args:
            emitter_state: current emitter state
            repertoire: the current genotypes repertoire
            genotypes: the genotypes of the batch of emitted offspring.
            fitnesses: the fitnesses of the batch of emitted offspring.
            descriptors: the descriptors of the emitted offspring.

        Returns:
            The modified emitter state.
        """

        # Updating novelty archive
        novelty_archive = emitter_state.novelty_archive.update(descriptors)

        # Check if genotype from previous iteration has been added to the grid:
        # it was added iff the cell for its descriptor now holds exactly it.
        indice = get_cells_indices(descriptors, repertoire.centroids)
        added_genotype = jnp.all(
            jnp.asarray(
                jax.tree_util.tree_leaves(
                    jax.tree_util.tree_map(
                        lambda new_gen, rep_gen: jnp.all(
                            jnp.equal(
                                jnp.ravel(new_gen), jnp.ravel(rep_gen.at[indice].get())
                            ),
                            axis=0,
                        ),
                        genotypes,
                        repertoire.genotypes,
                    ),
                )
            ),
            axis=0,
        )

        # Update last_updated buffers. When the genotype was not added, write
        # to an out-of-bounds index; mode="drop" makes that write a no-op.
        last_updated_position = jnp.where(
            added_genotype,
            emitter_state.last_updated_position,
            self._config.last_updated_size + 1,
        )
        last_updated_fitnesses = emitter_state.last_updated_fitnesses
        last_updated_fitnesses = last_updated_fitnesses.at[last_updated_position].set(
            fitnesses[0], mode="drop"
        )
        last_updated_genotypes = jax.tree_util.tree_map(
            lambda last_gen, gen: last_gen.at[
                jnp.expand_dims(last_updated_position, axis=0)
            ].set(gen, mode="drop"),
            emitter_state.last_updated_genotypes,
            genotypes,
        )
        # Circular buffer: only advance the write position when something
        # was written.
        last_updated_position = (
            emitter_state.last_updated_position + added_genotype
        ) % self._config.last_updated_size

        # Return new emitter_state
        return emitter_state.replace(  # type: ignore
            novelty_archive=novelty_archive,
            last_updated_genotypes=last_updated_genotypes,
            last_updated_fitnesses=last_updated_fitnesses,
            last_updated_position=last_updated_position,
        )

    @partial(
        jax.jit,
        static_argnames=("self",),
    )
    def state_update(
        self,
        emitter_state: MEESEmitterState,
        repertoire: MapElitesRepertoire,
        genotypes: Genotype,
        fitnesses: Fitness,
        descriptors: Descriptor,
        extra_scores: ExtraScores,
    ) -> MEESEmitterState:
        """Generate the gradient offspring for the next emitter call. Also
        update the novelty archive, the last-updated buffers and the
        generation count from the current call.

        Args:
            emitter_state: current emitter state
            repertoire: the current genotypes repertoire
            genotypes: the genotypes of the batch of emitted offspring.
            fitnesses: the fitnesses of the batch of emitted offspring.
            descriptors: the descriptors of the emitted offspring.
            extra_scores: a dictionary with other values outputted by the
                scoring function (unused here).

        Returns:
            The modified emitter state.
        """

        assert jax.tree_util.tree_leaves(genotypes)[0].shape[0] == 1, (
            "ERROR: MAP-Elites-ES generates 1 offspring per generation, "
            + "batch_size should be 1, the inputed batch has size:"
            + str(jax.tree_util.tree_leaves(genotypes)[0].shape[0])
        )

        # Update all the buffers and archives of the emitter_state
        emitter_state = self._buffers_update(
            emitter_state, repertoire, genotypes, fitnesses, descriptors
        )

        # Use new or previous parents and exploitation or exploration.
        # The Python and/or only short-circuit on the config flags; the traced
        # comparison is returned as-is (never bool()-converted), so this is
        # jit-compatible.
        generation_count = emitter_state.generation_count
        sample_new_parent = generation_count % self._config.num_optimizer_steps == 0
        use_exploration = (
            self._config.use_explore and not self._config.use_exploit
        ) or (
            self._config.use_explore
            and self._config.use_exploit
            and ((generation_count // self._config.num_optimizer_steps) % 2 == 0)
        )

        # Select parent and optimizer_state
        parent, random_key = jax.lax.cond(
            sample_new_parent,
            lambda emitter_state, repertoire, random_key: jax.lax.cond(
                use_exploration,
                self._sample_explore,
                self._sample_exploit,
                emitter_state,
                repertoire,
                random_key,
            ),
            lambda emitter_state, repertoire, random_key: (
                emitter_state.offspring,
                random_key,
            ),
            emitter_state,
            repertoire,
            emitter_state.random_key,
        )
        # A freshly sampled parent restarts from the initial optimizer state
        optimizer_state = jax.lax.cond(
            sample_new_parent,
            lambda _unused: emitter_state.initial_optimizer_state,
            lambda _unused: emitter_state.optimizer_state,
            (),
        )

        # Define scores for es process
        def exploration_exploitation_scores(
            fitnesses: Fitness, descriptors: Descriptor
        ) -> jnp.ndarray:
            scores = jax.lax.cond(
                use_exploration,
                lambda fitnesses, descriptors: emitter_state.novelty_archive.novelty(
                    descriptors, self._config.novelty_nearest_neighbors
                ),
                lambda fitnesses, descriptors: fitnesses,
                fitnesses,
                descriptors,
            )
            return scores

        # Run es process
        offspring, optimizer_state, random_key = self._es_emitter(
            parent=parent,
            optimizer_state=optimizer_state,
            random_key=random_key,
            scores_fn=exploration_exploitation_scores,
        )

        return emitter_state.replace(  # type: ignore
            optimizer_state=optimizer_state,
            offspring=offspring,
            generation_count=generation_count + 1,
            random_key=random_key,
        )

batch_size: int property readonly

Returns:
  • int – the batch size emitted by the emitter.

__init__(self, config, total_generations, scoring_fn, num_descriptors) special

Initialise the MAP-Elites-ES emitter. WARNING: total_generations is required to build the novelty archive.

Parameters:
  • config (MEESConfig) – algorithm config

  • scoring_fn (Callable[[Genotype, RNGKey], Tuple[Fitness, Descriptor, ExtraScores, RNGKey]]) – used to evaluate the samples for the gradient estimate.

  • total_generations (int) – total number of generations for which the emitter will run; used to initialise the novelty archive.

  • num_descriptors (int) – dimension of the descriptors, used to initialise the empty novelty archive.

Source code in qdax/core/emitters/mees_emitter.py
def __init__(
    self,
    config: MEESConfig,
    total_generations: int,
    scoring_fn: Callable[
        [Genotype, RNGKey], Tuple[Fitness, Descriptor, ExtraScores, RNGKey]
    ],
    num_descriptors: int,
) -> None:
    """Set up the MAP-Elites-ES emitter and its gradient optimizer.

    WARNING: total_generations is required to build the novelty archive.

    Args:
        config: algorithm configuration (optimizer choice, learning rate,
            ES sampling and archive settings).
        total_generations: number of generations the emitter is expected to
            run for; stored for building the novelty archive in init.
        scoring_fn: evaluates the ES samples for the gradient estimate.
        num_descriptors: dimension of the descriptors; stored for building
            the empty novelty archive in init.
    """
    # Keep references to everything init/state_update will need later.
    self._config = config
    self._scoring_fn = scoring_fn
    self._total_generations = total_generations
    self._num_descriptors = num_descriptors

    # Adam when requested by the config, plain SGD otherwise; both share the
    # configured learning rate.
    optimizer_factory = optax.adam if config.adam_optimizer else optax.sgd
    self._optimizer = optimizer_factory(learning_rate=config.learning_rate)

init(self, init_genotypes, random_key)

Initializes the emitter state.

Parameters:
  • init_genotypes (Genotype) – The initial population.

  • random_key (RNGKey) – A random key.

Returns:
  • Tuple[MEESEmitterState, RNGKey] – The initial state of the MEESEmitter, and the random key (returned unchanged).

Source code in qdax/core/emitters/mees_emitter.py
@partial(
    jax.jit,
    static_argnames=("self",),
)
def init(
    self, init_genotypes: Genotype, random_key: RNGKey
) -> Tuple[MEESEmitterState, RNGKey]:
    """Initializes the emitter state.

    Only the first genotype of init_genotypes is used. It becomes the first
    offspring and the template for the optimizer state and the buffers.

    Args:
        init_genotypes: The initial population, with a leading batch axis.
        random_key: A random key.

    Returns:
        The initial state of the MEESEmitter, and the (unchanged) random key.
    """
    # Initialisation requires one initial genotype. Slice with [:1] rather
    # than [0] so the leading batch axis (of size 1) is kept: the rest of the
    # emitter (x.shape[1:], jnp.repeat(..., axis=0), the batch-size check in
    # state_update) relies on it.
    init_genotypes = jax.tree_util.tree_map(
        lambda x: x[:1],
        init_genotypes,
    )

    # Initialise optimizer
    initial_optimizer_state = self._optimizer.init(init_genotypes)

    # Create empty Novelty archive
    if self._config.use_explore:
        novelty_archive = NoveltyArchive.init(
            self._total_generations, self._num_descriptors
        )
    else:
        novelty_archive = NoveltyArchive.init(
            self._config.novelty_nearest_neighbors, self._num_descriptors
        )

    # Create empty updated genotypes and fitness; a -inf fitness marks an
    # empty slot. The genotype dtype is kept so that samples drawn from this
    # buffer match those drawn from the repertoire.
    last_updated_genotypes = jax.tree_util.tree_map(
        lambda x: jnp.zeros(
            shape=(self._config.last_updated_size,) + x.shape[1:],
            dtype=x.dtype,
        ),
        init_genotypes,
    )
    last_updated_fitnesses = -jnp.inf * jnp.ones(
        shape=self._config.last_updated_size
    )

    return (
        MEESEmitterState(
            initial_optimizer_state=initial_optimizer_state,
            optimizer_state=initial_optimizer_state,
            offspring=init_genotypes,
            generation_count=0,
            novelty_archive=novelty_archive,
            last_updated_genotypes=last_updated_genotypes,
            last_updated_fitnesses=last_updated_fitnesses,
            last_updated_position=0,
            random_key=random_key,
        ),
        random_key,
    )

emit(self, repertoire, emitter_state, random_key)

Return the offspring generated through gradient update.

Parameters:
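  • repertoire (MapElitesRepertoire) – the MAP-Elites repertoire (not used by this method; the offspring is computed in state_update)

  • emitter_state (MEESEmitterState) – the current emitter state, which holds the offspring to emit

  • random_key (RNGKey) – a jax PRNG random key

Returns:
  • Tuple[Genotype, RNGKey] – the gradient-generated offspring, and the random key (returned unchanged)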

Source code in qdax/core/emitters/mees_emitter.py
@partial(
    jax.jit,
    static_argnames=("self",),
)
def emit(
    self,
    repertoire: MapElitesRepertoire,
    emitter_state: MEESEmitterState,
    random_key: RNGKey,
) -> Tuple[Genotype, RNGKey]:
    """Hand out the offspring produced by the previous gradient update.

    All the ES work happens in state_update; emitting only reads back the
    offspring stored in the emitter state.

    Args:
        repertoire: the MAP-Elites repertoire (not read here).
        emitter_state: the current emitter state holding the offspring.
        random_key: a jax PRNG random key.

    Returns:
        The stored gradient offspring, and the random key, unchanged.
    """
    offspring = emitter_state.offspring
    return offspring, random_key

state_update(self, emitter_state, repertoire, genotypes, fitnesses, descriptors, extra_scores)

Generate the gradient offspring for the next emitter call. Also update the novelty archive, the last-updated buffers and the generation count using the evaluations from the current call.

Parameters:
  • emitter_state (MEESEmitterState) – current emitter state

  • repertoire (MapElitesRepertoire) – the current genotypes repertoire

  • genotypes (Genotype) – the genotypes of the batch of emitted offspring.

  • fitnesses (Fitness) – the fitnesses of the batch of emitted offspring.

  • descriptors (Descriptor) – the descriptors of the emitted offspring.

  • extra_scores (ExtraScores) – a dictionary with other values outputted by the scoring function.

Returns:
  • MEESEmitterState – The modified emitter state.

Source code in qdax/core/emitters/mees_emitter.py
@partial(
    jax.jit,
    static_argnames=("self",),
)
def state_update(
    self,
    emitter_state: MEESEmitterState,
    repertoire: MapElitesRepertoire,
    genotypes: Genotype,
    fitnesses: Fitness,
    descriptors: Descriptor,
    extra_scores: ExtraScores,
) -> MEESEmitterState:
    """Generate the gradient offspring for the next emitter call. Also
    update the novelty archive, the last-updated buffers and the
    generation count from the current call.

    Every num_optimizer_steps generations a new parent is sampled (by novelty
    when exploring, by fitness otherwise) and the optimizer state is reset;
    in between, the previous offspring is used as the parent. One ES step is
    then run on the parent to produce the next offspring.

    Args:
        emitter_state: current emitter state
        repertoire: the current genotypes repertoire
        genotypes: the genotypes of the batch of emitted offspring.
        fitnesses: the fitnesses of the batch of emitted offspring.
        descriptors: the descriptors of the emitted offspring.
        extra_scores: a dictionary with other values outputted by the
            scoring function (not used by this method).

    Returns:
        The modified emitter state.
    """

    # Trace-time shape check: the emitter only ever emits one offspring.
    assert jax.tree_util.tree_leaves(genotypes)[0].shape[0] == 1, (
        "ERROR: MAP-Elites-ES generates 1 offspring per generation, "
        + "batch_size should be 1, the inputed batch has size:"
        + str(jax.tree_util.tree_leaves(genotypes)[0].shape[0])
    )

    # Update all the buffers and archives of the emitter_state
    emitter_state = self._buffers_update(
        emitter_state, repertoire, genotypes, fitnesses, descriptors
    )

    # Use new or previous parents and exploitation or exploration.
    # The Python and/or only short-circuit on the config flags (presumably
    # static Python bools); the traced comparison is returned as-is and never
    # converted with bool(), which keeps this jit-compatible. In
    # exploit-explore mode, exploration is used for the first
    # num_optimizer_steps generations and then alternates.
    generation_count = emitter_state.generation_count
    sample_new_parent = generation_count % self._config.num_optimizer_steps == 0
    use_exploration = (
        self._config.use_explore and not self._config.use_exploit
    ) or (
        self._config.use_explore
        and self._config.use_exploit
        and ((generation_count // self._config.num_optimizer_steps) % 2 == 0)
    )

    # Select parent and optimizer_state: either sample a new parent (by
    # novelty or by fitness) or keep optimising the previous offspring.
    parent, random_key = jax.lax.cond(
        sample_new_parent,
        lambda emitter_state, repertoire, random_key: jax.lax.cond(
            use_exploration,
            self._sample_explore,
            self._sample_exploit,
            emitter_state,
            repertoire,
            random_key,
        ),
        lambda emitter_state, repertoire, random_key: (
            emitter_state.offspring,
            random_key,
        ),
        emitter_state,
        repertoire,
        emitter_state.random_key,
    )
    # A freshly sampled parent restarts from the initial optimizer state
    optimizer_state = jax.lax.cond(
        sample_new_parent,
        lambda _unused: emitter_state.initial_optimizer_state,
        lambda _unused: emitter_state.optimizer_state,
        (),
    )

    # Define scores for es process: novelty of the ES samples' descriptors
    # (w.r.t. the updated novelty archive) when exploring, raw fitness
    # otherwise.
    def exploration_exploitation_scores(
        fitnesses: Fitness, descriptors: Descriptor
    ) -> jnp.ndarray:
        scores = jax.lax.cond(
            use_exploration,
            lambda fitnesses, descriptors: emitter_state.novelty_archive.novelty(
                descriptors, self._config.novelty_nearest_neighbors
            ),
            lambda fitnesses, descriptors: fitnesses,
            fitnesses,
            descriptors,
        )
        return scores

    # Run es process
    offspring, optimizer_state, random_key = self._es_emitter(
        parent=parent,
        optimizer_state=optimizer_state,
        random_key=random_key,
        scores_fn=exploration_exploitation_scores,
    )

    # The new offspring is what the next emit() call returns
    return emitter_state.replace(  # type: ignore
        optimizer_state=optimizer_state,
        offspring=offspring,
        generation_count=generation_count + 1,
        random_key=random_key,
    )