MAP-Elites with Evolution Strategies (ME-ES)
To create an instance of ME-ES, one needs to use an instance of MAP-Elites with the MEESEmitter, detailed below.
qdax.core.emitters.mees_emitter.MEESEmitter (Emitter)
¶
Emitter reproducing the MAP-Elites-ES algorithm from "Scaling MAP-Elites to Deep Neuroevolution" by Colas et al: https://dl.acm.org/doi/pdf/10.1145/3377930.3390217
One can choose between the three variants by setting use_explore and use_exploit:
- ME-ES exploit-explore (use_exploit=True and use_explore=True): alternates between num_optimizer_steps of fitness gradients and num_optimizer_steps of novelty gradients, resampling the parent from the archive every num_optimizer_steps steps.
- ME-ES exploit (use_exploit=True and use_explore=False): only uses fitness gradients, no novelty gradients, but resamples the parent from the archive every num_optimizer_steps steps.
- ME-ES explore (use_exploit=False and use_explore=True): only uses novelty gradients, no fitness gradients, but resamples the parent from the archive every num_optimizer_steps steps.
Source code in qdax/core/emitters/mees_emitter.py
class MEESEmitter(Emitter):
    """
    Emitter reproducing the MAP-Elites-ES algorithm from
    "Scaling MAP-Elites to Deep Neuroevolution" by Colas et al:
    https://dl.acm.org/doi/pdf/10.1145/3377930.3390217

    One can choose between the three variants by setting use_explore and
    use_exploit:

    ME-ES exploit-explore: use_exploit=True and use_explore=True
        Alternates between num_optimizer_steps of fitness gradients and
        num_optimizer_steps of novelty gradients, resampling the parent from
        the archive every num_optimizer_steps steps.
    ME-ES exploit: use_exploit=True and use_explore=False
        Only uses fitness gradients, no novelty gradients, but resamples the
        parent from the archive every num_optimizer_steps steps.
    ME-ES explore: use_exploit=False and use_explore=True
        Only uses novelty gradients, no fitness gradients, but resamples the
        parent from the archive every num_optimizer_steps steps.
    """

    def __init__(
        self,
        config: MEESConfig,
        total_generations: int,
        scoring_fn: Callable[
            [Genotype, RNGKey], Tuple[Fitness, Descriptor, ExtraScores, RNGKey]
        ],
        num_descriptors: int,
    ) -> None:
        """Initialise the MAP-Elites-ES emitter.

        WARNING: total_generations is required to build the novelty archive.

        Args:
            config: algorithm config
            total_generations: total number of generations for which the
                emitter will run, used to size the novelty archive.
            scoring_fn: used to evaluate the samples for the gradient estimate.
            num_descriptors: dimension of the descriptors, used to initialise
                the empty novelty archive.
        """
        self._config = config
        self._scoring_fn = scoring_fn
        self._total_generations = total_generations
        self._num_descriptors = num_descriptors

        # Initialise optimizer: Adam or plain SGD, same learning rate.
        if self._config.adam_optimizer:
            self._optimizer = optax.adam(learning_rate=config.learning_rate)
        else:
            self._optimizer = optax.sgd(learning_rate=config.learning_rate)

    @property
    def batch_size(self) -> int:
        """
        Returns:
            the batch size emitted by the emitter (ME-ES emits a single
            offspring per generation).
        """
        return 1

    @partial(
        jax.jit,
        static_argnames=("self",),
    )
    def init(
        self, init_genotypes: Genotype, random_key: RNGKey
    ) -> Tuple[MEESEmitterState, RNGKey]:
        """Initializes the emitter state.

        Args:
            init_genotypes: The initial population. Only its first genotype
                is kept.
            random_key: A random key.

        Returns:
            The initial state of the MEESEmitter, a new random key.
        """
        # Initialisation requires one initial genotype. Slice with x[:1]
        # rather than x[0] so the leading batch axis of size 1 is kept: the
        # rest of the emitter expects (1, ...) genotypes (x.shape[1:] below,
        # and the batch-size assertion in state_update).
        if jax.tree_util.tree_leaves(init_genotypes)[0].shape[0] > 1:
            init_genotypes = jax.tree_util.tree_map(
                lambda x: x[:1],
                init_genotypes,
            )

        # Initialise optimizer
        initial_optimizer_state = self._optimizer.init(init_genotypes)

        # Create empty Novelty archive. It only needs to be large when novelty
        # is actually used (one descriptor is added per generation).
        if self._config.use_explore:
            novelty_archive = NoveltyArchive.init(
                self._total_generations, self._num_descriptors
            )
        else:
            novelty_archive = NoveltyArchive.init(
                self._config.novelty_nearest_neighbors, self._num_descriptors
            )

        # Create empty updated genotypes and fitness; -inf marks empty slots.
        last_updated_genotypes = jax.tree_util.tree_map(
            lambda x: jnp.zeros(shape=(self._config.last_updated_size,) + x.shape[1:]),
            init_genotypes,
        )
        last_updated_fitnesses = -jnp.inf * jnp.ones(
            shape=self._config.last_updated_size
        )

        return (
            MEESEmitterState(
                initial_optimizer_state=initial_optimizer_state,
                optimizer_state=initial_optimizer_state,
                offspring=init_genotypes,
                generation_count=0,
                novelty_archive=novelty_archive,
                last_updated_genotypes=last_updated_genotypes,
                last_updated_fitnesses=last_updated_fitnesses,
                last_updated_position=0,
                random_key=random_key,
            ),
            random_key,
        )

    @partial(
        jax.jit,
        static_argnames=("self",),
    )
    def emit(
        self,
        repertoire: MapElitesRepertoire,
        emitter_state: MEESEmitterState,
        random_key: RNGKey,
    ) -> Tuple[Genotype, RNGKey]:
        """Return the offspring generated through gradient update.

        The offspring is computed ahead of time in state_update; this method
        only hands it out.

        Params:
            repertoire: the MAP-Elites repertoire to sample from (unused)
            emitter_state: current emitter state holding the offspring
            random_key: a jax PRNG random key, returned unchanged

        Returns:
            a new gradient offspring
            a new jax PRNG key
        """
        return emitter_state.offspring, random_key

    @partial(
        jax.jit,
        static_argnames=("self",),
    )
    def _sample_exploit(
        self,
        emitter_state: MEESEmitterState,
        repertoire: MapElitesRepertoire,
        random_key: RNGKey,
    ) -> Tuple[Genotype, RNGKey]:
        """Sample half of the time uniformly from the exploit_num_cell_sample
        highest-performing cells of the repertoire and half of the time uniformly
        from the exploit_num_cell_sample highest-performing cells among the
        last updated cells. While the last-updated buffer is still empty, always
        sample from the repertoire.

        Args:
            emitter_state: current emitter_state
            repertoire: the current repertoire
            random_key: a jax PRNG random key

        Returns:
            samples: a genotype sampled in the repertoire
            random_key: an updated jax PRNG random key
        """

        def _sample(
            random_key: RNGKey,
            genotypes: Genotype,
            fitnesses: Fitness,
        ) -> Tuple[Genotype, RNGKey]:
            """Sample uniformly from the exploit_num_cell_sample highest
            fitness cells."""
            max_fitnesses, _ = jax.lax.top_k(
                fitnesses, self._config.exploit_num_cell_sample
            )
            # Smallest finite fitness among the top-k; cells below it are
            # given probability zero.
            min_fitness = jnp.nanmin(
                jnp.where(max_fitnesses > -jnp.inf, max_fitnesses, jnp.inf)
            )
            genotypes_empty = fitnesses < min_fitness
            p = (1.0 - genotypes_empty) / jnp.sum(1.0 - genotypes_empty)
            random_key, subkey = jax.random.split(random_key)
            # The same subkey for every leaf is intentional: all leaves must
            # select the same index so the sampled genotype stays consistent.
            samples = jax.tree_util.tree_map(
                lambda x: jax.random.choice(subkey, x, shape=(1,), p=p),
                genotypes,
            )
            return samples, random_key

        random_key, subkey = jax.random.split(random_key)

        # Sample p uniformly
        p = jax.random.uniform(subkey)

        # The last-updated buffer starts with all fitnesses at -inf. Until a
        # genotype has been stored in it, every sampling probability would be
        # 0/0 = NaN, so fall back on the repertoire in that case.
        last_updated_is_empty = jnp.all(
            emitter_state.last_updated_fitnesses == -jnp.inf
        )
        use_repertoire = jnp.logical_or(p < 0.5, last_updated_is_empty)

        # Depending on the value of p, use one of the two sampling options
        repertoire_sample = partial(
            _sample, genotypes=repertoire.genotypes, fitnesses=repertoire.fitnesses
        )
        last_updated_sample = partial(
            _sample,
            genotypes=emitter_state.last_updated_genotypes,
            fitnesses=emitter_state.last_updated_fitnesses,
        )
        samples, random_key = jax.lax.cond(
            use_repertoire,
            repertoire_sample,
            last_updated_sample,
            random_key,
        )

        return samples, random_key

    @partial(
        jax.jit,
        static_argnames=("self",),
    )
    def _sample_explore(
        self,
        emitter_state: MEESEmitterState,
        repertoire: MapElitesRepertoire,
        random_key: RNGKey,
    ) -> Tuple[Genotype, RNGKey]:
        """Sample uniformly from the explore_num_cell_sample most-novel genotypes.

        Args:
            emitter_state: current emitter state
            repertoire: the current genotypes repertoire
            random_key: a jax PRNG random key

        Returns:
            samples: a genotype sampled in the repertoire
            random_key: an updated jax PRNG random key
        """
        # Compute the novelty of all indivs in the archive; empty cells
        # (fitness -inf) get novelty -inf so they are never selected.
        novelties = emitter_state.novelty_archive.novelty(
            repertoire.descriptors, self._config.novelty_nearest_neighbors
        )
        novelties = jnp.where(repertoire.fitnesses > -jnp.inf, novelties, -jnp.inf)

        # Sample uniformly for the explore_num_cell_sample most novel cells
        max_novelties, _ = jax.lax.top_k(
            novelties, self._config.explore_num_cell_sample
        )
        min_novelty = jnp.nanmin(
            jnp.where(max_novelties > -jnp.inf, max_novelties, jnp.inf)
        )
        repertoire_empty = novelties < min_novelty
        p = (1.0 - repertoire_empty) / jnp.sum(1.0 - repertoire_empty)
        random_key, subkey = jax.random.split(random_key)
        # Shared subkey across leaves: every leaf must pick the same cell.
        samples = jax.tree_util.tree_map(
            lambda x: jax.random.choice(subkey, x, shape=(1,), p=p),
            repertoire.genotypes,
        )
        return samples, random_key

    @partial(
        jax.jit,
        static_argnames=("self", "scores_fn"),
    )
    def _es_emitter(
        self,
        parent: Genotype,
        optimizer_state: optax.OptState,
        random_key: RNGKey,
        scores_fn: Callable[[Fitness, Descriptor], jnp.ndarray],
    ) -> Tuple[Genotype, optax.OptState, RNGKey]:
        """Main es component: given a parent and a way to infer the score from
        the fitnesses and descriptors of its es-samples, return its
        approximated-gradient-generated offspring.

        Args:
            parent: the considered parent, with a leading batch axis of size 1.
            optimizer_state: the optimizer state used for this update.
            random_key: a jax PRNG random key.
            scores_fn: a function to infer the score of its es-samples from
                their fitness and descriptors.

        Returns:
            The approximated-gradients-generated offspring, the updated
            optimizer state and a new random_key.
        """
        random_key, subkey = jax.random.split(random_key)

        def _gaussian_noise(num_samples: int) -> Genotype:
            """Draw num_samples standard-normal perturbations per leaf of
            parent, stacked along the leading axis."""
            # One independent key per leaf: reusing a single key would give
            # identical (perfectly correlated) noise to every leaf sharing a
            # shape, e.g. two hidden layers of the same width.
            leaves, treedef = jax.tree_util.tree_flatten(parent)
            leaf_keys = jax.random.split(subkey, len(leaves))
            noise_leaves = [
                jax.random.normal(
                    key=leaf_key,
                    shape=(x.shape[0] * num_samples,) + x.shape[1:],
                )
                for leaf_key, x in zip(leaf_keys, leaves)
            ]
            return jax.tree_util.tree_unflatten(treedef, noise_leaves)

        if self._config.sample_mirror:
            # Mirrored sampling: each perturbation eps is evaluated as +eps and
            # -eps, interleaved as [eps_0, -eps_0, eps_1, -eps_1, ...].
            sample_number = self._config.sample_number // 2
            # An odd sample_number cannot be fully mirrored; evaluate the
            # largest even number of samples rather than failing on a shape
            # mismatch between the samples and the noise.
            total_sample_number = 2 * sample_number
            half_sample_noise = _gaussian_noise(sample_number)
            sample_noise = jax.tree_util.tree_map(
                lambda x: jnp.stack([x, -x], axis=1).reshape(
                    (2 * x.shape[0],) + x.shape[1:]
                ),
                half_sample_noise,
            )
            gradient_noise = half_sample_noise
        else:
            sample_number = self._config.sample_number
            total_sample_number = sample_number
            sample_noise = _gaussian_noise(sample_number)
            gradient_noise = sample_noise

        # Applying noise around the (repeated) parent
        samples = jax.tree_util.tree_map(
            lambda mean, noise: jnp.repeat(mean, total_sample_number, axis=0)
            + self._config.sample_sigma * noise,
            parent,
            sample_noise,
        )

        # Evaluating samples
        fitnesses, descriptors, _, random_key = self._scoring_fn(
            samples, random_key
        )

        # Computing rank, with or without normalisation to [-0.5, 0.5]
        scores = scores_fn(fitnesses, descriptors)
        if self._config.sample_rank_norm:
            ranks = jnp.argsort(jnp.argsort(scores, axis=0), axis=0)
            ranks = (ranks / (total_sample_number - 1)) - 0.5
        else:
            ranks = scores

        # With mirroring, each noise direction is weighted by the score
        # difference between its +eps and -eps evaluations.
        if self._config.sample_mirror:
            ranks = jnp.reshape(ranks, (sample_number, 2))
            ranks = ranks[:, 0] - ranks[:, 1]
        ranks = jnp.ravel(ranks)

        # Weight every noise row by its rank (broadcast over the other axes).
        weighted_noise = jax.tree_util.tree_map(
            lambda noise: noise * jnp.reshape(ranks, (-1,) + (1,) * (noise.ndim - 1)),
            gradient_noise,
        )

        # Monte-Carlo gradient estimate. The weighted sum is an ascent
        # direction for the score, and optax optimisers descend, hence the
        # minus sign.
        gradient = jax.tree_util.tree_map(
            lambda g, p: jnp.reshape(
                -jnp.sum(jnp.reshape(g, (sample_number, -1)), axis=0)
                / (total_sample_number * self._config.sample_sigma),
                p.shape,
            ),
            weighted_noise,
            parent,
        )

        # Adding L2 regularisation
        gradient = jax.tree_util.tree_map(
            lambda g, p: g + self._config.l2_coefficient * p,
            gradient,
            parent,
        )

        # Applying gradients
        offspring_update, optimizer_state = self._optimizer.update(
            gradient, optimizer_state
        )
        offspring = optax.apply_updates(parent, offspring_update)

        return offspring, optimizer_state, random_key

    @partial(
        jax.jit,
        static_argnames=("self",),
    )
    def _buffers_update(
        self,
        emitter_state: MEESEmitterState,
        repertoire: MapElitesRepertoire,
        genotypes: Genotype,
        fitnesses: Fitness,
        descriptors: Descriptor,
    ) -> MEESEmitterState:
        """Update the different buffers and archives in the emitter
        state to generate the offspring for the next generation.

        Args:
            emitter_state: current emitter state
            repertoire: the current genotypes repertoire
            genotypes: the genotypes of the batch of emitted offspring.
            fitnesses: the fitnesses of the batch of emitted offspring.
            descriptors: the descriptors of the emitted offspring.

        Returns:
            The modified emitter state.
        """
        # Updating novelty archive
        novelty_archive = emitter_state.novelty_archive.update(descriptors)

        # Check if genotype from previous iteration has been added to the grid:
        # it was added iff the repertoire cell for its descriptor now holds
        # exactly this genotype (every leaf equal).
        indice = get_cells_indices(descriptors, repertoire.centroids)
        added_genotype = jnp.all(
            jnp.asarray(
                jax.tree_util.tree_leaves(
                    jax.tree_util.tree_map(
                        lambda new_gen, rep_gen: jnp.all(
                            jnp.equal(
                                jnp.ravel(new_gen), jnp.ravel(rep_gen.at[indice].get())
                            ),
                            axis=0,
                        ),
                        genotypes,
                        repertoire.genotypes,
                    ),
                )
            ),
            axis=0,
        )

        # Update last_updated buffers. When the genotype was not added, write
        # to an index past the end of the buffer; mode="drop" makes that
        # scatter a no-op explicitly.
        last_updated_position = jnp.where(
            added_genotype,
            emitter_state.last_updated_position,
            self._config.last_updated_size + 1,
        )
        last_updated_fitnesses = emitter_state.last_updated_fitnesses
        last_updated_fitnesses = last_updated_fitnesses.at[last_updated_position].set(
            fitnesses[0], mode="drop"
        )
        last_updated_genotypes = jax.tree_util.tree_map(
            lambda last_gen, gen: last_gen.at[
                jnp.expand_dims(last_updated_position, axis=0)
            ].set(gen, mode="drop"),
            emitter_state.last_updated_genotypes,
            genotypes,
        )
        # Circular buffer: only advance the write position on insertion.
        last_updated_position = (
            emitter_state.last_updated_position + added_genotype
        ) % self._config.last_updated_size

        # Return new emitter_state
        return emitter_state.replace(  # type: ignore
            novelty_archive=novelty_archive,
            last_updated_genotypes=last_updated_genotypes,
            last_updated_fitnesses=last_updated_fitnesses,
            last_updated_position=last_updated_position,
        )

    @partial(
        jax.jit,
        static_argnames=("self",),
    )
    def state_update(
        self,
        emitter_state: MEESEmitterState,
        repertoire: MapElitesRepertoire,
        genotypes: Genotype,
        fitnesses: Fitness,
        descriptors: Descriptor,
        extra_scores: ExtraScores,
    ) -> MEESEmitterState:
        """Generate the gradient offspring for the next emitter call. Also
        update the novelty archive and generation count from current call.

        Every num_optimizer_steps generations a new parent is sampled from the
        repertoire and the optimizer state is reset. In between, the previous
        offspring is the parent and the optimizer state is carried over.

        Args:
            emitter_state: current emitter state
            repertoire: the current genotypes repertoire
            genotypes: the genotypes of the batch of emitted offspring.
            fitnesses: the fitnesses of the batch of emitted offspring.
            descriptors: the descriptors of the emitted offspring.
            extra_scores: a dictionary with other values outputted by the
                scoring function.

        Returns:
            The modified emitter state.
        """
        # Shape check; shapes are static, so this runs at trace time.
        assert jax.tree_util.tree_leaves(genotypes)[0].shape[0] == 1, (
            "ERROR: MAP-Elites-ES generates 1 offspring per generation, "
            + "batch_size should be 1, the inputed batch has size:"
            + str(jax.tree_util.tree_leaves(genotypes)[0].shape[0])
        )

        # Update all the buffers and archives of the emitter_state
        emitter_state = self._buffers_update(
            emitter_state, repertoire, genotypes, fitnesses, descriptors
        )

        # Use new or previous parents and exploitation or exploration
        generation_count = emitter_state.generation_count
        sample_new_parent = generation_count % self._config.num_optimizer_steps == 0
        # Python and/or only evaluate the static config flags as booleans; the
        # traced generation-parity term is returned as-is when reached.
        use_exploration = (
            self._config.use_explore and not self._config.use_exploit
        ) or (
            self._config.use_explore
            and self._config.use_exploit
            and ((generation_count // self._config.num_optimizer_steps) % 2 == 0)
        )

        # Select parent and optimizer_state
        parent, random_key = jax.lax.cond(
            sample_new_parent,
            lambda emitter_state, repertoire, random_key: jax.lax.cond(
                use_exploration,
                self._sample_explore,
                self._sample_exploit,
                emitter_state,
                repertoire,
                random_key,
            ),
            lambda emitter_state, repertoire, random_key: (
                emitter_state.offspring,
                random_key,
            ),
            emitter_state,
            repertoire,
            emitter_state.random_key,
        )
        optimizer_state = jax.lax.cond(
            sample_new_parent,
            lambda _unused: emitter_state.initial_optimizer_state,
            lambda _unused: emitter_state.optimizer_state,
            (),
        )

        # Define scores for es process: novelty w.r.t. the updated archive
        # when exploring, raw fitness otherwise.
        def exploration_exploitation_scores(
            fitnesses: Fitness, descriptors: Descriptor
        ) -> jnp.ndarray:
            scores = jax.lax.cond(
                use_exploration,
                lambda fitnesses, descriptors: emitter_state.novelty_archive.novelty(
                    descriptors, self._config.novelty_nearest_neighbors
                ),
                lambda fitnesses, descriptors: fitnesses,
                fitnesses,
                descriptors,
            )
            return scores

        # Run es process
        offspring, optimizer_state, random_key = self._es_emitter(
            parent=parent,
            optimizer_state=optimizer_state,
            random_key=random_key,
            scores_fn=exploration_exploitation_scores,
        )

        return emitter_state.replace(  # type: ignore
            optimizer_state=optimizer_state,
            offspring=offspring,
            generation_count=generation_count + 1,
            random_key=random_key,
        )
batch_size: int
property
readonly
¶
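Returns:
- `int`: the batch size emitted by the emitter. ME-ES emits a single offspring per generation, so this is always 1.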
__init__(self, config, total_generations, scoring_fn, num_descriptors)
special
¶
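Initialise the MAP-Elites-ES emitter. WARNING: total_generations is required to build the novelty archive.
Parameters:
- config (MEESConfig): algorithm config.
- total_generations (int): total number of generations for which the emitter will run; used to initialise the novelty archive.
- scoring_fn (Callable[[Genotype, RNGKey], Tuple[Fitness, Descriptor, ExtraScores, RNGKey]]): used to evaluate the samples for the gradient estimate.
- num_descriptors (int): dimension of the descriptors, used to initialise the empty novelty archive.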
Source code in qdax/core/emitters/mees_emitter.py
def __init__(
    self,
    config: MEESConfig,
    total_generations: int,
    scoring_fn: Callable[
        [Genotype, RNGKey], Tuple[Fitness, Descriptor, ExtraScores, RNGKey]
    ],
    num_descriptors: int,
) -> None:
    """Set up the MAP-Elites-ES emitter.

    WARNING: total_generations is required to build the novelty archive.

    Args:
        config: algorithm config. Its adam_optimizer flag selects Adam (True)
            or plain SGD (False); both use config.learning_rate.
        total_generations: total number of generations for which the
            emitter will run, used to size the novelty archive.
        scoring_fn: used to evaluate the samples for the gradient estimate.
        num_descriptors: dimension of the descriptors, used to initialise
            the empty novelty archive.
    """
    # Keep the constructor arguments for later use by init/state_update.
    self._config = config
    self._scoring_fn = scoring_fn
    self._total_generations = total_generations
    self._num_descriptors = num_descriptors

    # Pick the optimizer factory from the config, then build it once.
    optimizer_factory = optax.adam if config.adam_optimizer else optax.sgd
    self._optimizer = optimizer_factory(learning_rate=config.learning_rate)
init(self, init_genotypes, random_key)
¶
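Initializes the emitter state.
Parameters:
- init_genotypes (Genotype): the initial population; only its first genotype is used.
- random_key (RNGKey): a random key.
Returns:
- Tuple[MEESEmitterState, RNGKey]: the initial state of the MEESEmitter and a new random key.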
Source code in qdax/core/emitters/mees_emitter.py
@partial(
    jax.jit,
    static_argnames=("self",),
)
def init(
    self, init_genotypes: Genotype, random_key: RNGKey
) -> Tuple[MEESEmitterState, RNGKey]:
    """Initializes the emitter state.

    Args:
        init_genotypes: The initial population. Only its first genotype is
            kept.
        random_key: A random key.

    Returns:
        The initial state of the MEESEmitter, a new random key.
    """
    # Initialisation requires one initial genotype. Slice with x[:1] rather
    # than x[0] so the leading batch axis of size 1 is kept: the rest of the
    # emitter expects (1, ...) genotypes (x.shape[1:] below, and the
    # batch-size assertion in state_update).
    if jax.tree_util.tree_leaves(init_genotypes)[0].shape[0] > 1:
        init_genotypes = jax.tree_util.tree_map(
            lambda x: x[:1],
            init_genotypes,
        )

    # Initialise optimizer
    initial_optimizer_state = self._optimizer.init(init_genotypes)

    # Create empty Novelty archive. It only needs to be large when novelty is
    # actually used (one descriptor is added per generation).
    if self._config.use_explore:
        novelty_archive = NoveltyArchive.init(
            self._total_generations, self._num_descriptors
        )
    else:
        novelty_archive = NoveltyArchive.init(
            self._config.novelty_nearest_neighbors, self._num_descriptors
        )

    # Create empty updated genotypes and fitness; -inf marks empty slots.
    last_updated_genotypes = jax.tree_util.tree_map(
        lambda x: jnp.zeros(shape=(self._config.last_updated_size,) + x.shape[1:]),
        init_genotypes,
    )
    last_updated_fitnesses = -jnp.inf * jnp.ones(
        shape=self._config.last_updated_size
    )

    return (
        MEESEmitterState(
            initial_optimizer_state=initial_optimizer_state,
            optimizer_state=initial_optimizer_state,
            offspring=init_genotypes,
            generation_count=0,
            novelty_archive=novelty_archive,
            last_updated_genotypes=last_updated_genotypes,
            last_updated_fitnesses=last_updated_fitnesses,
            last_updated_position=0,
            random_key=random_key,
        ),
        random_key,
    )
emit(self, repertoire, emitter_state, random_key)
¶
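Return the offspring generated through gradient update.
Parameters:
- repertoire (MapElitesRepertoire): the MAP-Elites repertoire (not used here; the offspring was already computed in state_update).
- emitter_state (MEESEmitterState): the current emitter state, which holds the offspring.
- random_key (RNGKey): a JAX PRNG key.
Returns:
- Tuple[Genotype, RNGKey]: the gradient offspring stored in the emitter state, and the random key (returned unchanged).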
Source code in qdax/core/emitters/mees_emitter.py
@partial(jax.jit, static_argnames=("self",))
def emit(
    self,
    repertoire: MapElitesRepertoire,
    emitter_state: MEESEmitterState,
    random_key: RNGKey,
) -> Tuple[Genotype, RNGKey]:
    """Hand out the offspring prepared by the previous gradient update.

    The gradient step itself happens in state_update. This method only
    returns its result, so neither the repertoire nor the key is consumed.

    Params:
        repertoire: the MAP-Elites repertoire (unused here)
        emitter_state: current emitter state holding the offspring
        random_key: a jax PRNG random key, passed through unchanged

    Returns:
        a new gradient offspring
        a new jax PRNG key
    """
    offspring = emitter_state.offspring
    return offspring, random_key
state_update(self, emitter_state, repertoire, genotypes, fitnesses, descriptors, extra_scores)
¶
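Generate the gradient offspring for the next emitter call. Also update the novelty archive and generation count from the current call.
Parameters:
- emitter_state (MEESEmitterState): the current emitter state.
- repertoire (MapElitesRepertoire): the current genotypes repertoire.
- genotypes (Genotype): the genotypes of the batch of emitted offspring (must have batch size 1).
- fitnesses (Fitness): the fitnesses of the batch of emitted offspring.
- descriptors (Descriptor): the descriptors of the emitted offspring.
- extra_scores (ExtraScores): a dictionary with other values output by the scoring function.
Returns:
- MEESEmitterState: the modified emitter state.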
Source code in qdax/core/emitters/mees_emitter.py
@partial(
    jax.jit,
    static_argnames=("self",),
)
def state_update(
    self,
    emitter_state: MEESEmitterState,
    repertoire: MapElitesRepertoire,
    genotypes: Genotype,
    fitnesses: Fitness,
    descriptors: Descriptor,
    extra_scores: ExtraScores,
) -> MEESEmitterState:
    """Generate the gradient offspring for the next emitter call. Also
    update the novelty archive and generation count from current call.

    Every num_optimizer_steps generations a new parent is sampled from the
    repertoire (by novelty when exploring, by fitness otherwise) and the
    optimizer state is reset to its initial value. In between, the previous
    offspring is used as the parent and the optimizer state is carried over.

    Args:
        emitter_state: current emitter state
        repertoire: the current genotypes repertoire
        genotypes: the genotypes of the batch of emitted offspring.
        fitnesses: the fitnesses of the batch of emitted offspring.
        descriptors: the descriptors of the emitted offspring.
        extra_scores: a dictionary with other values outputted by the
            scoring function.

    Returns:
        The modified emitter state.
    """
    # Shape check; shapes are static under jax.jit, so this runs at trace time.
    assert jax.tree_util.tree_leaves(genotypes)[0].shape[0] == 1, (
        "ERROR: MAP-Elites-ES generates 1 offspring per generation, "
        + "batch_size should be 1, the inputed batch has size:"
        + str(jax.tree_util.tree_leaves(genotypes)[0].shape[0])
    )

    # Update all the buffers and archives of the emitter_state
    emitter_state = self._buffers_update(
        emitter_state, repertoire, genotypes, fitnesses, descriptors
    )

    # Use new or previous parents and exploitation or exploration
    generation_count = emitter_state.generation_count
    # A new parent is drawn at the start of every num_optimizer_steps window.
    sample_new_parent = generation_count % self._config.num_optimizer_steps == 0
    # Python and/or only test the static config flags for truthiness; the
    # traced parity term is returned as-is when reached. With both flags set,
    # even-numbered windows explore and odd-numbered windows exploit.
    use_exploration = (
        self._config.use_explore and not self._config.use_exploit
    ) or (
        self._config.use_explore
        and self._config.use_exploit
        and ((generation_count // self._config.num_optimizer_steps) % 2 == 0)
    )

    # Select parent and optimizer_state
    parent, random_key = jax.lax.cond(
        sample_new_parent,
        lambda emitter_state, repertoire, random_key: jax.lax.cond(
            use_exploration,
            self._sample_explore,
            self._sample_exploit,
            emitter_state,
            repertoire,
            random_key,
        ),
        # Otherwise keep optimizing the previous offspring.
        lambda emitter_state, repertoire, random_key: (
            emitter_state.offspring,
            random_key,
        ),
        emitter_state,
        repertoire,
        emitter_state.random_key,
    )
    # Reset the optimizer state whenever the parent changes.
    optimizer_state = jax.lax.cond(
        sample_new_parent,
        lambda _unused: emitter_state.initial_optimizer_state,
        lambda _unused: emitter_state.optimizer_state,
        (),
    )

    # Define scores for es process: novelty w.r.t. the (already updated)
    # novelty archive when exploring, raw fitness otherwise.
    def exploration_exploitation_scores(
        fitnesses: Fitness, descriptors: Descriptor
    ) -> jnp.ndarray:
        scores = jax.lax.cond(
            use_exploration,
            lambda fitnesses, descriptors: emitter_state.novelty_archive.novelty(
                descriptors, self._config.novelty_nearest_neighbors
            ),
            lambda fitnesses, descriptors: fitnesses,
            fitnesses,
            descriptors,
        )
        return scores

    # Run es process
    offspring, optimizer_state, random_key = self._es_emitter(
        parent=parent,
        optimizer_state=optimizer_state,
        random_key=random_key,
        scores_fn=exploration_exploitation_scores,
    )

    return emitter_state.replace(  # type: ignore
        optimizer_state=optimizer_state,
        offspring=offspring,
        generation_count=generation_count + 1,
        random_key=random_key,
    )