MAP-Elites Low-Spread (ME-LS)

ME-LS is a variant of MAP-Elites that drives the search towards solutions whose behavior is consistent (low-spread) in the behavior space, which matters in uncertain domains.

qdax.core.mels.MELS (MAPElites)

Core elements of the MAP-Elites Low-Spread algorithm.

Most methods in this class are inherited from MAPElites.

The same scoring function can be passed into both MAPElites and this class. We have overridden `__init__` so that it takes in the scoring function and wraps it, so that every solution is evaluated `num_samples` times.

We also overrode the init method to use the MELSRepertoire instead of MapElitesRepertoire.

Source code in qdax/core/mels.py
class MELS(MAPElites):
    """Core elements of the MAP-Elites Low-Spread (ME-LS) algorithm.

    Most methods in this class are inherited from MAPElites.

    The same scoring function can be passed into both MAPElites and this class.
    We have overridden __init__ so that it takes in the scoring function and
    wraps it with `multi_sample_scoring_function`, so that every solution is
    evaluated `num_samples` times.

    We also overrode the init method to use the MELSRepertoire instead of
    MapElitesRepertoire.
    """

    def __init__(
        self,
        scoring_function: Callable[
            [Genotype, RNGKey], Tuple[Fitness, Descriptor, ExtraScores, RNGKey]
        ],
        emitter: Emitter,
        metrics_function: Callable[[MELSRepertoire], Metrics],
        num_samples: int,
    ) -> None:
        """Build an ME-LS instance around a single-evaluation scoring function.

        Args:
            scoring_function: the same scoring function one would give to
                MAPElites. It is wrapped (via `multi_sample_scoring_function`)
                so that each genotype is scored `num_samples` times.
            emitter: the emitter; its `init` and `state_update` methods are
                called by `init` below.
            metrics_function: function computing metrics from a
                MELSRepertoire, stored for use by the inherited methods.
            num_samples: number of evaluations per solution. Must be >= 1.

        Raises:
            ValueError: if `num_samples` is smaller than 1.
        """
        # Fail fast on an invalid sample count: ME-LS needs at least one
        # evaluation per solution, and a bad value would otherwise only be
        # noticed (if at all) deep inside the jitted scoring wrapper.
        if num_samples < 1:
            raise ValueError(
                f"num_samples must be a positive integer, got {num_samples}."
            )

        # Wrap the user's scoring function so every call evaluates each
        # genotype `num_samples` times.
        self._scoring_function = partial(
            multi_sample_scoring_function,
            scoring_fn=scoring_function,
            num_samples=num_samples,
        )
        self._emitter = emitter
        self._metrics_function = metrics_function
        self._num_samples = num_samples

    @partial(jax.jit, static_argnames=("self",))
    def init(
        self,
        init_genotypes: Genotype,
        centroids: Centroid,
        random_key: RNGKey,
    ) -> Tuple[MELSRepertoire, Optional[EmitterState], RNGKey]:
        """Initialize a MAP-Elites Low-Spread repertoire with an initial
        population of genotypes. Requires the definition of centroids that can
        be computed with any method such as CVT or Euclidean mapping.

        Args:
            init_genotypes: initial genotypes, pytree in which leaves
                have shape (batch_size, num_features)
            centroids: tessellation centroids of shape
                (num_centroids, num_descriptors)
            random_key: a random key used for stochastic operations.

        Returns:
            A tuple of (initialized MAP-Elites Low-Spread repertoire, initial emitter
            state, JAX random key).
        """
        # score initial genotypes; through the wrapper built in __init__, each
        # genotype is evaluated `num_samples` times
        fitnesses, descriptors, extra_scores, random_key = self._scoring_function(
            init_genotypes, random_key
        )

        # init the repertoire from the (multi-sample) evaluations
        repertoire = MELSRepertoire.init(
            genotypes=init_genotypes,
            fitnesses=fitnesses,
            descriptors=descriptors,
            centroids=centroids,
            extra_scores=extra_scores,
        )

        # get initial state of the emitter, threading the key returned by the
        # scoring function
        emitter_state, random_key = self._emitter.init(
            init_genotypes=init_genotypes, random_key=random_key
        )

        # update emitter state with the initial population's evaluations
        emitter_state = self._emitter.state_update(
            emitter_state=emitter_state,
            repertoire=repertoire,
            genotypes=init_genotypes,
            fitnesses=fitnesses,
            descriptors=descriptors,
            extra_scores=extra_scores,
        )
        return repertoire, emitter_state, random_key

init(self, init_genotypes, centroids, random_key)

Initialize a MAP-Elites Low-Spread repertoire with an initial population of genotypes. Requires the definition of centroids that can be computed with any method such as CVT or Euclidean mapping.

Parameters:
  • init_genotypes (Genotype) – initial genotypes, pytree in which leaves have shape (batch_size, num_features)

  • centroids (Centroid) – tessellation centroids of shape (num_centroids, num_descriptors)

  • random_key (RNGKey) – a random key used for stochastic operations.

Returns:
  • Tuple[MELSRepertoire, Optional[EmitterState], RNGKey] – A tuple of (initialized MAP-Elites Low-Spread repertoire, initial emitter state, JAX random key).

Source code in qdax/core/mels.py
@partial(jax.jit, static_argnames=("self",))
def init(
    self,
    init_genotypes: Genotype,
    centroids: Centroid,
    random_key: RNGKey,
) -> Tuple[MELSRepertoire, Optional[EmitterState], RNGKey]:
    """Create the ME-LS repertoire and emitter state from a first population.

    Requires the centroids of the tessellation, which can be computed with
    any method such as CVT or Euclidean mapping.

    Args:
        init_genotypes: initial genotypes, pytree in which leaves
            have shape (batch_size, num_features)
        centroids: tessellation centroids of shape
            (num_centroids, num_descriptors)
        random_key: a random key used for stochastic operations.

    Returns:
        A tuple of (initialized MAP-Elites Low-Spread repertoire, initial emitter
        state, JAX random key).
    """
    # Evaluate the initial population with the instance's scoring function
    # (in MELS, the multi-sample wrapper built in __init__). It hands back a
    # fresh key, which is threaded into the emitter initialisation below.
    (
        scored_fitnesses,
        scored_descriptors,
        scored_extras,
        next_key,
    ) = self._scoring_function(init_genotypes, random_key)

    # The same evaluation results feed both the repertoire and the emitter.
    evaluation = {
        "fitnesses": scored_fitnesses,
        "descriptors": scored_descriptors,
        "extra_scores": scored_extras,
    }

    repertoire = MELSRepertoire.init(
        genotypes=init_genotypes,
        centroids=centroids,
        **evaluation,
    )

    # Initialise the emitter, then let it observe the initial population.
    initial_emitter_state, next_key = self._emitter.init(
        init_genotypes=init_genotypes, random_key=next_key
    )
    updated_emitter_state = self._emitter.state_update(
        emitter_state=initial_emitter_state,
        repertoire=repertoire,
        genotypes=init_genotypes,
        **evaluation,
    )

    return repertoire, updated_emitter_state, next_key