MOME class

qdax.core.mome.MOME (MAPElites)

Implements Multi-Objective MAP-Elites (MOME).

Note: most functions are inherited from MAPElites. The only function that has to be overridden is the init function, as it must take into account the specificities of the Multi-Objective repertoire (such as the maximum Pareto front length).

Source code in qdax/core/mome.py
class MOME(MAPElites):
    """Multi-Objective MAP-Elites (MOME).

    Almost all of the algorithm's machinery is inherited unchanged from
    MAPElites. Only ``init`` is overridden, because the Multi-Objective
    repertoire needs an extra ``pareto_front_max_length`` argument when it
    is constructed.
    """

    @partial(jax.jit, static_argnames=("self", "pareto_front_max_length"))
    def init(
        self,
        init_genotypes: jnp.ndarray,
        centroids: Centroid,
        pareto_front_max_length: int,
        random_key: RNGKey,
    ) -> Tuple[MOMERepertoire, Optional[EmitterState], RNGKey]:
        """Build the initial MOME repertoire and emitter state.

        The initial genotypes are scored exactly once. The resulting
        fitnesses, descriptors and extra scores are used both to fill the
        repertoire and to update the freshly initialised emitter state.

        Args:
            init_genotypes: genotypes of the initial population.
            centroids: centroids of the repertoire, computed with any method
                such as CVT or Euclidean mapping.
            pareto_front_max_length: maximum size of the Pareto front. It is a
                static argument of ``jax.jit`` because jitted code requires
                fixed array shapes.
            random_key: a random key, consumed by the scoring function and by
                the emitter initialisation.

        Returns:
            A tuple ``(repertoire, emitter_state, random_key)``, where
            ``random_key`` is the new key returned by the emitter
            initialisation.
        """
        # Evaluate the initial population a single time.
        scoring_output = self._scoring_function(init_genotypes, random_key)
        fitnesses, descriptors, extra_scores, random_key = scoring_output

        # Evaluation results shared by the repertoire constructor and the
        # emitter state update; forwarded to both as keyword arguments.
        evaluation = {
            "fitnesses": fitnesses,
            "descriptors": descriptors,
            "extra_scores": extra_scores,
        }

        repertoire = MOMERepertoire.init(
            genotypes=init_genotypes,
            centroids=centroids,
            pareto_front_max_length=pareto_front_max_length,
            **evaluation,
        )

        # Create the emitter's starting state, then let it observe the
        # initial population together with the repertoire just built.
        initial_state, random_key = self._emitter.init(
            init_genotypes=init_genotypes, random_key=random_key
        )
        emitter_state = self._emitter.state_update(
            emitter_state=initial_state,
            repertoire=repertoire,
            genotypes=init_genotypes,
            **evaluation,
        )

        return repertoire, emitter_state, random_key

init(self, init_genotypes, centroids, pareto_front_max_length, random_key)

Initialize a MOME grid with an initial population of genotypes. Requires the definition of centroids that can be computed with any method such as CVT or Euclidean mapping.

Parameters:
  • init_genotypes (jnp.ndarray) – genotypes of the initial population.

  • centroids (Centroid) – centroids of the repertoire.

  • pareto_front_max_length (int) – maximum size of the Pareto front. This is necessary to respect the fixed-shape constraint imposed by jax.jit.

  • random_key (RNGKey) – a random key to handle stochasticity.

Returns:
  • Tuple[MOMERepertoire, Optional[EmitterState], RNGKey] – The initial repertoire and emitter state, and a new random key.

Source code in qdax/core/mome.py
@partial(jax.jit, static_argnames=("self", "pareto_front_max_length"))
def init(
    self,
    init_genotypes: jnp.ndarray,
    centroids: Centroid,
    pareto_front_max_length: int,
    random_key: RNGKey,
) -> Tuple[MOMERepertoire, Optional[EmitterState], RNGKey]:
    """Build the initial MOME repertoire and emitter state.

    The initial genotypes are scored exactly once. The resulting fitnesses,
    descriptors and extra scores are used both to fill the repertoire and to
    update the freshly initialised emitter state.

    Args:
        init_genotypes: genotypes of the initial population.
        centroids: centroids of the repertoire, computed with any method such
            as CVT or Euclidean mapping.
        pareto_front_max_length: maximum size of the Pareto front. It is a
            static argument of ``jax.jit`` because jitted code requires fixed
            array shapes.
        random_key: a random key, consumed by the scoring function and by the
            emitter initialisation.

    Returns:
        A tuple ``(repertoire, emitter_state, random_key)``, where
        ``random_key`` is the new key returned by the emitter initialisation.
    """
    # Evaluate the initial population a single time.
    scoring_output = self._scoring_function(init_genotypes, random_key)
    fitnesses, descriptors, extra_scores, random_key = scoring_output

    # Evaluation results shared by the repertoire constructor and the emitter
    # state update; forwarded to both as keyword arguments.
    evaluation = {
        "fitnesses": fitnesses,
        "descriptors": descriptors,
        "extra_scores": extra_scores,
    }

    repertoire = MOMERepertoire.init(
        genotypes=init_genotypes,
        centroids=centroids,
        pareto_front_max_length=pareto_front_max_length,
        **evaluation,
    )

    # Create the emitter's starting state, then let it observe the initial
    # population together with the repertoire just built.
    initial_state, random_key = self._emitter.init(
        init_genotypes=init_genotypes, random_key=random_key
    )
    emitter_state = self._emitter.state_update(
        emitter_state=initial_state,
        repertoire=repertoire,
        genotypes=init_genotypes,
        **evaluation,
    )

    return repertoire, emitter_state, random_key