MAP Elites class¶
This class implements the base mechanism of MAP-Elites. It must be used with an emitter. To get the usual MAP-Elites algorithm, one must use the mixing emitter.
The MAP-Elites class can be used with other emitters to create variants, like PGAME, CMA-MEGA and OMG-MEGA.
qdax.core.map_elites.MAPElites
¶
Core elements of the MAP-Elites algorithm.
Note: Although very similar to the GeneticAlgorithm, we decided to keep the MAPElites class independent of the GeneticAlgorithm class at the moment to keep elements explicit.
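Parameters:
| Name | Type | Description |
|---|---|---|
| scoring_function | Callable[[Genotype, RNGKey], Tuple[Fitness, Descriptor, ExtraScores, RNGKey]] | a function that takes a batch of genotypes and computes their fitnesses and descriptors |
| emitter | Emitter | an emitter is used to suggest offspring given a MAP-Elites repertoire. It has two compulsory functions: one that emits a new population, and one that updates the internal state of the emitter |
| metrics_function | Callable[[MapElitesRepertoire], Metrics] | a function that takes a MAP-Elites repertoire and computes any useful metric to track its evolution |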
Source code in qdax/core/map_elites.py
class MAPElites:
    """Core loop of MAP-Elites, delegating offspring generation to an emitter.

    Note: this class is very close to GeneticAlgorithm, but it is deliberately
    kept independent from it for now so that every element stays explicit.

    Args:
        scoring_function: callable that receives a batch of genotypes and a
            random key, and returns their fitnesses, descriptors, extra scores
            and a random key.
        emitter: object that suggests offspring from a MAP-Elites repertoire.
            This class calls its `init`, `emit` and `state_update` methods.
        metrics_function: callable that receives a MAP-Elites repertoire and
            returns the metrics used to track its evolution.
    """

    def __init__(
        self,
        scoring_function: Callable[
            [Genotype, RNGKey], Tuple[Fitness, Descriptor, ExtraScores, RNGKey]
        ],
        emitter: Emitter,
        metrics_function: Callable[[MapElitesRepertoire], Metrics],
    ) -> None:
        # Plain attribute storage; nothing is evaluated at construction time.
        self._metrics_function = metrics_function
        self._emitter = emitter
        self._scoring_function = scoring_function

    @partial(jax.jit, static_argnames=("self",))
    def init(
        self,
        init_genotypes: Genotype,
        centroids: Centroid,
        random_key: RNGKey,
    ) -> Tuple[MapElitesRepertoire, Optional[EmitterState], RNGKey]:
        """Build the initial repertoire and emitter state from a population.

        Args:
            init_genotypes: initial genotypes, a pytree whose leaves have shape
                (batch_size, num_features).
            centroids: tessellation centroids (computed with any method, e.g.
                CVT or Euclidean mapping), forwarded to the repertoire.
                NOTE(review): previously documented as shape
                (batch_size, num_descriptors); the leading axis is presumably
                the number of centroids -- confirm.
            random_key: random key used for stochastic operations.

        Returns:
            The initialised repertoire, the emitter state after its first
            update, and the random key returned by the emitter initialisation.
        """
        # Evaluate the initial population.
        fitnesses, descriptors, extra_scores, scoring_key = self._scoring_function(
            init_genotypes, random_key
        )

        # The same evaluation results feed both the repertoire and the emitter.
        evaluation = {
            "fitnesses": fitnesses,
            "descriptors": descriptors,
            "extra_scores": extra_scores,
        }

        initial_repertoire = MapElitesRepertoire.init(
            genotypes=init_genotypes, centroids=centroids, **evaluation
        )

        # Create the emitter state, then let the emitter see the first batch.
        fresh_state, emitter_key = self._emitter.init(
            init_genotypes=init_genotypes, random_key=scoring_key
        )
        initial_state = self._emitter.state_update(
            emitter_state=fresh_state,
            repertoire=initial_repertoire,
            genotypes=init_genotypes,
            **evaluation,
        )

        return initial_repertoire, initial_state, emitter_key

    @partial(jax.jit, static_argnames=("self",))
    def update(
        self,
        repertoire: MapElitesRepertoire,
        emitter_state: Optional[EmitterState],
        random_key: RNGKey,
    ) -> Tuple[MapElitesRepertoire, Optional[EmitterState], Metrics, RNGKey]:
        """Run one MAP-Elites iteration.

        The emitter proposes offspring from the repertoire, the offspring are
        scored and added to the repertoire, the emitter state is updated with
        the scored offspring, and metrics are computed on the new repertoire.

        Args:
            repertoire: the MAP-Elites repertoire.
            emitter_state: state of the emitter.
            random_key: a jax PRNG random key.

        Returns:
            The updated repertoire, the updated emitter state, the metrics of
            the updated repertoire and the random key returned by the scoring
            function.
        """
        # Ask the emitter for offspring, then evaluate them.
        offspring, emit_key = self._emitter.emit(repertoire, emitter_state, random_key)
        fitnesses, descriptors, extra_scores, next_key = self._scoring_function(
            offspring, emit_key
        )

        # Insert the scored offspring into the repertoire.
        updated_repertoire = repertoire.add(
            offspring, descriptors, fitnesses, extra_scores
        )

        # The emitter is updated against the repertoire *after* insertion.
        updated_state = self._emitter.state_update(
            emitter_state=emitter_state,
            repertoire=updated_repertoire,
            genotypes=offspring,
            fitnesses=fitnesses,
            descriptors=descriptors,
            extra_scores=extra_scores,
        )

        step_metrics = self._metrics_function(updated_repertoire)
        return updated_repertoire, updated_state, step_metrics, next_key

    @partial(jax.jit, static_argnames=("self",))
    def scan_update(
        self,
        carry: Tuple[MapElitesRepertoire, Optional[EmitterState], RNGKey],
        unused: Any,
    ) -> Tuple[Tuple[MapElitesRepertoire, Optional[EmitterState], RNGKey], Metrics]:
        """Adapt `update` to the (carry, x) -> (carry, y) shape of jax.lax.scan.

        Args:
            carry: tuple holding the repertoire, the emitter state and a random
                key.
            unused: placeholder required by the jax.lax.scan API; ignored.

        Returns:
            The new carry (repertoire, emitter state, random key) and the
            metrics of this step.
        """
        current_repertoire, current_state, current_key = carry
        next_repertoire, next_state, step_metrics, next_key = self.update(
            current_repertoire,
            current_state,
            current_key,
        )
        next_carry = (next_repertoire, next_state, next_key)
        return next_carry, step_metrics
init(self, init_genotypes, centroids, random_key)
¶
Initialize a Map-Elites repertoire with an initial population of genotypes. Requires the definition of centroids that can be computed with any method such as CVT or Euclidean mapping.
Parameters: |
|
---|
Returns: |
|
---|
Source code in qdax/core/map_elites.py
@partial(jax.jit, static_argnames=("self",))
def init(
    self,
    init_genotypes: Genotype,
    centroids: Centroid,
    random_key: RNGKey,
) -> Tuple[MapElitesRepertoire, Optional[EmitterState], RNGKey]:
    """Build the initial repertoire and emitter state from a population.

    Args:
        init_genotypes: initial genotypes, a pytree whose leaves have shape
            (batch_size, num_features).
        centroids: tessellation centroids (computed with any method, e.g.
            CVT or Euclidean mapping), forwarded to the repertoire.
            NOTE(review): previously documented as shape
            (batch_size, num_descriptors); the leading axis is presumably
            the number of centroids -- confirm.
        random_key: random key used for stochastic operations.

    Returns:
        The initialised repertoire, the emitter state after its first
        update, and the random key returned by the emitter initialisation.
    """
    # Evaluate the initial population.
    fitnesses, descriptors, extra_scores, scoring_key = self._scoring_function(
        init_genotypes, random_key
    )

    # The same evaluation results feed both the repertoire and the emitter.
    evaluation = {
        "fitnesses": fitnesses,
        "descriptors": descriptors,
        "extra_scores": extra_scores,
    }

    initial_repertoire = MapElitesRepertoire.init(
        genotypes=init_genotypes, centroids=centroids, **evaluation
    )

    # Create the emitter state, then let the emitter see the first batch.
    fresh_state, emitter_key = self._emitter.init(
        init_genotypes=init_genotypes, random_key=scoring_key
    )
    initial_state = self._emitter.state_update(
        emitter_state=fresh_state,
        repertoire=initial_repertoire,
        genotypes=init_genotypes,
        **evaluation,
    )

    return initial_repertoire, initial_state, emitter_key
update(self, repertoire, emitter_state, random_key)
¶
Performs one iteration of the MAP-Elites algorithm:
1. A batch of genotypes is sampled from the repertoire and the genotypes are copied.
2. The copies are mutated and crossed over.
3. The resulting offspring are scored and then added to the repertoire.
Parameters: |
|
---|
Returns: |
|
---|
Source code in qdax/core/map_elites.py
@partial(jax.jit, static_argnames=("self",))
def update(
    self,
    repertoire: MapElitesRepertoire,
    emitter_state: Optional[EmitterState],
    random_key: RNGKey,
) -> Tuple[MapElitesRepertoire, Optional[EmitterState], Metrics, RNGKey]:
    """Run one MAP-Elites iteration.

    The emitter proposes offspring from the repertoire, the offspring are
    scored and added to the repertoire, the emitter state is updated with
    the scored offspring, and metrics are computed on the new repertoire.

    Args:
        repertoire: the MAP-Elites repertoire.
        emitter_state: state of the emitter.
        random_key: a jax PRNG random key.

    Returns:
        The updated repertoire, the updated emitter state, the metrics of
        the updated repertoire and the random key returned by the scoring
        function.
    """
    # Ask the emitter for offspring, then evaluate them.
    offspring, emit_key = self._emitter.emit(repertoire, emitter_state, random_key)
    fitnesses, descriptors, extra_scores, next_key = self._scoring_function(
        offspring, emit_key
    )

    # Insert the scored offspring into the repertoire.
    updated_repertoire = repertoire.add(
        offspring, descriptors, fitnesses, extra_scores
    )

    # The emitter is updated against the repertoire *after* insertion.
    updated_state = self._emitter.state_update(
        emitter_state=emitter_state,
        repertoire=updated_repertoire,
        genotypes=offspring,
        fitnesses=fitnesses,
        descriptors=descriptors,
        extra_scores=extra_scores,
    )

    step_metrics = self._metrics_function(updated_repertoire)
    return updated_repertoire, updated_state, step_metrics, next_key
scan_update(self, carry, unused)
¶
Rewrites the update function in a way that makes it compatible with the jax.lax.scan primitive.
Parameters: |
|
---|
Returns: |
|
---|
Source code in qdax/core/map_elites.py
@partial(jax.jit, static_argnames=("self",))
def scan_update(
    self,
    carry: Tuple[MapElitesRepertoire, Optional[EmitterState], RNGKey],
    unused: Any,
) -> Tuple[Tuple[MapElitesRepertoire, Optional[EmitterState], RNGKey], Metrics]:
    """Adapt `update` to the (carry, x) -> (carry, y) shape of jax.lax.scan.

    Args:
        carry: tuple holding the repertoire, the emitter state and a random
            key.
        unused: placeholder required by the jax.lax.scan API; ignored.

    Returns:
        The new carry (repertoire, emitter state, random key) and the
        metrics of this step.
    """
    current_repertoire, current_state, current_key = carry
    next_repertoire, next_state, step_metrics, next_key = self.update(
        current_repertoire,
        current_state,
        current_key,
    )
    next_carry = (next_repertoire, next_state, next_key)
    return next_carry, step_metrics
We also provide a class to have MAP-Elites efficiently distributed over several devices.
qdax.core.distributed_map_elites.DistributedMAPElites (MAPElites)
¶
Source code in qdax/core/distributed_map_elites.py
class DistributedMAPElites(MAPElites):
    """MAP-Elites variant whose `init` and `update` gather individuals across devices.

    `init` and `update` call `jax.lax.all_gather` over the axis named "p", so
    they must run inside a mapped context that defines that axis. The
    `get_distributed_init_fn` and `get_distributed_update_fn` helpers build
    such `jax.pmap`-wrapped functions.
    """

    @partial(jax.jit, static_argnames=("self",))
    def init(
        self,
        init_genotypes: Genotype,
        centroids: Centroid,
        random_key: RNGKey,
    ) -> Tuple[MapElitesRepertoire, Optional[EmitterState], RNGKey]:
        """
        Initialize a Map-Elites repertoire with an initial population of genotypes.
        Requires the definition of centroids that can be computed with any method
        such as CVT or Euclidean mapping.

        Before the repertoire is initialised, individuals are gathered from all the
        devices.

        Args:
            init_genotypes: initial genotypes, pytree in which leaves
                have shape (batch_size, num_features)
            centroids: tessellation centroids, forwarded to the repertoire.
                NOTE(review): previously documented as shape
                (batch_size, num_descriptors); the leading axis is presumably
                the number of centroids -- confirm.
            random_key: a random key used for stochastic operations.

        Returns:
            An initialized MAP-Elite repertoire with the initial state of the emitter,
            and a random key.
        """
        # score initial genotypes
        fitnesses, descriptors, extra_scores, random_key = self._scoring_function(
            init_genotypes, random_key
        )

        # gather across all devices: all_gather stacks the per-device values
        # along a new leading axis, and the concatenation merges that axis
        # into the batch axis
        (
            gathered_genotypes,
            gathered_fitnesses,
            gathered_descriptors,
        ) = jax.tree_util.tree_map(
            lambda x: jnp.concatenate(jax.lax.all_gather(x, axis_name="p"), axis=0),
            (init_genotypes, fitnesses, descriptors),
        )

        # init the repertoire from the gathered individuals
        # (extra_scores are neither gathered nor passed to the repertoire)
        repertoire = MapElitesRepertoire.init(
            genotypes=gathered_genotypes,
            fitnesses=gathered_fitnesses,
            descriptors=gathered_descriptors,
            centroids=centroids,
        )

        # get initial state of the emitter
        emitter_state, random_key = self._emitter.init(
            init_genotypes=init_genotypes, random_key=random_key
        )

        # update emitter state with the local (non-gathered) evaluation
        emitter_state = self._emitter.state_update(
            emitter_state=emitter_state,
            repertoire=repertoire,
            genotypes=init_genotypes,
            fitnesses=fitnesses,
            descriptors=descriptors,
            extra_scores=extra_scores,
        )

        return repertoire, emitter_state, random_key

    @partial(jax.jit, static_argnames=("self",))
    def update(
        self,
        repertoire: MapElitesRepertoire,
        emitter_state: Optional[EmitterState],
        random_key: RNGKey,
    ) -> Tuple[MapElitesRepertoire, Optional[EmitterState], Metrics, RNGKey]:
        """Performs one iteration of the MAP-Elites algorithm.

        1. The emitter generates a batch of offspring from the repertoire.
        2. The offspring are scored.
        3. The scored offspring are gathered from all the devices and added
           to the repertoire.

        Args:
            repertoire: the MAP-Elites repertoire
            emitter_state: state of the emitter
            random_key: a jax PRNG random key

        Returns:
            the updated MAP-Elites repertoire
            the updated (if needed) emitter state
            metrics about the updated repertoire
            a new jax PRNG key
        """
        # generate offsprings with the emitter
        genotypes, random_key = self._emitter.emit(
            repertoire, emitter_state, random_key
        )

        # scores the offsprings
        fitnesses, descriptors, extra_scores, random_key = self._scoring_function(
            genotypes, random_key
        )

        # gather across all devices (per-device batches are concatenated
        # along axis 0)
        (
            gathered_genotypes,
            gathered_fitnesses,
            gathered_descriptors,
        ) = jax.tree_util.tree_map(
            lambda x: jnp.concatenate(jax.lax.all_gather(x, axis_name="p"), axis=0),
            (genotypes, fitnesses, descriptors),
        )

        # add the gathered genotypes in the repertoire
        # (extra_scores are neither gathered nor passed to the repertoire)
        repertoire = repertoire.add(
            gathered_genotypes, gathered_descriptors, gathered_fitnesses
        )

        # update emitter state after scoring is made, using the local
        # (non-gathered) offspring and the updated repertoire
        emitter_state = self._emitter.state_update(
            emitter_state=emitter_state,
            repertoire=repertoire,
            genotypes=genotypes,
            fitnesses=fitnesses,
            descriptors=descriptors,
            extra_scores=extra_scores,
        )

        # update the metrics
        metrics = self._metrics_function(repertoire)

        return repertoire, emitter_state, metrics, random_key

    def get_distributed_init_fn(
        self, centroids: Centroid, devices: List[Any]
    ) -> Callable[
        [Genotype, RNGKey], Tuple[MapElitesRepertoire, Optional[EmitterState], RNGKey]
    ]:
        """Create a function that init MAP-Elites in a distributed way.

        Args:
            centroids: centroids that structure the repertoire.
            devices: hardware devices.

        Returns:
            A callable function that inits the MAP-Elites algorithm in a
            distributed way.
        """
        # NOTE(review): centroids is bound by keyword although it sits between
        # init_genotypes and random_key in init's signature, so the returned
        # function presumably has to be called with keyword arguments
        # (init_genotypes=..., random_key=...) -- confirm against callers.
        return jax.pmap(  # type: ignore
            partial(self.init, centroids=centroids),
            devices=devices,
            axis_name="p",
        )

    def get_distributed_update_fn(
        self, num_iterations: int, devices: List[Any]
    ) -> Callable[
        [MapElitesRepertoire, Optional[EmitterState], RNGKey],
        Tuple[MapElitesRepertoire, Optional[EmitterState], RNGKey, Metrics],
    ]:
        """Create a function that can do a certain number of updates of
        MAP-Elites in a way that is distributed on several devices.

        Args:
            num_iterations: number of iterations to realize.
            devices: hardware devices to distribute on.

        Returns:
            The update function that can be called directly to apply a sequence
            of MAP-Elites updates. It returns the final repertoire, emitter
            state and random key, followed by the metrics collected by
            jax.lax.scan over the iterations.
        """

        # `self` is captured by closure and is not a parameter of this
        # function, so no static_argnames is declared: recent JAX releases
        # reject static argument names the function does not accept.
        @jax.jit
        def _scan_update(
            carry: Tuple[MapElitesRepertoire, Optional[EmitterState], RNGKey],
            unused: Any,
        ) -> Tuple[Tuple[MapElitesRepertoire, Optional[EmitterState], RNGKey], Metrics]:
            """Rewrites the update function in a way that makes it compatible with the
            jax.lax.scan primitive."""
            # unwrap the input
            repertoire, emitter_state, random_key = carry

            # apply one step of update
            (repertoire, emitter_state, metrics, random_key,) = self.update(
                repertoire,
                emitter_state,
                random_key,
            )

            return (repertoire, emitter_state, random_key), metrics

        def update_fn(
            repertoire: MapElitesRepertoire,
            emitter_state: Optional[EmitterState],
            random_key: RNGKey,
        ) -> Tuple[MapElitesRepertoire, Optional[EmitterState], RNGKey, Metrics]:
            """Apply num_iterations of update."""
            # there is nothing to scan over, so xs is empty and the number
            # of steps is given by `length`
            (repertoire, emitter_state, random_key,), metrics = jax.lax.scan(
                _scan_update,
                (repertoire, emitter_state, random_key),
                (),
                length=num_iterations,
            )
            return repertoire, emitter_state, random_key, metrics

        # the axis name must match the one used by all_gather in `update`
        return jax.pmap(update_fn, devices=devices, axis_name="p")  # type: ignore
init(self, init_genotypes, centroids, random_key)
¶
Initialize a Map-Elites repertoire with an initial population of genotypes. Requires the definition of centroids that can be computed with any method such as CVT or Euclidean mapping.
Before the repertoire is initialised, individuals are gathered from all the devices.
Parameters: |
|
---|
Returns: |
|
---|
Source code in qdax/core/distributed_map_elites.py
@partial(jax.jit, static_argnames=("self",))
def init(
    self,
    init_genotypes: Genotype,
    centroids: Centroid,
    random_key: RNGKey,
) -> Tuple[MapElitesRepertoire, Optional[EmitterState], RNGKey]:
    """Build the initial repertoire from the individuals of every device.

    The initial genotypes are scored locally, then genotypes, fitnesses and
    descriptors are gathered over the mapped axis "p" before the repertoire
    is created. The emitter is initialised and updated with the local
    (non-gathered) data.

    Args:
        init_genotypes: initial genotypes, a pytree whose leaves have shape
            (batch_size, num_features).
        centroids: tessellation centroids (computed with any method, e.g.
            CVT or Euclidean mapping), forwarded to the repertoire.
            NOTE(review): previously documented as shape
            (batch_size, num_descriptors); the leading axis is presumably
            the number of centroids -- confirm.
        random_key: random key used for stochastic operations.

    Returns:
        The initialised repertoire, the emitter state after its first
        update, and the random key returned by the emitter initialisation.
    """

    def _gather_from_devices(leaf: Any) -> Any:
        # all_gather stacks the per-device values along a new leading axis;
        # concatenating merges that axis into the batch axis.
        return jnp.concatenate(jax.lax.all_gather(leaf, axis_name="p"), axis=0)

    # Evaluate the local initial population.
    fitnesses, descriptors, extra_scores, scoring_key = self._scoring_function(
        init_genotypes, random_key
    )

    # Collect every device's individuals (extra_scores are not gathered).
    all_genotypes, all_fitnesses, all_descriptors = jax.tree_util.tree_map(
        _gather_from_devices,
        (init_genotypes, fitnesses, descriptors),
    )

    initial_repertoire = MapElitesRepertoire.init(
        genotypes=all_genotypes,
        fitnesses=all_fitnesses,
        descriptors=all_descriptors,
        centroids=centroids,
    )

    # Create the emitter state, then let the emitter see the local batch.
    fresh_state, emitter_key = self._emitter.init(
        init_genotypes=init_genotypes, random_key=scoring_key
    )
    initial_state = self._emitter.state_update(
        emitter_state=fresh_state,
        repertoire=initial_repertoire,
        genotypes=init_genotypes,
        fitnesses=fitnesses,
        descriptors=descriptors,
        extra_scores=extra_scores,
    )

    return initial_repertoire, initial_state, emitter_key
update(self, repertoire, emitter_state, random_key)
¶
Performs one iteration of the MAP-Elites algorithm.
- A batch of genotypes is sampled in the repertoire and the genotypes are copied.
- The copies are mutated and crossed-over
- The obtained offsprings are scored and then added to the repertoire.
Before the repertoire is updated, individuals are gathered from all the devices.
Parameters: |
|
---|
Returns: |
|
---|
Source code in qdax/core/distributed_map_elites.py
@partial(jax.jit, static_argnames=("self",))
def update(
    self,
    repertoire: MapElitesRepertoire,
    emitter_state: Optional[EmitterState],
    random_key: RNGKey,
) -> Tuple[MapElitesRepertoire, Optional[EmitterState], Metrics, RNGKey]:
    """Run one MAP-Elites iteration with individuals gathered across devices.

    The emitter proposes offspring from the repertoire and the offspring are
    scored locally. Genotypes, fitnesses and descriptors are then gathered
    over the mapped axis "p" and added to the repertoire. The emitter state
    is updated with the local (non-gathered) offspring, and metrics are
    computed on the new repertoire.

    Args:
        repertoire: the MAP-Elites repertoire.
        emitter_state: state of the emitter.
        random_key: a jax PRNG random key.

    Returns:
        The updated repertoire, the updated emitter state, the metrics of
        the updated repertoire and the random key returned by the scoring
        function.
    """

    def _gather_from_devices(leaf: Any) -> Any:
        # all_gather stacks the per-device values along a new leading axis;
        # concatenating merges that axis into the batch axis.
        return jnp.concatenate(jax.lax.all_gather(leaf, axis_name="p"), axis=0)

    # Ask the emitter for offspring, then evaluate them locally.
    offspring, emit_key = self._emitter.emit(repertoire, emitter_state, random_key)
    fitnesses, descriptors, extra_scores, next_key = self._scoring_function(
        offspring, emit_key
    )

    # Collect every device's offspring (extra_scores are not gathered).
    all_offspring, all_fitnesses, all_descriptors = jax.tree_util.tree_map(
        _gather_from_devices,
        (offspring, fitnesses, descriptors),
    )

    # Insert the gathered offspring into the repertoire.
    updated_repertoire = repertoire.add(all_offspring, all_descriptors, all_fitnesses)

    # The emitter is updated against the repertoire *after* insertion.
    updated_state = self._emitter.state_update(
        emitter_state=emitter_state,
        repertoire=updated_repertoire,
        genotypes=offspring,
        fitnesses=fitnesses,
        descriptors=descriptors,
        extra_scores=extra_scores,
    )

    step_metrics = self._metrics_function(updated_repertoire)
    return updated_repertoire, updated_state, step_metrics, next_key
get_distributed_init_fn(self, centroids, devices)
¶
Create a function that init MAP-Elites in a distributed way.
Parameters: |
|
---|
Returns: |
|
---|
Source code in qdax/core/distributed_map_elites.py
def get_distributed_init_fn(
    self, centroids: Centroid, devices: List[Any]
) -> Callable[
    [Genotype, RNGKey], Tuple[MapElitesRepertoire, Optional[EmitterState], RNGKey]
]:
    """Wrap `init` in a jax.pmap so it runs on several devices.

    Args:
        centroids: centroids that structure the repertoire; bound once here
            rather than passed to the returned function.
        devices: hardware devices handed to jax.pmap.

    Returns:
        The pmapped initialisation function, mapped over the axis named "p".
    """
    # NOTE(review): centroids is bound by keyword although it sits between
    # init_genotypes and random_key in init's signature, so the returned
    # function presumably has to be called with keyword arguments
    # (init_genotypes=..., random_key=...) -- confirm against callers.
    init_with_centroids = partial(self.init, centroids=centroids)
    return jax.pmap(  # type: ignore
        init_with_centroids, devices=devices, axis_name="p"
    )
get_distributed_update_fn(self, num_iterations, devices)
¶
Create a function that can do a certain number of updates of MAP-Elites in a way that is distributed on several devices.
Parameters: |
|
---|
Returns: |
|
---|
Source code in qdax/core/distributed_map_elites.py
def get_distributed_update_fn(
    self, num_iterations: int, devices: List[Any]
) -> Callable[
    [MapElitesRepertoire, Optional[EmitterState], RNGKey],
    Tuple[MapElitesRepertoire, Optional[EmitterState], RNGKey, Metrics],
]:
    """Create a function that can do a certain number of updates of
    MAP-Elites in a way that is distributed on several devices.

    Args:
        num_iterations: number of iterations to realize.
        devices: hardware devices to distribute on.

    Returns:
        The update function that can be called directly to apply a sequence
        of MAP-Elites updates. It returns the final repertoire, emitter
        state and random key, followed by the metrics collected by
        jax.lax.scan over the iterations.
    """

    # `self` is captured by closure and is not a parameter of this function,
    # so no static_argnames is declared: recent JAX releases reject static
    # argument names the function does not accept.
    @jax.jit
    def _scan_update(
        carry: Tuple[MapElitesRepertoire, Optional[EmitterState], RNGKey],
        unused: Any,
    ) -> Tuple[Tuple[MapElitesRepertoire, Optional[EmitterState], RNGKey], Metrics]:
        """Rewrites the update function in a way that makes it compatible with the
        jax.lax.scan primitive."""
        # unwrap the input
        repertoire, emitter_state, random_key = carry

        # apply one step of update
        (repertoire, emitter_state, metrics, random_key,) = self.update(
            repertoire,
            emitter_state,
            random_key,
        )

        return (repertoire, emitter_state, random_key), metrics

    def update_fn(
        repertoire: MapElitesRepertoire,
        emitter_state: Optional[EmitterState],
        random_key: RNGKey,
    ) -> Tuple[MapElitesRepertoire, Optional[EmitterState], RNGKey, Metrics]:
        """Apply num_iterations of update."""
        # there is nothing to scan over, so xs is empty and the number of
        # steps is given by `length`
        (repertoire, emitter_state, random_key,), metrics = jax.lax.scan(
            _scan_update,
            (repertoire, emitter_state, random_key),
            (),
            length=num_iterations,
        )
        return repertoire, emitter_state, random_key, metrics

    # the mapped axis is named "p"
    return jax.pmap(update_fn, devices=devices, axis_name="p")  # type: ignore