MAP-Elites class

This class implements the base mechanism of MAP-Elites. It must be used with an emitter. To obtain the standard MAP-Elites algorithm, use the mixing emitter.

The MAP-Elites class can be used with other emitters to create variants, like PGAME, CMA-MEGA and OMG-MEGA.
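The following is a minimal, self-contained NumPy sketch of the base MAP-Elites loop on a toy 2-D problem. It illustrates how the work is split between the pieces. The scoring function evaluates genotypes, the emitter proposes offspring from the repertoire, and the repertoire keeps the fittest individual per cell. This is not QDax code. The names `score`, `emit`, `add` and `cell_index`, the 10 x 10 grid and the Gaussian mutation are illustrative assumptions.

```python
import numpy as np

GRID = 10  # hypothetical: 10 x 10 grid over the descriptor space [0, 1]^2


def score(genotypes):
    """Toy scoring function: fitness peaks at (0.5, 0.5); descriptor = genotype."""
    fitnesses = -np.sum((genotypes - 0.5) ** 2, axis=1)
    descriptors = genotypes.copy()
    return fitnesses, descriptors


def cell_index(descriptors):
    """Map each 2-D descriptor to the flat index of its grid cell."""
    cells = np.clip((descriptors * GRID).astype(int), 0, GRID - 1)
    return cells[:, 0] * GRID + cells[:, 1]


def add(rep_genotypes, rep_fitnesses, genotypes, fitnesses, descriptors):
    """Repertoire insertion: each cell keeps the fittest individual seen so far."""
    for genotype, fitness, cell in zip(genotypes, fitnesses, cell_index(descriptors)):
        if fitness > rep_fitnesses[cell]:
            rep_fitnesses[cell] = fitness
            rep_genotypes[cell] = genotype
    return rep_genotypes, rep_fitnesses


def emit(rng, rep_genotypes, rep_fitnesses, batch_size):
    """Toy emitter: pick parents among filled cells and apply Gaussian mutation."""
    filled = np.flatnonzero(np.isfinite(rep_fitnesses))
    parents = rep_genotypes[rng.choice(filled, size=batch_size)]
    offspring = parents + 0.05 * rng.standard_normal(parents.shape)
    return np.clip(offspring, 0.0, 1.0)


rng = np.random.default_rng(0)
rep_genotypes = np.zeros((GRID * GRID, 2))
rep_fitnesses = np.full(GRID * GRID, -np.inf)  # -inf marks an empty cell

# "init": score an initial population and insert it into the repertoire
init_genotypes = rng.uniform(size=(32, 2))
rep_genotypes, rep_fitnesses = add(
    rep_genotypes, rep_fitnesses, init_genotypes, *score(init_genotypes)
)

# "update", repeated: emit -> score -> add
for _ in range(200):
    offspring = emit(rng, rep_genotypes, rep_fitnesses, batch_size=32)
    rep_genotypes, rep_fitnesses = add(
        rep_genotypes, rep_fitnesses, offspring, *score(offspring)
    )

coverage = np.isfinite(rep_fitnesses).mean()  # fraction of filled cells
```

You could replace `emit` while keeping `score` and `add` fixed. That mirrors how the text above describes building variants from other emitters.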

qdax.core.map_elites.MAPElites

Core elements of the MAP-Elites algorithm.

Note: Although very similar to the GeneticAlgorithm, we decided to keep the MAPElites class independent of the GeneticAlgorithm class at the moment to keep elements explicit.

Parameters:
  • scoring_function (Callable[[Genotype, RNGKey], Tuple[Fitness, Descriptor, ExtraScores, RNGKey]]) – a function that takes a batch of genotypes and a random key, and computes their fitnesses and descriptors (it also returns extra scores and a new random key)

  • emitter (Emitter) – an emitter is used to suggest offspring given a MAP-Elites repertoire. It has two compulsory functions: one that emits a new population, and one that updates the internal state of the emitter.

  • metrics_function (Callable[[MapElitesRepertoire], Metrics]) – a function that takes a MAP-Elites repertoire and computes any useful metric to track its evolution
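Here is a hypothetical toy task in JAX that follows the `scoring_function` signature above: it takes a batch of genotypes and a key, and returns fitnesses, descriptors, extra scores and a new key. The following choices are assumptions, not part of QDax:

- the sphere objective,
- the use of the first two coordinates as descriptors,
- the empty `extra_scores` dict,
- the `noise_std` parameter.

```python
import jax
import jax.numpy as jnp


def scoring_function(genotypes, random_key, noise_std=0.0):
    """Hypothetical toy task following the documented signature.

    genotypes: array of shape (batch_size, num_features)
    returns: fitnesses (batch_size,), descriptors (batch_size, 2),
             extra_scores (an empty dict here) and a new random key
    """
    # consume a subkey for the (optional) evaluation noise, keep the other half
    random_key, subkey = jax.random.split(random_key)
    fitnesses = -jnp.sum(genotypes**2, axis=-1)  # maximize the negative sphere
    fitnesses = fitnesses + noise_std * jax.random.normal(subkey, fitnesses.shape)
    descriptors = jnp.clip(genotypes[:, :2], -1.0, 1.0)  # first two coordinates
    extra_scores = {}
    return fitnesses, descriptors, extra_scores, random_key


key = jax.random.PRNGKey(0)
genotypes = jnp.array([[0.0, 0.0, 0.0], [1.0, 2.0, 0.5]])
fitnesses, descriptors, extra_scores, new_key = jax.jit(scoring_function)(
    genotypes, key
)
```

The function returns the key it did not consume. This lets successive calls (as in `init` and `update`) keep threading fresh randomness.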

Source code in qdax/core/map_elites.py
class MAPElites:
    """Core elements of the MAP-Elites algorithm.

    Note: Although very similar to the GeneticAlgorithm, we decided to keep the
    MAPElites class independent of the GeneticAlgorithm class at the moment to keep
    elements explicit.

    Args:
        scoring_function: a function that takes a batch of genotypes and a random
            key, and computes their fitnesses and descriptors
        emitter: an emitter is used to suggest offspring given a MAP-Elites
            repertoire. It has two compulsory functions: one that emits a new
            population, and one that updates the internal state of the emitter.
        metrics_function: a function that takes a MAP-Elites repertoire and computes
            any useful metric to track its evolution
    """

    def __init__(
        self,
        scoring_function: Callable[
            [Genotype, RNGKey], Tuple[Fitness, Descriptor, ExtraScores, RNGKey]
        ],
        emitter: Emitter,
        metrics_function: Callable[[MapElitesRepertoire], Metrics],
    ) -> None:
        self._scoring_function = scoring_function
        self._emitter = emitter
        self._metrics_function = metrics_function

    @partial(jax.jit, static_argnames=("self",))
    def init(
        self,
        init_genotypes: Genotype,
        centroids: Centroid,
        random_key: RNGKey,
    ) -> Tuple[MapElitesRepertoire, Optional[EmitterState], RNGKey]:
        """
        Initialize a MAP-Elites repertoire with an initial population of genotypes.
        Requires the definition of centroids that can be computed with any method
        such as CVT or Euclidean mapping.

        Args:
            init_genotypes: initial genotypes, pytree in which leaves
                have shape (batch_size, num_features)
            centroids: tessellation centroids of shape (num_centroids, num_descriptors)
            random_key: a random key used for stochastic operations.

        Returns:
            An initialized MAP-Elites repertoire with the initial state of the emitter,
            and a random key.
        """
        # score initial genotypes
        fitnesses, descriptors, extra_scores, random_key = self._scoring_function(
            init_genotypes, random_key
        )

        # init the repertoire
        repertoire = MapElitesRepertoire.init(
            genotypes=init_genotypes,
            fitnesses=fitnesses,
            descriptors=descriptors,
            centroids=centroids,
            extra_scores=extra_scores,
        )

        # get initial state of the emitter
        emitter_state, random_key = self._emitter.init(
            init_genotypes=init_genotypes, random_key=random_key
        )

        # update emitter state
        emitter_state = self._emitter.state_update(
            emitter_state=emitter_state,
            repertoire=repertoire,
            genotypes=init_genotypes,
            fitnesses=fitnesses,
            descriptors=descriptors,
            extra_scores=extra_scores,
        )

        return repertoire, emitter_state, random_key

    @partial(jax.jit, static_argnames=("self",))
    def update(
        self,
        repertoire: MapElitesRepertoire,
        emitter_state: Optional[EmitterState],
        random_key: RNGKey,
    ) -> Tuple[MapElitesRepertoire, Optional[EmitterState], Metrics, RNGKey]:
        """
        Performs one iteration of the MAP-Elites algorithm.
        1. A batch of genotypes is sampled in the repertoire and the genotypes
            are copied.
        2. The copies are mutated and crossed-over.
        3. The obtained offspring are scored and then added to the repertoire.


        Args:
            repertoire: the MAP-Elites repertoire
            emitter_state: state of the emitter
            random_key: a jax PRNG random key

        Returns:
            the updated MAP-Elites repertoire
            the updated (if needed) emitter state
            metrics about the updated repertoire
            a new jax PRNG key
        """
        # generate offspring with the emitter
        genotypes, random_key = self._emitter.emit(
            repertoire, emitter_state, random_key
        )
        # score the offspring
        fitnesses, descriptors, extra_scores, random_key = self._scoring_function(
            genotypes, random_key
        )

        # add genotypes in the repertoire
        repertoire = repertoire.add(genotypes, descriptors, fitnesses, extra_scores)

        # update emitter state after scoring is made
        emitter_state = self._emitter.state_update(
            emitter_state=emitter_state,
            repertoire=repertoire,
            genotypes=genotypes,
            fitnesses=fitnesses,
            descriptors=descriptors,
            extra_scores=extra_scores,
        )

        # update the metrics
        metrics = self._metrics_function(repertoire)

        return repertoire, emitter_state, metrics, random_key

    @partial(jax.jit, static_argnames=("self",))
    def scan_update(
        self,
        carry: Tuple[MapElitesRepertoire, Optional[EmitterState], RNGKey],
        unused: Any,
    ) -> Tuple[Tuple[MapElitesRepertoire, Optional[EmitterState], RNGKey], Metrics]:
        """Wraps the update function so that it is compatible with the
        jax.lax.scan primitive.

        Args:
            carry: a tuple containing the repertoire, the emitter state and a
                random key.
            unused: unused element, necessary to respect jax.lax.scan API.

        Returns:
            The updated repertoire and emitter state, with a new random key and metrics.
        """
        repertoire, emitter_state, random_key = carry
        (repertoire, emitter_state, metrics, random_key,) = self.update(
            repertoire,
            emitter_state,
            random_key,
        )

        return (repertoire, emitter_state, random_key), metrics

init(self, init_genotypes, centroids, random_key)

Initialize a MAP-Elites repertoire with an initial population of genotypes. Requires the definition of centroids that can be computed with any method such as CVT or Euclidean mapping.

Parameters:
  • init_genotypes (Genotype) – initial genotypes, pytree in which leaves have shape (batch_size, num_features)

  • centroids (Centroid) – tessellation centroids of shape (num_centroids, num_descriptors)

  • random_key (RNGKey) – a random key used for stochastic operations.

Returns:
  • Tuple[MapElitesRepertoire, Optional[EmitterState], RNGKey] – An initialized MAP-Elites repertoire with the initial state of the emitter, and a random key.
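The centroids form a (num_centroids, num_descriptors) array, one row per cell of the tessellation, independent of the batch size. Below is a sketch of the Euclidean-grid option and of the usual nearest-centroid cell assignment. It assumes squared Euclidean distance. Both `euclidean_grid_centroids` and `nearest_centroid` are hypothetical helpers, not QDax functions.

```python
import jax.numpy as jnp


def euclidean_grid_centroids(grid_shape, minval, maxval):
    """Hypothetical helper: centres of a regular grid over [minval, maxval]^d.

    Returns an array of shape (num_centroids, num_descriptors), where
    num_centroids = prod(grid_shape) and num_descriptors = len(grid_shape).
    """
    axes = [minval + (jnp.arange(n) + 0.5) * (maxval - minval) / n for n in grid_shape]
    mesh = jnp.meshgrid(*axes, indexing="ij")
    return jnp.stack([m.ravel() for m in mesh], axis=-1)


def nearest_centroid(descriptors, centroids):
    """Index of the closest centroid (squared Euclidean distance) per descriptor."""
    diffs = descriptors[:, None, :] - centroids[None, :, :]
    return jnp.argmin(jnp.sum(diffs**2, axis=-1), axis=-1)


centroids = euclidean_grid_centroids((3, 2), minval=0.0, maxval=1.0)  # shape (6, 2)
cells = nearest_centroid(jnp.array([[0.1, 0.1], [0.9, 0.9]]), centroids)
```

A CVT tessellation would only change how `centroids` is computed. The nearest-centroid assignment stays the same.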

update(self, repertoire, emitter_state, random_key)

Performs one iteration of the MAP-Elites algorithm:

  1. A batch of genotypes is sampled in the repertoire and the genotypes are copied.
  2. The copies are mutated and crossed-over.
  3. The obtained offspring are scored and then added to the repertoire.

Parameters:
  • repertoire (MapElitesRepertoire) – the MAP-Elites repertoire

  • emitter_state (Optional[EmitterState]) – state of the emitter

  • random_key (RNGKey) – a jax PRNG random key

Returns:
  • Tuple[MapElitesRepertoire, Optional[EmitterState], Metrics, RNGKey] – the updated MAP-Elites repertoire, the updated (if needed) emitter state, metrics about the updated repertoire, and a new jax PRNG key
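In the last step of the update, several offspring may land in the same cell within one batch. The sketch below shows one way to vectorize the elitist insertion in JAX on a flat repertoire. It assumes that empty cells hold fitness -inf and that each offspring's cell index is already known. `add_batch` is hypothetical and is not the QDax `repertoire.add` implementation. Ties between equally fit offspring for the same cell are resolved arbitrarily.

```python
import jax
import jax.numpy as jnp


@jax.jit
def add_batch(rep_fitnesses, rep_genotypes, cells, fitnesses, genotypes):
    """Hypothetical elitist insertion of a batch into a flat repertoire.

    rep_fitnesses: (num_cells,), -inf marks an empty cell
    rep_genotypes: (num_cells, num_features)
    cells:         (batch_size,) cell index of each offspring
    """
    num_cells = rep_fitnesses.shape[0]
    # best fitness proposed for each cell within this batch
    batch_best = jax.ops.segment_max(fitnesses, cells, num_segments=num_cells)
    # keep an offspring only if it is the best of its cell in the batch
    # and strictly better than the current elite of that cell
    keep = (fitnesses == batch_best[cells]) & (fitnesses > rep_fitnesses[cells])
    # send discarded offspring to an out-of-range index, which mode="drop" ignores
    target = jnp.where(keep, cells, num_cells)
    rep_fitnesses = rep_fitnesses.at[target].set(fitnesses, mode="drop")
    rep_genotypes = rep_genotypes.at[target].set(genotypes, mode="drop")
    return rep_fitnesses, rep_genotypes


rep_fitnesses = jnp.full((3,), -jnp.inf)
rep_genotypes = jnp.zeros((3, 2))
cells = jnp.array([0, 0, 2])  # two offspring compete for cell 0
fitnesses = jnp.array([1.0, 3.0, -1.0])
genotypes = jnp.array([[1.0, 1.0], [3.0, 3.0], [5.0, 5.0]])
rep_fitnesses, rep_genotypes = add_batch(
    rep_fitnesses, rep_genotypes, cells, fitnesses, genotypes
)

# a worse offspring for cell 0 does not replace the current elite
rep_fitnesses2, rep_genotypes2 = add_batch(
    rep_fitnesses, rep_genotypes, jnp.array([0]), jnp.array([2.0]),
    jnp.array([[9.0, 9.0]]),
)
```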

scan_update(self, carry, unused)

Wraps the update function so that it is compatible with the jax.lax.scan primitive.

Parameters:
  • carry (Tuple[MapElitesRepertoire, Optional[EmitterState], RNGKey]) – a tuple containing the repertoire, the emitter state and a random key.

  • unused (Any) – unused element, necessary to respect jax.lax.scan API.

Returns:
  • Tuple[Tuple[MapElitesRepertoire, Optional[EmitterState], RNGKey], Metrics] – The updated repertoire and emitter state, with a new random key and metrics.
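The point of the `(carry, unused) -> (new_carry, metrics)` shape is that `jax.lax.scan` can then run many iterations in a single compiled loop, stacking the metrics along a leading axis. Below is a toy sketch of the pattern. `toy_update` is a hypothetical stand-in for `update`, with a scalar "best fitness" in place of the repertoire and emitter state.

```python
import jax
import jax.numpy as jnp


def toy_update(best, random_key):
    """Hypothetical stand-in for MAPElites.update: returns (state, metrics, key)."""
    random_key, subkey = jax.random.split(random_key)
    candidate = jax.random.uniform(subkey)  # fitness of a "new offspring"
    best = jnp.maximum(best, candidate)
    return best, {"best": best}, random_key


def toy_scan_update(carry, unused):
    """Same shape as scan_update: (carry, unused) -> (new_carry, metrics)."""
    best, random_key = carry
    best, metrics, random_key = toy_update(best, random_key)
    return (best, random_key), metrics


init_carry = (jnp.zeros((), dtype=jnp.float32), jax.random.PRNGKey(0))
(final_best, final_key), metrics = jax.lax.scan(
    toy_scan_update, init_carry, (), length=10
)
```

With the real class, the analogous call would pass `scan_update` (on a hypothetical instance, say `map_elites`) together with `(repertoire, emitter_state, random_key)`, `()` and `length=num_iterations`. This is the same pattern `get_distributed_update_fn` uses internally.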

We also provide a class that distributes MAP-Elites efficiently over several devices.

qdax.core.distributed_map_elites.DistributedMAPElites (MAPElites)

Source code in qdax/core/distributed_map_elites.py
class DistributedMAPElites(MAPElites):
    @partial(jax.jit, static_argnames=("self",))
    def init(
        self,
        init_genotypes: Genotype,
        centroids: Centroid,
        random_key: RNGKey,
    ) -> Tuple[MapElitesRepertoire, Optional[EmitterState], RNGKey]:
        """
        Initialize a MAP-Elites repertoire with an initial population of genotypes.
        Requires the definition of centroids that can be computed with any method
        such as CVT or Euclidean mapping.

        Before the repertoire is initialised, individuals are gathered from all the
        devices.

        Args:
            init_genotypes: initial genotypes, pytree in which leaves
                have shape (batch_size, num_features)
            centroids: tessellation centroids of shape (num_centroids, num_descriptors)
            random_key: a random key used for stochastic operations.

        Returns:
            An initialized MAP-Elites repertoire with the initial state of the emitter,
            and a random key.
        """
        # score initial genotypes
        fitnesses, descriptors, extra_scores, random_key = self._scoring_function(
            init_genotypes, random_key
        )

        # gather across all devices
        (
            gathered_genotypes,
            gathered_fitnesses,
            gathered_descriptors,
        ) = jax.tree_util.tree_map(
            lambda x: jnp.concatenate(jax.lax.all_gather(x, axis_name="p"), axis=0),
            (init_genotypes, fitnesses, descriptors),
        )

        # init the repertoire
        repertoire = MapElitesRepertoire.init(
            genotypes=gathered_genotypes,
            fitnesses=gathered_fitnesses,
            descriptors=gathered_descriptors,
            centroids=centroids,
        )

        # get initial state of the emitter
        emitter_state, random_key = self._emitter.init(
            init_genotypes=init_genotypes, random_key=random_key
        )

        # update emitter state
        emitter_state = self._emitter.state_update(
            emitter_state=emitter_state,
            repertoire=repertoire,
            genotypes=init_genotypes,
            fitnesses=fitnesses,
            descriptors=descriptors,
            extra_scores=extra_scores,
        )

        return repertoire, emitter_state, random_key

    @partial(jax.jit, static_argnames=("self",))
    def update(
        self,
        repertoire: MapElitesRepertoire,
        emitter_state: Optional[EmitterState],
        random_key: RNGKey,
    ) -> Tuple[MapElitesRepertoire, Optional[EmitterState], Metrics, RNGKey]:
        """Performs one iteration of the MAP-Elites algorithm.

        1. A batch of genotypes is sampled in the repertoire and the genotypes
            are copied.
        2. The copies are mutated and crossed-over.
        3. The obtained offspring are scored and then added to the repertoire.

        Before the repertoire is updated, individuals are gathered from all the
        devices.

        Args:
            repertoire: the MAP-Elites repertoire
            emitter_state: state of the emitter
            random_key: a jax PRNG random key

        Returns:
            the updated MAP-Elites repertoire
            the updated (if needed) emitter state
            metrics about the updated repertoire
            a new jax PRNG key
        """
        # generate offspring with the emitter
        genotypes, random_key = self._emitter.emit(
            repertoire, emitter_state, random_key
        )
        # score the offspring
        fitnesses, descriptors, extra_scores, random_key = self._scoring_function(
            genotypes, random_key
        )

        # gather across all devices
        (
            gathered_genotypes,
            gathered_fitnesses,
            gathered_descriptors,
        ) = jax.tree_util.tree_map(
            lambda x: jnp.concatenate(jax.lax.all_gather(x, axis_name="p"), axis=0),
            (genotypes, fitnesses, descriptors),
        )

        # add genotypes in the repertoire
        repertoire = repertoire.add(
            gathered_genotypes, gathered_descriptors, gathered_fitnesses
        )

        # update emitter state after scoring is made
        emitter_state = self._emitter.state_update(
            emitter_state=emitter_state,
            repertoire=repertoire,
            genotypes=genotypes,
            fitnesses=fitnesses,
            descriptors=descriptors,
            extra_scores=extra_scores,
        )

        # update the metrics
        metrics = self._metrics_function(repertoire)

        return repertoire, emitter_state, metrics, random_key

    def get_distributed_init_fn(
        self, centroids: Centroid, devices: List[Any]
    ) -> Callable[
        [Genotype, RNGKey], Tuple[MapElitesRepertoire, Optional[EmitterState], RNGKey]
    ]:
        """Create a function that initializes MAP-Elites in a distributed way.

        Args:
            centroids: centroids that structure the repertoire.
            devices: hardware devices.

        Returns:
            A function that initializes MAP-Elites in a distributed way.
        """
        return jax.pmap(  # type: ignore
            partial(self.init, centroids=centroids),
            devices=devices,
            axis_name="p",
        )

    def get_distributed_update_fn(
        self, num_iterations: int, devices: List[Any]
    ) -> Callable[
        [MapElitesRepertoire, Optional[EmitterState], RNGKey],
        Tuple[MapElitesRepertoire, Optional[EmitterState], RNGKey, Metrics],
    ]:
        """Create a function that runs a given number of MAP-Elites updates,
        distributed over several devices.

        Args:
            num_iterations: number of iterations to run.
            devices: hardware devices to distribute on.

        Returns:
            The update function that can be called directly to apply a sequence
            of MAP-Elites updates.
        """

        @jax.jit
        def _scan_update(
            carry: Tuple[MapElitesRepertoire, Optional[EmitterState], RNGKey],
            unused: Any,
        ) -> Tuple[Tuple[MapElitesRepertoire, Optional[EmitterState], RNGKey], Metrics]:
            """Rewrites the update function in a way that makes it compatible with the
            jax.lax.scan primitive."""
            # unwrap the input
            repertoire, emitter_state, random_key = carry

            # apply one step of update
            (repertoire, emitter_state, metrics, random_key,) = self.update(
                repertoire,
                emitter_state,
                random_key,
            )

            return (repertoire, emitter_state, random_key), metrics

        def update_fn(
            repertoire: MapElitesRepertoire,
            emitter_state: Optional[EmitterState],
            random_key: RNGKey,
        ) -> Tuple[MapElitesRepertoire, Optional[EmitterState], RNGKey, Metrics]:
            """Apply num_iterations of update."""
            (repertoire, emitter_state, random_key,), metrics = jax.lax.scan(
                _scan_update,
                (repertoire, emitter_state, random_key),
                (),
                length=num_iterations,
            )
            return repertoire, emitter_state, random_key, metrics

        return jax.pmap(update_fn, devices=devices, axis_name="p")  # type: ignore

init(self, init_genotypes, centroids, random_key)

Initialize a MAP-Elites repertoire with an initial population of genotypes. Requires the definition of centroids that can be computed with any method such as CVT or Euclidean mapping.

Before the repertoire is initialised, individuals are gathered from all the devices.

Parameters:
  • init_genotypes (Genotype) – initial genotypes, pytree in which leaves have shape (batch_size, num_features)

  • centroids (Centroid) – tessellation centroids of shape (num_centroids, num_descriptors)

  • random_key (RNGKey) – a random key used for stochastic operations.

Returns:
  • Tuple[MapElitesRepertoire, Optional[EmitterState], RNGKey] – An initialized MAP-Elites repertoire with the initial state of the emitter, and a random key.
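The gather step uses `jax.lax.all_gather(x, axis_name="p")` followed by `jnp.concatenate(..., axis=0)`. The repertoire is built from the gathered batch, while the emitter state is updated with the local batch only. The toy sketch below, with the hypothetical `toy_distributed_step`, shows the resulting shapes and which quantities end up identical across devices. It runs on a single CPU device too, where `n_devices` is 1.

```python
import jax
import jax.numpy as jnp

n_devices = jax.local_device_count()


def toy_distributed_step(local_fitnesses):
    """Hypothetical per-device step mirroring the gather in DistributedMAPElites.

    local_fitnesses: this device's batch, shape (local_batch,)
    """
    # all_gather stacks the per-device batches: (n_devices, local_batch)
    gathered = jax.lax.all_gather(local_fitnesses, axis_name="p")
    # concatenating along axis 0 flattens them into one batch: (n_devices * local_batch,)
    global_batch = jnp.concatenate(gathered, axis=0)
    # repertoire-side quantity: computed from the global batch, same on every device
    global_best = global_batch.max()
    # emitter-side quantity: computed from the local batch only, may differ per device
    local_mean = local_fitnesses.mean()
    return global_batch, global_best, local_mean


local = jnp.arange(n_devices * 4, dtype=jnp.float32).reshape(n_devices, 4)
global_batch, global_best, local_mean = jax.pmap(
    toy_distributed_step, axis_name="p"
)(local)
```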

update(self, repertoire, emitter_state, random_key)

Performs one iteration of the MAP-Elites algorithm.

  1. A batch of genotypes is sampled in the repertoire and the genotypes are copied.
  2. The copies are mutated and crossed-over.
  3. The obtained offspring are scored and then added to the repertoire.

Before the repertoire is updated, individuals are gathered from all the devices.

Parameters:
  • repertoire (MapElitesRepertoire) – the MAP-Elites repertoire

  • emitter_state (Optional[EmitterState]) – state of the emitter

  • random_key (RNGKey) – a jax PRNG random key

Returns:
  • Tuple[MapElitesRepertoire, Optional[EmitterState], Metrics, RNGKey] – the updated MAP-Elites repertoire, the updated (if needed) emitter state, metrics about the updated repertoire, and a new jax PRNG key

get_distributed_init_fn(self, centroids, devices)

Create a function that initializes MAP-Elites in a distributed way.

Parameters:
  • centroids (Centroid) – centroids that structure the repertoire.

  • devices (List[Any]) – hardware devices.

Returns:
  • Callable[[Genotype, RNGKey], Tuple[MapElitesRepertoire, Optional[EmitterState], RNGKey]] – A function that initializes MAP-Elites in a distributed way.
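The returned function is `jax.pmap` applied to `partial(self.init, centroids=centroids)`. The shared centroids are therefore baked into the function, while every other input needs a leading device axis. The sketch below reproduces that wiring with a hypothetical `toy_init` that has the same parameter order as `init`; the toy "repertoire" is an illustrative assumption. Because `centroids` is bound by keyword and sits before `random_key` in the parameter list, the sketch passes the key by keyword. Passing it positionally would make Python bind it to `centroids` as well and raise a TypeError.

```python
import functools

import jax
import jax.numpy as jnp


def toy_init(init_genotypes, centroids, random_key):
    """Hypothetical stand-in for DistributedMAPElites.init (same parameter order)."""
    fitnesses = -jnp.sum(init_genotypes**2, axis=-1)
    all_fitnesses = jnp.concatenate(
        jax.lax.all_gather(fitnesses, axis_name="p"), axis=0
    )
    # toy "repertoire": one slot per centroid, all holding the best gathered fitness
    repertoire = jnp.full((centroids.shape[0],), all_fitnesses.max())
    random_key, _ = jax.random.split(random_key)
    return repertoire, random_key


n_devices = jax.local_device_count()
centroids = jnp.zeros((5, 2))  # shared by every device, bound with partial
distributed_init_fn = jax.pmap(
    functools.partial(toy_init, centroids=centroids), axis_name="p"
)

# every mapped input gets a leading device axis
keys = jax.random.split(jax.random.PRNGKey(0), n_devices)
genotypes = jnp.ones((n_devices, 8, 3))
repertoire, new_keys = distributed_init_fn(genotypes, random_key=keys)
```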

get_distributed_update_fn(self, num_iterations, devices)

Create a function that runs a given number of MAP-Elites updates, distributed over several devices.

Parameters:
  • num_iterations (int) – number of iterations to run.

  • devices (List[Any]) – hardware devices to distribute on.

Returns:
  • Callable[[MapElitesRepertoire, Optional[EmitterState], RNGKey], Tuple[MapElitesRepertoire, Optional[EmitterState], RNGKey, Metrics]] – The update function that can be called directly to apply a sequence of MAP-Elites updates.
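The returned function is a `jax.lax.scan` over `num_iterations` steps, wrapped in `jax.pmap`. Its inputs and outputs carry a leading device axis, and the metrics gain a device axis and an iteration axis. Below is a toy analogue in the same structure. `make_distributed_update_fn` and its scalar "best fitness" state are hypothetical.

```python
import jax
import jax.numpy as jnp


def make_distributed_update_fn(num_iterations):
    """Hypothetical analogue of get_distributed_update_fn: scan inside pmap."""

    def scan_step(carry, unused):
        best, random_key = carry
        random_key, subkey = jax.random.split(random_key)
        local_candidates = jax.random.uniform(subkey, (4,))  # this device's "offspring"
        # gather candidates from every device so that all replicas stay identical
        gathered = jnp.concatenate(
            jax.lax.all_gather(local_candidates, axis_name="p"), axis=0
        )
        best = jnp.maximum(best, gathered.max())
        return (best, random_key), best  # the second output plays the role of metrics

    def update_fn(best, random_key):
        (best, random_key), metrics = jax.lax.scan(
            scan_step, (best, random_key), (), length=num_iterations
        )
        # like the documented update_fn, the key comes before the metrics
        return best, random_key, metrics

    return jax.pmap(update_fn, axis_name="p")


n_devices = jax.local_device_count()
update_fn = make_distributed_update_fn(num_iterations=5)
best, keys, metrics = update_fn(
    jnp.zeros(n_devices), jax.random.split(jax.random.PRNGKey(0), n_devices)
)
```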
