Emitters¶
qdax.core.emitters
special
¶
cma_emitter
¶
CMAEmitterState (EmitterState)
dataclass
¶
Emitter state for the CMA-ME emitter.
Source code in qdax/core/emitters/cma_emitter.py
class CMAEmitterState(EmitterState):
"""
Emitter state for the CMA-ME emitter.
Args:
random_key: a random key to handle stochastic operations. Used for
state update only, another key is used to emit. This might be
subject to refactoring discussions in the future.
cmaes_state: state of the underlying CMA-ES algorithm
previous_fitnesses: store last fitnesses of the repertoire. Used to
compute the improvement.
emit_count: count the number of emission events.
"""
random_key: RNGKey
cmaes_state: CMAESState
previous_fitnesses: Fitness
emit_count: int
replace(self, **updates)
¶
Returns a new object replacing the specified fields with new values.
Source code in qdax/core/emitters/cma_emitter.py
def replace(self, **updates):
"""Returns a new object replacing the specified fields with new values."""
return dataclasses.replace(self, **updates)
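Emitter states are immutable dataclasses: every update goes through `replace`, which returns a modified copy. A minimal sketch, assuming an existing `CMAEmitterState` instance:
```python
# Hedged sketch: `emitter_state` is assumed to be an existing CMAEmitterState.
# `replace` returns an updated copy; the original object is left untouched.
new_state = emitter_state.replace(emit_count=emitter_state.emit_count + 1)
```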
CMAEmitter (Emitter, ABC)
¶
Source code in qdax/core/emitters/cma_emitter.py
class CMAEmitter(Emitter, ABC):
def __init__(
self,
batch_size: int,
genotype_dim: int,
centroids: Centroid,
sigma_g: float,
min_count: Optional[int] = None,
max_count: Optional[float] = None,
):
"""
Class for the emitter of CMA-ME from "Covariance Matrix Adaptation for the
Rapid Illumination of Behavior Space" by Fontaine et al.
Args:
batch_size: number of solutions sampled at each iteration
genotype_dim: dimension of the genotype space.
centroids: centroids used for the repertoire.
sigma_g: standard deviation for the coefficients - called step size.
min_count: minimum number of CMA-ES optimization steps before the emitter
is considered for reinitialization.
max_count: maximum number of CMA-ES optimization steps authorized.
"""
self._batch_size = batch_size
# define a CMAES instance
self._cmaes = CMAES(
population_size=batch_size,
search_dim=genotype_dim,
# no need for fitness function in that specific case
fitness_function=None, # type: ignore
num_best=batch_size,
init_sigma=sigma_g,
mean_init=None, # will be init at zeros in cmaes
bias_weights=True,
delay_eigen_decomposition=True,
)
# minimum number of emitted solutions before an emitter can be re-initialized
if min_count is None:
min_count = 0
self._min_count = min_count
if max_count is None:
max_count = jnp.inf
self._max_count = max_count
self._centroids = centroids
self._cma_initial_state = self._cmaes.init()
@property
def batch_size(self) -> int:
"""
Returns:
the batch size emitted by the emitter.
"""
return self._batch_size
@partial(jax.jit, static_argnames=("self",))
def init(
self, init_genotypes: Genotype, random_key: RNGKey
) -> Tuple[CMAEmitterState, RNGKey]:
"""
Initializes the CMA-ME emitter
Args:
init_genotypes: initial genotypes to add to the grid.
random_key: a random key to handle stochastic operations.
Returns:
The initial state of the emitter.
"""
# Initialize repertoire with default values
num_centroids = self._centroids.shape[0]
default_fitnesses = -jnp.inf * jnp.ones(shape=num_centroids)
# return the initial state
random_key, subkey = jax.random.split(random_key)
return (
CMAEmitterState(
random_key=subkey,
cmaes_state=self._cma_initial_state,
previous_fitnesses=default_fitnesses,
emit_count=0,
),
random_key,
)
@partial(jax.jit, static_argnames=("self",))
def emit(
self,
repertoire: Optional[MapElitesRepertoire],
emitter_state: CMAEmitterState,
random_key: RNGKey,
) -> Tuple[Genotype, RNGKey]:
"""
Emits new individuals. Note that this method does not directly modify
individuals from the repertoire but samples from a distribution. Hence the
repertoire is not used in the emit function.
Args:
repertoire: a repertoire of genotypes (unused).
emitter_state: the state of the CMA-ME emitter.
random_key: a random key to handle random operations.
Returns:
New genotypes and a new random key.
"""
# emit from CMA-ES
offsprings, random_key = self._cmaes.sample(
cmaes_state=emitter_state.cmaes_state, random_key=random_key
)
return offsprings, random_key
@partial(
jax.jit,
static_argnames=("self",),
)
def state_update(
self,
emitter_state: CMAEmitterState,
repertoire: MapElitesRepertoire,
genotypes: Genotype,
fitnesses: Fitness,
descriptors: Descriptor,
extra_scores: Optional[ExtraScores] = None,
) -> Optional[EmitterState]:
"""
Updates the CMA-ME emitter state.
Note: we use the update_state function from CMAES, a function that assumes
that the candidates are already sorted. We do this because we have to sort
them in this function anyway, in order to apply the right weights to the
terms when updating theta.
Args:
emitter_state: current emitter state
repertoire: the current genotypes repertoire
genotypes: the genotypes of the batch of emitted offspring (unused).
fitnesses: the fitnesses of the batch of emitted offspring.
descriptors: the descriptors of the emitted offspring.
extra_scores: unused
Returns:
The updated emitter state.
"""
# retrieve elements from the emitter state
cmaes_state = emitter_state.cmaes_state
# Compute the improvements - needed for re-init condition
indices = get_cells_indices(descriptors, repertoire.centroids)
improvements = fitnesses - emitter_state.previous_fitnesses[indices]
ranking_criteria = self._ranking_criteria(
emitter_state=emitter_state,
repertoire=repertoire,
genotypes=genotypes,
fitnesses=fitnesses,
descriptors=descriptors,
extra_scores=extra_scores,
improvements=improvements,
)
# get the indices
sorted_indices = jnp.flip(jnp.argsort(ranking_criteria))
# sort the candidates
sorted_candidates = jax.tree_util.tree_map(
lambda x: x[sorted_indices], genotypes
)
sorted_improvements = improvements[sorted_indices]
# compute reinitialize condition
emit_count = emitter_state.emit_count + 1
# check if the criteria are too similar
sorted_criteria = ranking_criteria[sorted_indices]
flat_criteria_condition = (
jnp.linalg.norm(sorted_criteria[0] - sorted_criteria[-1]) < 1e-12
)
# check all conditions
reinitialize = (
jnp.all(improvements < 0) * (emit_count > self._min_count)
+ (emit_count > self._max_count)
+ self._cmaes.stop_condition(cmaes_state)
+ flat_criteria_condition
)
# If true, draw randomly and re-initialize parameters
def update_and_reinit(
operand: Tuple[
CMAESState, CMAEmitterState, MapElitesRepertoire, int, RNGKey
],
) -> Tuple[CMAEmitterState, RNGKey]:
return self._update_and_init_emitter_state(*operand)
def update_wo_reinit(
operand: Tuple[
CMAESState, CMAEmitterState, MapElitesRepertoire, int, RNGKey
],
) -> Tuple[CMAEmitterState, RNGKey]:
"""Update the emitter when no reinit event happened.
Here lies a divergence compared to the original implementation. We
get better results when using no mask and doing the update
with the whole batch of individuals rather than keeping only the ones
that were added to the archive.
Interestingly, keeping the best half did not do better. We think that
this might be due to the small batch size used.
This applies to the setting from the CMA-ME paper. These observations
might not hold for other problems and hyperparameters.
To replicate the code described in the paper, replace:
`mask = jnp.ones_like(sorted_improvements)`
by:
```
mask = sorted_improvements >= 0
mask = mask + 1e-6
```
Remark: the addition of 1e-6 is here to fix a numerical
instability.
"""
(cmaes_state, emitter_state, repertoire, emit_count, random_key) = operand
# Update CMA Parameters
mask = jnp.ones_like(sorted_improvements)
cmaes_state = self._cmaes.update_state_with_mask(
cmaes_state, sorted_candidates, mask=mask
)
emitter_state = emitter_state.replace(
cmaes_state=cmaes_state,
emit_count=emit_count,
)
return emitter_state, random_key
# Update CMA Parameters
emitter_state, random_key = jax.lax.cond(
reinitialize,
update_and_reinit,
update_wo_reinit,
operand=(
cmaes_state,
emitter_state,
repertoire,
emit_count,
emitter_state.random_key,
),
)
# update the emitter state
emitter_state = emitter_state.replace(
random_key=random_key, previous_fitnesses=repertoire.fitnesses
)
return emitter_state
def _update_and_init_emitter_state(
self,
cmaes_state: CMAESState,
emitter_state: CMAEmitterState,
repertoire: MapElitesRepertoire,
emit_count: int,
random_key: RNGKey,
) -> Tuple[CMAEmitterState, RNGKey]:
"""Update the emitter state in the case of a reinit event.
Reinit the cmaes state and use an individual from the repertoire
as the starting mean.
Args:
cmaes_state: current cmaes state
emitter_state: current cmame state
repertoire: most recent repertoire
emit_count: counter of the emitter
random_key: key to handle stochastic events
Returns:
The updated emitter state.
"""
# re-sample
random_genotype, random_key = repertoire.sample(random_key, 1)
# remove the batch dim
new_mean = jax.tree_util.tree_map(lambda x: x.squeeze(0), random_genotype)
cmaes_init_state = self._cma_initial_state.replace(mean=new_mean, num_updates=0)
emitter_state = emitter_state.replace(
cmaes_state=cmaes_init_state, emit_count=0
)
return emitter_state, random_key
@abstractmethod
def _ranking_criteria(
self,
emitter_state: CMAEmitterState,
repertoire: MapElitesRepertoire,
genotypes: Genotype,
fitnesses: Fitness,
descriptors: Descriptor,
extra_scores: Optional[ExtraScores],
improvements: jnp.ndarray,
) -> jnp.ndarray:
"""Defines how the genotypes should be sorted. Impacts the update
of the CMAES state. In the end, this defines the type of CMAES emitter
used (optimizing, random direction or improvement).
Args:
emitter_state: current state of the emitter.
repertoire: latest repertoire of genotypes.
genotypes: emitted genotypes.
fitnesses: corresponding fitnesses.
descriptors: corresponding descriptors.
extra_scores: corresponding extra scores.
improvements: improvements of the emitted genotypes. This corresponds
to the difference between their fitness and the fitness of the
individual previously occupying the corresponding cell.
Returns:
The values to take into account in order to rank the emitted genotypes.
Here, it's the improvement, or the fitness when the cell was previously
unoccupied. Additionally, genotypes that discovered a new cell are
given an offset to be ranked in front of other genotypes.
"""
pass
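Since CMAEmitter is abstract, a concrete emitter only has to supply `_ranking_criteria`. Below is a minimal sketch of a hypothetical subclass that ranks offspring by raw fitness, in the spirit of the "optimizing" variant mentioned in the docstring above:
```python
import jax.numpy as jnp

from qdax.core.emitters.cma_emitter import CMAEmitter, CMAEmitterState


class FitnessRankingEmitter(CMAEmitter):
    """Hypothetical emitter that ranks emitted candidates purely by fitness."""

    def _ranking_criteria(
        self,
        emitter_state: CMAEmitterState,
        repertoire,
        genotypes,
        fitnesses,
        descriptors,
        extra_scores,
        improvements: jnp.ndarray,
    ) -> jnp.ndarray:
        # ignore improvements: sort the emitted batch by raw fitness only
        return fitnesses
```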
batch_size: int
property
readonly
¶
__init__(self, batch_size, genotype_dim, centroids, sigma_g, min_count=None, max_count=None)
special
¶
Class for the emitter of CMA-ME from "Covariance Matrix Adaptation for the Rapid Illumination of Behavior Space" by Fontaine et al.
Source code in qdax/core/emitters/cma_emitter.py
def __init__(
self,
batch_size: int,
genotype_dim: int,
centroids: Centroid,
sigma_g: float,
min_count: Optional[int] = None,
max_count: Optional[float] = None,
):
"""
Class for the emitter of CMA-ME from "Covariance Matrix Adaptation for the
Rapid Illumination of Behavior Space" by Fontaine et al.
Args:
batch_size: number of solutions sampled at each iteration
genotype_dim: dimension of the genotype space.
centroids: centroids used for the repertoire.
sigma_g: standard deviation for the coefficients - called step size.
min_count: minimum number of CMA-ES optimization steps before the emitter
is considered for reinitialization.
max_count: maximum number of CMA-ES optimization steps authorized.
"""
self._batch_size = batch_size
# define a CMAES instance
self._cmaes = CMAES(
population_size=batch_size,
search_dim=genotype_dim,
# no need for fitness function in that specific case
fitness_function=None, # type: ignore
num_best=batch_size,
init_sigma=sigma_g,
mean_init=None, # will be init at zeros in cmaes
bias_weights=True,
delay_eigen_decomposition=True,
)
# minimum number of emitted solutions before an emitter can be re-initialized
if min_count is None:
min_count = 0
self._min_count = min_count
if max_count is None:
max_count = jnp.inf
self._max_count = max_count
self._centroids = centroids
self._cma_initial_state = self._cmaes.init()
init(self, init_genotypes, random_key)
¶
Initializes the CMA-ME emitter
Source code in qdax/core/emitters/cma_emitter.py
@partial(jax.jit, static_argnames=("self",))
def init(
self, init_genotypes: Genotype, random_key: RNGKey
) -> Tuple[CMAEmitterState, RNGKey]:
"""
Initializes the CMA-ME emitter
Args:
init_genotypes: initial genotypes to add to the grid.
random_key: a random key to handle stochastic operations.
Returns:
The initial state of the emitter.
"""
# Initialize repertoire with default values
num_centroids = self._centroids.shape[0]
default_fitnesses = -jnp.inf * jnp.ones(shape=num_centroids)
# return the initial state
random_key, subkey = jax.random.split(random_key)
return (
CMAEmitterState(
random_key=subkey,
cmaes_state=self._cma_initial_state,
previous_fitnesses=default_fitnesses,
emit_count=0,
),
random_key,
)
emit(self, repertoire, emitter_state, random_key)
¶
Emits new individuals. Note that this method does not directly modify individuals from the repertoire but samples from a distribution. Hence the repertoire is not used in the emit function.
Source code in qdax/core/emitters/cma_emitter.py
@partial(jax.jit, static_argnames=("self",))
def emit(
self,
repertoire: Optional[MapElitesRepertoire],
emitter_state: CMAEmitterState,
random_key: RNGKey,
) -> Tuple[Genotype, RNGKey]:
"""
Emits new individuals. Note that this method does not directly modify
individuals from the repertoire but samples from a distribution. Hence the
repertoire is not used in the emit function.
Args:
repertoire: a repertoire of genotypes (unused).
emitter_state: the state of the CMA-ME emitter.
random_key: a random key to handle random operations.
Returns:
New genotypes and a new random key.
"""
# emit from CMA-ES
offsprings, random_key = self._cmaes.sample(
cmaes_state=emitter_state.cmaes_state, random_key=random_key
)
return offsprings, random_key
state_update(self, emitter_state, repertoire, genotypes, fitnesses, descriptors, extra_scores=None)
¶
Updates the CMA-ME emitter state.
Note: we use the update_state function from CMAES, a function that assumes that the candidates are already sorted. We do this because we have to sort them in this function anyway, in order to apply the right weights to the terms when updating theta.
Source code in qdax/core/emitters/cma_emitter.py
@partial(
jax.jit,
static_argnames=("self",),
)
def state_update(
self,
emitter_state: CMAEmitterState,
repertoire: MapElitesRepertoire,
genotypes: Genotype,
fitnesses: Fitness,
descriptors: Descriptor,
extra_scores: Optional[ExtraScores] = None,
) -> Optional[EmitterState]:
"""
Updates the CMA-ME emitter state.
Note: we use the update_state function from CMAES, a function that assumes
that the candidates are already sorted. We do this because we have to sort
them in this function anyway, in order to apply the right weights to the
terms when updating theta.
Args:
emitter_state: current emitter state
repertoire: the current genotypes repertoire
genotypes: the genotypes of the batch of emitted offspring (unused).
fitnesses: the fitnesses of the batch of emitted offspring.
descriptors: the descriptors of the emitted offspring.
extra_scores: unused
Returns:
The updated emitter state.
"""
# retrieve elements from the emitter state
cmaes_state = emitter_state.cmaes_state
# Compute the improvements - needed for re-init condition
indices = get_cells_indices(descriptors, repertoire.centroids)
improvements = fitnesses - emitter_state.previous_fitnesses[indices]
ranking_criteria = self._ranking_criteria(
emitter_state=emitter_state,
repertoire=repertoire,
genotypes=genotypes,
fitnesses=fitnesses,
descriptors=descriptors,
extra_scores=extra_scores,
improvements=improvements,
)
# get the indices
sorted_indices = jnp.flip(jnp.argsort(ranking_criteria))
# sort the candidates
sorted_candidates = jax.tree_util.tree_map(
lambda x: x[sorted_indices], genotypes
)
sorted_improvements = improvements[sorted_indices]
# compute reinitialize condition
emit_count = emitter_state.emit_count + 1
# check if the criteria are too similar
sorted_criteria = ranking_criteria[sorted_indices]
flat_criteria_condition = (
jnp.linalg.norm(sorted_criteria[0] - sorted_criteria[-1]) < 1e-12
)
# check all conditions
reinitialize = (
jnp.all(improvements < 0) * (emit_count > self._min_count)
+ (emit_count > self._max_count)
+ self._cmaes.stop_condition(cmaes_state)
+ flat_criteria_condition
)
# If true, draw randomly and re-initialize parameters
def update_and_reinit(
operand: Tuple[
CMAESState, CMAEmitterState, MapElitesRepertoire, int, RNGKey
],
) -> Tuple[CMAEmitterState, RNGKey]:
return self._update_and_init_emitter_state(*operand)
def update_wo_reinit(
operand: Tuple[
CMAESState, CMAEmitterState, MapElitesRepertoire, int, RNGKey
],
) -> Tuple[CMAEmitterState, RNGKey]:
"""Update the emitter when no reinit event happened.
Here lies a divergence compared to the original implementation. We
get better results when using no mask and doing the update
with the whole batch of individuals rather than keeping only the ones
that were added to the archive.
Interestingly, keeping the best half did not do better. We think that
this might be due to the small batch size used.
This applies to the setting from the CMA-ME paper. These observations
might not hold for other problems and hyperparameters.
To replicate the code described in the paper, replace:
`mask = jnp.ones_like(sorted_improvements)`
by:
```
mask = sorted_improvements >= 0
mask = mask + 1e-6
```
Remark: the addition of 1e-6 is here to fix a numerical
instability.
"""
(cmaes_state, emitter_state, repertoire, emit_count, random_key) = operand
# Update CMA Parameters
mask = jnp.ones_like(sorted_improvements)
cmaes_state = self._cmaes.update_state_with_mask(
cmaes_state, sorted_candidates, mask=mask
)
emitter_state = emitter_state.replace(
cmaes_state=cmaes_state,
emit_count=emit_count,
)
return emitter_state, random_key
# Update CMA Parameters
emitter_state, random_key = jax.lax.cond(
reinitialize,
update_and_reinit,
update_wo_reinit,
operand=(
cmaes_state,
emitter_state,
repertoire,
emit_count,
emitter_state.random_key,
),
)
# update the emitter state
emitter_state = emitter_state.replace(
random_key=random_key, previous_fitnesses=repertoire.fitnesses
)
return emitter_state
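Putting the three methods together, one iteration of a MAP-Elites loop with this emitter could look like the sketch below; `emitter`, `emitter_state`, `repertoire`, and the evaluation step are placeholders assumed to be defined elsewhere:
```python
# Hedged sketch of one emit/score/update cycle.
offspring, random_key = emitter.emit(repertoire, emitter_state, random_key)
# ... evaluate `offspring` elsewhere to obtain `fitnesses` and `descriptors`,
# and add them to the repertoire ...
emitter_state = emitter.state_update(
    emitter_state=emitter_state,
    repertoire=repertoire,
    genotypes=offspring,
    fitnesses=fitnesses,
    descriptors=descriptors,
)
```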
cma_improvement_emitter
¶
CMAImprovementEmitter (CMAEmitter)
¶
Class for the emitter of CMA-ME from "Covariance Matrix Adaptation for the Rapid Illumination of Behavior Space" by Fontaine et al.
This class implements the improvement emitter, where the update of the distribution is biased towards solutions that improve the QD score.
Source code in qdax/core/emitters/cma_improvement_emitter.py
class CMAImprovementEmitter(CMAEmitter):
"""Class for the emitter of CMA-ME from "Covariance Matrix Adaptation
for the Rapid Illumination of Behavior Space" by Fontaine et al.
This class implements the improvement emitter, where the update of the
distribution is biased towards solutions that improve the QD score.
Args:
batch_size: number of solutions sampled at each iteration
genotype_dim: dimension of the genotype space.
centroids: centroids used for the repertoire.
sigma_g: standard deviation for the coefficients - called step size.
min_count: minimum number of CMAES opt step before being considered for
reinitialisation.
max_count: maximum number of CMAES opt step authorized.
"""
def _ranking_criteria(
self,
emitter_state: CMAEmitterState,
repertoire: MapElitesRepertoire,
genotypes: Genotype,
fitnesses: Fitness,
descriptors: Descriptor,
extra_scores: Optional[ExtraScores],
improvements: jnp.ndarray,
) -> jnp.ndarray:
"""Defines how the genotypes should be sorted. Impacts the update
of the CMAES state. In the end, this defines the type of CMAES emitter
used (optimizing, random direction or improvement).
Args:
emitter_state: current state of the emitter.
repertoire: latest repertoire of genotypes.
genotypes: emitted genotypes.
fitnesses: corresponding fitnesses.
descriptors: corresponding descriptors.
extra_scores: corresponding extra scores.
improvements: improvements of the emitted genotypes. This corresponds
to the difference between their fitness and the fitness of the
individual previously occupying the corresponding cell.
Returns:
The values to take into account in order to rank the emitted genotypes.
Here, it's the improvement, or the fitness when the cell was previously
unoccupied. Additionally, genotypes that discovered a new cell are
given an offset to be ranked in front of other genotypes.
"""
# condition for being a new cell
condition = improvements == jnp.inf
# criteria: fitness if new cell, improvement else
ranking_criteria = jnp.where(condition, x=fitnesses, y=improvements)
# make sure to have all the new cells first
new_cell_offset = jnp.max(ranking_criteria) - jnp.min(ranking_criteria)
ranking_criteria = jnp.where(
condition, x=ranking_criteria + new_cell_offset, y=ranking_criteria
)
return ranking_criteria # type: ignore
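A small worked example of this criterion, with hypothetical values:
```python
import jax.numpy as jnp

# Hypothetical batch: offspring 0 and 2 discovered new cells
# (improvement == inf); offspring 1 and 3 landed in occupied cells.
fitnesses = jnp.array([1.0, 3.0, 2.0, 0.5])
improvements = jnp.array([jnp.inf, 0.4, jnp.inf, -0.1])

condition = improvements == jnp.inf
criteria = jnp.where(condition, fitnesses, improvements)  # [1.0, 0.4, 2.0, -0.1]
offset = jnp.max(criteria) - jnp.min(criteria)            # 2.1
criteria = jnp.where(condition, criteria + offset, criteria)
# [3.1, 0.4, 4.1, -0.1]: offspring that discovered new cells rank first.
```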
cma_mega_emitter
¶
CMAMEGAState (EmitterState)
dataclass
¶
Emitter state for the CMA-MEGA emitter.
Source code in qdax/core/emitters/cma_mega_emitter.py
class CMAMEGAState(EmitterState):
"""
Emitter state for the CMA-MEGA emitter.
Args:
theta: current genotype from which candidates will be drawn.
theta_grads: normalized fitness and descriptors gradients of theta.
random_key: a random key to handle stochastic operations. Used for
state update only, another key is used to emit. This might be
subject to refactoring discussions in the future.
cmaes_state: state of the underlying CMA-ES algorithm
previous_fitnesses: store last fitnesses of the repertoire. Used to
compute the improvement.
"""
theta: Genotype
theta_grads: Gradient
random_key: RNGKey
cmaes_state: CMAESState
previous_fitnesses: Fitness
replace(self, **updates)
¶
Returns a new object replacing the specified fields with new values.
Source code in qdax/core/emitters/cma_mega_emitter.py
def replace(self, **updates):
"""Returns a new object replacing the specified fields with new values."""
return dataclasses.replace(self, **updates)
CMAMEGAEmitter (Emitter)
¶
Source code in qdax/core/emitters/cma_mega_emitter.py
class CMAMEGAEmitter(Emitter):
def __init__(
self,
scoring_function: Callable[
[Genotype, RNGKey], Tuple[Fitness, Descriptor, ExtraScores, RNGKey]
],
batch_size: int,
learning_rate: float,
num_descriptors: int,
centroids: Centroid,
sigma_g: float,
):
"""
Class for the emitter of CMA-MEGA from "Differentiable Quality Diversity" by
Fontaine et al.
Args:
scoring_function: a function to score individuals, outputting fitness,
descriptors and extra scores. With this emitter, the extra score
contains gradients and normalized gradients.
batch_size: number of solutions sampled at each iteration
learning_rate: rate at which the mean of the distribution is updated.
num_descriptors: number of descriptors
centroids: centroids of the repertoire used to store the genotypes
sigma_g: standard deviation for the coefficients
"""
self._scoring_function = scoring_function
self._batch_size = batch_size
self._learning_rate = learning_rate
# weights used to update the gradient direction through a linear combination
self._weights = jnp.expand_dims(
jnp.log(batch_size + 0.5) - jnp.log(jnp.arange(1, batch_size + 1)), axis=-1
)
self._weights = self._weights / (self._weights.sum())
# define a CMAES instance - used to update the coeffs
self._cmaes = CMAES(
population_size=batch_size,
search_dim=num_descriptors + 1,
# no need for fitness function in that specific case
fitness_function=None, # type: ignore
num_best=batch_size,
init_sigma=sigma_g,
bias_weights=True,
delay_eigen_decomposition=True,
)
self._centroids = centroids
self._cma_initial_state = self._cmaes.init()
@partial(jax.jit, static_argnames=("self",))
def init(
self, init_genotypes: Genotype, random_key: RNGKey
) -> Tuple[CMAMEGAState, RNGKey]:
"""
Initializes the CMA-MEGA emitter.
Args:
init_genotypes: initial genotypes to add to the grid.
random_key: a random key to handle stochastic operations.
Returns:
The initial state of the emitter.
"""
# define init theta as 0
theta = jax.tree_util.tree_map(
lambda x: jnp.zeros_like(x[:1, ...]),
init_genotypes,
)
# score it
_, _, extra_score, random_key = self._scoring_function(theta, random_key)
theta_grads = extra_score["normalized_grads"]
# Initialize repertoire with default values
num_centroids = self._centroids.shape[0]
default_fitnesses = -jnp.inf * jnp.ones(shape=num_centroids)
# return the initial state
random_key, subkey = jax.random.split(random_key)
return (
CMAMEGAState(
theta=theta,
theta_grads=theta_grads,
random_key=subkey,
cmaes_state=self._cma_initial_state,
previous_fitnesses=default_fitnesses,
),
random_key,
)
@partial(jax.jit, static_argnames=("self",))
def emit(
self,
repertoire: Optional[MapElitesRepertoire],
emitter_state: CMAMEGAState,
random_key: RNGKey,
) -> Tuple[Genotype, RNGKey]:
"""
Emits new individuals. Interestingly, this method does not directly modifies
individuals from the repertoire but sample from a distribution. Hence the
repertoire is not used in the emit function.
Args:
repertoire: a repertoire of genotypes (unused).
emitter_state: the state of the CMA-MEGA emitter.
random_key: a random key to handle random operations.
Returns:
New genotypes and a new random key.
"""
# retrieve elements from the emitter state
theta = jnp.nan_to_num(emitter_state.theta)
cmaes_state = emitter_state.cmaes_state
# get grads - remove nan and first dimension
grads = jnp.nan_to_num(emitter_state.theta_grads.squeeze(axis=0))
# Draw random coefficients - use the emitter state key
coeffs, random_key = self._cmaes.sample(
cmaes_state=cmaes_state, random_key=emitter_state.random_key
)
# make sure the fitness coefficient is positive
coeffs = coeffs.at[:, 0].set(jnp.abs(coeffs[:, 0]))
update_grad = coeffs @ grads.T
# Compute new candidates
new_thetas = jax.tree_util.tree_map(lambda x, y: x + y, theta, update_grad)
return new_thetas, random_key
@partial(
jax.jit,
static_argnames=("self",),
)
def state_update(
self,
emitter_state: CMAMEGAState,
repertoire: MapElitesRepertoire,
genotypes: Genotype,
fitnesses: Fitness,
descriptors: Descriptor,
extra_scores: Optional[ExtraScores] = None,
) -> Optional[EmitterState]:
"""
Updates the CMA-MEGA emitter state.
Note: in order to recover the coeffs that were used to sample the genotypes,
we reuse the emitter state's random key in this function.
Note: we use the update_state function from CMAES, a function that assumes
that the candidates are already sorted. We do this because we have to sort
them in this function anyway, in order to apply the right weights to the
terms when updating theta.
Args:
emitter_state: current emitter state
repertoire: the current genotypes repertoire
genotypes: the genotypes of the batch of emitted offspring (unused).
fitnesses: the fitnesses of the batch of emitted offspring.
descriptors: the descriptors of the emitted offspring.
extra_scores: unused
Returns:
The updated emitter state.
"""
# retrieve elements from the emitter state
cmaes_state = emitter_state.cmaes_state
theta = jnp.nan_to_num(emitter_state.theta)
grads = jnp.nan_to_num(emitter_state.theta_grads[0])
# Update the archive and compute the improvements
indices = get_cells_indices(descriptors, repertoire.centroids)
improvements = fitnesses - emitter_state.previous_fitnesses[indices]
# condition for being a new cell
condition = improvements == jnp.inf
# criteria: fitness if new cell, improvement else
ranking_criteria = jnp.where(condition, x=fitnesses, y=improvements)
# make sure to have all the new cells first
new_cell_offset = jnp.max(ranking_criteria) - jnp.min(ranking_criteria)
ranking_criteria = jnp.where(
condition, x=ranking_criteria + new_cell_offset, y=ranking_criteria
)
# sort indices according to the criteria
sorted_indices = jnp.flip(jnp.argsort(ranking_criteria))
# Draw the coeffs - reuse the emitter state key to get same coeffs
coeffs, random_key = self._cmaes.sample(
cmaes_state=cmaes_state, random_key=emitter_state.random_key
)
# make sure the fitness coeff is positive
coeffs = coeffs.at[:, 0].set(jnp.abs(coeffs[:, 0]))
# get the gradients that must be applied
update_grad = coeffs @ grads.T
# weight terms - based on improvement rank
gradient_step = jnp.sum(self._weights[sorted_indices] * update_grad, axis=0)
# update theta
theta = jax.tree_util.tree_map(
lambda x, y: x + self._learning_rate * y, theta, gradient_step
)
# Update CMA Parameters
sorted_candidates = coeffs[sorted_indices]
cmaes_state = self._cmaes.update_state(cmaes_state, sorted_candidates)
# If no improvement draw randomly and re-initialize parameters
reinitialize = jnp.all(improvements < 0) + self._cmaes.stop_condition(
cmaes_state
)
# re-sample
random_theta, random_key = repertoire.sample(random_key, 1)
# update theta in case of reinit
theta = jax.tree_util.tree_map(
lambda x, y: jnp.where(reinitialize, x=x, y=y), random_theta, theta
)
# update cmaes state in case of reinit
cmaes_state = jax.tree_util.tree_map(
lambda x, y: jnp.where(reinitialize, x=x, y=y),
self._cma_initial_state,
cmaes_state,
)
# score theta
_, _, extra_score, random_key = self._scoring_function(theta, random_key)
# create new emitter state
emitter_state = CMAMEGAState(
theta=theta,
theta_grads=extra_score["normalized_grads"],
random_key=random_key,
cmaes_state=cmaes_state,
previous_fitnesses=repertoire.fitnesses,
)
return emitter_state
@property
def batch_size(self) -> int:
"""
Returns:
the batch size emitted by the emitter.
"""
return self._batch_size
batch_size: int
property
readonly
¶
__init__(self, scoring_function, batch_size, learning_rate, num_descriptors, centroids, sigma_g)
special
¶
Class for the emitter of CMA-MEGA from "Differentiable Quality Diversity" by Fontaine et al.
Source code in qdax/core/emitters/cma_mega_emitter.py
def __init__(
self,
scoring_function: Callable[
[Genotype, RNGKey], Tuple[Fitness, Descriptor, ExtraScores, RNGKey]
],
batch_size: int,
learning_rate: float,
num_descriptors: int,
centroids: Centroid,
sigma_g: float,
):
"""
Class for the emitter of CMA-MEGA from "Differentiable Quality Diversity" by
Fontaine et al.
Args:
scoring_function: a function to score individuals, outputting fitness,
descriptors and extra scores. With this emitter, the extra score
contains gradients and normalized gradients.
batch_size: number of solutions sampled at each iteration
learning_rate: rate at which the mean of the distribution is updated.
num_descriptors: number of descriptors
centroids: centroids of the repertoire used to store the genotypes
sigma_g: standard deviation for the coefficients
"""
self._scoring_function = scoring_function
self._batch_size = batch_size
self._learning_rate = learning_rate
# weights used to update the gradient direction through a linear combination
self._weights = jnp.expand_dims(
jnp.log(batch_size + 0.5) - jnp.log(jnp.arange(1, batch_size + 1)), axis=-1
)
self._weights = self._weights / (self._weights.sum())
# define a CMAES instance - used to update the coeffs
self._cmaes = CMAES(
population_size=batch_size,
search_dim=num_descriptors + 1,
# no need for fitness function in that specific case
fitness_function=None, # type: ignore
num_best=batch_size,
init_sigma=sigma_g,
bias_weights=True,
delay_eigen_decomposition=True,
)
self._centroids = centroids
self._cma_initial_state = self._cmaes.init()
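The log-rank weights defined in this constructor decay with rank, so higher-ranked samples contribute more to the gradient step. A quick sketch of their values for a small batch (before the `expand_dims` to shape `(batch_size, 1)` used above):
```python
import jax.numpy as jnp

batch_size = 4
weights = jnp.log(batch_size + 0.5) - jnp.log(jnp.arange(1, batch_size + 1))
weights = weights / weights.sum()
# approximately [0.53, 0.29, 0.14, 0.04]: rank 1 dominates, rank 4 barely counts
```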
init(self, init_genotypes, random_key)
¶
Initializes the CMA-MEGA emitter.
Source code in qdax/core/emitters/cma_mega_emitter.py
@partial(jax.jit, static_argnames=("self",))
def init(
self, init_genotypes: Genotype, random_key: RNGKey
) -> Tuple[CMAMEGAState, RNGKey]:
"""
Initializes the CMA-MEGA emitter.
Args:
init_genotypes: initial genotypes to add to the grid.
random_key: a random key to handle stochastic operations.
Returns:
The initial state of the emitter.
"""
# define init theta as 0
theta = jax.tree_util.tree_map(
lambda x: jnp.zeros_like(x[:1, ...]),
init_genotypes,
)
# score it
_, _, extra_score, random_key = self._scoring_function(theta, random_key)
theta_grads = extra_score["normalized_grads"]
# Initialize repertoire with default values
num_centroids = self._centroids.shape[0]
default_fitnesses = -jnp.inf * jnp.ones(shape=num_centroids)
# return the initial state
random_key, subkey = jax.random.split(random_key)
return (
CMAMEGAState(
theta=theta,
theta_grads=theta_grads,
random_key=subkey,
cmaes_state=self._cma_initial_state,
previous_fitnesses=default_fitnesses,
),
random_key,
)
emit(self, repertoire, emitter_state, random_key)
¶
Emits new individuals. Note that this method does not directly modify individuals from the repertoire but samples from a distribution. Hence the repertoire is not used in the emit function.
Source code in qdax/core/emitters/cma_mega_emitter.py
@partial(jax.jit, static_argnames=("self",))
def emit(
self,
repertoire: Optional[MapElitesRepertoire],
emitter_state: CMAMEGAState,
random_key: RNGKey,
) -> Tuple[Genotype, RNGKey]:
"""
Emits new individuals. Note that this method does not directly modify
individuals from the repertoire but samples from a distribution. Hence the
repertoire is not used in the emit function.
Args:
repertoire: a repertoire of genotypes (unused).
emitter_state: the state of the CMA-MEGA emitter.
random_key: a random key to handle random operations.
Returns:
New genotypes and a new random key.
"""
# retrieve elements from the emitter state
theta = jnp.nan_to_num(emitter_state.theta)
cmaes_state = emitter_state.cmaes_state
# get grads - remove nan and first dimension
grads = jnp.nan_to_num(emitter_state.theta_grads.squeeze(axis=0))
# Draw random coefficients - use the emitter state key
coeffs, random_key = self._cmaes.sample(
cmaes_state=cmaes_state, random_key=emitter_state.random_key
)
# make sure the fitness coefficient is positive
coeffs = coeffs.at[:, 0].set(jnp.abs(coeffs[:, 0]))
update_grad = coeffs @ grads.T
# Compute new candidates
new_thetas = jax.tree_util.tree_map(lambda x, y: x + y, theta, update_grad)
return new_thetas, random_key
state_update(self, emitter_state, repertoire, genotypes, fitnesses, descriptors, extra_scores=None)
¶
Updates the CMA-MEGA emitter state.
Note: in order to recover the coeffs that were used to sample the genotypes, we reuse the emitter state's random key in this function.
Note: we use the update_state function from CMAES, a function that assumes that the candidates are already sorted. We do this because we have to sort them in this function anyway, in order to apply the right weights to the terms when updating theta.
Source code in qdax/core/emitters/cma_mega_emitter.py
@partial(
jax.jit,
static_argnames=("self",),
)
def state_update(
self,
emitter_state: CMAMEGAState,
repertoire: MapElitesRepertoire,
genotypes: Genotype,
fitnesses: Fitness,
descriptors: Descriptor,
extra_scores: Optional[ExtraScores] = None,
) -> Optional[EmitterState]:
"""
Updates the CMA-MEGA emitter state.
Note: in order to recover the coeffs that were used to sample the genotypes,
we reuse the emitter state's random key in this function.
Note: we use the update_state function from CMAES, a function that assumes
that the candidates are already sorted. We do this because we have to sort
them in this function anyway, in order to apply the right weights to the
terms when updating theta.
Args:
emitter_state: current emitter state
repertoire: the current genotypes repertoire
genotypes: the genotypes of the batch of emitted offspring (unused).
fitnesses: the fitnesses of the batch of emitted offspring.
descriptors: the descriptors of the emitted offspring.
extra_scores: unused
Returns:
The updated emitter state.
"""
# retrieve elements from the emitter state
cmaes_state = emitter_state.cmaes_state
theta = jnp.nan_to_num(emitter_state.theta)
grads = jnp.nan_to_num(emitter_state.theta_grads[0])
# Update the archive and compute the improvements
indices = get_cells_indices(descriptors, repertoire.centroids)
improvements = fitnesses - emitter_state.previous_fitnesses[indices]
# condition for being a new cell
condition = improvements == jnp.inf
# criteria: fitness if new cell, improvement else
ranking_criteria = jnp.where(condition, x=fitnesses, y=improvements)
# make sure to have all the new cells first
new_cell_offset = jnp.max(ranking_criteria) - jnp.min(ranking_criteria)
ranking_criteria = jnp.where(
condition, x=ranking_criteria + new_cell_offset, y=ranking_criteria
)
# sort indices according to the criteria
sorted_indices = jnp.flip(jnp.argsort(ranking_criteria))
# Draw the coeffs - reuse the emitter state key to get same coeffs
coeffs, random_key = self._cmaes.sample(
cmaes_state=cmaes_state, random_key=emitter_state.random_key
)
# make sure the fitness coeff is positive
coeffs = coeffs.at[:, 0].set(jnp.abs(coeffs[:, 0]))
# get the gradients that must be applied
update_grad = coeffs @ grads.T
# weight terms - based on improvement rank
gradient_step = jnp.sum(self._weights[sorted_indices] * update_grad, axis=0)
# update theta
theta = jax.tree_util.tree_map(
lambda x, y: x + self._learning_rate * y, theta, gradient_step
)
# Update CMA Parameters
sorted_candidates = coeffs[sorted_indices]
cmaes_state = self._cmaes.update_state(cmaes_state, sorted_candidates)
# If no improvement draw randomly and re-initialize parameters
reinitialize = jnp.all(improvements < 0) + self._cmaes.stop_condition(
cmaes_state
)
# re-sample
random_theta, random_key = repertoire.sample(random_key, 1)
# update theta in case of reinit
theta = jax.tree_util.tree_map(
lambda x, y: jnp.where(reinitialize, x=x, y=y), random_theta, theta
)
# update cmaes state in case of reinit
cmaes_state = jax.tree_util.tree_map(
lambda x, y: jnp.where(reinitialize, x=x, y=y),
self._cma_initial_state,
cmaes_state,
)
# score theta
_, _, extra_score, random_key = self._scoring_function(theta, random_key)
# create new emitter state
emitter_state = CMAMEGAState(
theta=theta,
theta_grads=extra_score["normalized_grads"],
random_key=random_key,
cmaes_state=cmaes_state,
previous_fitnesses=repertoire.fitnesses,
)
return emitter_state
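A hedged usage sketch of the constructor and init call; `scoring_fn`, `centroids`, and `init_genotypes` are placeholders that must satisfy the contracts described above (in particular, the scoring function's extra scores must contain a "normalized_grads" entry), and the hyperparameters are illustrative only:
```python
import jax

from qdax.core.emitters.cma_mega_emitter import CMAMEGAEmitter

emitter = CMAMEGAEmitter(
    scoring_function=scoring_fn,  # must expose extra_scores["normalized_grads"]
    batch_size=36,
    learning_rate=1.0,
    num_descriptors=2,
    centroids=centroids,
    sigma_g=10.0,
)
emitter_state, random_key = emitter.init(init_genotypes, jax.random.PRNGKey(0))
```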
cma_pool_emitter
¶
CMAPoolEmitterState (EmitterState)
dataclass
¶
Emitter state for the pool of CMA emitters.
This is for a pool of homogeneous emitters.
Source code in qdax/core/emitters/cma_pool_emitter.py
class CMAPoolEmitterState(EmitterState):
"""
Emitter state for the pool of CMA emitters.
This is for a pool of homogeneous emitters.
Args:
current_index: the index of the current emitter state used.
emitter_states: the batch of emitter states currently used.
"""
current_index: int
emitter_states: CMAEmitterState
replace(self, **updates)
¶
Returns a new object replacing the specified fields with new values.
Source code in qdax/core/emitters/cma_pool_emitter.py
def replace(self, **updates):
"""Returns a new object replacing the specified fields with new values."""
return dataclasses.replace(self, **updates)
CMAPoolEmitter (Emitter)
¶
Source code in qdax/core/emitters/cma_pool_emitter.py
class CMAPoolEmitter(Emitter):
def __init__(self, num_states: int, emitter: CMAEmitter):
"""Instantiate a pool of homogeneous emitters.
Args:
num_states: the number of emitters to consider. We can use a
single emitter object and a batched emitter state.
emitter: the type of emitter for the pool.
"""
self._num_states = num_states
self._emitter = emitter
@property
def batch_size(self) -> int:
"""
Returns:
the batch size emitted by the emitter.
"""
return self._emitter.batch_size
@partial(jax.jit, static_argnames=("self",))
def init(
self, init_genotypes: Genotype, random_key: RNGKey
) -> Tuple[CMAPoolEmitterState, RNGKey]:
"""
Initializes the pool of emitter states
Args:
init_genotypes: initial genotypes to add to the grid.
random_key: a random key to handle stochastic operations.
Returns:
The initial state of the emitter.
"""
def scan_emitter_init(
carry: RNGKey, unused: Any
) -> Tuple[RNGKey, CMAEmitterState]:
random_key = carry
emitter_state, random_key = self._emitter.init(init_genotypes, random_key)
return random_key, emitter_state
# init all the emitter states
random_key, emitter_states = jax.lax.scan(
scan_emitter_init, random_key, (), length=self._num_states
)
# define the emitter state of the pool
emitter_state = CMAPoolEmitterState(
current_index=0, emitter_states=emitter_states
)
return (
emitter_state,
random_key,
)
@partial(jax.jit, static_argnames=("self",))
def emit(
self,
repertoire: Optional[MapElitesRepertoire],
emitter_state: CMAPoolEmitterState,
random_key: RNGKey,
) -> Tuple[Genotype, RNGKey]:
"""
Emits new individuals.
Args:
repertoire: a repertoire of genotypes (unused).
emitter_state: the state of the pool emitter.
random_key: a random key to handle random operations.
Returns:
New genotypes and a new random key.
"""
# retrieve the relevant emitter state
current_index = emitter_state.current_index
used_emitter_state = jax.tree_util.tree_map(
lambda x: x[current_index], emitter_state.emitter_states
)
# use it to emit offsprings
offsprings, random_key = self._emitter.emit(
repertoire, used_emitter_state, random_key
)
return offsprings, random_key
@partial(
jax.jit,
static_argnames=("self",),
)
def state_update(
self,
emitter_state: CMAPoolEmitterState,
repertoire: MapElitesRepertoire,
genotypes: Genotype,
fitnesses: Fitness,
descriptors: Descriptor,
extra_scores: Optional[ExtraScores] = None,
) -> Optional[EmitterState]:
"""
Updates the emitter state.
Args:
emitter_state: current emitter state
repertoire: the current genotypes repertoire
genotypes: the genotypes of the batch of emitted offspring (unused).
fitnesses: the fitnesses of the batch of emitted offspring.
descriptors: the descriptors of the emitted offspring.
extra_scores: unused
Returns:
The updated emitter state.
"""
# retrieve the emitter that has been used and its emitter state
current_index = emitter_state.current_index
emitter_states = emitter_state.emitter_states
used_emitter_state = jax.tree_util.tree_map(
lambda x: x[current_index], emitter_states
)
# update the used emitter state
used_emitter_state = self._emitter.state_update(
emitter_state=used_emitter_state,
repertoire=repertoire,
genotypes=genotypes,
fitnesses=fitnesses,
descriptors=descriptors,
extra_scores=extra_scores,
)
# update the emitter state
emitter_states = jax.tree_util.tree_map(
lambda x, y: x.at[current_index].set(y), emitter_states, used_emitter_state
)
# determine the next emitter to be used
emit_counts = emitter_states.emit_count
new_index = jnp.argmin(emit_counts)
emitter_state = emitter_state.replace(
current_index=new_index, emitter_states=emitter_states
)
return emitter_state # type: ignore
batch_size: int
property
readonly
¶
__init__(self, num_states, emitter)
special
¶
Instantiate a pool of homogeneous emitters.
Source code in qdax/core/emitters/cma_pool_emitter.py
def __init__(self, num_states: int, emitter: CMAEmitter):
"""Instantiate a pool of homogeneous emitters.
Args:
num_states: the number of emitters to consider. We can use a
single emitter object and a batched emitter state.
emitter: the type of emitter for the pool.
"""
self._num_states = num_states
self._emitter = emitter
init(self, init_genotypes, random_key)
¶
Initializes the pool of emitter states
Source code in qdax/core/emitters/cma_pool_emitter.py
@partial(jax.jit, static_argnames=("self",))
def init(
self, init_genotypes: Genotype, random_key: RNGKey
) -> Tuple[CMAPoolEmitterState, RNGKey]:
"""
Initializes the pool of emitter states
Args:
init_genotypes: initial genotypes to add to the grid.
random_key: a random key to handle stochastic operations.
Returns:
The initial state of the emitter.
"""
def scan_emitter_init(
carry: RNGKey, unused: Any
) -> Tuple[RNGKey, CMAEmitterState]:
random_key = carry
emitter_state, random_key = self._emitter.init(init_genotypes, random_key)
return random_key, emitter_state
# init all the emitter states
random_key, emitter_states = jax.lax.scan(
scan_emitter_init, random_key, (), length=self._num_states
)
# define the emitter state of the pool
emitter_state = CMAPoolEmitterState(
current_index=0, emitter_states=emitter_states
)
return (
emitter_state,
random_key,
)
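A hedged usage sketch of building and initializing the pool; `cma_emitter` and `init_genotypes` are placeholders assumed to come from the previous sections:
```python
import jax

from qdax.core.emitters.cma_pool_emitter import CMAPoolEmitter

# Wrap a single CMAEmitter object into a pool of 3 batched emitter states.
pool_emitter = CMAPoolEmitter(num_states=3, emitter=cma_emitter)
pool_state, random_key = pool_emitter.init(init_genotypes, jax.random.PRNGKey(0))
```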
emit(self, repertoire, emitter_state, random_key)
¶
Emits new individuals.
Source code in qdax/core/emitters/cma_pool_emitter.py
@partial(jax.jit, static_argnames=("self",))
def emit(
self,
repertoire: Optional[MapElitesRepertoire],
emitter_state: CMAPoolEmitterState,
random_key: RNGKey,
) -> Tuple[Genotype, RNGKey]:
"""
Emits new individuals.
Args:
repertoire: a repertoire of genotypes (unused).
emitter_state: the state of the pool emitter.
random_key: a random key to handle random operations.
Returns:
New genotypes and a new random key.
"""
# retrieve the relevant emitter state
current_index = emitter_state.current_index
used_emitter_state = jax.tree_util.tree_map(
lambda x: x[current_index], emitter_state.emitter_states
)
# use it to emit offsprings
offsprings, random_key = self._emitter.emit(
repertoire, used_emitter_state, random_key
)
return offsprings, random_key
state_update(self, emitter_state, repertoire, genotypes, fitnesses, descriptors, extra_scores=None)
¶
Updates the emitter state.
Source code in qdax/core/emitters/cma_pool_emitter.py
@partial(
jax.jit,
static_argnames=("self",),
)
def state_update(
self,
emitter_state: CMAPoolEmitterState,
repertoire: MapElitesRepertoire,
genotypes: Genotype,
fitnesses: Fitness,
descriptors: Descriptor,
extra_scores: Optional[ExtraScores] = None,
) -> Optional[EmitterState]:
"""
Updates the emitter state.
Args:
emitter_state: current emitter state
repertoire: the current genotypes repertoire
genotypes: the genotypes of the batch of emitted offspring (unused).
fitnesses: the fitnesses of the batch of emitted offspring.
descriptors: the descriptors of the emitted offspring.
extra_scores: unused
Returns:
The updated emitter state.
"""
# retrieve the emitter that has been used and its emitter state
current_index = emitter_state.current_index
emitter_states = emitter_state.emitter_states
used_emitter_state = jax.tree_util.tree_map(
lambda x: x[current_index], emitter_states
)
# update the used emitter state
used_emitter_state = self._emitter.state_update(
emitter_state=used_emitter_state,
repertoire=repertoire,
genotypes=genotypes,
fitnesses=fitnesses,
descriptors=descriptors,
extra_scores=extra_scores,
)
# update the emitter state
emitter_states = jax.tree_util.tree_map(
lambda x, y: x.at[current_index].set(y), emitter_states, used_emitter_state
)
# determine the next emitter to be used
emit_counts = emitter_states.emit_count
new_index = jnp.argmin(emit_counts)
emitter_state = emitter_state.replace(
current_index=new_index, emitter_states=emitter_states
)
return emitter_state # type: ignore
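As the emit and state_update code above shows, the pooled states are stored as one batched pytree; a single emitter's state is recovered by indexing every leaf. A minimal sketch, reusing `pool_state` from the sketch above:
```python
import jax

# Each leaf of `pool_state.emitter_states` has a leading axis of size
# num_states; indexing every leaf recovers one emitter's state.
first_state = jax.tree_util.tree_map(lambda x: x[0], pool_state.emitter_states)
```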
cma_rnd_emitter
¶
CMARndEmitterState (CMAEmitterState)
dataclass
¶
Emitter state for the CMA-ME random direction emitter.
Source code in qdax/core/emitters/cma_rnd_emitter.py
class CMARndEmitterState(CMAEmitterState):
"""
Emitter state for the CMA-ME random direction emitter.
Args:
random_key: a random key to handle stochastic operations. Used for
state update only, another key is used to emit. This might be
subject to refactoring discussions in the future.
cmaes_state: state of the underlying CMA-ES algorithm
previous_fitnesses: store last fitnesses of the repertoire. Used to
compute the improvement.
emit_count: count the number of emission events.
random_direction: direction of the behavior space we are trying to
explore.
"""
random_direction: Descriptor
replace(self, **updates)
¶
Returns a new object replacing the specified fields with new values.
Source code in qdax/core/emitters/cma_rnd_emitter.py
def replace(self, **updates):
"""Returns a new object replacing the specified fields with new values."""
return dataclasses.replace(self, **updates)
CMARndEmitter (CMAEmitter)
¶
Source code in qdax/core/emitters/cma_rnd_emitter.py
class CMARndEmitter(CMAEmitter):
@partial(jax.jit, static_argnames=("self",))
def init(
self, init_genotypes: Genotype, random_key: RNGKey
) -> Tuple[CMARndEmitterState, RNGKey]:
"""
Initializes the CMA-ME random direction emitter
Args:
init_genotypes: initial genotypes to add to the grid.
random_key: a random key to handle stochastic operations.
Returns:
The initial state of the emitter.
"""
# Initialize repertoire with default values
num_centroids = self._centroids.shape[0]
default_fitnesses = -jnp.inf * jnp.ones(shape=num_centroids)
# take a random direction
random_key, subkey = jax.random.split(random_key)
random_direction = jax.random.uniform(
subkey,
shape=(self._centroids.shape[-1],),
)
# return the initial state
random_key, subkey = jax.random.split(random_key)
return (
CMARndEmitterState(
random_key=subkey,
cmaes_state=self._cma_initial_state,
previous_fitnesses=default_fitnesses,
emit_count=0,
random_direction=random_direction,
),
random_key,
)
def _update_and_init_emitter_state(
self,
cmaes_state: CMAESState,
emitter_state: CMAEmitterState,
repertoire: MapElitesRepertoire,
emit_count: int,
random_key: RNGKey,
) -> Tuple[CMAEmitterState, RNGKey]:
"""Update the emitter state in the case of a reinit event.
Reinit the cmaes state and use an individual from the repertoire
as the starting mean.
Args:
cmaes_state: current cmaes state
emitter_state: current cmame state
repertoire: most recent repertoire
emit_count: counter of the emitter
random_key: key to handle stochastic events
Returns:
The updated emitter state.
"""
# re-sample
random_genotype, random_key = repertoire.sample(random_key, 1)
# get new mean - remove the batch dim
new_mean = jax.tree_util.tree_map(lambda x: x.squeeze(0), random_genotype)
# define the corresponding cmaes init state
cmaes_init_state = self._cma_initial_state.replace(mean=new_mean, num_updates=0)
# take a new random direction
random_key, subkey = jax.random.split(random_key)
random_direction = jax.random.uniform(
subkey,
shape=(self._centroids.shape[-1],),
)
emitter_state = emitter_state.replace(
cmaes_state=cmaes_init_state,
emit_count=0,
random_direction=random_direction,
)
return emitter_state, random_key
def _ranking_criteria(
self,
emitter_state: CMARndEmitterState,
repertoire: MapElitesRepertoire,
genotypes: Genotype,
fitnesses: Fitness,
descriptors: Descriptor,
extra_scores: Optional[ExtraScores],
improvements: jnp.ndarray,
) -> jnp.ndarray:
"""Defines how the genotypes should be sorted. Impacts the update
of the CMAES state. In the end, this defines the type of CMAES emitter
used (optimizing, random direction or improvement).
Args:
emitter_state: current state of the emitter.
repertoire: latest repertoire of genotypes.
genotypes: emitted genotypes.
fitnesses: corresponding fitnesses.
descriptors: corresponding descriptors.
extra_scores: corresponding extra scores.
improvements: improvements of the emitted genotypes. This corresponds
to the difference between their fitness and the fitness of the
individual previously occupying the corresponding cell.
Returns:
The values to take into account in order to rank the emitted genotypes.
Here, it is the dot product of the descriptor with the current random
direction.
"""
# criteria: projection of the descriptors along the random direction
ranking_criteria = jnp.dot(descriptors, emitter_state.random_direction)
# make sure to have all the new cells first
new_cell_offset = jnp.max(ranking_criteria) - jnp.min(ranking_criteria)
# condition for being a new cell
condition = improvements == jnp.inf
ranking_criteria = jnp.where(
condition, x=ranking_criteria + new_cell_offset, y=ranking_criteria
)
return ranking_criteria # type: ignore
init(self, init_genotypes, random_key)
¶
Initializes the CMA-ME random direction emitter
Source code in qdax/core/emitters/cma_rnd_emitter.py
@partial(jax.jit, static_argnames=("self",))
def init(
self, init_genotypes: Genotype, random_key: RNGKey
) -> Tuple[CMARndEmitterState, RNGKey]:
"""
Initializes the CMA-ME random direction emitter
Args:
init_genotypes: initial genotypes to add to the grid.
random_key: a random key to handle stochastic operations.
Returns:
The initial state of the emitter.
"""
# Initialize repertoire with default values
num_centroids = self._centroids.shape[0]
default_fitnesses = -jnp.inf * jnp.ones(shape=num_centroids)
# take a random direction
random_key, subkey = jax.random.split(random_key)
random_direction = jax.random.uniform(
subkey,
shape=(self._centroids.shape[-1],),
)
# return the initial state
random_key, subkey = jax.random.split(random_key)
return (
CMARndEmitterState(
random_key=subkey,
cmaes_state=self._cma_initial_state,
previous_fitnesses=default_fitnesses,
emit_count=0,
random_direction=random_direction,
),
random_key,
)
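A small sketch of this emitter's ranking criterion with hypothetical 2-D descriptors:
```python
import jax.numpy as jnp

descriptors = jnp.array([[0.1, 0.9], [0.8, 0.2]])
random_direction = jnp.array([1.0, 0.0])  # explore along the first axis
criteria = jnp.dot(descriptors, random_direction)  # [0.1, 0.8]
# offspring 1 moves furthest along the sampled direction and ranks first
# (offspring in new cells would additionally receive the offset above).
```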
dpg_emitter
¶
Implements the Diversity PG emitter, inspired by the QDPG algorithm, in JAX for Brax environments. Based on: https://arxiv.org/abs/2006.08505
DiversityPGConfig (QualityPGConfig)
dataclass
¶
Configuration for DiversityPG Emitter
Source code in qdax/core/emitters/dpg_emitter.py
@dataclass
class DiversityPGConfig(QualityPGConfig):
"""Configuration for DiversityPG Emitter"""
# inherits fields from QualityPGConfig
# Archive params
archive_acceptance_threshold: float = 0.1
archive_max_size: int = 10000
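A minimal sketch of overriding the archive parameters; it assumes the inherited QualityPGConfig fields all carry defaults (the values below are illustrative):
```python
from qdax.core.emitters.dpg_emitter import DiversityPGConfig

# Assumes every inherited QualityPGConfig field has a default value.
config = DiversityPGConfig(
    archive_acceptance_threshold=0.05,
    archive_max_size=5_000,
)
```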
DiversityPGEmitterState (QualityPGEmitterState)
dataclass
¶
Contains training state for the learner.
Source code in qdax/core/emitters/dpg_emitter.py
class DiversityPGEmitterState(QualityPGEmitterState):
"""Contains training state for the learner."""
# inherits from QualityPGEmitterState
archive: Archive
replace(self, **updates)
¶
Returns a new object replacing the specified fields with new values.
Source code in qdax/core/emitters/dpg_emitter.py
def replace(self, **updates):
"""Returns a new object replacing the specified fields with new values."""
return dataclasses.replace(self, **updates)
DiversityPGEmitter (QualityPGEmitter)
¶
A diversity policy gradient emitter used to implement the QDPG algorithm.
Please note that the inheritance between DiversityPGEmitter and QualityPGEmitter could be increased with changes in the way transition samples are handled in the QualityPGEmitter. But this would modify the computation/memory strategy of the current implementation. Hence, we won't apply this yet and will discuss it with the development team.
Source code in qdax/core/emitters/dpg_emitter.py
class DiversityPGEmitter(QualityPGEmitter):
"""
A diversity policy gradient emitter used to implement the QDPG algorithm.
Please note that the inheritance between DiversityPGEmitter and QualityPGEmitter
could be increased with changes in the way transition samples are handled in
the QualityPGEmitter. But this would modify the computation/memory strategy of the
current implementation. Hence, we won't apply this yet and will discuss this with
the development team.
"""
def __init__(
self,
config: DiversityPGConfig,
policy_network: nn.Module,
env: QDEnv,
score_novelty: Callable[[Archive, StateDescriptor], Reward],
) -> None:
# usual init operations from PGAME
super().__init__(config, policy_network, env)
self._config: DiversityPGConfig = config
# define scoring function
self._score_novelty = score_novelty
def init(
self, init_genotypes: Genotype, random_key: RNGKey
) -> Tuple[DiversityPGEmitterState, RNGKey]:
"""Initializes the emitter state.
Args:
init_genotypes: The initial population.
random_key: A random key.
Returns:
The initial state of the PGAMEEmitter, a new random key.
"""
# init elements of diversity emitter state with QualityEmitterState.init()
diversity_emitter_state, random_key = super().init(init_genotypes, random_key)
# store elements in a dictionary
attributes_dict = vars(diversity_emitter_state)
# init archive
archive = Archive.create(
acceptance_threshold=self._config.archive_acceptance_threshold,
state_descriptor_size=self._env.state_descriptor_length,
max_size=self._config.archive_max_size,
)
# init emitter state
emitter_state = DiversityPGEmitterState(
# retrieve all attributes from the QualityPGEmitterState
**attributes_dict,
# add the last element: archive
archive=archive,
)
return emitter_state, random_key
@partial(jax.jit, static_argnames=("self",))
def state_update(
self,
emitter_state: DiversityPGEmitterState,
repertoire: Optional[Repertoire],
genotypes: Optional[Genotype],
fitnesses: Optional[Fitness],
descriptors: Optional[Descriptor],
extra_scores: ExtraScores,
) -> DiversityPGEmitterState:
"""This function gives an opportunity to update the emitter state
after the genotypes have been scored.
Here it is used to fill the Replay Buffer with the transitions
from the scoring of the genotypes, and then the training of the
critic/actor happens. Hence the params of critic/actor are updated,
as well as their optimizer states.
Args:
emitter_state: current emitter state.
repertoire: the current genotypes repertoire
genotypes: unused here - but compulsory in the signature.
fitnesses: unused here - but compulsory in the signature.
descriptors: unused here - but compulsory in the signature.
extra_scores: extra information coming from the scoring function,
this contains the transitions added to the replay buffer.
Returns:
New emitter state where the replay buffer has been filled with
the new experienced transitions.
"""
# get the transitions out of the dictionary
assert "transitions" in extra_scores.keys(), "Missing transitions or wrong key"
transitions = extra_scores["transitions"]
# add transitions in the replay buffer
replay_buffer = emitter_state.replay_buffer.insert(transitions)
emitter_state = emitter_state.replace(replay_buffer=replay_buffer)
archive = emitter_state.archive.insert(transitions.state_desc)
def scan_train_critics(
carry: DiversityPGEmitterState, transitions: QDTransition
) -> Tuple[DiversityPGEmitterState, Any]:
emitter_state = carry
new_emitter_state = self._train_critics(emitter_state, transitions)
return new_emitter_state, ()
# sample transitions
(transitions, random_key,) = emitter_state.replay_buffer.sample(
random_key=emitter_state.random_key,
sample_size=self._config.num_critic_training_steps
* self._config.batch_size,
)
# update the rewards - diversity rewards
state_descriptors = transitions.state_desc
diversity_rewards = self._score_novelty(archive, state_descriptors)
transitions = transitions.replace(rewards=diversity_rewards)
# reshape the transitions
transitions = jax.tree_util.tree_map(
lambda x: x.reshape(
(
self._config.num_critic_training_steps,
self._config.batch_size,
)
+ x.shape[1:]
),
transitions,
)
# Train critics and greedy actor
emitter_state, _ = jax.lax.scan(
scan_train_critics,
emitter_state,
(transitions),
length=self._config.num_critic_training_steps,
)
emitter_state = emitter_state.replace(archive=archive)
return emitter_state # type: ignore
@partial(jax.jit, static_argnames=("self",))
def _train_critics(
self, emitter_state: DiversityPGEmitterState, transitions: QDTransition
) -> DiversityPGEmitterState:
"""Apply one gradient step to critics and to the greedy actor
(contained in carry in training_state), then soft update target critics
and target greedy actor.
Those updates are very similar to those made in TD3.
Args:
emitter_state: actual emitter state
Returns:
New emitter state where the critic and the greedy actor have been
updated. Optimizer states have also been updated in the process.
"""
# Update Critic
(
critic_optimizer_state,
critic_params,
target_critic_params,
random_key,
) = self._update_critic(
critic_params=emitter_state.critic_params,
target_critic_params=emitter_state.target_critic_params,
target_actor_params=emitter_state.target_actor_params,
critic_optimizer_state=emitter_state.critic_optimizer_state,
transitions=transitions,
random_key=emitter_state.random_key,
)
# Update greedy policy
(policy_optimizer_state, actor_params, target_actor_params,) = jax.lax.cond(
emitter_state.steps % self._config.policy_delay == 0,
lambda x: self._update_actor(*x),
lambda _: (
emitter_state.actor_opt_state,
emitter_state.actor_params,
emitter_state.target_actor_params,
),
operand=(
emitter_state.actor_params,
emitter_state.actor_opt_state,
emitter_state.target_actor_params,
emitter_state.critic_params,
transitions,
),
)
# Create new training state
new_emitter_state = emitter_state.replace(
critic_params=critic_params,
critic_optimizer_state=critic_optimizer_state,
actor_params=actor_params,
actor_opt_state=policy_optimizer_state,
target_critic_params=target_critic_params,
target_actor_params=target_actor_params,
random_key=random_key,
steps=emitter_state.steps + 1,
replay_buffer=emitter_state.replay_buffer,
)
return new_emitter_state # type: ignore
@partial(jax.jit, static_argnames=("self",))
def _mutation_function_pg(
self,
policy_params: Genotype,
emitter_state: DiversityPGEmitterState,
) -> Genotype:
"""Apply pg mutation to a policy via multiple steps of gradient descent.
Args:
policy_params: a policy, supposed to be a differentiable neural
network.
emitter_state: the current state of the emitter, containing among others,
the replay buffer, the critic.
Returns:
the updated params of the neural network.
"""
# Define new policy optimizer state
policy_optimizer_state = self._policies_optimizer.init(policy_params)
def scan_train_policy(
carry: Tuple[DiversityPGEmitterState, Genotype, optax.OptState],
transitions: QDTransition,
) -> Tuple[Tuple[DiversityPGEmitterState, Genotype, optax.OptState], Any]:
emitter_state, policy_params, policy_optimizer_state = carry
(
new_emitter_state,
new_policy_params,
new_policy_optimizer_state,
) = self._train_policy(
emitter_state,
policy_params,
policy_optimizer_state,
transitions,
)
return (
new_emitter_state,
new_policy_params,
new_policy_optimizer_state,
), ()
# sample transitions
transitions, _random_key = emitter_state.replay_buffer.sample(
random_key=emitter_state.random_key,
sample_size=self._config.num_pg_training_steps * self._config.batch_size,
)
# update the rewards - diversity rewards
state_descriptors = transitions.state_desc
diversity_rewards = self._score_novelty(
emitter_state.archive, state_descriptors
)
transitions = transitions.replace(rewards=diversity_rewards)
# reshape the transitions
transitions = jax.tree_util.tree_map(
lambda x: x.reshape(
(
self._config.num_pg_training_steps,
self._config.batch_size,
)
+ x.shape[1:]
),
transitions,
)
(emitter_state, policy_params, policy_optimizer_state,), _ = jax.lax.scan(
scan_train_policy,
(emitter_state, policy_params, policy_optimizer_state),
(transitions),
length=self._config.num_pg_training_steps,
)
return policy_params
@partial(jax.jit, static_argnames=("self",))
def _train_policy(
self,
emitter_state: DiversityPGEmitterState,
policy_params: Params,
policy_optimizer_state: optax.OptState,
transitions: QDTransition,
) -> Tuple[DiversityPGEmitterState, Params, optax.OptState]:
"""Apply one gradient step to a policy (called policies_params).
Args:
emitter_state: current state of the emitter.
policy_params: parameters corresponding to the weights and bias of
the neural network that defines the policy.
Returns:
The new emitter state and new params of the NN.
"""
# update policy
policy_optimizer_state, policy_params = self._update_policy(
critic_params=emitter_state.critic_params,
policy_optimizer_state=policy_optimizer_state,
policy_params=policy_params,
transitions=transitions,
)
return emitter_state, policy_params, policy_optimizer_state
init(self, init_genotypes, random_key)
¶
Initializes the emitter state.
Source code in qdax/core/emitters/dpg_emitter.py
def init(
self, init_genotypes: Genotype, random_key: RNGKey
) -> Tuple[DiversityPGEmitterState, RNGKey]:
"""Initializes the emitter state.
Args:
init_genotypes: The initial population.
random_key: A random key.
Returns:
The initial state of the PGAMEEmitter, a new random key.
"""
# init elements of diversity emitter state with QualityPGEmitter.init()
diversity_emitter_state, random_key = super().init(init_genotypes, random_key)
# store elements in a dictionary
attributes_dict = vars(diversity_emitter_state)
# init archive
archive = Archive.create(
acceptance_threshold=self._config.archive_acceptance_threshold,
state_descriptor_size=self._env.state_descriptor_length,
max_size=self._config.archive_max_size,
)
# init emitter state
emitter_state = DiversityPGEmitterState(
# retrieve all attributes from the QualityPGEmitterState
**attributes_dict,
# add the last element: archive
archive=archive,
)
return emitter_state, random_key
state_update(self, emitter_state, repertoire, genotypes, fitnesses, descriptors, extra_scores)
¶
This function gives an opportunity to update the emitter state after the genotypes have been scored.
Here it is used to fill the replay buffer with the transitions from the scoring of the genotypes; the critic/actor training then happens, so the params of the critic/actor are updated, as well as their optimizer states.
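For intuition, here is a minimal runnable sketch of the reshaping step performed before the scan over critic training steps; the dict below is an illustrative stand-in for the QDTransition pytree, and the sizes are arbitrary.

import jax
import jax.numpy as jnp

num_critic_training_steps, batch_size = 4, 3  # illustrative values
# stand-in for the flat batch of transitions sampled from the replay buffer
transitions = {
    "rewards": jnp.zeros(num_critic_training_steps * batch_size),
    "state_desc": jnp.zeros((num_critic_training_steps * batch_size, 2)),
}
# same reshape as in state_update: one mini-batch per critic training step,
# so that jax.lax.scan can consume one slice per step
transitions = jax.tree_util.tree_map(
    lambda x: x.reshape((num_critic_training_steps, batch_size) + x.shape[1:]),
    transitions,
)
print(transitions["rewards"].shape)  # (4, 3)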
Source code in qdax/core/emitters/dpg_emitter.py
@partial(jax.jit, static_argnames=("self",))
def state_update(
self,
emitter_state: DiversityPGEmitterState,
repertoire: Optional[Repertoire],
genotypes: Optional[Genotype],
fitnesses: Optional[Fitness],
descriptors: Optional[Descriptor],
extra_scores: ExtraScores,
) -> DiversityPGEmitterState:
"""This function gives an opportunity to update the emitter state
after the genotypes have been scored.
Here it is used to fill the Replay Buffer with the transitions
from the scoring of the genotypes, and then the training of the
critic/actor happens. Hence the params of critic/actor are updated,
as well as their optimizer states.
Args:
emitter_state: current emitter state.
repertoire: the current genotypes repertoire
genotypes: unused here - but compulsory in the signature.
fitnesses: unused here - but compulsory in the signature.
descriptors: unused here - but compulsory in the signature.
extra_scores: extra information coming from the scoring function,
this contains the transitions added to the replay buffer.
Returns:
New emitter state where the replay buffer has been filled with
the new experienced transitions.
"""
# get the transitions out of the dictionary
assert "transitions" in extra_scores.keys(), "Missing transitions or wrong key"
transitions = extra_scores["transitions"]
# add transitions in the replay buffer
replay_buffer = emitter_state.replay_buffer.insert(transitions)
emitter_state = emitter_state.replace(replay_buffer=replay_buffer)
archive = emitter_state.archive.insert(transitions.state_desc)
def scan_train_critics(
carry: DiversityPGEmitterState, transitions: QDTransition
) -> Tuple[DiversityPGEmitterState, Any]:
emitter_state = carry
new_emitter_state = self._train_critics(emitter_state, transitions)
return new_emitter_state, ()
# sample transitions
(transitions, random_key,) = emitter_state.replay_buffer.sample(
random_key=emitter_state.random_key,
sample_size=self._config.num_critic_training_steps
* self._config.batch_size,
)
# update the rewards - diversity rewards
state_descriptors = transitions.state_desc
diversity_rewards = self._score_novelty(archive, state_descriptors)
transitions = transitions.replace(rewards=diversity_rewards)
# reshape the transitions
transitions = jax.tree_util.tree_map(
lambda x: x.reshape(
(
self._config.num_critic_training_steps,
self._config.batch_size,
)
+ x.shape[1:]
),
transitions,
)
# Train critics and greedy actor
emitter_state, _ = jax.lax.scan(
scan_train_critics,
emitter_state,
(transitions),
length=self._config.num_critic_training_steps,
)
emitter_state = emitter_state.replace(archive=archive)
return emitter_state # type: ignore
emitter
¶
EmitterState (PyTreeNode)
dataclass
¶
The state of an emitter. Emitters are used to suggest offspring when evolving a population of genotypes. To emit new genotypes, some emitters need to have a state that carries useful information, like running means, distribution parameters, critics, replay buffers, etc.
The emitter state object is used to store this information and is updated along the process.
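For illustration, a minimal sketch of a custom emitter state (a hypothetical example, not part of QDax): fields become pytree leaves, and replace() returns an updated copy without mutating the original, as for any flax.struct dataclass.

import jax.numpy as jnp

from qdax.core.emitters.emitter import EmitterState

class RandomSearchEmitterState(EmitterState):  # hypothetical example
    sigma: jnp.ndarray  # current mutation scale
    emit_count: int     # number of emission events so far

state = RandomSearchEmitterState(sigma=jnp.array(0.1), emit_count=0)
state = state.replace(emit_count=state.emit_count + 1)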
Source code in qdax/core/emitters/emitter.py
class EmitterState(PyTreeNode):
"""The state of an emitter. Emitters are used to suggest offspring
when evolving a population of genotypes. To emit new genotypes, some
emitters need to have a state, that carries useful informations, like
running means, distribution parameters, critics, replay buffers etc...
The object emitter state is used to store them and is updated along
the process.
Args:
PyTreeNode: EmitterState base class inherits from PyTreeNode object
from flax.struct package. It help registering objects as Pytree
nodes automatically, and as the same benefits as classic Python
@dataclass decorator.
"""
pass
replace(self, **updates)
¶
"Returns a new object replacing the specified fields with new values.
Source code in qdax/core/emitters/emitter.py
def replace(self, **updates):
""" "Returns a new object replacing the specified fields with new values."""
return dataclasses.replace(self, **updates)
Emitter (ABC)
¶
Source code in qdax/core/emitters/emitter.py
class Emitter(ABC):
def init(
self, init_genotypes: Optional[Genotype], random_key: RNGKey
) -> Tuple[Optional[EmitterState], RNGKey]:
"""Initialises the state of the emitter. Some emitters do
not need a state, in which case the value None can be
returned.
Args:
init_genotypes: The genotypes of the initial population.
random_key: a random key to handle stochastic operations.
Returns:
The initial emitter state and a random key.
"""
return None, random_key
@abstractmethod
def emit(
self,
repertoire: Optional[Repertoire],
emitter_state: Optional[EmitterState],
random_key: RNGKey,
) -> Tuple[Genotype, RNGKey]:
"""Function used to emit a population of offspring by any possible
means. A new population can be sampled from a distribution or obtained
through mutations of individuals sampled from the repertoire.
Args:
repertoire: a repertoire of genotypes.
emitter_state: the state of the emitter.
random_key: a random key to handle random operations.
Returns:
A batch of offspring, a new random key.
"""
pass
@partial(
jax.jit,
static_argnames=("self",),
)
def state_update(
self,
emitter_state: Optional[EmitterState],
repertoire: Optional[Repertoire] = None,
genotypes: Optional[Genotype] = None,
fitnesses: Optional[Fitness] = None,
descriptors: Optional[Descriptor] = None,
extra_scores: Optional[ExtraScores] = None,
) -> Optional[EmitterState]:
"""This function gives an opportunity to update the emitter state
after the genotypes have been scored.
As a matter of fact, many emitter states need information from
the evaluations of the genotypes in order to be updated, for instance:
- CMA emitter: to update the rank of the covariance matrix
- PGA emitter: to fill the replay buffer and update the critic/greedy
couple.
This function does not need to be overridden. By default, it outputs
the same emitter state.
Args:
emitter_state: current emitter state
repertoire: the current genotypes repertoire
genotypes: the genotypes of the batch of emitted offspring.
fitnesses: the fitnesses of the batch of emitted offspring.
descriptors: the descriptors of the emitted offspring.
extra_scores: a dictionary with other values outputted by the
scoring function.
Returns:
The modified emitter state.
"""
return emitter_state
@property
@abstractmethod
def batch_size(self) -> int:
"""
Returns:
the batch size emitted by the emitter.
"""
pass
@property
def use_all_data(self) -> bool:
"""Whether to use all data or not when used along other emitters.
Used when an emitter is used in a multi emitter setting.
Some emitter only the information from the genotypes they emitted when
they update their state (for instance, the CMA emitters); but other use data
from genotypes emitted by others (for instance, QualityPGEmitter and
DiversityPGEmitter). The meta emitters like MultiEmitter need to know which
data to give the sub emitter when udapting them. This property is used at
this moment.
Default behavior is to used only the data related to what was emitted.
Returns:
Whether to pass only the genotypes (and their evaluations) that the emitter
emitted when updating it or all the genotypes emitted by all the emitters.
"""
return False
batch_size: int
property
readonly
¶
use_all_data: bool
property
readonly
¶
Whether to use all data or not when used alongside other emitters.
Used when an emitter is used in a multi-emitter setting.
Some emitters use only the information from the genotypes they emitted when they update their state (for instance, the CMA emitters); others use data from genotypes emitted by other emitters (for instance, QualityPGEmitter and DiversityPGEmitter). Meta emitters like MultiEmitter need to know which data to give each sub emitter when updating them. This property is used for that purpose.
The default behavior is to use only the data related to what was emitted.
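For illustration, a sketch of how a sub emitter opts in to receiving the full batch of evaluations from MultiEmitter (a hypothetical class; emit and batch_size would still need to be implemented, as for any concrete Emitter):

from qdax.core.emitters.emitter import Emitter

class CriticBasedEmitter(Emitter):  # hypothetical example
    @property
    def use_all_data(self) -> bool:
        # train on every transition, whichever sub emitter produced it,
        # as QualityPGEmitter and DiversityPGEmitter do
        return True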
init(self, init_genotypes, random_key)
¶
Initialises the state of the emitter. Some emitters do not need a state, in which case the value None can be returned.
Source code in qdax/core/emitters/emitter.py
def init(
self, init_genotypes: Optional[Genotype], random_key: RNGKey
) -> Tuple[Optional[EmitterState], RNGKey]:
"""Initialises the state of the emitter. Some emitters do
not need a state, in which case the value None can be
returned.
Args:
init_genotypes: The genotypes of the initial population.
random_key: a random key to handle stochastic operations.
Returns:
The initial emitter state and a random key.
"""
return None, random_key
emit(self, repertoire, emitter_state, random_key)
¶
Function used to emit a population of offspring by any possible means. A new population can be sampled from a distribution or obtained through mutations of individuals sampled from the repertoire.
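For illustration, a minimal sketch of a concrete emitter (a hypothetical GaussianEmitter, not part of QDax), assuming a repertoire that exposes genotypes and fitnesses arrays the way MapElitesRepertoire does: it samples parents uniformly among filled cells and perturbs them with isotropic Gaussian noise.

import jax
import jax.numpy as jnp

from qdax.core.emitters.emitter import Emitter

class GaussianEmitter(Emitter):  # hypothetical example
    def __init__(self, batch_size: int, sigma: float):
        self._batch_size = batch_size
        self._sigma = sigma

    @property
    def batch_size(self) -> int:
        return self._batch_size

    def emit(self, repertoire, emitter_state, random_key):
        random_key, subkey_sample, subkey_noise = jax.random.split(random_key, 3)
        # uniform probabilities over filled cells (empty cells have -inf fitness)
        filled = repertoire.fitnesses > -jnp.inf
        p = filled / jnp.sum(filled)
        parents = jax.tree_util.tree_map(
            lambda x: jax.random.choice(
                subkey_sample, x, shape=(self._batch_size,), p=p
            ),
            repertoire.genotypes,
        )
        # apply isotropic Gaussian mutations to the sampled parents
        offspring = jax.tree_util.tree_map(
            lambda x: x + self._sigma * jax.random.normal(subkey_noise, x.shape),
            parents,
        )
        return offspring, random_key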
Source code in qdax/core/emitters/emitter.py
@abstractmethod
def emit(
self,
repertoire: Optional[Repertoire],
emitter_state: Optional[EmitterState],
random_key: RNGKey,
) -> Tuple[Genotype, RNGKey]:
"""Function used to emit a population of offspring by any possible
means. A new population can be sampled from a distribution or obtained
through mutations of individuals sampled from the repertoire.
Args:
repertoire: a repertoire of genotypes.
emitter_state: the state of the emitter.
random_key: a random key to handle random operations.
Returns:
A batch of offspring, a new random key.
"""
pass
state_update(self, emitter_state, repertoire=None, genotypes=None, fitnesses=None, descriptors=None, extra_scores=None)
¶
This function gives an opportunity to update the emitter state after the genotypes have been scored.
As a matter of fact, many emitter states need information from the evaluations of the genotypes in order to be updated, for instance:
- CMA emitter: to update the rank of the covariance matrix
- PGA emitter: to fill the replay buffer and update the critic/greedy couple.
This function does not need to be overridden. By default, it outputs the same emitter state.
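As a small demonstration of the default behavior, the hypothetical GaussianEmitter sketched above does not override state_update, so the emitter state comes back unchanged:

import jax

emitter = GaussianEmitter(batch_size=4, sigma=0.1)
emitter_state, random_key = emitter.init(None, jax.random.PRNGKey(0))
new_state = emitter.state_update(emitter_state, extra_scores={})
assert new_state is emitter_state  # both are None for this stateless emitter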
Source code in qdax/core/emitters/emitter.py
@partial(
jax.jit,
static_argnames=("self",),
)
def state_update(
self,
emitter_state: Optional[EmitterState],
repertoire: Optional[Repertoire] = None,
genotypes: Optional[Genotype] = None,
fitnesses: Optional[Fitness] = None,
descriptors: Optional[Descriptor] = None,
extra_scores: Optional[ExtraScores] = None,
) -> Optional[EmitterState]:
"""This function gives an opportunity to update the emitter state
after the genotypes have been scored.
As a matter of fact, many emitter states need information from
the evaluations of the genotypes in order to be updated, for instance:
- CMA emitter: to update the rank of the covariance matrix
- PGA emitter: to fill the replay buffer and update the critic/greedy
couple.
This function does not need to be overridden. By default, it outputs
the same emitter state.
Args:
emitter_state: current emitter state
repertoire: the current genotypes repertoire
genotypes: the genotypes of the batch of emitted offspring.
fitnesses: the fitnesses of the batch of emitted offspring.
descriptors: the descriptors of the emitted offspring.
extra_scores: a dictionary with other values outputted by the
scoring function.
Returns:
The modified emitter state.
"""
return emitter_state
mees_emitter
¶
Emitter and utils necessary to reproduce the MAP-Elites-ES algorithm from "Scaling MAP-Elites to Deep Neuroevolution" by Colas et al: https://dl.acm.org/doi/pdf/10.1145/3377930.3390217
NoveltyArchive (PyTreeNode)
dataclass
¶
Novelty Archive used by the MAP-Elites-ES emitter.
Source code in qdax/core/emitters/mees_emitter.py
class NoveltyArchive(flax.struct.PyTreeNode):
"""Novelty Archive used by the MAP-Elites-ES emitter.
Args:
archive: content of the archive
size: total size of the archive
position: current position in the archive
"""
archive: jnp.ndarray
size: int = flax.struct.field(pytree_node=False)
position: jnp.ndarray = flax.struct.field()
@classmethod
def init(
cls,
size: int,
num_descriptors: int,
) -> NoveltyArchive:
archive = jnp.zeros((size, num_descriptors))
return cls(archive=archive, size=size, position=jnp.array(0, dtype=int))
@jax.jit
def update(
self,
descriptor: Descriptor,
) -> NoveltyArchive:
"""Update the content of the novelty archive with newly generated descriptor.
Args:
descriptor: new descriptor generated by MAP-Elites-ES
Returns:
The updated NoveltyArchive
"""
new_archive = jax.lax.dynamic_update_slice_in_dim(
self.archive,
descriptor,
self.position,
axis=0,
)
new_position = (self.position + 1) % self.size
return NoveltyArchive(
archive=new_archive, size=self.size, position=new_position
)
@partial(jax.jit, static_argnames=("num_nearest_neighbors",))
def novelty(
self,
descriptors: Descriptor,
num_nearest_neighbors: int,
) -> jnp.ndarray:
"""Compute the novelty of the given descriptors as the average distance
to the k nearest neighbours in the archive.
Args:
descriptors: the descriptors to compute novelty for
num_nearest_neighbors: k used to compute the k-nearest-neighbours
Returns:
the novelty of each descriptor in descriptors.
"""
# Compute all distances with archive content
def distance(x: jnp.ndarray, y: jnp.ndarray) -> jnp.ndarray:
return jnp.sqrt(jnp.sum(jnp.square(x - y)))
distances = jax.vmap(
jax.vmap(partial(distance), in_axes=(None, 0)), in_axes=(0, None)
)(descriptors, self.archive)
# Filter distance with empty slot of archive
indices = jnp.arange(0, self.size, step=1) < self.position + 1
distances = jax.vmap(lambda distance: jnp.where(indices, distance, jnp.inf))(
distances
)
# Find k nearest neighbours
_, indices = jax.lax.top_k(-distances, num_nearest_neighbors)
# Compute novelty as average distance to the k nearest neighbours
distances = jnp.where(distances == jnp.inf, jnp.nan, distances)
novelty = jnp.nanmean(jnp.take_along_axis(distances, indices, axis=1), axis=1)
return novelty
update(self, descriptor)
¶
Update the content of the novelty archive with newly generated descriptor.
Source code in qdax/core/emitters/mees_emitter.py
@jax.jit
def update(
self,
descriptor: Descriptor,
) -> NoveltyArchive:
"""Update the content of the novelty archive with newly generated descriptor.
Args:
descriptor: new descriptor generated by MAP-Elites-ES
Returns:
The updated NoveltyArchive
"""
new_archive = jax.lax.dynamic_update_slice_in_dim(
self.archive,
descriptor,
self.position,
axis=0,
)
new_position = (self.position + 1) % self.size
return NoveltyArchive(
archive=new_archive, size=self.size, position=new_position
)
novelty(self, descriptors, num_nearest_neighbors)
¶
Compute the novelty of the given descriptors as the average distance to the k nearest neighbours in the archive.
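A small usage sketch (values chosen arbitrarily): insert a few 2-D descriptors, then query the novelty of new descriptors against their 2 nearest neighbours in the archive.

import jax.numpy as jnp

from qdax.core.emitters.mees_emitter import NoveltyArchive

archive = NoveltyArchive.init(size=8, num_descriptors=2)
for desc in ([0.0, 0.0], [1.0, 0.0], [0.0, 2.0]):
    # update expects a batch of descriptors, here one (1, 2) array at a time
    archive = archive.update(jnp.array([desc]))

queries = jnp.array([[0.5, 0.0], [3.0, 3.0]])
novelty = archive.novelty(queries, num_nearest_neighbors=2)
print(novelty.shape)  # (2,): one novelty score per query descriptor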
Source code in qdax/core/emitters/mees_emitter.py
@partial(jax.jit, static_argnames=("num_nearest_neighbors",))
def novelty(
self,
descriptors: Descriptor,
num_nearest_neighbors: int,
) -> jnp.ndarray:
"""Compute the novelty of the given descriptors as the average distance
to the k nearest neighbours in the archive.
Args:
descriptors: the descriptors to compute novelty for
num_nearest_neighbors: k used to compute the k-nearest-neighbours
Returns:
the novelty of each descriptor in descriptors.
"""
# Compute all distances with archive content
def distance(x: jnp.ndarray, y: jnp.ndarray) -> jnp.ndarray:
return jnp.sqrt(jnp.sum(jnp.square(x - y)))
distances = jax.vmap(
jax.vmap(partial(distance), in_axes=(None, 0)), in_axes=(0, None)
)(descriptors, self.archive)
# Filter distance with empty slot of archive
indices = jnp.arange(0, self.size, step=1) < self.position + 1
distances = jax.vmap(lambda distance: jnp.where(indices, distance, jnp.inf))(
distances
)
# Find k nearest neighbours
_, indices = jax.lax.top_k(-distances, num_nearest_neighbors)
# Compute novelty as average distance to the k nearest neighbours
distances = jnp.where(distances == jnp.inf, jnp.nan, distances)
novelty = jnp.nanmean(jnp.take_along_axis(distances, indices, axis=1), axis=1)
return novelty
replace(self, **updates)
¶
"Returns a new object replacing the specified fields with new values.
Source code in qdax/core/emitters/mees_emitter.py
def replace(self, **updates):
""" "Returns a new object replacing the specified fields with new values."""
return dataclasses.replace(self, **updates)
MEESConfig
dataclass
¶
Configuration for the MAP-Elites-ES emitter.
Source code in qdax/core/emitters/mees_emitter.py
@dataclass
class MEESConfig:
"""Configuration for the MAP-Elites-ES emitter.
Args:
sample_number: num of samples for gradient estimate
sample_sigma: std to sample the samples for gradient estimate
sample_mirror: if True, use mirroring sampling
sample_rank_norm: if True, use normalisation
num_optimizer_steps: frequency of archive-sampling
adam_optimizer: if True, use ADAM, if False, use SGD
learning_rate: learning rate used by the optimizer
l2_coefficient: coefficient for regularisation
novelty_nearest_neighbors: number of nearest neighbours used to compute novelty
last_updated_size: number of last updated indiv used to
choose parents from repertoire
exploit_num_cell_sample: number of highest-performing cells
from which to choose parents, when using exploit
explore_num_cell_sample: number of most-novel cells from
which to choose parents, when using explore
use_explore: if False, use only fitness gradient
use_exploit: if False, use only novelty gradient
"""
sample_number: int = 1000
sample_sigma: float = 0.02
sample_mirror: bool = True
sample_rank_norm: bool = True
num_optimizer_steps: int = 10
adam_optimizer: bool = True
learning_rate: float = 0.01
l2_coefficient: float = 0.02
novelty_nearest_neighbors: int = 10
last_updated_size: int = 5
exploit_num_cell_sample: int = 2
explore_num_cell_sample: int = 5
use_explore: bool = True
use_exploit: bool = True
MEESEmitterState (EmitterState)
dataclass
¶
Emitter State for the MAP-Elites-ES emitter.
Source code in qdax/core/emitters/mees_emitter.py
class MEESEmitterState(EmitterState):
"""Emitter State for the MAP-Elites-ES emitter.
Args:
initial_optimizer_state: stored to re-initialise when sampling new parent
optimizer_state: current optimizer state
offspring: offspring generated through gradient estimate
generation_count: generation counter used to update the novelty archive
novelty_archive: used to compute novelty for explore
last_updated_genotypes: used to choose parents from repertoire
last_updated_fitnesses: used to choose parents from repertoire
last_updated_position: used to choose parents from repertoire
random_key: key to handle stochastic operations
"""
initial_optimizer_state: optax.OptState
optimizer_state: optax.OptState
offspring: Genotype
generation_count: int
novelty_archive: NoveltyArchive
last_updated_genotypes: Genotype
last_updated_fitnesses: Fitness
last_updated_position: jnp.ndarray
random_key: RNGKey
replace(self, **updates)
¶
"Returns a new object replacing the specified fields with new values.
Source code in qdax/core/emitters/mees_emitter.py
def replace(self, **updates):
""" "Returns a new object replacing the specified fields with new values."""
return dataclasses.replace(self, **updates)
MEESEmitter (Emitter)
¶
Emitter reproducing the MAP-Elites-ES algorithm from "Scaling MAP-Elites to Deep Neuroevolution" by Colas et al: https://dl.acm.org/doi/pdf/10.1145/3377930.3390217
One can choose between the three variants by setting use_explore and use_exploit:
- ME-ES exploit-explore: use_exploit=True and use_explore=True. Alternates between num_optimizer_steps of fitness gradients and num_optimizer_steps of novelty gradients, resampling the parent from the archive every num_optimizer_steps steps.
- ME-ES exploit: use_exploit=True and use_explore=False. Only uses fitness gradients, no novelty gradients, but resamples the parent from the archive every num_optimizer_steps steps.
- ME-ES explore: use_exploit=False and use_explore=True. Only uses novelty gradients, no fitness gradients, but resamples the parent from the archive every num_optimizer_steps steps.
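For illustration, a sketch of selecting the explore-only variant. The scoring function below is a toy stand-in that assumes array genotypes (fitness is the negative squared norm, the descriptor is the first two genotype dimensions); any QDax-style scoring function can replace it.

import jax.numpy as jnp

from qdax.core.emitters.mees_emitter import MEESConfig, MEESEmitter

def scoring_fn(genotypes, random_key):
    # toy scoring function for illustration only
    fitnesses = -jnp.sum(genotypes**2, axis=-1)
    descriptors = genotypes[:, :2]
    return fitnesses, descriptors, {}, random_key

config = MEESConfig(
    sample_number=512,       # samples per gradient estimate
    num_optimizer_steps=10,  # resample a parent every 10 generations
    use_explore=True,
    use_exploit=False,       # ME-ES explore: novelty gradients only
)
emitter = MEESEmitter(
    config=config,
    total_generations=1000,  # required to size the novelty archive
    scoring_fn=scoring_fn,
    num_descriptors=2,
)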
Source code in qdax/core/emitters/mees_emitter.py
class MEESEmitter(Emitter):
"""
Emitter reproducing the MAP-Elites-ES algorithm from
"Scaling MAP-Elites to Deep Neuroevolution" by Colas et al:
https://dl.acm.org/doi/pdf/10.1145/3377930.3390217
One can choose between the three variants by setting use_explore and use_exploit:
ME-ES exploit-explore: use_exploit=True and use_explore=True
Alternates between num_optimizer_steps of fitness gradients and
num_optimizer_steps of novelty gradients, resample parent from the archive
every num_optimizer_steps steps.
ME-ES exploit: use_exploit=True and use_explore=False
Only uses fitness gradient, no novelty gradients, but resample parent from
the archive every num_optimizer_steps steps.
ME-ES explore: use_exploit=False and use_explore=True
Only uses novelty gradient, no fitness gradients, but resample parent from
the archive every num_optimizer_steps steps.
"""
def __init__(
self,
config: MEESConfig,
total_generations: int,
scoring_fn: Callable[
[Genotype, RNGKey], Tuple[Fitness, Descriptor, ExtraScores, RNGKey]
],
num_descriptors: int,
) -> None:
"""Initialise the MAP-Elites-ES emitter.
WARNING: total_generations is required to build the novelty archive.
Args:
config: algorithm config
scoring_fn: used to evaluate the samples for the gradient estimate.
total_generations: total number of generations for which the
emitter will run; used to initialise the novelty archive.
num_descriptors: dimension of the descriptors, used to initialise
the empty novelty archive.
"""
self._config = config
self._scoring_fn = scoring_fn
self._total_generations = total_generations
self._num_descriptors = num_descriptors
# Initialise optimizer
if self._config.adam_optimizer:
self._optimizer = optax.adam(learning_rate=config.learning_rate)
else:
self._optimizer = optax.sgd(learning_rate=config.learning_rate)
@property
def batch_size(self) -> int:
"""
Returns:
the batch size emitted by the emitter.
"""
return 1
@partial(
jax.jit,
static_argnames=("self",),
)
def init(
self, init_genotypes: Genotype, random_key: RNGKey
) -> Tuple[MEESEmitterState, RNGKey]:
"""Initializes the emitter state.
Args:
init_genotypes: The initial population.
random_key: A random key.
Returns:
The initial state of the MEESEmitter, a new random key.
"""
# Initialisation requires one initial genotype
if jax.tree_util.tree_leaves(init_genotypes)[0].shape[0] > 1:
init_genotypes = jax.tree_util.tree_map(
lambda x: x[0],
init_genotypes,
)
# Initialise optimizer
initial_optimizer_state = self._optimizer.init(init_genotypes)
# Create empty Novelty archive
if self._config.use_explore:
novelty_archive = NoveltyArchive.init(
self._total_generations, self._num_descriptors
)
else:
novelty_archive = NoveltyArchive.init(
self._config.novelty_nearest_neighbors, self._num_descriptors
)
# Create empty updated genotypes and fitness
last_updated_genotypes = jax.tree_util.tree_map(
lambda x: jnp.zeros(shape=(self._config.last_updated_size,) + x.shape[1:]),
init_genotypes,
)
last_updated_fitnesses = -jnp.inf * jnp.ones(
shape=self._config.last_updated_size
)
return (
MEESEmitterState(
initial_optimizer_state=initial_optimizer_state,
optimizer_state=initial_optimizer_state,
offspring=init_genotypes,
generation_count=0,
novelty_archive=novelty_archive,
last_updated_genotypes=last_updated_genotypes,
last_updated_fitnesses=last_updated_fitnesses,
last_updated_position=0,
random_key=random_key,
),
random_key,
)
@partial(
jax.jit,
static_argnames=("self",),
)
def emit(
self,
repertoire: MapElitesRepertoire,
emitter_state: MEESEmitterState,
random_key: RNGKey,
) -> Tuple[Genotype, RNGKey]:
"""Return the offspring generated through gradient update.
Params:
repertoire: the MAP-Elites repertoire to sample from
emitter_state
random_key: a jax PRNG random key
Returns:
a new gradient offspring
a new jax PRNG key
"""
return emitter_state.offspring, random_key
@partial(
jax.jit,
static_argnames=("self",),
)
def _sample_exploit(
self,
emitter_state: MEESEmitterState,
repertoire: MapElitesRepertoire,
random_key: RNGKey,
) -> Tuple[Genotype, RNGKey]:
"""Sample half of the time uniformly from the exploit_num_cell_sample
highest-performing cells of the repertoire and half of the time uniformly
from the exploit_num_cell_sample highest-performing cells among the
last updated cells.
Args:
emitter_state: current emitter_state
repertoire: the current repertoire
random_key: a jax PRNG random key
Returns:
samples: a genotype sampled in the repertoire
random_key: an updated jax PRNG random key
"""
def _sample(
random_key: RNGKey,
genotypes: Genotype,
fitnesses: Fitness,
) -> Tuple[Genotype, RNGKey]:
"""Sample uniformly from the 2 highest fitness cells."""
max_fitnesses, _ = jax.lax.top_k(
fitnesses, self._config.exploit_num_cell_sample
)
min_fitness = jnp.nanmin(
jnp.where(max_fitnesses > -jnp.inf, max_fitnesses, jnp.inf)
)
genotypes_empty = fitnesses < min_fitness
p = (1.0 - genotypes_empty) / jnp.sum(1.0 - genotypes_empty)
random_key, subkey = jax.random.split(random_key)
samples = jax.tree_map(
lambda x: jax.random.choice(subkey, x, shape=(1,), p=p),
genotypes,
)
return samples, random_key
random_key, subkey = jax.random.split(random_key)
# Sample p uniformly
p = jax.random.uniform(subkey)
# Depending on the value of p, use one of the two sampling options
repertoire_sample = partial(
_sample, genotypes=repertoire.genotypes, fitnesses=repertoire.fitnesses
)
last_updated_sample = partial(
_sample,
genotypes=emitter_state.last_updated_genotypes,
fitnesses=emitter_state.last_updated_fitnesses,
)
samples, random_key = jax.lax.cond(
p < 0.5,
repertoire_sample,
last_updated_sample,
random_key,
)
return samples, random_key
@partial(
jax.jit,
static_argnames=("self",),
)
def _sample_explore(
self,
emitter_state: MEESEmitterState,
repertoire: MapElitesRepertoire,
random_key: RNGKey,
) -> Tuple[Genotype, RNGKey]:
"""Sample uniformly from the explore_num_cell_sample most-novel genotypes.
Args:
emitter_state: current emitter state
repertoire: the current genotypes repertoire
random_key: a jax PRNG random key
Returns:
samples: a genotype sampled in the repertoire
random_key: an updated jax PRNG random key
"""
# Compute the novelty of all indivs in the archive
novelties = emitter_state.novelty_archive.novelty(
repertoire.descriptors, self._config.novelty_nearest_neighbors
)
novelties = jnp.where(repertoire.fitnesses > -jnp.inf, novelties, -jnp.inf)
# Sample uniformly for the explore_num_cell_sample most novel cells
max_novelties, _ = jax.lax.top_k(
novelties, self._config.explore_num_cell_sample
)
min_novelty = jnp.nanmin(
jnp.where(max_novelties > -jnp.inf, max_novelties, jnp.inf)
)
repertoire_empty = novelties < min_novelty
p = (1.0 - repertoire_empty) / jnp.sum(1.0 - repertoire_empty)
random_key, subkey = jax.random.split(random_key)
samples = jax.tree_map(
lambda x: jax.random.choice(subkey, x, shape=(1,), p=p),
repertoire.genotypes,
)
return samples, random_key
@partial(
jax.jit,
static_argnames=("self", "scores_fn"),
)
def _es_emitter(
self,
parent: Genotype,
optimizer_state: optax.OptState,
random_key: RNGKey,
scores_fn: Callable[[Fitness, Descriptor], jnp.ndarray],
) -> Tuple[Genotype, optax.OptState, RNGKey]:
"""Main es component, given a parent and a way to infer the score from
the fitnesses and descriptors of its es-samples, return its
approximated-gradient-generated offspring.
Args:
parent: the considered parent.
scores_fn: a function to infer the score of its es-samples from
their fitness and descriptors.
random_key
Returns:
The approximated-gradients-generated offspring and a new random_key.
"""
random_key, subkey = jax.random.split(random_key)
# Sampling mirror noise
total_sample_number = self._config.sample_number
if self._config.sample_mirror:
sample_number = total_sample_number // 2
half_sample_noise = jax.tree_util.tree_map(
lambda x: jax.random.normal(
key=subkey,
shape=jnp.repeat(x, sample_number, axis=0).shape,
),
parent,
)
sample_noise = jax.tree_util.tree_map(
lambda x: jnp.concatenate(
[jnp.expand_dims(x, axis=1), jnp.expand_dims(-x, axis=1)], axis=1
).reshape(jnp.repeat(x, 2, axis=0).shape),
half_sample_noise,
)
gradient_noise = half_sample_noise
# Sampling non-mirror noise
else:
sample_number = total_sample_number
sample_noise = jax.tree_map(
lambda x: jax.random.normal(
key=subkey,
shape=jnp.repeat(x, sample_number, axis=0).shape,
),
parent,
)
gradient_noise = sample_noise
# Applying noise
samples = jax.tree_map(
lambda x: jnp.repeat(x, total_sample_number, axis=0),
parent,
)
samples = jax.tree_map(
lambda mean, noise: mean + self._config.sample_sigma * noise,
samples,
sample_noise,
)
# Evaluating samples
fitnesses, descriptors, extra_scores, random_key = self._scoring_fn(
samples, random_key
)
# Computing rank, with or without normalisation
scores = scores_fn(fitnesses, descriptors)
if self._config.sample_rank_norm:
ranking_indices = jnp.argsort(scores, axis=0)
ranks = jnp.argsort(ranking_indices, axis=0)
ranks = (ranks / (total_sample_number - 1)) - 0.5
else:
ranks = scores
# Reshaping rank to match shape of genotype_noise
if self._config.sample_mirror:
ranks = jnp.reshape(ranks, (sample_number, 2))
ranks = jnp.apply_along_axis(lambda rank: rank[0] - rank[1], 1, ranks)
ranks = jax.tree_map(
lambda x: jnp.reshape(
jnp.repeat(ranks.ravel(), x[0].ravel().shape[0], axis=0), x.shape
),
gradient_noise,
)
# Computing the gradients
gradient = jax.tree_map(
lambda noise, rank: jnp.multiply(noise, rank),
gradient_noise,
ranks,
)
gradient = jax.tree_map(
lambda x: jnp.reshape(x, (sample_number, -1)),
gradient,
)
gradient = jax.tree_map(
lambda g, p: jnp.reshape(
-jnp.sum(g, axis=0) / (total_sample_number * self._config.sample_sigma),
p.shape,
),
gradient,
parent,
)
# Adding regularisation
gradient = jax.tree_map(
lambda g, p: g + self._config.l2_coefficient * p,
gradient,
parent,
)
# Applying gradients
(offspring_update, optimizer_state) = self._optimizer.update(
gradient, optimizer_state
)
offspring = optax.apply_updates(parent, offspring_update)
return offspring, optimizer_state, random_key
@partial(
jax.jit,
static_argnames=("self",),
)
def _buffers_update(
self,
emitter_state: MEESEmitterState,
repertoire: MapElitesRepertoire,
genotypes: Genotype,
fitnesses: Fitness,
descriptors: Descriptor,
) -> MEESEmitterState:
"""Update the different buffers and archives in the emitter
state to generate the offspring for the next generation.
Args:
emitter_state: current emitter state
repertoire: the current genotypes repertoire
genotypes: the genotypes of the batch of emitted offspring.
fitnesses: the fitnesses of the batch of emitted offspring.
descriptors: the descriptors of the emitted offspring.
Returns:
The modified emitter state.
"""
# Updating novelty archive
novelty_archive = emitter_state.novelty_archive.update(descriptors)
# Check if genotype from previous iteration has been added to the grid
indice = get_cells_indices(descriptors, repertoire.centroids)
added_genotype = jnp.all(
jnp.asarray(
jax.tree_util.tree_leaves(
jax.tree_util.tree_map(
lambda new_gen, rep_gen: jnp.all(
jnp.equal(
jnp.ravel(new_gen), jnp.ravel(rep_gen.at[indice].get())
),
axis=0,
),
genotypes,
repertoire.genotypes,
),
)
),
axis=0,
)
# Update last_updated buffers
last_updated_position = jnp.where(
added_genotype,
emitter_state.last_updated_position,
self._config.last_updated_size + 1,
)
last_updated_fitnesses = emitter_state.last_updated_fitnesses
last_updated_fitnesses = last_updated_fitnesses.at[last_updated_position].set(
fitnesses[0]
)
last_updated_genotypes = jax.tree_map(
lambda last_gen, gen: last_gen.at[
jnp.expand_dims(last_updated_position, axis=0)
].set(gen),
emitter_state.last_updated_genotypes,
genotypes,
)
last_updated_position = (
emitter_state.last_updated_position + added_genotype
) % self._config.last_updated_size
# Return new emitter_state
return emitter_state.replace( # type: ignore
novelty_archive=novelty_archive,
last_updated_genotypes=last_updated_genotypes,
last_updated_fitnesses=last_updated_fitnesses,
last_updated_position=last_updated_position,
)
@partial(
jax.jit,
static_argnames=("self",),
)
def state_update(
self,
emitter_state: MEESEmitterState,
repertoire: MapElitesRepertoire,
genotypes: Genotype,
fitnesses: Fitness,
descriptors: Descriptor,
extra_scores: ExtraScores,
) -> MEESEmitterState:
"""Generate the gradient offspring for the next emitter call. Also
update the novelty archive and generation count from current call.
Args:
emitter_state: current emitter state
repertoire: the current genotypes repertoire
genotypes: the genotypes of the batch of emitted offspring.
fitnesses: the fitnesses of the batch of emitted offspring.
descriptors: the descriptors of the emitted offspring.
extra_scores: a dictionary with other values outputted by the
scoring function.
Returns:
The modified emitter state.
"""
assert jax.tree_util.tree_leaves(genotypes)[0].shape[0] == 1, (
"ERROR: MAP-Elites-ES generates 1 offspring per generation, "
+ "batch_size should be 1, the inputed batch has size:"
+ str(jax.tree_util.tree_leaves(genotypes)[0].shape[0])
)
# Update all the buffers and archives of the emitter_state
emitter_state = self._buffers_update(
emitter_state, repertoire, genotypes, fitnesses, descriptors
)
# Use new or previous parents and exploitation or exploration
generation_count = emitter_state.generation_count
sample_new_parent = generation_count % self._config.num_optimizer_steps == 0
use_exploration = (
self._config.use_explore and not self._config.use_exploit
) or (
self._config.use_explore
and self._config.use_exploit
and ((generation_count // self._config.num_optimizer_steps) % 2 == 0)
)
# Select parent and optimizer_state
parent, random_key = jax.lax.cond(
sample_new_parent,
lambda emitter_state, repertoire, random_key: jax.lax.cond(
use_exploration,
self._sample_explore,
self._sample_exploit,
emitter_state,
repertoire,
random_key,
),
lambda emitter_state, repertoire, random_key: (
emitter_state.offspring,
random_key,
),
emitter_state,
repertoire,
emitter_state.random_key,
)
optimizer_state = jax.lax.cond(
sample_new_parent,
lambda _unused: emitter_state.initial_optimizer_state,
lambda _unused: emitter_state.optimizer_state,
(),
)
# Define scores for es process
def exploration_exploitation_scores(
fitnesses: Fitness, descriptors: Descriptor
) -> jnp.ndarray:
scores = jax.lax.cond(
use_exploration,
lambda fitnesses, descriptors: emitter_state.novelty_archive.novelty(
descriptors, self._config.novelty_nearest_neighbors
),
lambda fitnesses, descriptors: fitnesses,
fitnesses,
descriptors,
)
return scores
# Run es process
offspring, optimizer_state, random_key = self._es_emitter(
parent=parent,
optimizer_state=optimizer_state,
random_key=random_key,
scores_fn=exploration_exploitation_scores,
)
return emitter_state.replace( # type: ignore
optimizer_state=optimizer_state,
offspring=offspring,
generation_count=generation_count + 1,
random_key=random_key,
)
batch_size: int
property
readonly
¶
__init__(self, config, total_generations, scoring_fn, num_descriptors)
special
¶
Initialise the MAP-Elites-ES emitter. WARNING: total_generations is required to build the novelty archive.
Source code in qdax/core/emitters/mees_emitter.py
def __init__(
self,
config: MEESConfig,
total_generations: int,
scoring_fn: Callable[
[Genotype, RNGKey], Tuple[Fitness, Descriptor, ExtraScores, RNGKey]
],
num_descriptors: int,
) -> None:
"""Initialise the MAP-Elites-ES emitter.
WARNING: total_generations is required to build the novelty archive.
Args:
config: algorithm config
scoring_fn: used to evaluate the samples for the gradient estimate.
total_generations: total number of generations for which the
emitter will run; used to initialise the novelty archive.
num_descriptors: dimension of the descriptors, used to initialise
the empty novelty archive.
"""
self._config = config
self._scoring_fn = scoring_fn
self._total_generations = total_generations
self._num_descriptors = num_descriptors
# Initialise optimizer
if self._config.adam_optimizer:
self._optimizer = optax.adam(learning_rate=config.learning_rate)
else:
self._optimizer = optax.sgd(learning_rate=config.learning_rate)
init(self, init_genotypes, random_key)
¶
Initializes the emitter state.
Source code in qdax/core/emitters/mees_emitter.py
@partial(
jax.jit,
static_argnames=("self",),
)
def init(
self, init_genotypes: Genotype, random_key: RNGKey
) -> Tuple[MEESEmitterState, RNGKey]:
"""Initializes the emitter state.
Args:
init_genotypes: The initial population.
random_key: A random key.
Returns:
The initial state of the MEESEmitter, a new random key.
"""
# Initialisation requires one initial genotype
if jax.tree_util.tree_leaves(init_genotypes)[0].shape[0] > 1:
init_genotypes = jax.tree_util.tree_map(
lambda x: x[0],
init_genotypes,
)
# Initialise optimizer
initial_optimizer_state = self._optimizer.init(init_genotypes)
# Create empty Novelty archive
if self._config.use_explore:
novelty_archive = NoveltyArchive.init(
self._total_generations, self._num_descriptors
)
else:
novelty_archive = NoveltyArchive.init(
self._config.novelty_nearest_neighbors, self._num_descriptors
)
# Create empty updated genotypes and fitness
last_updated_genotypes = jax.tree_util.tree_map(
lambda x: jnp.zeros(shape=(self._config.last_updated_size,) + x.shape[1:]),
init_genotypes,
)
last_updated_fitnesses = -jnp.inf * jnp.ones(
shape=self._config.last_updated_size
)
return (
MEESEmitterState(
initial_optimizer_state=initial_optimizer_state,
optimizer_state=initial_optimizer_state,
offspring=init_genotypes,
generation_count=0,
novelty_archive=novelty_archive,
last_updated_genotypes=last_updated_genotypes,
last_updated_fitnesses=last_updated_fitnesses,
last_updated_position=0,
random_key=random_key,
),
random_key,
)
emit(self, repertoire, emitter_state, random_key)
¶
Return the offspring generated through gradient update.
Source code in qdax/core/emitters/mees_emitter.py
@partial(
jax.jit,
static_argnames=("self",),
)
def emit(
self,
repertoire: MapElitesRepertoire,
emitter_state: MEESEmitterState,
random_key: RNGKey,
) -> Tuple[Genotype, RNGKey]:
"""Return the offspring generated through gradient update.
Params:
repertoire: the MAP-Elites repertoire to sample from
emitter_state
random_key: a jax PRNG random key
Returns:
a new gradient offspring
a new jax PRNG key
"""
return emitter_state.offspring, random_key
state_update(self, emitter_state, repertoire, genotypes, fitnesses, descriptors, extra_scores)
¶
Generate the gradient offspring for the next emitter call. Also update the novelty archive and generation count from current call.
Source code in qdax/core/emitters/mees_emitter.py
@partial(
jax.jit,
static_argnames=("self",),
)
def state_update(
self,
emitter_state: MEESEmitterState,
repertoire: MapElitesRepertoire,
genotypes: Genotype,
fitnesses: Fitness,
descriptors: Descriptor,
extra_scores: ExtraScores,
) -> MEESEmitterState:
"""Generate the gradient offspring for the next emitter call. Also
update the novelty archive and generation count from current call.
Args:
emitter_state: current emitter state
repertoire: the current genotypes repertoire
genotypes: the genotypes of the batch of emitted offspring.
fitnesses: the fitnesses of the batch of emitted offspring.
descriptors: the descriptors of the emitted offspring.
extra_scores: a dictionary with other values outputted by the
scoring function.
Returns:
The modified emitter state.
"""
assert jax.tree_util.tree_leaves(genotypes)[0].shape[0] == 1, (
"ERROR: MAP-Elites-ES generates 1 offspring per generation, "
+ "batch_size should be 1, the inputed batch has size:"
+ str(jax.tree_util.tree_leaves(genotypes)[0].shape[0])
)
# Update all the buffers and archives of the emitter_state
emitter_state = self._buffers_update(
emitter_state, repertoire, genotypes, fitnesses, descriptors
)
# Use new or previous parents and exploitation or exploration
generation_count = emitter_state.generation_count
sample_new_parent = generation_count % self._config.num_optimizer_steps == 0
use_exploration = (
self._config.use_explore and not self._config.use_exploit
) or (
self._config.use_explore
and self._config.use_exploit
and ((generation_count // self._config.num_optimizer_steps) % 2 == 0)
)
# Select parent and optimizer_state
parent, random_key = jax.lax.cond(
sample_new_parent,
lambda emitter_state, repertoire, random_key: jax.lax.cond(
use_exploration,
self._sample_explore,
self._sample_exploit,
emitter_state,
repertoire,
random_key,
),
lambda emitter_state, repertoire, random_key: (
emitter_state.offspring,
random_key,
),
emitter_state,
repertoire,
emitter_state.random_key,
)
optimizer_state = jax.lax.cond(
sample_new_parent,
lambda _unused: emitter_state.initial_optimizer_state,
lambda _unused: emitter_state.optimizer_state,
(),
)
# Define scores for es process
def exploration_exploitation_scores(
fitnesses: Fitness, descriptors: Descriptor
) -> jnp.ndarray:
scores = jax.lax.cond(
use_exploration,
lambda fitnesses, descriptors: emitter_state.novelty_archive.novelty(
descriptors, self._config.novelty_nearest_neighbors
),
lambda fitnesses, descriptors: fitnesses,
fitnesses,
descriptors,
)
return scores
# Run es process
offspring, optimizer_state, random_key = self._es_emitter(
parent=parent,
optimizer_state=optimizer_state,
random_key=random_key,
scores_fn=exploration_exploitation_scores,
)
return emitter_state.replace( # type: ignore
optimizer_state=optimizer_state,
offspring=offspring,
generation_count=generation_count + 1,
random_key=random_key,
)
multi_emitter
¶
MultiEmitterState (EmitterState)
dataclass
¶
State of an emitter that uses multiple emitters in a parallel manner.
WARNING: this is not the emitter state of Multi-Emitter MAP-Elites.
Source code in qdax/core/emitters/multi_emitter.py
class MultiEmitterState(EmitterState):
"""State of an emitter than use multiple emitters in a parallel manner.
WARNING: this is not the emitter state of Multi-Emitter MAP-Elites.
Args:
emitter_states: a tuple of emitter states
"""
emitter_states: Tuple[EmitterState, ...]
replace(self, **updates)
¶
"Returns a new object replacing the specified fields with new values.
Source code in qdax/core/emitters/multi_emitter.py
def replace(self, **updates):
""" "Returns a new object replacing the specified fields with new values."""
return dataclasses.replace(self, **updates)
MultiEmitter (Emitter)
¶
Emitter that mixes several emitters in parallel.
WARNING: this is not the emitter of Multi-Emitter MAP-Elites.
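For illustration, a sketch composing two instances of the hypothetical GaussianEmitter defined earlier on this page; any tuple of Emitter instances can be mixed this way, and the global batch is the concatenation of the sub batches.

from qdax.core.emitters.multi_emitter import MultiEmitter

multi_emitter = MultiEmitter(
    emitters=(
        GaussianEmitter(batch_size=32, sigma=0.05),  # small local mutations
        GaussianEmitter(batch_size=16, sigma=0.2),   # more exploratory noise
    )
)
assert multi_emitter.batch_size == 48  # sum of the sub emitter batch sizes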
Source code in qdax/core/emitters/multi_emitter.py
class MultiEmitter(Emitter):
"""Emitter that mixes several emitters in parallel.
WARNING: this is not the emitter of Multi-Emitter MAP-Elites.
"""
def __init__(
self,
emitters: Tuple[Emitter, ...],
):
self.emitters = emitters
indexes_separation_batches = self.get_indexes_separation_batches(emitters)
self.indexes_start_batches = indexes_separation_batches[:-1]
self.indexes_end_batches = indexes_separation_batches[1:]
@staticmethod
def get_indexes_separation_batches(
emitters: Tuple[Emitter, ...]
) -> Tuple[int, ...]:
"""Get the indexes of the separation between batches of each emitter.
Args:
emitters: the emitters
Returns:
a tuple of indexes
"""
indexes_separation_batches = np.cumsum(
[0] + [emitter.batch_size for emitter in emitters]
)
return tuple(indexes_separation_batches)
def init(
self, init_genotypes: Optional[Genotype], random_key: RNGKey
) -> Tuple[Optional[EmitterState], RNGKey]:
"""
Initialize the state of the emitter.
Args:
init_genotypes: The genotypes of the initial population.
random_key: a random key to handle stochastic operations.
Returns:
The initial emitter state and a random key.
"""
# prepare keys for each emitter
random_key, subkey = jax.random.split(random_key)
subkeys = jax.random.split(subkey, len(self.emitters))
# init all emitter states - gather them
emitter_states = []
for emitter, subkey_emitter in zip(self.emitters, subkeys):
emitter_state, _ = emitter.init(init_genotypes, subkey_emitter)
emitter_states.append(emitter_state)
return MultiEmitterState(tuple(emitter_states)), random_key
@partial(jax.jit, static_argnames=("self",))
def emit(
self,
repertoire: Optional[Repertoire],
emitter_state: Optional[MultiEmitterState],
random_key: RNGKey,
) -> Tuple[Genotype, RNGKey]:
"""Emit new population. Use all the sub emitters to emit subpopulation
and gather them.
Args:
repertoire: a repertoire of genotypes.
emitter_state: the current state of the emitter.
random_key: key for random operations.
Returns:
Offsprings and a new random key.
"""
assert emitter_state is not None
assert len(emitter_state.emitter_states) == len(self.emitters)
# prepare subkeys for each sub emitter
random_key, subkey = jax.random.split(random_key)
subkeys = jax.random.split(subkey, len(self.emitters))
# emit from all emitters and gather offsprings
all_offsprings = []
for emitter, sub_emitter_state, subkey_emitter in zip(
self.emitters,
emitter_state.emitter_states,
subkeys,
):
genotype, _ = emitter.emit(repertoire, sub_emitter_state, subkey_emitter)
batch_size = jax.tree_util.tree_leaves(genotype)[0].shape[0]
assert batch_size == emitter.batch_size
all_offsprings.append(genotype)
# concatenate offsprings together
offsprings = jax.tree_util.tree_map(
lambda *x: jnp.concatenate(x, axis=0), *all_offsprings
)
return offsprings, random_key
@partial(jax.jit, static_argnames=("self",))
def state_update(
self,
emitter_state: Optional[MultiEmitterState],
repertoire: Optional[Repertoire] = None,
genotypes: Optional[Genotype] = None,
fitnesses: Optional[Fitness] = None,
descriptors: Optional[Descriptor] = None,
extra_scores: Optional[ExtraScores] = None,
) -> Optional[MultiEmitterState]:
"""Update emitter state by updating all sub emitter states.
Args:
emitter_state: current emitter state.
repertoire: current repertoire of genotypes. Defaults to None.
genotypes: proposed genotypes. Defaults to None.
fitnesses: associated fitnesses. Defaults to None.
descriptors: associated descriptors. Defaults to None.
extra_scores: associated extra_scores. Defaults to None.
Returns:
The updated global emitter state.
"""
if emitter_state is None:
return None
# update all the sub emitter states
emitter_states = []
def _get_sub_pytree(pytree: ArrayTree, start: int, end: int) -> ArrayTree:
return jax.tree_util.tree_map(lambda x: x[start:end], pytree)
for emitter, sub_emitter_state, index_start, index_end in zip(
self.emitters,
emitter_state.emitter_states,
self.indexes_start_batches,
self.indexes_end_batches,
):
# update with all genotypes, fitnesses, etc...
if emitter.use_all_data:
new_sub_emitter_state = emitter.state_update(
sub_emitter_state,
repertoire,
genotypes,
fitnesses,
descriptors,
extra_scores,
)
emitter_states.append(new_sub_emitter_state)
# update only with the data of the emitted genotypes
else:
# extract relevant data
sub_gen, sub_fit, sub_desc, sub_extra_scores = jax.tree_util.tree_map(
partial(_get_sub_pytree, start=index_start, end=index_end),
(
genotypes,
fitnesses,
descriptors,
extra_scores,
),
)
# update only with the relevant data
new_sub_emitter_state = emitter.state_update(
sub_emitter_state,
repertoire,
sub_gen,
sub_fit,
sub_desc,
sub_extra_scores,
)
emitter_states.append(new_sub_emitter_state)
# return the updated global emitter state
return MultiEmitterState(tuple(emitter_states))
@property
def batch_size(self) -> int:
"""
Returns:
the batch size emitted by the emitter.
"""
return sum(emitter.batch_size for emitter in self.emitters)
batch_size: int
property
readonly
¶
get_indexes_separation_batches(emitters)
staticmethod
¶
Get the indexes of the separation between batches of each emitter.
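A small worked example of this computation: for sub emitter batch sizes (3, 5, 2), the cumulative sums give where each sub batch starts and ends in the concatenated offspring array.

import numpy as np

separations = tuple(np.cumsum([0] + [3, 5, 2]))
starts, ends = separations[:-1], separations[1:]
# starts = (0, 3, 8), ends = (3, 8, 10): sub batch i is offspring[starts[i]:ends[i]]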
Source code in qdax/core/emitters/multi_emitter.py
@staticmethod
def get_indexes_separation_batches(
emitters: Tuple[Emitter, ...]
) -> Tuple[int, ...]:
"""Get the indexes of the separation between batches of each emitter.
Args:
emitters: the emitters
Returns:
a tuple of indexes
"""
indexes_separation_batches = np.cumsum(
[0] + [emitter.batch_size for emitter in emitters]
)
return tuple(indexes_separation_batches)
init(self, init_genotypes, random_key)
¶
Initialize the state of the emitter.
Source code in qdax/core/emitters/multi_emitter.py
def init(
self, init_genotypes: Optional[Genotype], random_key: RNGKey
) -> Tuple[Optional[EmitterState], RNGKey]:
"""
Initialize the state of the emitter.
Args:
init_genotypes: The genotypes of the initial population.
random_key: a random key to handle stochastic operations.
Returns:
The initial emitter state and a random key.
"""
# prepare keys for each emitter
random_key, subkey = jax.random.split(random_key)
subkeys = jax.random.split(subkey, len(self.emitters))
# init all emitter states - gather them
emitter_states = []
for emitter, subkey_emitter in zip(self.emitters, subkeys):
emitter_state, _ = emitter.init(init_genotypes, subkey_emitter)
emitter_states.append(emitter_state)
return MultiEmitterState(tuple(emitter_states)), random_key
emit(self, repertoire, emitter_state, random_key)
¶
Emit new population. Use all the sub emitters to emit subpopulations and gather them.
Source code in qdax/core/emitters/multi_emitter.py
@partial(jax.jit, static_argnames=("self",))
def emit(
self,
repertoire: Optional[Repertoire],
emitter_state: Optional[MultiEmitterState],
random_key: RNGKey,
) -> Tuple[Genotype, RNGKey]:
"""Emit new population. Use all the sub emitters to emit subpopulation
and gather them.
Args:
repertoire: a repertoire of genotypes.
emitter_state: the current state of the emitter.
random_key: key for random operations.
Returns:
Offsprings and a new random key.
"""
assert emitter_state is not None
assert len(emitter_state.emitter_states) == len(self.emitters)
# prepare subkeys for each sub emitter
random_key, subkey = jax.random.split(random_key)
subkeys = jax.random.split(subkey, len(self.emitters))
# emit from all emitters and gather offsprings
all_offsprings = []
for emitter, sub_emitter_state, subkey_emitter in zip(
self.emitters,
emitter_state.emitter_states,
subkeys,
):
genotype, _ = emitter.emit(repertoire, sub_emitter_state, subkey_emitter)
batch_size = jax.tree_util.tree_leaves(genotype)[0].shape[0]
assert batch_size == emitter.batch_size
all_offsprings.append(genotype)
# concatenate offsprings together
offsprings = jax.tree_util.tree_map(
lambda *x: jnp.concatenate(x, axis=0), *all_offsprings
)
return offsprings, random_key
state_update(self, emitter_state, repertoire=None, genotypes=None, fitnesses=None, descriptors=None, extra_scores=None)
¶
Update emitter state by updating all sub emitter states.
Source code in qdax/core/emitters/multi_emitter.py
@partial(jax.jit, static_argnames=("self",))
def state_update(
self,
emitter_state: Optional[MultiEmitterState],
repertoire: Optional[Repertoire] = None,
genotypes: Optional[Genotype] = None,
fitnesses: Optional[Fitness] = None,
descriptors: Optional[Descriptor] = None,
extra_scores: Optional[ExtraScores] = None,
) -> Optional[MultiEmitterState]:
"""Update emitter state by updating all sub emitter states.
Args:
emitter_state: current emitter state.
repertoire: current repertoire of genotypes. Defaults to None.
genotypes: proposed genotypes. Defaults to None.
fitnesses: associated fitnesses. Defaults to None.
descriptors: associated descriptors. Defaults to None.
extra_scores: associated extra_scores. Defaults to None.
Returns:
The updated global emitter state.
"""
if emitter_state is None:
return None
# update all the sub emitter states
emitter_states = []
def _get_sub_pytree(pytree: ArrayTree, start: int, end: int) -> ArrayTree:
return jax.tree_util.tree_map(lambda x: x[start:end], pytree)
for emitter, sub_emitter_state, index_start, index_end in zip(
self.emitters,
emitter_state.emitter_states,
self.indexes_start_batches,
self.indexes_end_batches,
):
# update with all genotypes, fitnesses, etc...
if emitter.use_all_data:
new_sub_emitter_state = emitter.state_update(
sub_emitter_state,
repertoire,
genotypes,
fitnesses,
descriptors,
extra_scores,
)
emitter_states.append(new_sub_emitter_state)
# update only with the data of the emitted genotypes
else:
# extract relevant data
sub_gen, sub_fit, sub_desc, sub_extra_scores = jax.tree_util.tree_map(
partial(_get_sub_pytree, start=index_start, end=index_end),
(
genotypes,
fitnesses,
descriptors,
extra_scores,
),
)
# update only with the relevant data
new_sub_emitter_state = emitter.state_update(
sub_emitter_state,
repertoire,
sub_gen,
sub_fit,
sub_desc,
sub_extra_scores,
)
emitter_states.append(new_sub_emitter_state)
# return the updated global emitter state
return MultiEmitterState(tuple(emitter_states))
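To make the slicing above concrete, here is a minimal sketch (not part of qdax) of how `_get_sub_pytree` extracts, leaf by leaf, the rows of a batched pytree that belong to one sub emitter:
import jax
import jax.numpy as jnp

# a batch of 112 genotypes stored as a pytree (shapes are illustrative)
batched = {"weights": jnp.ones((112, 8)), "bias": jnp.zeros((112,))}

def get_sub_pytree(pytree, start, end):
    # slice every leaf along the batch axis
    return jax.tree_util.tree_map(lambda x: x[start:end], pytree)

# data for a sub emitter covering indices 64 to 96
sub = get_sub_pytree(batched, 64, 96)
assert sub["weights"].shape == (32, 8) and sub["bias"].shape == (32,)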
mutation_operators
¶
File defining mutation and crossover functions.
polynomial_mutation(x, random_key, proportion_to_mutate, eta, minval, maxval)
¶
Polynomial mutation over several genotypes
Source code in qdax/core/emitters/mutation_operators.py
def polynomial_mutation(
x: Genotype,
random_key: RNGKey,
proportion_to_mutate: float,
eta: float,
minval: float,
maxval: float,
) -> Tuple[Genotype, RNGKey]:
"""
Polynomial mutation over several genotypes
Parameters:
x: array of genotypes to transform (real values only).
Assumed to be of shape (batch_size, genotype_dim).
random_key: RNG key for reproducibility.
proportion_to_mutate (float): proportion of variables to mutate in
each genotype (must be in [0, 1]).
eta: scaling parameter, the larger the more spread the new
values will be.
minval: minimum value to clip the genotypes.
maxval: maximum value to clip the genotypes.
Returns:
New genotypes - same shape as input and a new RNG key
"""
random_key, subkey = jax.random.split(random_key)
batch_size = jax.tree_util.tree_leaves(x)[0].shape[0]
mutation_key = jax.random.split(subkey, num=batch_size)
mutation_fn = partial(
_polynomial_mutation,
proportion_to_mutate=proportion_to_mutate,
eta=eta,
minval=minval,
maxval=maxval,
)
mutation_fn = jax.vmap(mutation_fn)
x = jax.tree_util.tree_map(lambda x_: mutation_fn(x_, mutation_key), x)
return x, random_key
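A hedged usage sketch of polynomial_mutation; the hyperparameter values below are illustrative, not recommended settings:
import jax
import jax.numpy as jnp
from qdax.core.emitters.mutation_operators import polynomial_mutation

genotypes = jnp.zeros((10, 4))  # batch of 10 genotypes of dimension 4
random_key = jax.random.PRNGKey(0)

new_genotypes, random_key = polynomial_mutation(
    genotypes,
    random_key,
    proportion_to_mutate=0.5,
    eta=20.0,
    minval=-1.0,
    maxval=1.0,
)
assert new_genotypes.shape == genotypes.shape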
polynomial_crossover(x1, x2, random_key, proportion_var_to_change)
¶
Crossover over a set of pairs of genotypes.
Batched version of _simple_crossover_function. x1 and x2 should have the same shape; in this function we assume x1 and x2 to be of shape (batch_size, genotype_dim).
Source code in qdax/core/emitters/mutation_operators.py
def polynomial_crossover(
x1: Genotype,
x2: Genotype,
random_key: RNGKey,
proportion_var_to_change: float,
) -> Tuple[Genotype, RNGKey]:
"""
Crossover over a set of pairs of genotypes.
Batched version of _simple_crossover_function
x1 and x2 should have the same shape
In this function we assume x1 shape and x2 shape to be
(batch_size, genotype_dim)
Parameters:
x1: first batch of genotypes
x2: second batch of genotypes
random_key: RNG key for reproducibility
proportion_var_to_change: proportion of variables to exchange
between genotypes (must be [0, 1])
Returns:
New genotypes and a new RNG key
"""
random_key, subkey = jax.random.split(random_key)
batch_size = jax.tree_util.tree_leaves(x2)[0].shape[0]
crossover_keys = jax.random.split(subkey, num=batch_size)
crossover_fn = partial(
_polynomial_crossover,
proportion_var_to_change=proportion_var_to_change,
)
crossover_fn = jax.vmap(crossover_fn)
# TODO: check that key usage is correct
x = jax.tree_util.tree_map(
lambda x1_, x2_: crossover_fn(x1_, x2_, crossover_keys), x1, x2
)
return x, random_key
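A similar hedged usage sketch for polynomial_crossover (illustrative values):
import jax
import jax.numpy as jnp
from qdax.core.emitters.mutation_operators import polynomial_crossover

random_key = jax.random.PRNGKey(0)
x1 = jnp.zeros((10, 4))  # first batch of parents
x2 = jnp.ones((10, 4))   # second batch of parents

offspring, random_key = polynomial_crossover(
    x1, x2, random_key, proportion_var_to_change=0.5
)
assert offspring.shape == x1.shape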
isoline_variation(x1, x2, random_key, iso_sigma, line_sigma, minval=None, maxval=None)
¶
Iso+Line-DD Variation Operator [1] over a set of pairs of genotypes
[1] Vassiliades, Vassilis, and Jean-Baptiste Mouret. "Discovering the elite hypervolume by leveraging interspecies correlation." Proceedings of the Genetic and Evolutionary Computation Conference. 2018.
Source code in qdax/core/emitters/mutation_operators.py
def isoline_variation(
x1: Genotype,
x2: Genotype,
random_key: RNGKey,
iso_sigma: float,
line_sigma: float,
minval: Optional[float] = None,
maxval: Optional[float] = None,
) -> Tuple[Genotype, RNGKey]:
"""
Iso+Line-DD Variation Operator [1] over a set of pairs of genotypes
Parameters:
x1 (Genotypes): first batch of genotypes
x2 (Genotypes): second batch of genotypes
random_key (RNGKey): RNG key for reproducibility
iso_sigma (float): spread parameter (noise)
line_sigma (float): line parameter (direction of the new genotype)
minval (float, Optional): minimum value to clip the genotypes
maxval (float, Optional): maximum value to clip the genotypes
Returns:
x (Genotypes): new genotypes
random_key (RNGKey): new RNG key
[1] Vassiliades, Vassilis, and Jean-Baptiste Mouret. "Discovering the elite
hypervolume by leveraging interspecies correlation." Proceedings of the Genetic and
Evolutionary Computation Conference. 2018.
"""
# Computing line_noise
random_key, key_line_noise = jax.random.split(random_key)
batch_size = jax.tree_util.tree_leaves(x1)[0].shape[0]
line_noise = jax.random.normal(key_line_noise, shape=(batch_size,)) * line_sigma
def _variation_fn(
x1: jnp.ndarray, x2: jnp.ndarray, random_key: RNGKey
) -> jnp.ndarray:
iso_noise = jax.random.normal(random_key, shape=x1.shape) * iso_sigma
x = (x1 + iso_noise) + jax.vmap(jnp.multiply)((x2 - x1), line_noise)
# Back in bounds if necessary (floating point issues)
if (minval is not None) or (maxval is not None):
x = jnp.clip(x, minval, maxval)
return x
# create a tree with random keys
nb_leaves = len(jax.tree_util.tree_leaves(x1))
random_key, subkey = jax.random.split(random_key)
subkeys = jax.random.split(subkey, num=nb_leaves)
keys_tree = jax.tree_util.tree_unflatten(jax.tree_util.tree_structure(x1), subkeys)
# apply isolinedd to each branch of the tree
x = jax.tree_util.tree_map(
lambda y1, y2, key: _variation_fn(y1, y2, key), x1, x2, keys_tree
)
return x, random_key
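For a single pair of flat genotypes, the operator reduces to x = x1 + sigma_iso * eps + sigma_line * delta * (x2 - x1), with eps a per-gene standard normal and delta a scalar standard normal shared across genes. A minimal sketch (not qdax code) of that update:
import jax
import jax.numpy as jnp

key = jax.random.PRNGKey(0)
key_iso, key_line = jax.random.split(key)

x1 = jnp.array([0.0, 0.0, 0.0])
x2 = jnp.array([1.0, 1.0, 1.0])
iso_sigma, line_sigma = 0.005, 0.05  # illustrative values

iso_noise = jax.random.normal(key_iso, shape=x1.shape) * iso_sigma  # per-gene noise
line_noise = jax.random.normal(key_line, shape=()) * line_sigma     # shared line scale

x = (x1 + iso_noise) + line_noise * (x2 - x1)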
omg_mega_emitter
¶
OMGMEGAEmitterState (EmitterState)
dataclass
¶
Emitter state for the OMG-MEGA emitter.
Source code in qdax/core/emitters/omg_mega_emitter.py
class OMGMEGAEmitterState(EmitterState):
"""
Emitter state for the OMG-MEGA emitter.
Args:
gradients_repertoire: MapElites repertoire containing the gradients
of the individuals.
"""
gradients_repertoire: MapElitesRepertoire
replace(self, **updates)
¶
"Returns a new object replacing the specified fields with new values.
Source code in qdax/core/emitters/omg_mega_emitter.py
def replace(self, **updates):
""" "Returns a new object replacing the specified fields with new values."""
return dataclasses.replace(self, **updates)
OMGMEGAEmitter (Emitter)
¶
Class for the emitter of OMG Mega from "Differentiable Quality Diversity" by Fontaine et al.
NOTE: in order to implement this emitter while staying in the MAPElites framework, we had to make two temporary design choices: - in the emit function, we use the same random key to sample from the genotypes and gradients repertoires, in order to get the gradients that correspond to the right genotypes. Although acceptable, this is definitely not the best coding practice and we would prefer to get rid of it in a future version. A solution that we are discussing with the development team is to decompose the sampling function of the repertoire into two phases: one sampling the indices to be sampled, the other one retrieving the corresponding elements. This would make it possible to reuse the indices instead of doing this double sampling. - in the state_update, we have to insert the gradients into the gradients repertoire in the same way the individuals were inserted. Once again, this is slightly suboptimal because the same addition mechanism has to be computed twice. One solution that we are discussing, very similar to the first solution discussed above, would be to decompose the addition mechanism into two phases: one outputting the indices at which individuals will be added, then the actual insertion step. This would make it possible to reuse the same indices to add the gradients instead of having to recompute them.
The two design choices seem acceptable and make OMG MEGA compatible with the current implementation of the MAPElites and MAPElitesRepertoire classes.
Our suggested solutions are quite simple and are likely to be useful for implementing other variants. They will be further discussed with the development team and potentially added in a future version of the package.
Source code in qdax/core/emitters/omg_mega_emitter.py
class OMGMEGAEmitter(Emitter):
"""
Class for the emitter of OMG Mega from "Differentiable Quality Diversity" by
Fontaine et al.
NOTE: in order to implement this emitter while staying in the MAPElites
framework, we had to make two temporary design choices:
- in the emit function, we use the same random key to sample from the
genotypes and gradients repertoires, in order to get the gradients that
correspond to the right genotypes. Although acceptable, this is definitely
not the best coding practice and we would prefer to get rid of it in a
future version. A solution that we are discussing with the development team
is to decompose the sampling function of the repertoire into two phases: one
sampling the indices to be sampled, the other one retrieving the
corresponding elements. This would make it possible to reuse the indices
instead of doing this double sampling.
- in the state_update, we have to insert the gradients into the gradients
repertoire in the same way the individuals were inserted. Once again, this
is slightly suboptimal because the same addition mechanism has to be
computed twice. One solution that we are discussing, very similar to the
first solution discussed above, would be to decompose the addition mechanism
into two phases: one outputting the indices at which individuals will be
added, then the actual insertion step. This would make it possible to reuse
the same indices to add the gradients instead of having to recompute them.
The two design choices seem acceptable and make OMG MEGA compatible
with the current implementation of the MAPElites and MAPElitesRepertoire
classes.
Our suggested solutions are quite simple and are likely to be useful for
implementing other variants. They will be further discussed with the
development team and potentially added in a future version of the package.
"""
def __init__(
self,
batch_size: int,
sigma_g: float,
num_descriptors: int,
centroids: Centroid,
):
"""Creates an instance of the OMGMEGAEmitter class.
Args:
batch_size: number of solutions sampled at each iteration
sigma_g: CAUTION - square of the standard deviation for the coefficients.
This notation can be misleading as, although it's called sigma, it
refers to the variance and not the standard deviation.
num_descriptors: number of descriptors
centroids: centroids used to create the repertoire of solutions.
This will be used to create the repertoire of gradients.
"""
# set the mean of the coeff distribution to zero
self._mu = jnp.zeros(num_descriptors + 1)
# set the cov matrix to sigma * I
self._sigma = jnp.eye(num_descriptors + 1) * sigma_g
# define other parameters of the distribution
self._batch_size = batch_size
self._centroids = centroids
self._num_descriptors = num_descriptors
def init(
self, init_genotypes: Genotype, random_key: RNGKey
) -> Tuple[OMGMEGAEmitterState, RNGKey]:
"""Initialises the state of the emitter. Creates an empty repertoire
that will later contain the gradients of the individuals.
Args:
init_genotypes: The genotypes of the initial population.
random_key: a random key to handle stochastic operations.
Returns:
The initial emitter state.
"""
# retrieve one genotype from the population
first_genotype = jax.tree_util.tree_map(lambda x: x[0], init_genotypes)
# add a dimension of size num descriptors + 1
gradient_genotype = jax.tree_util.tree_map(
lambda x: jnp.repeat(
jnp.expand_dims(x, axis=-1), repeats=self._num_descriptors + 1, axis=-1
),
first_genotype,
)
# create the gradients repertoire
gradients_repertoire = MapElitesRepertoire.init_default(
genotype=gradient_genotype, centroids=self._centroids
)
return (
OMGMEGAEmitterState(gradients_repertoire=gradients_repertoire),
random_key,
)
@partial(
jax.jit,
static_argnames=("self",),
)
def emit(
self,
repertoire: MapElitesRepertoire,
emitter_state: OMGMEGAEmitterState,
random_key: RNGKey,
) -> Tuple[Genotype, RNGKey]:
"""
OMG emitter function that samples elements in the repertoire and does a gradient
update with random coefficients to create new candidates.
Args:
repertoire: current repertoire
emitter_state: current emitter state, contains the gradients
random_key: random key
Returns:
new_genotypes: new candidates to be added to the grid
random_key: updated random key
"""
# sample genotypes
(
genotypes,
_,
) = repertoire.sample(random_key, num_samples=self._batch_size)
# sample gradients - use the same random key for sampling
# See class docstrings for discussion about this choice
gradients, random_key = emitter_state.gradients_repertoire.sample(
random_key, num_samples=self._batch_size
)
fitness_gradients = jax.tree_util.tree_map(
lambda x: jnp.expand_dims(x[:, :, 0], axis=-1), gradients
)
descriptors_gradients = jax.tree_util.tree_map(lambda x: x[:, :, 1:], gradients)
# Normalize the gradients
norm_fitness_gradients = jnp.linalg.norm(
fitness_gradients, axis=1, keepdims=True
)
fitness_gradients = fitness_gradients / norm_fitness_gradients
norm_descriptors_gradients = jnp.linalg.norm(
descriptors_gradients, axis=1, keepdims=True
)
descriptors_gradients = descriptors_gradients / norm_descriptors_gradients
# Draw random coefficients
random_key, subkey = jax.random.split(random_key)
coeffs = jax.random.multivariate_normal(
subkey,
shape=(self._batch_size,),
mean=self._mu,
cov=self._sigma,
)
coeffs = coeffs.at[:, 0].set(jnp.abs(coeffs[:, 0]))
grads = jax.tree_util.tree_map(
lambda x, y: jnp.concatenate((x, y), axis=-1),
fitness_gradients,
descriptors_gradients,
)
update_grad = jnp.sum(jax.vmap(lambda x, y: x * y)(coeffs, grads), axis=-1)
# update the genotypes
new_genotypes = jax.tree_util.tree_map(
lambda x, y: x + y, genotypes, update_grad
)
return new_genotypes, random_key
@partial(
jax.jit,
static_argnames=("self",),
)
def state_update(
self,
emitter_state: OMGMEGAEmitterState,
repertoire: MapElitesRepertoire,
genotypes: Genotype,
fitnesses: Fitness,
descriptors: Descriptor,
extra_scores: ExtraScores,
) -> OMGMEGAEmitterState:
"""Update the gradients repertoire to have the right gradients.
NOTE: see discussion in the class docstrings to see how this could
be improved.
Args:
emitter_state: current emitter state
repertoire: the current genotypes repertoire
genotypes: the genotypes of the batch of emitted offspring.
fitnesses: the fitnesses of the batch of emitted offspring.
descriptors: the descriptors of the emitted offspring.
extra_scores: a dictionary with other values outputted by the
scoring function.
Returns:
The modified emitter state.
"""
# get gradients out of the extra scores
assert "gradients" in extra_scores.keys(), "Missing gradients or wrong key"
gradients = extra_scores["gradients"]
# update the gradients repertoire
gradients_repertoire = emitter_state.gradients_repertoire.add(
gradients,
descriptors,
fitnesses,
extra_scores,
)
return emitter_state.replace( # type: ignore
gradients_repertoire=gradients_repertoire
)
@property
def batch_size(self) -> int:
"""
Returns:
the batch size emitted by the emitter.
"""
return self._batch_size
batch_size: int
property
readonly
¶
__init__(self, batch_size, sigma_g, num_descriptors, centroids)
special
¶
Creates an instance of the OMGMEGAEmitter class.
Source code in qdax/core/emitters/omg_mega_emitter.py
def __init__(
self,
batch_size: int,
sigma_g: float,
num_descriptors: int,
centroids: Centroid,
):
"""Creates an instance of the OMGMEGAEmitter class.
Args:
batch_size: number of solutions sampled at each iteration
sigma_g: CAUTION - square of the standard deviation for the coefficients.
This notation can be misleading as, although it's called sigma, it
refers to the variance and not the standard deviation.
num_descriptors: number of descriptors
centroids: centroids used to create the repertoire of solutions.
This will be used to create the repertoire of gradients.
"""
# set the mean of the coeff distribution to zero
self._mu = jnp.zeros(num_descriptors + 1)
# set the cov matrix to sigma * I
self._sigma = jnp.eye(num_descriptors + 1) * sigma_g
# define other parameters of the distribution
self._batch_size = batch_size
self._centroids = centroids
self._num_descriptors = num_descriptors
init(self, init_genotypes, random_key)
¶
Initialises the state of the emitter. Creates an empty repertoire that will later contain the gradients of the individuals.
Source code in qdax/core/emitters/omg_mega_emitter.py
def init(
self, init_genotypes: Genotype, random_key: RNGKey
) -> Tuple[OMGMEGAEmitterState, RNGKey]:
"""Initialises the state of the emitter. Creates an empty repertoire
that will later contain the gradients of the individuals.
Args:
init_genotypes: The genotypes of the initial population.
random_key: a random key to handle stochastic operations.
Returns:
The initial emitter state.
"""
# retrieve one genotype from the population
first_genotype = jax.tree_util.tree_map(lambda x: x[0], init_genotypes)
# add a dimension of size num descriptors + 1
gradient_genotype = jax.tree_util.tree_map(
lambda x: jnp.repeat(
jnp.expand_dims(x, axis=-1), repeats=self._num_descriptors + 1, axis=-1
),
first_genotype,
)
# create the gradients repertoire
gradients_repertoire = MapElitesRepertoire.init_default(
genotype=gradient_genotype, centroids=self._centroids
)
return (
OMGMEGAEmitterState(gradients_repertoire=gradients_repertoire),
random_key,
)
emit(self, repertoire, emitter_state, random_key)
¶
OMG emitter function that samples elements in the repertoire and does a gradient update with random coefficients to create new candidates.
Source code in qdax/core/emitters/omg_mega_emitter.py
@partial(
jax.jit,
static_argnames=("self",),
)
def emit(
self,
repertoire: MapElitesRepertoire,
emitter_state: OMGMEGAEmitterState,
random_key: RNGKey,
) -> Tuple[Genotype, RNGKey]:
"""
OMG emitter function that samples elements in the repertoire and does a gradient
update with random coefficients to create new candidates.
Args:
repertoire: current repertoire
emitter_state: current emitter state, contains the gradients
random_key: random key
Returns:
new_genotypes: new candidates to be added to the grid
random_key: updated random key
"""
# sample genotypes
(
genotypes,
_,
) = repertoire.sample(random_key, num_samples=self._batch_size)
# sample gradients - use the same random key for sampling
# See class docstrings for discussion about this choice
gradients, random_key = emitter_state.gradients_repertoire.sample(
random_key, num_samples=self._batch_size
)
fitness_gradients = jax.tree_util.tree_map(
lambda x: jnp.expand_dims(x[:, :, 0], axis=-1), gradients
)
descriptors_gradients = jax.tree_util.tree_map(lambda x: x[:, :, 1:], gradients)
# Normalize the gradients
norm_fitness_gradients = jnp.linalg.norm(
fitness_gradients, axis=1, keepdims=True
)
fitness_gradients = fitness_gradients / norm_fitness_gradients
norm_descriptors_gradients = jnp.linalg.norm(
descriptors_gradients, axis=1, keepdims=True
)
descriptors_gradients = descriptors_gradients / norm_descriptors_gradients
# Draw random coefficients
random_key, subkey = jax.random.split(random_key)
coeffs = jax.random.multivariate_normal(
subkey,
shape=(self._batch_size,),
mean=self._mu,
cov=self._sigma,
)
coeffs = coeffs.at[:, 0].set(jnp.abs(coeffs[:, 0]))
grads = jax.tree_util.tree_map(
lambda x, y: jnp.concatenate((x, y), axis=-1),
fitness_gradients,
descriptors_gradients,
)
update_grad = jnp.sum(jax.vmap(lambda x, y: x * y)(coeffs, grads), axis=-1)
# update the genotypes
new_genotypes = jax.tree_util.tree_map(
lambda x, y: x + y, genotypes, update_grad
)
return new_genotypes, random_key
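Stripped of batching and pytree handling, the update applied in emit reduces to: normalize the fitness and descriptor gradients, draw coefficients from a normal distribution with covariance sigma_g * I, force the fitness coefficient to be non-negative, and step along the weighted sum of gradients. A minimal sketch (not qdax code) for a single genotype:
import jax
import jax.numpy as jnp

num_descriptors, genotype_dim = 2, 4
key = jax.random.PRNGKey(0)

x = jnp.zeros(genotype_dim)  # one genotype
# one column per gradient: fitness gradient first, then descriptor gradients
grads = jax.random.normal(key, (genotype_dim, num_descriptors + 1))
grads = grads / jnp.linalg.norm(grads, axis=0, keepdims=True)  # normalize columns

key, subkey = jax.random.split(key)
coeffs = jax.random.multivariate_normal(
    subkey,
    mean=jnp.zeros(num_descriptors + 1),
    cov=jnp.eye(num_descriptors + 1) * 0.5,  # sigma_g = 0.5 (a variance here)
)
coeffs = coeffs.at[0].set(jnp.abs(coeffs[0]))  # fitness coefficient >= 0

new_x = x + grads @ coeffs  # weighted-gradient step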
state_update(self, emitter_state, repertoire, genotypes, fitnesses, descriptors, extra_scores)
¶
Update the gradients repertoire to have the right gradients.
NOTE: see discussion in the class docstrings to see how this could be improved.
Source code in qdax/core/emitters/omg_mega_emitter.py
@partial(
jax.jit,
static_argnames=("self",),
)
def state_update(
self,
emitter_state: OMGMEGAEmitterState,
repertoire: MapElitesRepertoire,
genotypes: Genotype,
fitnesses: Fitness,
descriptors: Descriptor,
extra_scores: ExtraScores,
) -> OMGMEGAEmitterState:
"""Update the gradients repertoire to have the right gradients.
NOTE: see discussion in the class docstrings to see how this could
be improved.
Args:
emitter_state: current emitter state
repertoire: the current genotypes repertoire
genotypes: the genotypes of the batch of emitted offspring.
fitnesses: the fitnesses of the batch of emitted offspring.
descriptors: the descriptors of the emitted offspring.
extra_scores: a dictionary with other values outputted by the
scoring function.
Returns:
The modified emitter state.
"""
# get gradients out of the extra scores
assert "gradients" in extra_scores.keys(), "Missing gradients or wrong key"
gradients = extra_scores["gradients"]
# update the gradients repertoire
gradients_repertoire = emitter_state.gradients_repertoire.add(
gradients,
descriptors,
fitnesses,
extra_scores,
)
return emitter_state.replace( # type: ignore
gradients_repertoire=gradients_repertoire
)
pbt_me_emitter
¶
PBTEmitterState (EmitterState)
dataclass
¶
The PBT emitter state contains the replay buffers used by the population, as well as the population's agent training states and their starting environment states.
Source code in qdax/core/emitters/pbt_me_emitter.py
class PBTEmitterState(EmitterState):
"""
PBT emitter state contains the replay buffers that will be used by the population,
as well as the population's agent training states and their starting environment states.
"""
replay_buffers: ReplayBuffer
env_states: EnvState
training_states: PBTTrainingState
random_key: RNGKey
replace(self, **updates)
¶
"Returns a new object replacing the specified fields with new values.
Source code in qdax/core/emitters/pbt_me_emitter.py
def replace(self, **updates):
""" "Returns a new object replacing the specified fields with new values."""
return dataclasses.replace(self, **updates)
PBTEmitterConfig (PyTreeNode)
dataclass
¶
Config for the PBT-ME emitter. This mainly corresponds to the hyperparameters of the PBT-ME algorithm.
Source code in qdax/core/emitters/pbt_me_emitter.py
class PBTEmitterConfig(PyTreeNode):
"""
Config for the PBT-ME emitter. This mainly corresponds to the hyperparameters
of the PBT-ME algorithm.
"""
buffer_size: int
num_training_iterations: int
env_batch_size: int
grad_updates_per_step: int
pg_population_size_per_device: int
ga_population_size_per_device: int
num_devices: int
fraction_best_to_replace_from: float
fraction_to_replace_from_best: float
fraction_to_replace_from_samples: float
# this fraction is used only for transfer between devices
fraction_sort_exchange: float
replace(self, **updates)
¶
"Returns a new object replacing the specified fields with new values.
Source code in qdax/core/emitters/pbt_me_emitter.py
def replace(self, **updates):
""" "Returns a new object replacing the specified fields with new values."""
return dataclasses.replace(self, **updates)
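The fractions in the config translate into integer counts at construction time. A minimal sketch (illustrative values, mirroring the arithmetic in PBTEmitter.__init__ below):
pg_population_size_per_device = 10
num_devices = 2
pg_population_size = pg_population_size_per_device * num_devices  # 20

fraction_best_to_replace_from = 0.1     # size of the elite pool to copy from
fraction_to_replace_from_best = 0.2     # worst agents replaced by elites
fraction_to_replace_from_samples = 0.4  # agents replaced by repertoire samples

num_best_to_replace_from = int(pg_population_size * fraction_best_to_replace_from)
num_to_replace_from_best = int(pg_population_size * fraction_to_replace_from_best)
num_to_replace_from_samples = int(pg_population_size * fraction_to_replace_from_samples)
assert (num_best_to_replace_from, num_to_replace_from_best,
        num_to_replace_from_samples) == (2, 4, 8)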
PBTEmitter (Emitter)
¶
An emitter that trains a population of RL agents with Population Based Training (PBT), used to implement the PBT-MAP-Elites (PBT-ME) algorithm.
Source code in qdax/core/emitters/pbt_me_emitter.py
class PBTEmitter(Emitter):
"""
An emitter that trains a population of RL agents with Population Based Training
(PBT), used to implement the PBT-MAP-Elites (PBT-ME) algorithm.
"""
def __init__(
self,
pbt_agent: Union[PBTSAC, PBTTD3],
config: PBTEmitterConfig,
env: QDEnv,
variation_fn: Callable[[Params, Params, RNGKey], Tuple[Params, RNGKey]],
) -> None:
# Parameters internalization
self._env = env
self._variation_fn = variation_fn
self._config = config
self._agent = pbt_agent
self._train_fn = self._agent.get_train_fn(
env=env,
num_iterations=config.num_training_iterations,
env_batch_size=config.env_batch_size,
grad_updates_per_step=config.grad_updates_per_step,
)
# Compute numbers from fractions
pg_population_size = config.pg_population_size_per_device * config.num_devices
self._num_best_to_replace_from = int(
pg_population_size * config.fraction_best_to_replace_from
)
self._num_to_replace_from_best = int(
pg_population_size * config.fraction_to_replace_from_best
)
self._num_to_replace_from_samples = int(
pg_population_size * config.fraction_to_replace_from_samples
)
self._num_to_exchange = int(
config.pg_population_size_per_device * config.fraction_sort_exchange
)
def init(
self, init_genotypes: Genotype, random_key: RNGKey
) -> Tuple[PBTEmitterState, RNGKey]:
"""Initializes the emitter state.
Args:
init_genotypes: The initial population.
random_key: A random key.
Returns:
The initial state of the PBTEmitter and a new random key.
"""
observation_size = self._env.observation_size
action_size = self._env.action_size
# Initialise replay buffers
init_dummy_transition = partial(
Transition.init_dummy,
observation_dim=observation_size,
action_dim=action_size,
)
init_dummy_transition = jax.vmap(
init_dummy_transition, axis_size=self._config.pg_population_size_per_device
)
dummy_transitions = init_dummy_transition()
replay_buffer_init = partial(
ReplayBuffer.init,
buffer_size=self._config.buffer_size,
)
replay_buffer_init = jax.vmap(replay_buffer_init)
replay_buffers = replay_buffer_init(transition=dummy_transitions)
# Initialise env states
(random_key, subkey1, subkey2) = jax.random.split(random_key, num=3)
env_states = jax.jit(self._env.reset)(rng=subkey1)
reshape_fn = jax.jit(
lambda tree: jax.tree_util.tree_map(
lambda x: jnp.reshape(
x,
(
self._config.pg_population_size_per_device,
self._config.env_batch_size,
)
+ x.shape[1:],
),
tree,
),
)
env_states = reshape_fn(env_states)
# Create emitter state
# keep only pg population size training states if more are provided
init_genotypes = jax.tree_util.tree_map(
lambda x: x[: self._config.pg_population_size_per_device], init_genotypes
)
emitter_state = PBTEmitterState(
replay_buffers=replay_buffers,
env_states=env_states,
training_states=init_genotypes,
random_key=subkey2,
)
return emitter_state, random_key
@partial(
jax.jit,
static_argnames=("self",),
)
def emit(
self,
repertoire: Repertoire,
emitter_state: PBTEmitterState,
random_key: RNGKey,
) -> Tuple[Genotype, RNGKey]:
"""Do a single PGA-ME iteration: train critics and greedy policy,
make mutations (evo and pg), score solution, fill replay buffer and insert back
in the MAP-Elites grid.
Args:
repertoire: the current repertoire of genotypes
emitter_state: the state of the emitter used
random_key: a random key
Returns:
A batch of offspring and a new random key.
"""
# Mutation PG (the mutation has already been performed during the state update)
x_mutation_pg = emitter_state.training_states
# Mutation evo
if self._config.ga_population_size_per_device > 0:
mutation_ga_batch_size = self._config.ga_population_size_per_device
x1, random_key = repertoire.sample(random_key, mutation_ga_batch_size)
x2, random_key = repertoire.sample(random_key, mutation_ga_batch_size)
x_mutation_ga, random_key = self._variation_fn(x1, x2, random_key)
# Gather offspring
genotypes = jax.tree_util.tree_map(
lambda x, y: jnp.concatenate([x, y], axis=0),
x_mutation_ga,
x_mutation_pg,
)
else:
genotypes = x_mutation_pg
return genotypes, random_key
@property
def batch_size(self) -> int:
"""
Returns:
the batch size emitted by the emitter.
"""
mutation_pg_batch_size = self._config.pg_population_size_per_device
mutation_ga_batch_size = self._config.ga_population_size_per_device
return mutation_pg_batch_size + mutation_ga_batch_size
@partial(jax.jit, static_argnames=("self",))
def state_update(
self,
emitter_state: PBTEmitterState,
repertoire: Repertoire,
genotypes: Optional[Genotype],
fitnesses: Fitness,
descriptors: Optional[Descriptor],
extra_scores: ExtraScores,
) -> PBTEmitterState:
"""
Update the internal emitter state, i.e. update the population replay buffers
and agents.
Args:
emitter_state: current emitter state.
repertoire: the current genotypes repertoire
genotypes: unused here - but compulsory in the signature.
fitnesses: unused here - but compulsory in the signature.
descriptors: unused here - but compulsory in the signature.
extra_scores: extra information coming from the scoring function,
this contains the transitions added to the replay buffer.
Returns:
New emitter state where the replay buffer has been filled with
the new experienced transitions.
"""
# Look only at the fitness corresponding to emitter state individuals
fitnesses = fitnesses[self._config.ga_population_size_per_device :]
fitnesses = jnp.ravel(fitnesses)
training_states = emitter_state.training_states
replay_buffers = emitter_state.replay_buffers
genotypes = (training_states, replay_buffers)
# Incremental algorithm to gather top best among the population on each device
# First exchange
indices_to_share = jnp.arange(self._config.pg_population_size_per_device)
num_best_local = int(
self._config.pg_population_size_per_device
* self._config.fraction_best_to_replace_from
)
indices_to_share = indices_to_share[:num_best_local]
genotypes_to_share, fitnesses_to_share = jax.tree_util.tree_map(
lambda x: x[indices_to_share], (genotypes, fitnesses)
)
gathered_genotypes, gathered_fitnesses = jax.tree_util.tree_map(
lambda x: jnp.concatenate(jax.lax.all_gather(x, axis_name="p"), axis=0),
(genotypes_to_share, fitnesses_to_share),
)
genotypes_stacked, fitnesses_stacked = gathered_genotypes, gathered_fitnesses
best_indices_stacked = jnp.argsort(-fitnesses_stacked)
best_indices_stacked = best_indices_stacked[: self._num_best_to_replace_from]
best_genotypes_local, best_fitnesses_local = jax.tree_util.tree_map(
lambda x: x[best_indices_stacked], (genotypes_stacked, fitnesses_stacked)
)
# Define loop fn for the other exchanges
def _loop_fn(i, val): # type: ignore
best_genotypes_local, best_fitnesses_local = val
indices_to_share = jax.lax.dynamic_slice(
jnp.arange(self._config.pg_population_size_per_device),
[i * self._num_to_exchange],
[self._num_to_exchange],
)
genotypes_to_share, fitnesses_to_share = jax.tree_util.tree_map(
lambda x: x[indices_to_share], (genotypes, fitnesses)
)
gathered_genotypes, gathered_fitnesses = jax.tree_util.tree_map(
lambda x: jnp.concatenate(jax.lax.all_gather(x, axis_name="p"), axis=0),
(genotypes_to_share, fitnesses_to_share),
)
genotypes_stacked, fitnesses_stacked = jax.tree_util.tree_map(
lambda x, y: jnp.concatenate([x, y], axis=0),
(gathered_genotypes, gathered_fitnesses),
(best_genotypes_local, best_fitnesses_local),
)
best_indices_stacked = jnp.argsort(-fitnesses_stacked)
best_indices_stacked = best_indices_stacked[
: self._num_best_to_replace_from
]
best_genotypes_local, best_fitnesses_local = jax.tree_util.tree_map(
lambda x: x[best_indices_stacked],
(genotypes_stacked, fitnesses_stacked),
)
return (best_genotypes_local, best_fitnesses_local) # type: ignore
# Incrementally get the top fraction_best_to_replace_from best individuals
# on each device
(best_genotypes_local, best_fitnesses_local) = jax.lax.fori_loop(
lower=1,
upper=int(1.0 // self._config.fraction_sort_exchange) + 1,
body_fun=_loop_fn,
init_val=(best_genotypes_local, best_fitnesses_local),
)
# Gather fitnesses from all devices to rank locally against it
all_fitnesses = jax.tree_util.tree_map(
lambda x: jnp.concatenate(jax.lax.all_gather(x, axis_name="p"), axis=0),
fitnesses,
)
all_fitnesses = jnp.ravel(all_fitnesses)
all_fitnesses = -jnp.sort(-all_fitnesses)
random_key = emitter_state.random_key
random_key, sub_key = jax.random.split(random_key)
best_genotypes = jax.tree_util.tree_map(
lambda x: jax.random.choice(
sub_key, x, shape=(len(fitnesses),), replace=True
),
best_genotypes_local,
)
best_training_states, best_replay_buffers = best_genotypes
# Resample hyper-params
best_training_states = jax.vmap(
best_training_states.__class__.resample_hyperparams
)(best_training_states)
# Replace by individuals from the best
lower_bound = all_fitnesses[-self._num_to_replace_from_best]
cond = fitnesses <= lower_bound
training_states = jax.tree_util.tree_map(
lambda x, y: jnp.where(
jnp.expand_dims(
cond, axis=tuple([-(i + 1) for i in range(x.ndim - 1)])
),
x,
y,
),
best_training_states,
training_states,
)
replay_buffers = jax.tree_util.tree_map(
lambda x, y: jnp.where(
jnp.expand_dims(
cond, axis=tuple([-(i + 1) for i in range(x.ndim - 1)])
),
x,
y,
),
best_replay_buffers,
replay_buffers,
)
# Replacing with samples from the ME repertoire
if self._num_to_replace_from_samples > 0:
me_samples, random_key = repertoire.sample(
random_key, self._config.pg_population_size_per_device
)
# Resample hyper-params
me_samples = jax.vmap(me_samples.__class__.resample_hyperparams)(me_samples)
upper_bound = all_fitnesses[
-self._num_to_replace_from_best - self._num_to_replace_from_samples
]
cond = jnp.logical_and(fitnesses <= upper_bound, fitnesses >= lower_bound)
training_states = jax.tree_util.tree_map(
lambda x, y: jnp.where(
jnp.expand_dims(
cond, axis=tuple([-(i + 1) for i in range(x.ndim - 1)])
),
x,
y,
),
me_samples,
training_states,
)
# Train the agents
env_states = emitter_state.env_states
# Init optimizers state before training the population
training_states = jax.vmap(training_states.__class__.init_optimizers_states)(
training_states
)
(training_states, env_states, replay_buffers), metrics = self._train_fn(
training_states, env_states, replay_buffers
)
# Empty optimizers states to avoid storing the info in the RAM
# and having too heavy repertoires
training_states = jax.vmap(training_states.__class__.empty_optimizers_states)(
training_states
)
# Update emitter state
emitter_state = emitter_state.replace(
training_states=training_states,
replay_buffers=replay_buffers,
env_states=env_states,
random_key=random_key,
)
return emitter_state # type: ignore
batch_size: int
property
readonly
¶
init(self, init_genotypes, random_key)
¶
Initializes the emitter state.
Source code in qdax/core/emitters/pbt_me_emitter.py
def init(
self, init_genotypes: Genotype, random_key: RNGKey
) -> Tuple[PBTEmitterState, RNGKey]:
"""Initializes the emitter state.
Args:
init_genotypes: The initial population.
random_key: A random key.
Returns:
The initial state of the PBTEmitter and a new random key.
"""
observation_size = self._env.observation_size
action_size = self._env.action_size
# Initialise replay buffers
init_dummy_transition = partial(
Transition.init_dummy,
observation_dim=observation_size,
action_dim=action_size,
)
init_dummy_transition = jax.vmap(
init_dummy_transition, axis_size=self._config.pg_population_size_per_device
)
dummy_transitions = init_dummy_transition()
replay_buffer_init = partial(
ReplayBuffer.init,
buffer_size=self._config.buffer_size,
)
replay_buffer_init = jax.vmap(replay_buffer_init)
replay_buffers = replay_buffer_init(transition=dummy_transitions)
# Initialise env states
(random_key, subkey1, subkey2) = jax.random.split(random_key, num=3)
env_states = jax.jit(self._env.reset)(rng=subkey1)
reshape_fn = jax.jit(
lambda tree: jax.tree_util.tree_map(
lambda x: jnp.reshape(
x,
(
self._config.pg_population_size_per_device,
self._config.env_batch_size,
)
+ x.shape[1:],
),
tree,
),
)
env_states = reshape_fn(env_states)
# Create emitter state
# keep only pg population size training states if more are provided
init_genotypes = jax.tree_util.tree_map(
lambda x: x[: self._config.pg_population_size_per_device], init_genotypes
)
emitter_state = PBTEmitterState(
replay_buffers=replay_buffers,
env_states=env_states,
training_states=init_genotypes,
random_key=subkey2,
)
return emitter_state, random_key
emit(self, repertoire, emitter_state, random_key)
¶
Emit a new population: gather the agents trained with PG during the last state update together with offspring obtained by applying the variation operator to samples from the repertoire.
Source code in qdax/core/emitters/pbt_me_emitter.py
@partial(
jax.jit,
static_argnames=("self",),
)
def emit(
self,
repertoire: Repertoire,
emitter_state: PBTEmitterState,
random_key: RNGKey,
) -> Tuple[Genotype, RNGKey]:
"""Do a single PGA-ME iteration: train critics and greedy policy,
make mutations (evo and pg), score solution, fill replay buffer and insert back
in the MAP-Elites grid.
Args:
repertoire: the current repertoire of genotypes
emitter_state: the state of the emitter used
random_key: a random key
Returns:
A batch of offspring and a new random key.
"""
# Mutation PG (the mutation has already been performed during the state update)
x_mutation_pg = emitter_state.training_states
# Mutation evo
if self._config.ga_population_size_per_device > 0:
mutation_ga_batch_size = self._config.ga_population_size_per_device
x1, random_key = repertoire.sample(random_key, mutation_ga_batch_size)
x2, random_key = repertoire.sample(random_key, mutation_ga_batch_size)
x_mutation_ga, random_key = self._variation_fn(x1, x2, random_key)
# Gather offspring
genotypes = jax.tree_util.tree_map(
lambda x, y: jnp.concatenate([x, y], axis=0),
x_mutation_ga,
x_mutation_pg,
)
else:
genotypes = x_mutation_pg
return genotypes, random_key
state_update(self, emitter_state, repertoire, genotypes, fitnesses, descriptors, extra_scores)
¶
Update the internal emitter state, i.e. update the population replay buffers and agents.
Source code in qdax/core/emitters/pbt_me_emitter.py
@partial(jax.jit, static_argnames=("self",))
def state_update(
self,
emitter_state: PBTEmitterState,
repertoire: Repertoire,
genotypes: Optional[Genotype],
fitnesses: Fitness,
descriptors: Optional[Descriptor],
extra_scores: ExtraScores,
) -> PBTEmitterState:
"""
Update the internal emitter state, i.e. update the population replay buffers
and agents.
Args:
emitter_state: current emitter state.
repertoire: the current genotypes repertoire
genotypes: unused here - but compulsory in the signature.
fitnesses: unused here - but compulsory in the signature.
descriptors: unused here - but compulsory in the signature.
extra_scores: extra information coming from the scoring function,
this contains the transitions added to the replay buffer.
Returns:
New emitter state where the replay buffer has been filled with
the new experienced transitions.
"""
# Look only at the fitness corresponding to emitter state individuals
fitnesses = fitnesses[self._config.ga_population_size_per_device :]
fitnesses = jnp.ravel(fitnesses)
training_states = emitter_state.training_states
replay_buffers = emitter_state.replay_buffers
genotypes = (training_states, replay_buffers)
# Incremental algorithm to gather top best among the population on each device
# First exchange
indices_to_share = jnp.arange(self._config.pg_population_size_per_device)
num_best_local = int(
self._config.pg_population_size_per_device
* self._config.fraction_best_to_replace_from
)
indices_to_share = indices_to_share[:num_best_local]
genotypes_to_share, fitnesses_to_share = jax.tree_util.tree_map(
lambda x: x[indices_to_share], (genotypes, fitnesses)
)
gathered_genotypes, gathered_fitnesses = jax.tree_util.tree_map(
lambda x: jnp.concatenate(jax.lax.all_gather(x, axis_name="p"), axis=0),
(genotypes_to_share, fitnesses_to_share),
)
genotypes_stacked, fitnesses_stacked = gathered_genotypes, gathered_fitnesses
best_indices_stacked = jnp.argsort(-fitnesses_stacked)
best_indices_stacked = best_indices_stacked[: self._num_best_to_replace_from]
best_genotypes_local, best_fitnesses_local = jax.tree_util.tree_map(
lambda x: x[best_indices_stacked], (genotypes_stacked, fitnesses_stacked)
)
# Define loop fn for the other exchanges
def _loop_fn(i, val): # type: ignore
best_genotypes_local, best_fitnesses_local = val
indices_to_share = jax.lax.dynamic_slice(
jnp.arange(self._config.pg_population_size_per_device),
[i * self._num_to_exchange],
[self._num_to_exchange],
)
genotypes_to_share, fitnesses_to_share = jax.tree_util.tree_map(
lambda x: x[indices_to_share], (genotypes, fitnesses)
)
gathered_genotypes, gathered_fitnesses = jax.tree_util.tree_map(
lambda x: jnp.concatenate(jax.lax.all_gather(x, axis_name="p"), axis=0),
(genotypes_to_share, fitnesses_to_share),
)
genotypes_stacked, fitnesses_stacked = jax.tree_util.tree_map(
lambda x, y: jnp.concatenate([x, y], axis=0),
(gathered_genotypes, gathered_fitnesses),
(best_genotypes_local, best_fitnesses_local),
)
best_indices_stacked = jnp.argsort(-fitnesses_stacked)
best_indices_stacked = best_indices_stacked[
: self._num_best_to_replace_from
]
best_genotypes_local, best_fitnesses_local = jax.tree_util.tree_map(
lambda x: x[best_indices_stacked],
(genotypes_stacked, fitnesses_stacked),
)
return (best_genotypes_local, best_fitnesses_local) # type: ignore
# Incrementally get the top fraction_best_to_replace_from best individuals
# on each device
(best_genotypes_local, best_fitnesses_local) = jax.lax.fori_loop(
lower=1,
upper=int(1.0 // self._config.fraction_sort_exchange) + 1,
body_fun=_loop_fn,
init_val=(best_genotypes_local, best_fitnesses_local),
)
# Gather fitnesses from all devices to rank locally against it
all_fitnesses = jax.tree_util.tree_map(
lambda x: jnp.concatenate(jax.lax.all_gather(x, axis_name="p"), axis=0),
fitnesses,
)
all_fitnesses = jnp.ravel(all_fitnesses)
all_fitnesses = -jnp.sort(-all_fitnesses)
random_key = emitter_state.random_key
random_key, sub_key = jax.random.split(random_key)
best_genotypes = jax.tree_util.tree_map(
lambda x: jax.random.choice(
sub_key, x, shape=(len(fitnesses),), replace=True
),
best_genotypes_local,
)
best_training_states, best_replay_buffers = best_genotypes
# Resample hyper-params
best_training_states = jax.vmap(
best_training_states.__class__.resample_hyperparams
)(best_training_states)
# Replace by individuals from the best
lower_bound = all_fitnesses[-self._num_to_replace_from_best]
cond = fitnesses <= lower_bound
training_states = jax.tree_util.tree_map(
lambda x, y: jnp.where(
jnp.expand_dims(
cond, axis=tuple([-(i + 1) for i in range(x.ndim - 1)])
),
x,
y,
),
best_training_states,
training_states,
)
replay_buffers = jax.tree_util.tree_map(
lambda x, y: jnp.where(
jnp.expand_dims(
cond, axis=tuple([-(i + 1) for i in range(x.ndim - 1)])
),
x,
y,
),
best_replay_buffers,
replay_buffers,
)
# Replacing with samples from the ME repertoire
if self._num_to_replace_from_samples > 0:
me_samples, random_key = repertoire.sample(
random_key, self._config.pg_population_size_per_device
)
# Resample hyper-params
me_samples = jax.vmap(me_samples.__class__.resample_hyperparams)(me_samples)
upper_bound = all_fitnesses[
-self._num_to_replace_from_best - self._num_to_replace_from_samples
]
cond = jnp.logical_and(fitnesses <= upper_bound, fitnesses >= lower_bound)
training_states = jax.tree_util.tree_map(
lambda x, y: jnp.where(
jnp.expand_dims(
cond, axis=tuple([-(i + 1) for i in range(x.ndim - 1)])
),
x,
y,
),
me_samples,
training_states,
)
# Train the agents
env_states = emitter_state.env_states
# Init optimizers state before training the population
training_states = jax.vmap(training_states.__class__.init_optimizers_states)(
training_states
)
(training_states, env_states, replay_buffers), metrics = self._train_fn(
training_states, env_states, replay_buffers
)
# Empty optimizers states to avoid storing the info in the RAM
# and having too heavy repertoires
training_states = jax.vmap(training_states.__class__.empty_optimizers_states)(
training_states
)
# Update emitter state
emitter_state = emitter_state.replace(
training_states=training_states,
replay_buffers=replay_buffers,
env_states=env_states,
random_key=random_key,
)
return emitter_state # type: ignore
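The masked replacement above relies on broadcasting a per-agent boolean condition against every leaf of a parameter pytree. A minimal sketch (not qdax code) of that pattern:
import jax
import jax.numpy as jnp

cond = jnp.array([True, False, True])  # which agents to replace
best = {"w": jnp.ones((3, 4)), "lr": jnp.full((3,), 0.1)}
current = {"w": jnp.zeros((3, 4)), "lr": jnp.full((3,), 0.5)}

def masked_replace(x, y):
    # reshape cond to (3, 1, ..., 1) so it broadcasts against x
    mask = jnp.expand_dims(cond, axis=tuple(-(i + 1) for i in range(x.ndim - 1)))
    return jnp.where(mask, x, y)

new = jax.tree_util.tree_map(masked_replace, best, current)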
pbt_variation_operators
¶
sac_pbt_variation_fn(training_state1, training_state2, random_key, iso_sigma, line_sigma)
¶
This operator runs a cross-over between two SAC agents. It is used as variation operator in the SAC-PBT-Map-Elites algorithm. An isoline-dd variation is applied to policy networks, critic networks and entropy alpha coefficients.
Source code in qdax/core/emitters/pbt_variation_operators.py
def sac_pbt_variation_fn(
training_state1: PBTSacTrainingState,
training_state2: PBTSacTrainingState,
random_key: RNGKey,
iso_sigma: float,
line_sigma: float,
) -> Tuple[PBTSacTrainingState, RNGKey]:
"""
This operator runs a cross-over between two SAC agents. It is used as variation
operator in the SAC-PBT-Map-Elites algorithm. An isoline-dd variation is applied
to policy networks, critic networks and entropy alpha coefficients.
Args:
training_state1: Training state of first SAC agent.
training_state2: Training state of second SAC agent.
random_key: Random key.
iso_sigma: Spread parameter (noise).
line_sigma: Line parameter (direction of the new genotype).
Returns:
A new SAC training state obtained from cross-over and an updated random key.
"""
policy_params1, policy_params2 = (
training_state1.policy_params,
training_state2.policy_params,
)
critic_params1, critic_params2 = (
training_state1.critic_params,
training_state2.critic_params,
)
alpha_params1, alpha_params2 = (
training_state1.alpha_params,
training_state2.alpha_params,
)
(policy_params, critic_params, alpha_params), random_key = isoline_variation(
x1=(policy_params1, critic_params1, alpha_params1),
x2=(policy_params2, critic_params2, alpha_params2),
random_key=random_key,
iso_sigma=iso_sigma,
line_sigma=line_sigma,
)
new_training_state = training_state1.replace(
policy_params=policy_params,
critic_params=critic_params,
alpha_params=alpha_params,
)
return (
new_training_state,
random_key,
)
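A hedged usage sketch: fixing the two sigmas with functools.partial yields a variation function with the (training_state1, training_state2, random_key) signature that PBTEmitter expects; the sigma values are illustrative:
from functools import partial
from qdax.core.emitters.pbt_variation_operators import sac_pbt_variation_fn

variation_fn = partial(sac_pbt_variation_fn, iso_sigma=0.005, line_sigma=0.05)
# variation_fn(training_state1, training_state2, random_key)
# -> (new_training_state, random_key)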
td3_pbt_variation_fn(training_state1, training_state2, random_key, iso_sigma, line_sigma)
¶
This operator runs a cross-over between two TD3 agents. It is used as variation operator in the TD3-PBT-Map-Elites algorithm. An isoline-dd variation is applied to policy networks and critic networks.
Source code in qdax/core/emitters/pbt_variation_operators.py
def td3_pbt_variation_fn(
training_state1: PBTTD3TrainingState,
training_state2: PBTTD3TrainingState,
random_key: RNGKey,
iso_sigma: float,
line_sigma: float,
) -> Tuple[PBTTD3TrainingState, RNGKey]:
"""
This operator runs a cross-over between two TD3 agents. It is used as variation
operator in the TD3-PBT-Map-Elites algorithm. An isoline-dd variation is applied
to policy networks and critic networks.
Args:
training_state1: Training state of first TD3 agent.
training_state2: Training state of second TD3 agent.
random_key: Random key.
iso_sigma: Spread parameter (noise).
line_sigma: Line parameter (direction of the new genotype).
Returns:
A new TD3 training state obtained from cross-over and an updated random key.
"""
policy_params1, policy_params2 = (
training_state1.policy_params,
training_state2.policy_params,
)
critic_params1, critic_params2 = (
training_state1.critic_params,
training_state2.critic_params,
)
(policy_params, critic_params,), random_key = isoline_variation(
x1=(policy_params1, critic_params1),
x2=(policy_params2, critic_params2),
random_key=random_key,
iso_sigma=iso_sigma,
line_sigma=line_sigma,
)
new_training_state = training_state1.replace(
policy_params=policy_params,
critic_params=critic_params,
)
return (
new_training_state,
random_key,
)
pga_me_emitter
¶
PGAMEConfig
dataclass
¶
Configuration for PGAME Algorithm
Source code in qdax/core/emitters/pga_me_emitter.py
@dataclass
class PGAMEConfig:
"""Configuration for PGAME Algorithm"""
env_batch_size: int = 100
proportion_mutation_ga: float = 0.5
num_critic_training_steps: int = 300
num_pg_training_steps: int = 100
# TD3 params
replay_buffer_size: int = 1000000
critic_hidden_layer_size: Tuple[int, ...] = (256, 256)
critic_learning_rate: float = 3e-4
greedy_learning_rate: float = 3e-4
policy_learning_rate: float = 1e-3
noise_clip: float = 0.5
policy_noise: float = 0.2
discount: float = 0.99
reward_scaling: float = 1.0
batch_size: int = 256
soft_tau_update: float = 0.005
policy_delay: int = 2
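A hedged usage sketch: since PGAMEConfig is a dataclass with defaults, a config is built by overriding only the fields of interest (values below are illustrative):
from qdax.core.emitters.pga_me_emitter import PGAMEConfig

config = PGAMEConfig(env_batch_size=250, proportion_mutation_ga=0.3)
assert config.policy_delay == 2  # untouched fields keep their defaults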
qdpg_emitter
¶
Implementation of an updated version of the QDPG algorithm presented in the paper https://arxiv.org/abs/2006.08505.
QDPG has been updated to fit into the container+emitter framework of QD. Furthermore, it has been updated to run more efficiently with Jax. Those changes have been made in accordance with the authors of the algorithm.
QDPGEmitterConfig
dataclass
¶
QDPGEmitterConfig(qpg_config: qdax.core.emitters.qpg_emitter.QualityPGConfig, dpg_config: qdax.core.emitters.dpg_emitter.DiversityPGConfig, iso_sigma: float, line_sigma: float, ga_batch_size: int)
Source code in qdax/core/emitters/qdpg_emitter.py
@dataclass
class QDPGEmitterConfig:
qpg_config: QualityPGConfig
dpg_config: DiversityPGConfig
iso_sigma: float
line_sigma: float
ga_batch_size: int
qpg_emitter
¶
Implements the PG Emitter from the PGA-ME algorithm in Jax for Brax environments, based on: https://hal.archives-ouvertes.fr/hal-03135723v2/file/PGA_MAP_Elites_GECCO.pdf
QualityPGConfig
dataclass
¶
Configuration for QualityPG Emitter
Source code in qdax/core/emitters/qpg_emitter.py
@dataclass
class QualityPGConfig:
"""Configuration for QualityPG Emitter"""
env_batch_size: int = 100
num_critic_training_steps: int = 300
num_pg_training_steps: int = 100
# TD3 params
replay_buffer_size: int = 1000000
critic_hidden_layer_size: Tuple[int, ...] = (256, 256)
critic_learning_rate: float = 3e-4
actor_learning_rate: float = 3e-4
policy_learning_rate: float = 1e-3
noise_clip: float = 0.5
policy_noise: float = 0.2
discount: float = 0.99
reward_scaling: float = 1.0
batch_size: int = 256
soft_tau_update: float = 0.005
policy_delay: int = 2
QualityPGEmitterState (EmitterState)
dataclass
¶
Contains training state for the learner.
Source code in qdax/core/emitters/qpg_emitter.py
class QualityPGEmitterState(EmitterState):
"""Contains training state for the learner."""
critic_params: Params
critic_optimizer_state: optax.OptState
actor_params: Params
actor_opt_state: optax.OptState
target_critic_params: Params
target_actor_params: Params
replay_buffer: ReplayBuffer
random_key: RNGKey
steps: jnp.ndarray
replace(self, **updates)
¶
"Returns a new object replacing the specified fields with new values.
Source code in qdax/core/emitters/qpg_emitter.py
def replace(self, **updates):
""" "Returns a new object replacing the specified fields with new values."""
return dataclasses.replace(self, **updates)
QualityPGEmitter (Emitter)
¶
A policy gradient emitter used to implement the Policy Gradient Assisted MAP-Elites (PGA-Map-Elites) algorithm.
Source code in qdax/core/emitters/qpg_emitter.py
class QualityPGEmitter(Emitter):
"""
A policy gradient emitter used to implement the Policy Gradient Assisted MAP-Elites
(PGA-Map-Elites) algorithm.
"""
def __init__(
self,
config: QualityPGConfig,
policy_network: nn.Module,
env: QDEnv,
) -> None:
self._config = config
self._env = env
self._policy_network = policy_network
# Init Critics
critic_network = QModule(
n_critics=2, hidden_layer_sizes=self._config.critic_hidden_layer_size
)
self._critic_network = critic_network
# Set up the losses and optimizers - return the opt states
self._policy_loss_fn, self._critic_loss_fn = make_td3_loss_fn(
policy_fn=policy_network.apply,
critic_fn=critic_network.apply,
reward_scaling=self._config.reward_scaling,
discount=self._config.discount,
noise_clip=self._config.noise_clip,
policy_noise=self._config.policy_noise,
)
# Init optimizers
self._actor_optimizer = optax.adam(
learning_rate=self._config.actor_learning_rate
)
self._critic_optimizer = optax.adam(
learning_rate=self._config.critic_learning_rate
)
self._policies_optimizer = optax.adam(
learning_rate=self._config.policy_learning_rate
)
@property
def batch_size(self) -> int:
"""
Returns:
the batch size emitted by the emitter.
"""
return self._config.env_batch_size
@property
def use_all_data(self) -> bool:
"""Whether to use all data or not when used along other emitters.
QualityPGEmitter uses the transitions from the genotypes that were generated
by other emitters.
"""
return True
def init(
self, init_genotypes: Genotype, random_key: RNGKey
) -> Tuple[QualityPGEmitterState, RNGKey]:
"""Initializes the emitter state.
Args:
init_genotypes: The initial population.
random_key: A random key.
Returns:
The initial state of the QualityPGEmitter and a new random key.
"""
observation_size = self._env.observation_size
action_size = self._env.action_size
descriptor_size = self._env.state_descriptor_length
# Initialise critic, greedy actor and population
random_key, subkey = jax.random.split(random_key)
fake_obs = jnp.zeros(shape=(observation_size,))
fake_action = jnp.zeros(shape=(action_size,))
critic_params = self._critic_network.init(
subkey, obs=fake_obs, actions=fake_action
)
target_critic_params = jax.tree_util.tree_map(lambda x: x, critic_params)
actor_params = jax.tree_util.tree_map(lambda x: x[0], init_genotypes)
target_actor_params = jax.tree_util.tree_map(lambda x: x[0], init_genotypes)
# Prepare init optimizer states
critic_optimizer_state = self._critic_optimizer.init(critic_params)
actor_optimizer_state = self._actor_optimizer.init(actor_params)
# Initialize replay buffer
dummy_transition = QDTransition.init_dummy(
observation_dim=observation_size,
action_dim=action_size,
descriptor_dim=descriptor_size,
)
replay_buffer = ReplayBuffer.init(
buffer_size=self._config.replay_buffer_size, transition=dummy_transition
)
# Initial training state
random_key, subkey = jax.random.split(random_key)
emitter_state = QualityPGEmitterState(
critic_params=critic_params,
critic_optimizer_state=critic_optimizer_state,
actor_params=actor_params,
actor_opt_state=actor_optimizer_state,
target_critic_params=target_critic_params,
target_actor_params=target_actor_params,
random_key=subkey,
steps=jnp.array(0),
replay_buffer=replay_buffer,
)
return emitter_state, random_key
@partial(
jax.jit,
static_argnames=("self",),
)
def emit(
self,
repertoire: Repertoire,
emitter_state: QualityPGEmitterState,
random_key: RNGKey,
) -> Tuple[Genotype, RNGKey]:
"""Do a step of PG emission.
Args:
repertoire: the current repertoire of genotypes
emitter_state: the state of the emitter used
random_key: a random key
Returns:
A batch of offspring and a new random key.
"""
batch_size = self._config.env_batch_size
# sample parents
mutation_pg_batch_size = int(batch_size - 1)
parents, random_key = repertoire.sample(random_key, mutation_pg_batch_size)
# apply the pg mutation
offsprings_pg = self.emit_pg(emitter_state, parents)
# get the actor (greedy actor)
offspring_actor = self.emit_actor(emitter_state)
# add dimension for concatenation
offspring_actor = jax.tree_util.tree_map(
lambda x: jnp.expand_dims(x, axis=0), offspring_actor
)
# gather offspring
genotypes = jax.tree_util.tree_map(
lambda x, y: jnp.concatenate([x, y], axis=0),
offsprings_pg,
offspring_actor,
)
return genotypes, random_key
@partial(
jax.jit,
static_argnames=("self",),
)
def emit_pg(
self, emitter_state: QualityPGEmitterState, parents: Genotype
) -> Genotype:
"""Emit the offsprings generated through pg mutation.
Args:
emitter_state: current emitter state, contains critic and
replay buffer.
parents: the parents selected to be applied gradients in order
to mutate towards better performance.
Returns:
A new set of offsprings.
"""
mutation_fn = partial(
self._mutation_function_pg,
emitter_state=emitter_state,
)
offsprings = jax.vmap(mutation_fn)(parents)
return offsprings
@partial(
jax.jit,
static_argnames=("self",),
)
def emit_actor(self, emitter_state: QualityPGEmitterState) -> Genotype:
"""Emit the greedy actor.
Simply needs to be retrieved from the emitter state.
Args:
emitter_state: the current emitter state, it stores the
greedy actor.
Returns:
The parameters of the actor.
"""
return emitter_state.actor_params
@partial(jax.jit, static_argnames=("self",))
def state_update(
self,
emitter_state: QualityPGEmitterState,
repertoire: Optional[Repertoire],
genotypes: Optional[Genotype],
fitnesses: Optional[Fitness],
descriptors: Optional[Descriptor],
extra_scores: ExtraScores,
) -> QualityPGEmitterState:
"""This function gives an opportunity to update the emitter state
after the genotypes have been scored.
Here it is used to fill the Replay Buffer with the transitions
from the scoring of the genotypes, and then the training of the
critic/actor happens. Hence the params of critic/actor are updated,
as well as their optimizer states.
Args:
emitter_state: current emitter state.
repertoire: the current genotypes repertoire
genotypes: unused here - but compulsory in the signature.
fitnesses: unused here - but compulsory in the signature.
descriptors: unused here - but compulsory in the signature.
extra_scores: extra information coming from the scoring function,
this contains the transitions added to the replay buffer.
Returns:
New emitter state where the replay buffer has been filled with
the new experienced transitions.
"""
# get the transitions out of the dictionary
assert "transitions" in extra_scores.keys(), "Missing transitions or wrong key"
transitions = extra_scores["transitions"]
# add transitions in the replay buffer
replay_buffer = emitter_state.replay_buffer.insert(transitions)
emitter_state = emitter_state.replace(replay_buffer=replay_buffer)
def scan_train_critics(
carry: QualityPGEmitterState, unused: Any
) -> Tuple[QualityPGEmitterState, Any]:
emitter_state = carry
new_emitter_state = self._train_critics(emitter_state)
return new_emitter_state, ()
# Train critics and greedy actor
emitter_state, _ = jax.lax.scan(
scan_train_critics,
emitter_state,
(),
length=self._config.num_critic_training_steps,
)
return emitter_state # type: ignore
@partial(jax.jit, static_argnames=("self",))
def _train_critics(
self, emitter_state: QualityPGEmitterState
) -> QualityPGEmitterState:
"""Apply one gradient step to critics and to the greedy actor
(contained in carry in training_state), then soft update target critics
and target actor.
Those updates are very similar to those made in TD3.
Args:
emitter_state: actual emitter state
Returns:
New emitter state where the critic and the greedy actor have been
updated. Optimizer states have also been updated in the process.
"""
# Sample a batch of transitions in the buffer
random_key = emitter_state.random_key
replay_buffer = emitter_state.replay_buffer
transitions, random_key = replay_buffer.sample(
random_key, sample_size=self._config.batch_size
)
# Update Critic
(
critic_optimizer_state,
critic_params,
target_critic_params,
random_key,
) = self._update_critic(
critic_params=emitter_state.critic_params,
target_critic_params=emitter_state.target_critic_params,
target_actor_params=emitter_state.target_actor_params,
critic_optimizer_state=emitter_state.critic_optimizer_state,
transitions=transitions,
random_key=random_key,
)
# Update greedy actor
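# (Delayed policy update, as in TD3: the actor and its target are only
# refreshed when steps % policy_delay == 0; otherwise the second branch
# of the cond passes the previous values through unchanged.)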
(actor_optimizer_state, actor_params, target_actor_params,) = jax.lax.cond(
emitter_state.steps % self._config.policy_delay == 0,
lambda x: self._update_actor(*x),
lambda _: (
emitter_state.actor_opt_state,
emitter_state.actor_params,
emitter_state.target_actor_params,
),
operand=(
emitter_state.actor_params,
emitter_state.actor_opt_state,
emitter_state.target_actor_params,
emitter_state.critic_params,
transitions,
),
)
# Create new training state
new_emitter_state = emitter_state.replace(
critic_params=critic_params,
critic_optimizer_state=critic_optimizer_state,
actor_params=actor_params,
actor_opt_state=actor_optimizer_state,
target_critic_params=target_critic_params,
target_actor_params=target_actor_params,
random_key=random_key,
steps=emitter_state.steps + 1,
replay_buffer=replay_buffer,
)
return new_emitter_state # type: ignore
@partial(jax.jit, static_argnames=("self",))
def _update_critic(
self,
critic_params: Params,
target_critic_params: Params,
target_actor_params: Params,
critic_optimizer_state: optax.OptState,
transitions: QDTransition,
random_key: RNGKey,
) -> Tuple[optax.OptState, Params, Params, RNGKey]:
# compute loss and gradients
random_key, subkey = jax.random.split(random_key)
critic_loss, critic_gradient = jax.value_and_grad(self._critic_loss_fn)(
critic_params,
target_actor_params,
target_critic_params,
transitions,
subkey,
)
critic_updates, critic_optimizer_state = self._critic_optimizer.update(
critic_gradient, critic_optimizer_state
)
# update critic
critic_params = optax.apply_updates(critic_params, critic_updates)
# Soft update of target critic network
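# (Polyak averaging: target <- (1 - tau) * target + tau * online,
# with tau = self._config.soft_tau_update.)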
target_critic_params = jax.tree_util.tree_map(
lambda x1, x2: (1.0 - self._config.soft_tau_update) * x1
+ self._config.soft_tau_update * x2,
target_critic_params,
critic_params,
)
return critic_optimizer_state, critic_params, target_critic_params, random_key
@partial(jax.jit, static_argnames=("self",))
def _update_actor(
self,
actor_params: Params,
actor_opt_state: optax.OptState,
target_actor_params: Params,
critic_params: Params,
transitions: QDTransition,
) -> Tuple[optax.OptState, Params, Params]:
# Update greedy actor
policy_loss, policy_gradient = jax.value_and_grad(self._policy_loss_fn)(
actor_params,
critic_params,
transitions,
)
(
policy_updates,
actor_optimizer_state,
) = self._actor_optimizer.update(policy_gradient, actor_opt_state)
actor_params = optax.apply_updates(actor_params, policy_updates)
# Soft update of target greedy actor
target_actor_params = jax.tree_util.tree_map(
lambda x1, x2: (1.0 - self._config.soft_tau_update) * x1
+ self._config.soft_tau_update * x2,
target_actor_params,
actor_params,
)
return (
actor_optimizer_state,
actor_params,
target_actor_params,
)
@partial(jax.jit, static_argnames=("self",))
def _mutation_function_pg(
self,
policy_params: Genotype,
emitter_state: QualityPGEmitterState,
) -> Genotype:
"""Apply pg mutation to a policy via multiple steps of gradient descent.
First, update the rewards to be diversity rewards, then apply the gradient
steps.
Args:
policy_params: the parameters of a policy, assumed to be a
differentiable neural network.
emitter_state: the current state of the emitter, containing, among
others, the replay buffer and the critic.
Returns:
The updated params of the neural network.
"""
# Define new policy optimizer state
policy_optimizer_state = self._policies_optimizer.init(policy_params)
def scan_train_policy(
carry: Tuple[QualityPGEmitterState, Genotype, optax.OptState],
unused: Any,
) -> Tuple[Tuple[QualityPGEmitterState, Genotype, optax.OptState], Any]:
emitter_state, policy_params, policy_optimizer_state = carry
(
new_emitter_state,
new_policy_params,
new_policy_optimizer_state,
) = self._train_policy(
emitter_state,
policy_params,
policy_optimizer_state,
)
return (
new_emitter_state,
new_policy_params,
new_policy_optimizer_state,
), ()
(emitter_state, policy_params, policy_optimizer_state,), _ = jax.lax.scan(
scan_train_policy,
(emitter_state, policy_params, policy_optimizer_state),
(),
length=self._config.num_pg_training_steps,
)
return policy_params
@partial(jax.jit, static_argnames=("self",))
def _train_policy(
self,
emitter_state: QualityPGEmitterState,
policy_params: Params,
policy_optimizer_state: optax.OptState,
) -> Tuple[QualityPGEmitterState, Params, optax.OptState]:
"""Apply one gradient step to a policy (called policy_params).
Args:
emitter_state: current state of the emitter.
policy_params: parameters corresponding to the weights and bias of
the neural network that defines the policy.
Returns:
The new emitter state and new params of the NN.
"""
# Sample a batch of transitions in the buffer
random_key = emitter_state.random_key
replay_buffer = emitter_state.replay_buffer
transitions, random_key = replay_buffer.sample(
random_key, sample_size=self._config.batch_size
)
# update policy
policy_optimizer_state, policy_params = self._update_policy(
critic_params=emitter_state.critic_params,
policy_optimizer_state=policy_optimizer_state,
policy_params=policy_params,
transitions=transitions,
)
# Create new training state
new_emitter_state = emitter_state.replace(
random_key=random_key,
replay_buffer=replay_buffer,
)
return new_emitter_state, policy_params, policy_optimizer_state
@partial(jax.jit, static_argnames=("self",))
def _update_policy(
self,
critic_params: Params,
policy_optimizer_state: optax.OptState,
policy_params: Params,
transitions: QDTransition,
) -> Tuple[optax.OptState, Params]:
# compute loss
_policy_loss, policy_gradient = jax.value_and_grad(self._policy_loss_fn)(
policy_params,
critic_params,
transitions,
)
# Compute gradient and update policies
(
policy_updates,
policy_optimizer_state,
) = self._policies_optimizer.update(policy_gradient, policy_optimizer_state)
policy_params = optax.apply_updates(policy_params, policy_updates)
return policy_optimizer_state, policy_params
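For orientation, here is a rough construction sketch. All hyperparameter values below are placeholders, and `policy_network` / `env` are assumed to be an existing Flax policy module and a `QDEnv`; the config fields shown are exactly those read by the code above:

```python
# Illustrative values only; tune for your task.
config = QualityPGConfig(
    env_batch_size=100,
    batch_size=256,
    critic_hidden_layer_size=(256, 256),
    critic_learning_rate=3e-4,
    actor_learning_rate=3e-4,
    policy_learning_rate=1e-3,
    reward_scaling=1.0,
    discount=0.99,
    noise_clip=0.5,
    policy_noise=0.2,
    replay_buffer_size=1_000_000,
    soft_tau_update=0.005,
    num_critic_training_steps=300,
    num_pg_training_steps=100,
    policy_delay=2,
)
emitter = QualityPGEmitter(config=config, policy_network=policy_network, env=env)
```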
batch_size: int
property
readonly
¶
Returns:

| Type | Description |
|---|---|
| `int` | the batch size emitted by the emitter. |
use_all_data: bool
property
readonly
¶
Whether to use all data or not when used alongside other emitters.
QualityPGEmitter uses the transitions from the genotypes that were generated by other emitters.
init(self, init_genotypes, random_key)
¶
Initializes the emitter state.
Parameters:

| Name | Type | Description |
|---|---|---|
| `init_genotypes` | `Genotype` | The initial population. |
| `random_key` | `RNGKey` | A random key. |

Returns:

| Type | Description |
|---|---|
| `Tuple[QualityPGEmitterState, RNGKey]` | The initial state of the QualityPGEmitter and a new random key. |
Source code in qdax/core/emitters/qpg_emitter.py
def init(
self, init_genotypes: Genotype, random_key: RNGKey
) -> Tuple[QualityPGEmitterState, RNGKey]:
"""Initializes the emitter state.
Args:
init_genotypes: The initial population.
random_key: A random key.
Returns:
The initial state of the QualityPGEmitter and a new random key.
"""
observation_size = self._env.observation_size
action_size = self._env.action_size
descriptor_size = self._env.state_descriptor_length
# Initialise critic, greedy actor and population
random_key, subkey = jax.random.split(random_key)
fake_obs = jnp.zeros(shape=(observation_size,))
fake_action = jnp.zeros(shape=(action_size,))
critic_params = self._critic_network.init(
subkey, obs=fake_obs, actions=fake_action
)
target_critic_params = jax.tree_util.tree_map(lambda x: x, critic_params)
actor_params = jax.tree_util.tree_map(lambda x: x[0], init_genotypes)
target_actor_params = jax.tree_util.tree_map(lambda x: x[0], init_genotypes)
# Prepare init optimizer states
critic_optimizer_state = self._critic_optimizer.init(critic_params)
actor_optimizer_state = self._actor_optimizer.init(actor_params)
# Initialize replay buffer
dummy_transition = QDTransition.init_dummy(
observation_dim=observation_size,
action_dim=action_size,
descriptor_dim=descriptor_size,
)
replay_buffer = ReplayBuffer.init(
buffer_size=self._config.replay_buffer_size, transition=dummy_transition
)
# Initial training state
random_key, subkey = jax.random.split(random_key)
emitter_state = QualityPGEmitterState(
critic_params=critic_params,
critic_optimizer_state=critic_optimizer_state,
actor_params=actor_params,
actor_opt_state=actor_optimizer_state,
target_critic_params=target_critic_params,
target_actor_params=target_actor_params,
random_key=subkey,
steps=jnp.array(0),
replay_buffer=replay_buffer,
)
return emitter_state, random_key
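A minimal usage sketch, assuming `emitter` and a batch of `init_genotypes` already exist:

```python
import jax

random_key = jax.random.PRNGKey(0)
emitter_state, random_key = emitter.init(init_genotypes, random_key)
```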
emit(self, repertoire, emitter_state, random_key)
¶
Do a step of PG emission.
Parameters:

| Name | Type | Description |
|---|---|---|
| `repertoire` | `Repertoire` | the current repertoire of genotypes |
| `emitter_state` | `QualityPGEmitterState` | the state of the emitter used |
| `random_key` | `RNGKey` | a random key |

Returns:

| Type | Description |
|---|---|
| `Tuple[Genotype, RNGKey]` | A batch of offspring and a new random key. |
Source code in qdax/core/emitters/qpg_emitter.py
@partial(
jax.jit,
static_argnames=("self",),
)
def emit(
self,
repertoire: Repertoire,
emitter_state: QualityPGEmitterState,
random_key: RNGKey,
) -> Tuple[Genotype, RNGKey]:
"""Do a step of PG emission.
Args:
repertoire: the current repertoire of genotypes
emitter_state: the state of the emitter used
random_key: a random key
Returns:
A batch of offspring and a new random key.
"""
batch_size = self._config.env_batch_size
# sample parents
mutation_pg_batch_size = int(batch_size - 1)
parents, random_key = repertoire.sample(random_key, mutation_pg_batch_size)
# apply the pg mutation
offsprings_pg = self.emit_pg(emitter_state, parents)
# get the actor (greedy actor)
offspring_actor = self.emit_actor(emitter_state)
# add dimension for concatenation
offspring_actor = jax.tree_util.tree_map(
lambda x: jnp.expand_dims(x, axis=0), offspring_actor
)
# gather offspring
genotypes = jax.tree_util.tree_map(
lambda x, y: jnp.concatenate([x, y], axis=0),
offsprings_pg,
offspring_actor,
)
return genotypes, random_key
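A usage sketch (assuming `emitter`, `repertoire`, `emitter_state` and `random_key` exist): one call yields `env_batch_size` offspring, i.e. `env_batch_size - 1` PG-mutated parents plus the greedy actor appended at the end of the batch.

```python
offspring, random_key = emitter.emit(repertoire, emitter_state, random_key)
```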
emit_pg(self, emitter_state, parents)
¶
Emit the offspring generated through PG mutation.
Parameters:

| Name | Type | Description |
|---|---|---|
| `emitter_state` | `QualityPGEmitterState` | current emitter state; contains the critic and the replay buffer. |
| `parents` | `Genotype` | the parents selected, to which gradient steps are applied. |

Returns:

| Type | Description |
|---|---|
| `Genotype` | A new set of offspring. |
Source code in qdax/core/emitters/qpg_emitter.py
@partial(
jax.jit,
static_argnames=("self",),
)
def emit_pg(
self, emitter_state: QualityPGEmitterState, parents: Genotype
) -> Genotype:
"""Emit the offsprings generated through pg mutation.
Args:
emitter_state: current emitter state, contains critic and
replay buffer.
parents: the parents selected to be applied gradients in order
to mutate towards better performance.
Returns:
A new set of offsprings.
"""
mutation_fn = partial(
self._mutation_function_pg,
emitter_state=emitter_state,
)
offsprings = jax.vmap(mutation_fn)(parents)
return offsprings
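Design note: the mutation is vmapped over parents, so each parent receives its own fresh optimizer state and its own sequence of gradient steps in parallel, all against the shared critic and replay buffer held in the emitter state. A usage sketch:

```python
# Sample parents from the repertoire, then mutate them with PG steps.
parents, random_key = repertoire.sample(random_key, 16)
offspring = emitter.emit_pg(emitter_state, parents)
```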
emit_actor(self, emitter_state)
¶
Emit the greedy actor.
Simply needs to be retrieved from the emitter state.
Parameters:

| Name | Type | Description |
|---|---|---|
| `emitter_state` | `QualityPGEmitterState` | the current emitter state; it stores the greedy actor. |

Returns:

| Type | Description |
|---|---|
| `Genotype` | The parameters of the actor. |
Source code in qdax/core/emitters/qpg_emitter.py
@partial(
jax.jit,
static_argnames=("self",),
)
def emit_actor(self, emitter_state: QualityPGEmitterState) -> Genotype:
"""Emit the greedy actor.
Simply needs to be retrieved from the emitter state.
Args:
emitter_state: the current emitter state, it stores the
greedy actor.
Returns:
The parameters of the actor.
"""
return emitter_state.actor_params
state_update(self, emitter_state, repertoire, genotypes, fitnesses, descriptors, extra_scores)
¶
This function gives an opportunity to update the emitter state after the genotypes have been scored.
Here it is used to fill the replay buffer with the transitions produced while scoring the genotypes, and then to train the critic and the greedy actor; their parameters and optimizer states are updated in the process.
Parameters:

| Name | Type | Description |
|---|---|---|
| `emitter_state` | `QualityPGEmitterState` | current emitter state. |
| `repertoire` | `Optional[Repertoire]` | the current genotypes repertoire. |
| `genotypes` | `Optional[Genotype]` | unused here, but required by the signature. |
| `fitnesses` | `Optional[Fitness]` | unused here, but required by the signature. |
| `descriptors` | `Optional[Descriptor]` | unused here, but required by the signature. |
| `extra_scores` | `ExtraScores` | extra information from the scoring function; contains the transitions added to the replay buffer. |

Returns:

| Type | Description |
|---|---|
| `QualityPGEmitterState` | New emitter state with the replay buffer filled and the critic/actor trained. |
Source code in qdax/core/emitters/qpg_emitter.py
@partial(jax.jit, static_argnames=("self",))
def state_update(
self,
emitter_state: QualityPGEmitterState,
repertoire: Optional[Repertoire],
genotypes: Optional[Genotype],
fitnesses: Optional[Fitness],
descriptors: Optional[Descriptor],
extra_scores: ExtraScores,
) -> QualityPGEmitterState:
"""This function gives an opportunity to update the emitter state
after the genotypes have been scored.
Here it is used to fill the Replay Buffer with the transitions
from the scoring of the genotypes, and then the training of the
critic/actor happens. Hence the params of critic/actor are updated,
as well as their optimizer states.
Args:
emitter_state: current emitter state.
repertoire: the current genotypes repertoire
genotypes: unused here - but compulsory in the signature.
fitnesses: unused here - but compulsory in the signature.
descriptors: unused here - but compulsory in the signature.
extra_scores: extra information coming from the scoring function,
this contains the transitions added to the replay buffer.
Returns:
New emitter state where the replay buffer has been filled with
the new experienced transitions.
"""
# get the transitions out of the dictionary
assert "transitions" in extra_scores.keys(), "Missing transitions or wrong key"
transitions = extra_scores["transitions"]
# add transitions in the replay buffer
replay_buffer = emitter_state.replay_buffer.insert(transitions)
emitter_state = emitter_state.replace(replay_buffer=replay_buffer)
def scan_train_critics(
carry: QualityPGEmitterState, unused: Any
) -> Tuple[QualityPGEmitterState, Any]:
emitter_state = carry
new_emitter_state = self._train_critics(emitter_state)
return new_emitter_state, ()
# Train critics and greedy actor
emitter_state, _ = jax.lax.scan(
scan_train_critics,
emitter_state,
(),
length=self._config.num_critic_training_steps,
)
return emitter_state # type: ignore
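Putting the pieces together, a hedged sketch of one outer-loop iteration. Here `scoring_fn` and `repertoire.add` stand in for the usual QDax scoring function and repertoire-addition step, and their exact signatures may differ across versions; the key point is that `extra_scores["transitions"]` must be populated for `state_update` to work:

```python
# One iteration: emit, score, add to the repertoire, update the emitter.
offspring, random_key = emitter.emit(repertoire, emitter_state, random_key)
fitnesses, descriptors, extra_scores, random_key = scoring_fn(offspring, random_key)
repertoire = repertoire.add(offspring, descriptors, fitnesses)
emitter_state = emitter.state_update(
    emitter_state, repertoire, offspring, fitnesses, descriptors, extra_scores
)
```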
standard_emitters
¶
MixingEmitter (Emitter)
¶
Source code in qdax/core/emitters/standard_emitters.py
class MixingEmitter(Emitter):
def __init__(
self,
mutation_fn: Callable[[Genotype, RNGKey], Tuple[Genotype, RNGKey]],
variation_fn: Callable[[Genotype, Genotype, RNGKey], Tuple[Genotype, RNGKey]],
variation_percentage: float,
batch_size: int,
) -> None:
self._mutation_fn = mutation_fn
self._variation_fn = variation_fn
self._variation_percentage = variation_percentage
self._batch_size = batch_size
@partial(
jax.jit,
static_argnames=("self",),
)
def emit(
self,
repertoire: Repertoire,
emitter_state: Optional[EmitterState],
random_key: RNGKey,
) -> Tuple[Genotype, RNGKey]:
"""
Emitter that performs both mutation and variation. Two batches of
variation_percentage * batch_size genotypes are sampled from the repertoire,
copied and crossed over to obtain new offspring. One batch of
(1.0 - variation_percentage) * batch_size genotypes is sampled from the
repertoire, copied and mutated.
Note: this emitter has no state. A fake None state must be added
through a function redefinition to make this emitter usable with MAP-Elites.
Params:
repertoire: the MAP-Elites repertoire to sample from
emitter_state: void (this emitter is stateless)
random_key: a jax PRNG random key
Returns:
a batch of offspring
a new jax PRNG key
"""
n_variation = int(self._batch_size * self._variation_percentage)
n_mutation = self._batch_size - n_variation
if n_variation > 0:
x1, random_key = repertoire.sample(random_key, n_variation)
x2, random_key = repertoire.sample(random_key, n_variation)
x_variation, random_key = self._variation_fn(x1, x2, random_key)
if n_mutation > 0:
x1, random_key = repertoire.sample(random_key, n_mutation)
x_mutation, random_key = self._mutation_fn(x1, random_key)
if n_variation == 0:
genotypes = x_mutation
elif n_mutation == 0:
genotypes = x_variation
else:
genotypes = jax.tree_util.tree_map(
lambda x_1, x_2: jnp.concatenate([x_1, x_2], axis=0),
x_variation,
x_mutation,
)
return genotypes, random_key
@property
def batch_size(self) -> int:
"""
Returns:
the batch size emitted by the emitter.
"""
return self._batch_size
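A minimal construction sketch with toy operators. The operators below are illustrative stand-ins (assuming genotypes are flat `jnp` arrays); in practice one would plug in QDax's mutation/variation operators with the same signatures:

```python
import jax
import jax.numpy as jnp

def gaussian_mutation(x, key):
    # Toy mutation: perturb each genotype with small Gaussian noise.
    key, subkey = jax.random.split(key)
    return x + 0.01 * jax.random.normal(subkey, x.shape), key

def blend_variation(x1, x2, key):
    # Toy cross-over: average the two parents element-wise.
    return 0.5 * (x1 + x2), key

mixing_emitter = MixingEmitter(
    mutation_fn=gaussian_mutation,
    variation_fn=blend_variation,
    variation_percentage=0.5,
    batch_size=64,
)
```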
batch_size: int
property
readonly
¶
Returns:

| Type | Description |
|---|---|
| `int` | the batch size emitted by the emitter. |
emit(self, repertoire, emitter_state, random_key)
¶
Emitter that performs both mutation and variation. Two batches of variation_percentage * batch_size genotypes are sampled from the repertoire, copied and crossed over to obtain new offspring. One batch of (1.0 - variation_percentage) * batch_size genotypes is sampled from the repertoire, copied and mutated.
Note: this emitter has no state. A fake None state must be added through a function redefinition to make this emitter usable with MAP-Elites.
Parameters:

| Name | Type | Description |
|---|---|---|
| `repertoire` | `Repertoire` | the MAP-Elites repertoire to sample from |
| `emitter_state` | `Optional[EmitterState]` | void (this emitter is stateless) |
| `random_key` | `RNGKey` | a jax PRNG random key |

Returns:

| Type | Description |
|---|---|
| `Tuple[Genotype, RNGKey]` | a batch of offspring and a new jax PRNG key |
Source code in qdax/core/emitters/standard_emitters.py
@partial(
jax.jit,
static_argnames=("self",),
)
def emit(
self,
repertoire: Repertoire,
emitter_state: Optional[EmitterState],
random_key: RNGKey,
) -> Tuple[Genotype, RNGKey]:
"""
Emitter that performs both mutation and variation. Two batches of
variation_percentage * batch_size genotypes are sampled from the repertoire,
copied and crossed over to obtain new offspring. One batch of
(1.0 - variation_percentage) * batch_size genotypes is sampled from the
repertoire, copied and mutated.
Note: this emitter has no state. A fake None state must be added
through a function redefinition to make this emitter usable with MAP-Elites.
Params:
repertoire: the MAP-Elites repertoire to sample from
emitter_state: void (this emitter is stateless)
random_key: a jax PRNG random key
Returns:
a batch of offspring
a new jax PRNG key
"""
n_variation = int(self._batch_size * self._variation_percentage)
n_mutation = self._batch_size - n_variation
if n_variation > 0:
x1, random_key = repertoire.sample(random_key, n_variation)
x2, random_key = repertoire.sample(random_key, n_variation)
x_variation, random_key = self._variation_fn(x1, x2, random_key)
if n_mutation > 0:
x1, random_key = repertoire.sample(random_key, n_mutation)
x_mutation, random_key = self._mutation_fn(x1, random_key)
if n_variation == 0:
genotypes = x_mutation
elif n_mutation == 0:
genotypes = x_variation
else:
genotypes = jax.tree_util.tree_map(
lambda x_1, x_2: jnp.concatenate([x_1, x_2], axis=0),
x_variation,
x_mutation,
)
return genotypes, random_key
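Worked split for the sketch above: with `batch_size=64` and `variation_percentage=0.5`, each call to `emit` produces `n_variation = int(64 * 0.5) = 32` cross-over offspring and `n_mutation = 64 - 32 = 32` mutated offspring, concatenated leaf-wise into a single batch:

```python
# Stateless emitter: emitter_state can simply be None.
genotypes, random_key = mixing_emitter.emit(repertoire, None, random_key)
```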