Containers¶
qdax.core.containers
special
¶
archive
¶
Defines an unstructured archive and a euclidean novelty scorer.
Archive (PyTreeNode)
dataclass
¶
Stores jnp.ndarray in a way that makes insertion jittable.
An example of use of the archive is the algorithm QDPG: state descriptors are stored in this archive and a novelty scorer compares new state descriptors to the state descriptors stored in this archive.
Note: notations suppose that the elements are called state descriptors. If we were to use this structure for another application, it would be better to rename the variables. Does not seem necessary at the moment though.
Source code in qdax/core/containers/archive.py
class Archive(PyTreeNode):
    """Stores jnp.ndarray in a way that makes insertion jittable.

    The storage buffer has a fixed shape. Slots that have not been written
    yet hold NaN ("fake borders"); once every slot has been used, new
    elements overwrite the ones that were inserted first.

    An example of use of the archive is the algorithm QDPG: state
    descriptors are stored in this archive and a novelty scorer compares
    new state descriptors to the state descriptors stored in this archive.

    Note: notations suppose that the elements are called state descriptors.
    If we were to use this structure for another application, it would be
    better to change the variable names. Does not seem necessary at the
    moment though.
    """

    # buffer of shape (max_size, state_descriptor_size), initialised with
    # nan everywhere
    data: jnp.ndarray
    # number of insertions performed so far; it is wrapped modulo max_size
    # only when used as a write index
    current_position: int
    # minimal distance to already stored descriptors for a new descriptor
    # to be accepted
    acceptance_threshold: float
    state_descriptor_size: int
    max_size: int

    @property
    def size(self) -> float:
        """Compute the number of state descriptors stored in the archive.

        Returns:
            Size of the archive, as a scalar: the number of rows of the
            buffer that no longer contain NaN padding.
        """
        # a row is a fake border as long as it still holds NaN padding
        is_fake_row = jnp.any(jnp.isnan(self.data), axis=-1)

        # count rows, not individual entries: the builtin sum() applied to
        # the 2-D mask would return one count per descriptor dimension
        return jnp.sum(~is_fake_row)

    @classmethod
    def create(
        cls,
        acceptance_threshold: float,
        state_descriptor_size: int,
        max_size: int,
    ) -> Archive:
        """Create an Archive instance.

        This class method provides a convenient way to create the archive
        while keeping the __init__ function for more general way to init an
        archive.

        Args:
            acceptance_threshold: the minimal distance to a stored descriptor
                to be respected for a new descriptor to be added.
            state_descriptor_size: the number of elements in a state
                descriptor.
            max_size: the maximal size of the archive. In case of overflow,
                previous elements are replaced by new ones. This argument
                has no default value.

        Returns:
            A newly initialized archive, filled with NaN.
        """
        init_data = jnp.ones((max_size, state_descriptor_size)) * jnp.nan
        return cls(  # type: ignore
            data=init_data,
            current_position=0,
            acceptance_threshold=acceptance_threshold,
            state_descriptor_size=state_descriptor_size,
            max_size=max_size,
        )

    @jax.jit
    def _single_insertion(self, state_descriptor: jnp.ndarray) -> Archive:
        """Insert a single element.

        If the archive is not full yet, the new element replaces a fake
        border, if it is full, it replaces the element that was inserted
        first in the archive.

        Args:
            state_descriptor: state descriptor to be added.

        Returns:
            Return the archive with the newly added element.
        """
        new_current_position = self.current_position + 1

        # write at the wrapped position so that the buffer behaves as a ring
        new_data = jax.lax.dynamic_update_slice_in_dim(
            self.data,
            state_descriptor.reshape(1, -1),
            start_index=self.current_position % self.max_size,
            axis=0,
        )
        return self.replace(  # type: ignore
            current_position=new_current_position, data=new_data
        )

    @jax.jit
    def _conditioned_single_insertion(
        self, condition: bool, state_descriptor: jnp.ndarray
    ) -> Tuple[Archive, jnp.ndarray]:
        """Inserts a single element under a condition.

        The function also retrieves the added elements.

        Args:
            condition: condition for being added in the archive.
            state_descriptor: state descriptor to be added under the
                given condition.

        Returns:
            The new archive and the added element. When the condition is
            false the archive is returned unchanged together with a NaN
            array shaped like the state descriptor.
        """

        def true_fun(
            archive: Archive, state_descriptor: jnp.ndarray
        ) -> Tuple[Archive, jnp.ndarray]:
            return archive._single_insertion(state_descriptor), state_descriptor

        def false_fun(
            archive: Archive, state_descriptor: jnp.ndarray
        ) -> Tuple[Archive, jnp.ndarray]:
            return archive, jnp.ones_like(state_descriptor) * jnp.nan

        return jax.lax.cond(  # type: ignore
            condition, true_fun, false_fun, self, state_descriptor
        )

    @jax.jit
    def insert(self, state_descriptors: jnp.ndarray) -> Archive:
        """Tries to insert a batch of state descriptors in the archive.

        1. First, look at the distance of each new state descriptor with the
        already stored ones.
        2. Then, scan the state descriptors, check the distance with
        the descriptors inserted during the scan.
        3. If the state descriptor verified the first condition (not too close
        to a state descriptor in the old archive) and the second (not too close
        from a state descriptor that has just been added), then it is added
        to the archive.

        Note 1: the archive has a fixed size, hence, in case of overflow, the
        first elements added are removed first (FIFO style).
        Note 2: keep in mind that fake descriptors (NaN rows) are used to help
        keep the size constant.

        Args:
            state_descriptors: state descriptors to be added. Any leading
                dimensions are flattened; the last one is the descriptor size.

        Returns:
            New archive updated with the state descriptors.
        """
        state_descriptors = state_descriptors.reshape(
            (-1, state_descriptors.shape[-1])
        )

        # get nearest neighbor for each new state descriptor
        values, _indices = knn(self.data, state_descriptors, 1)

        # flag is 0 when the distance is bigger than the threshold (candidate
        # may be added) and 1 otherwise. The single neighbor column is indexed
        # explicitly: a bare squeeze() would turn a batch of one descriptor
        # into a 0-d array that lax.scan cannot iterate over.
        relevant_indices = jnp.where(
            values[:, 0] > self.acceptance_threshold, x=0, y=1
        )

        def iterate_fn(
            carry: Tuple[Archive, jnp.ndarray, int], condition_data: Dict
        ) -> Tuple[Tuple[Archive, jnp.ndarray, int], Any]:
            """Iterates over the archive to add elements one after the other.

            Args:
                carry: tuple containing the archive, the elements added so
                    far during the scan and the current index.
                condition_data: the first addition condition of the state
                    descriptor given (being sufficiently far away from
                    already stored descriptors) and the descriptor itself.

            Returns:
                The updated carry and an empty output.
            """
            archive, new_elements, index = carry

            first_condition = condition_data["condition"]
            state_descriptor = condition_data["state_descriptor"]

            # do the filtering among the elements added during this scan
            values, _indices = knn(new_elements, state_descriptor.reshape(1, -1), 1)

            # 1 when the candidate is too close to a freshly added element
            not_too_close = jnp.where(
                values.squeeze() > self.acceptance_threshold, x=0, y=1
            )
            second_condition = not_too_close.sum()

            # both flags must be 0 for the candidate to be inserted
            condition = (first_condition + second_condition) == 0

            new_archive, added_element = archive._conditioned_single_insertion(
                condition, state_descriptor
            )

            # rejected candidates are recorded as NaN rows
            new_elements = new_elements.at[index].set(added_element)
            index += 1

            return (
                (new_archive, new_elements, index),
                (),
            )

        new_elements = jnp.ones_like(state_descriptors) * jnp.nan

        # iterate over the candidates
        (new_archive, _, _), _ = jax.lax.scan(
            iterate_fn,
            (self, new_elements, 0),
            {
                "condition": relevant_indices,
                "state_descriptor": state_descriptors,
            },
        )

        return new_archive  # type: ignore
size: float
property
readonly
¶
Compute the number of state descriptors stored in the archive.
Returns: |
|
---|
create(acceptance_threshold, state_descriptor_size, max_size)
classmethod
¶
Create an Archive instance.
This class method provides a convenient way to create the archive while keeping the `__init__` function as the more general way to initialise an archive.
Parameters: |
|
---|
Returns: |
|
---|
Source code in qdax/core/containers/archive.py
@classmethod
def create(
    cls,
    acceptance_threshold: float,
    state_descriptor_size: int,
    max_size: int,
) -> Archive:
    """Build an empty archive.

    Convenience constructor: the regular __init__ stays available for more
    general ways of initialising an archive.

    Args:
        acceptance_threshold: minimal distance to every stored descriptor
            that a new descriptor must respect in order to be added.
        state_descriptor_size: number of elements in one state descriptor.
        max_size: maximal number of descriptors held by the archive. On
            overflow, previous elements are replaced by new ones.

    Returns:
        A freshly initialised archive whose buffer is filled with NaN.
    """
    # every slot starts as NaN padding
    buffer_shape = (max_size, state_descriptor_size)
    nan_buffer = jnp.ones(buffer_shape) * jnp.nan

    fields = {
        "data": nan_buffer,
        "current_position": 0,
        "acceptance_threshold": acceptance_threshold,
        "state_descriptor_size": state_descriptor_size,
        "max_size": max_size,
    }
    return cls(**fields)  # type: ignore
replace(self, **updates)
¶
Returns a new object replacing the specified fields with new values.
Source code in qdax/core/containers/archive.py
def replace(self, **updates):
    """Return a new object with the specified fields replaced.

    Args:
        **updates: field names mapped to their new values.

    Returns:
        A new instance built by dataclasses.replace.
    """
    replaced = dataclasses.replace(self, **updates)
    return replaced
insert(self, state_descriptors)
¶
Tries to insert a batch of state descriptors in the archive.
- First, look at the distance of each new state descriptor with the already stored ones.
- Then, scan the state descriptors, check the distance with the descriptors inserted during the scan.
- If the state descriptor verified the first condition (not too close to a state descriptor in the old archive) and the second (not too close from a state descriptor that has just been added), then it is added to the archive.
Note 1: the archive has a fixed size, hence, in case of overflow, the first elements added are removed first (FIFO style). Note 2: keep in mind that fake descriptors are used to help keep the size constant. Those correspond to a descriptor very far away from the typical values of the problem at hand.
Parameters: |
|
---|
Returns: |
|
---|
Source code in qdax/core/containers/archive.py
@jax.jit
def insert(self, state_descriptors: jnp.ndarray) -> Archive:
    """Tries to insert a batch of state descriptors in the archive.

    1. First, look at the distance of each new state descriptor with the
    already stored ones.
    2. Then, scan the state descriptors, check the distance with
    the descriptors inserted during the scan.
    3. If the state descriptor verified the first condition (not too close
    to a state descriptor in the old archive) and the second (not too close
    from a state descriptor that has just been added), then it is added
    to the archive.

    Note 1: the archive has a fixed size, hence, in case of overflow, the
    first elements added are removed first (FIFO style).
    Note 2: keep in mind that fake descriptors are used to help keep the
    size constant.

    Args:
        state_descriptors: state descriptors to be added. Any leading
            dimensions are flattened; the last one is the descriptor size.

    Returns:
        New archive updated with the state descriptors.
    """
    state_descriptors = state_descriptors.reshape(
        (-1, state_descriptors.shape[-1])
    )

    # get nearest neighbor for each new state descriptor
    values, _indices = knn(self.data, state_descriptors, 1)

    # flag is 0 when the distance is bigger than the threshold (candidate may
    # be added) and 1 otherwise. The single neighbor column is indexed
    # explicitly: a bare squeeze() would turn a batch of one descriptor into
    # a 0-d array that lax.scan cannot iterate over.
    relevant_indices = jnp.where(
        values[:, 0] > self.acceptance_threshold, x=0, y=1
    )

    def iterate_fn(
        carry: Tuple[Archive, jnp.ndarray, int], condition_data: Dict
    ) -> Tuple[Tuple[Archive, jnp.ndarray, int], Any]:
        """Iterates over the archive to add elements one after the other.

        Args:
            carry: tuple containing the archive, the elements added so far
                during the scan and the current index.
            condition_data: the first addition condition of the state
                descriptor given (being sufficiently far away from already
                stored descriptors) and the descriptor itself.

        Returns:
            The updated carry and an empty output.
        """
        archive, new_elements, index = carry

        first_condition = condition_data["condition"]
        state_descriptor = condition_data["state_descriptor"]

        # do the filtering among the elements added during this scan
        values, _indices = knn(new_elements, state_descriptor.reshape(1, -1), 1)

        # 1 when the candidate is too close to a freshly added element
        not_too_close = jnp.where(
            values.squeeze() > self.acceptance_threshold, x=0, y=1
        )
        second_condition = not_too_close.sum()

        # both flags must be 0 for the candidate to be inserted
        condition = (first_condition + second_condition) == 0

        new_archive, added_element = archive._conditioned_single_insertion(
            condition, state_descriptor
        )

        # rejected candidates are recorded as NaN rows
        new_elements = new_elements.at[index].set(added_element)
        index += 1

        return (
            (new_archive, new_elements, index),
            (),
        )

    new_elements = jnp.ones_like(state_descriptors) * jnp.nan

    # iterate over the candidates
    (new_archive, _, _), _ = jax.lax.scan(
        iterate_fn,
        (self, new_elements, 0),
        {
            "condition": relevant_indices,
            "state_descriptor": state_descriptors,
        },
    )

    return new_archive  # type: ignore
score_euclidean_novelty(archive, state_descriptors, num_nearest_neighb, scaling_ratio)
¶
Scores the novelty of a jnp.ndarray with respect to the elements of an archive.
Typical use case in the construction of the diversity rewards in QDPG.
Parameters: |
|
---|
Returns: |
|
---|
Source code in qdax/core/containers/archive.py
def score_euclidean_novelty(
    archive: Archive,
    state_descriptors: jnp.ndarray,
    num_nearest_neighb: int,
    scaling_ratio: float,
) -> jnp.ndarray:
    """Score how novel state descriptors are with respect to an archive.

    Typical use case: building the diversity rewards in QDPG.

    Args:
        archive: an archive of state descriptors.
        state_descriptors: state descriptors whose novelty must be scored.
        num_nearest_neighb: number of nearest neighbors taken into account
            for each descriptor.
        scaling_ratio: factor applied to the mean squared distance to obtain
            the final score.

    Returns:
        The novelty scores of the given state descriptors: for each one, the
        scaled mean of the squared distances to its nearest neighbors.
    """
    neighbor_distances, _ = knn(archive.data, state_descriptors, num_nearest_neighb)

    # average the squared distances over the neighbors of each descriptor
    mean_squared_distance = jnp.mean(jnp.square(neighbor_distances), axis=1)

    novelty = scaling_ratio * mean_squared_distance
    return novelty
knn(data, new_data, k)
¶
K nearest neighbors - Brute force implementation. Using euclidean distance.
Code from https://www.kernel-operations.io/keops/_auto_benchmarks/plot_benchmark_KNN.html
Parameters: |
|
---|
Returns: |
|
---|
Source code in qdax/core/containers/archive.py
@partial(jax.jit, static_argnames=("k",))
def knn(
    data: jnp.ndarray, new_data: jnp.ndarray, k: int
) -> Tuple[jnp.ndarray, jnp.ndarray]:
    """K nearest neighbors - Brute force implementation.

    Using euclidean distance.

    Code from https://www.kernel-operations.io/keops/_auto_benchmarks/
    plot_benchmark_KNN.html

    Args:
        data: given reference data.
        new_data: data to be compared to the reference data.
        k: number of neighbors to consider. It is a static argument of the
            jitted function, hence a plain Python int.

    Returns:
        The distances and indices of the nearest neighbors, one row per
        element of new_data, closest neighbor first.
    """
    # squared distances via ||a||^2 + ||b||^2 - 2 a.b
    squared_dist = (
        (new_data**2).sum(-1)[:, None]
        + (data**2).sum(-1)[None, :]
        - 2 * new_data @ data.T
    )

    # NaN entries in the data yield NaN distances: push them to infinity so
    # that they rank last
    squared_dist = jnp.nan_to_num(squared_dist, nan=jnp.inf)

    # clamping necessary - numerical approximations make some distances
    # negative. jnp.maximum is used rather than jnp.clip(a_min=...) whose
    # keyword is deprecated in recent JAX versions.
    dist = jnp.sqrt(jnp.maximum(squared_dist, 0.0))

    # the k smallest distances are the k largest negated distances
    values, indices = qdax_top_k(-dist, k)

    return -values, indices
qdax_top_k(data, k)
¶
Returns the top k elements of an array.
Interestingly, this naive implementation is faster than the native implementation of jax for small k. See issue: https://github.com/google/jax/issues/9940
Waiting for updates in jax to change this implementation.
Parameters: |
|
---|
Returns: |
|
---|
Source code in qdax/core/containers/archive.py
@partial(jax.jit, static_argnames=("k"))
def qdax_top_k(data: jnp.ndarray, k: int) -> Tuple[jnp.ndarray, jnp.ndarray]:
    """Return the k largest elements of each row of an array.

    This naive implementation is faster than the native jax one for small
    k, see https://github.com/google/jax/issues/9940. It is meant to be
    replaced once jax is updated.

    Args:
        data: given data, processed row by row.
        k: number of top elements to determine.

    Returns:
        The values of the top elements and their indices in the array, each
        with one row per row of data, in decreasing order of value.
    """

    def extract_row_maxima(
        remaining: jnp.ndarray, _unused: Any
    ) -> Tuple[jnp.ndarray, Tuple[jnp.ndarray, jnp.ndarray]]:
        # position and value of the current maximum of every row
        best_index = jnp.argmax(remaining, axis=1)
        best_value = jax.vmap(lambda row, col: row[col])(remaining, best_index)

        # mask the maximum so that the next step finds the runner-up
        remaining = jax.vmap(lambda row, col: row.at[col].set(-jnp.inf))(
            remaining, best_index
        )
        return remaining, (best_value, best_index)

    # each of the k scan steps peels one maximum off every row
    _, (stacked_values, stacked_indices) = jax.lax.scan(
        extract_row_maxima, data, (), k
    )

    # scan stacks results along the first axis: put rows first again
    return stacked_values.T, stacked_indices.T
ga_repertoire
¶
Defines a repertoire for simple genetic algorithms.
GARepertoire (Repertoire)
dataclass
¶
Class for a simple repertoire for a simple genetic algorithm.
Parameters: |
|
---|
Source code in qdax/core/containers/ga_repertoire.py
class GARepertoire(Repertoire):
    """Class for a simple repertoire for a simple genetic
    algorithm.

    Args:
        genotypes: a PyTree containing the genotypes of the
            individuals in the population. Each leaf has the
            shape (population_size, num_features).
        fitnesses: an array containing the fitness of the individuals
            in the population. With shape (population_size, fitness_dim).
            The implementation of GARepertoire was thought for the case
            where fitness_dim equals 1 but the class can be inherited and
            rules adapted for cases where fitness_dim is greater than 1.
    """

    genotypes: Genotype
    fitnesses: Fitness

    @property
    def size(self) -> int:
        """Gives the size of the population."""
        first_leaf = jax.tree_util.tree_leaves(self.genotypes)[0]
        return int(first_leaf.shape[0])

    def save(self, path: str = "./") -> None:
        """Saves the repertoire.

        Writes "genotypes.npy" and "fitnesses.npy", the two files read back
        by load.

        Args:
            path: place to store the files. Defaults to "./".
        """

        def flatten_genotype(genotype: Genotype) -> jnp.ndarray:
            flatten_genotype, _ = ravel_pytree(genotype)
            return flatten_genotype

        # flatten all the genotypes
        flat_genotypes = jax.vmap(flatten_genotype)(self.genotypes)

        jnp.save(path + "genotypes.npy", flat_genotypes)
        # must match the file name expected by load
        jnp.save(path + "fitnesses.npy", self.fitnesses)

    @classmethod
    def load(cls, reconstruction_fn: Callable, path: str = "./") -> GARepertoire:
        """Loads a GA Repertoire.

        Args:
            reconstruction_fn: Function to reconstruct a PyTree
                from a flat array.
            path: Path where the data is saved. Defaults to "./".

        Returns:
            A GA Repertoire.

        Raises:
            FileNotFoundError: if the genotypes file, or both the
                "fitnesses.npy" and the legacy "scores.npy" files, are
                missing.
        """
        flat_genotypes = jnp.load(path + "genotypes.npy")
        genotypes = jax.vmap(reconstruction_fn)(flat_genotypes)

        try:
            fitnesses = jnp.load(path + "fitnesses.npy")
        except FileNotFoundError:
            # fitnesses used to be saved under the name "scores.npy"
            fitnesses = jnp.load(path + "scores.npy")

        return cls(
            genotypes=genotypes,
            fitnesses=fitnesses,
        )

    @partial(jax.jit, static_argnames=("num_samples",))
    def sample(self, random_key: RNGKey, num_samples: int) -> Tuple[Genotype, RNGKey]:
        """Sample genotypes from the repertoire.

        Individuals whose fitness entries are all -inf get a zero sampling
        probability; the others are drawn uniformly, without replacement.

        Args:
            random_key: a random key to handle stochasticity.
            num_samples: the number of genotypes to sample.

        Returns:
            The sample of genotypes and the updated random key.
        """
        # prepare sampling probability
        mask = self.fitnesses != -jnp.inf
        p = jnp.any(mask, axis=-1) / jnp.sum(jnp.any(mask, axis=-1))

        # sample; the same subkey is used for every leaf of the PyTree
        random_key, subkey = jax.random.split(random_key)
        samples = jax.tree_util.tree_map(
            lambda x: jax.random.choice(
                subkey, x, shape=(num_samples,), p=p, replace=False
            ),
            self.genotypes,
        )

        return samples, random_key

    @jax.jit
    def add(
        self, batch_of_genotypes: Genotype, batch_of_fitnesses: Fitness
    ) -> GARepertoire:
        """Implements the repertoire addition rules.

        Parents and offsprings are gathered and only the population_size
        bests are kept. The others are killed.

        Args:
            batch_of_genotypes: new genotypes that we try to add.
            batch_of_fitnesses: fitness of those new genotypes.

        Returns:
            The updated repertoire.
        """
        # gather individuals and fitnesses
        candidates = jax.tree_util.tree_map(
            lambda x, y: jnp.concatenate((x, y), axis=0),
            self.genotypes,
            batch_of_genotypes,
        )
        candidates_fitnesses = jnp.concatenate(
            (self.fitnesses, batch_of_fitnesses), axis=0
        )

        # sort by summed fitnesses, best first
        indices = jnp.argsort(jnp.sum(candidates_fitnesses, axis=1))[::-1]

        # keep only as many individuals as the current population holds
        survivor_indices = indices[: self.size]

        new_candidates = jax.tree_util.tree_map(
            lambda x: x[survivor_indices], candidates
        )

        new_repertoire = self.replace(
            genotypes=new_candidates, fitnesses=candidates_fitnesses[survivor_indices]
        )

        return new_repertoire  # type: ignore

    @classmethod
    def init(  # type: ignore
        cls,
        genotypes: Genotype,
        fitnesses: Fitness,
        population_size: int,
    ) -> GARepertoire:
        """Initializes the repertoire.

        Start with default values and adds a first batch of genotypes
        to the repertoire.

        Args:
            genotypes: first batch of genotypes
            fitnesses: corresponding fitnesses
            population_size: size of the population we want to evolve

        Returns:
            An initial repertoire.
        """
        # create default fitnesses: -inf, so that any real individual wins
        default_fitnesses = -jnp.inf * jnp.ones(
            shape=(population_size, fitnesses.shape[-1])
        )

        # create default genotypes; keep the dtype of the given genotypes so
        # that concatenating them in add does not change their type
        default_genotypes = jax.tree_util.tree_map(
            lambda x: jnp.zeros(
                shape=(population_size,) + x.shape[1:], dtype=x.dtype
            ),
            genotypes,
        )

        # create an initial repertoire with those default values
        repertoire = cls(genotypes=default_genotypes, fitnesses=default_fitnesses)

        new_repertoire = repertoire.add(genotypes, fitnesses)

        return new_repertoire  # type: ignore
size: int
property
readonly
¶
Gives the size of the population.
save(self, path='./')
¶
Saves the repertoire.
Parameters: |
|
---|
Source code in qdax/core/containers/ga_repertoire.py
def save(self, path: str = "./") -> None:
    """Saves the repertoire.

    Writes "genotypes.npy" and "fitnesses.npy", the two files read back by
    load.

    Args:
        path: place to store the files. Defaults to "./".
    """

    def flatten_genotype(genotype: Genotype) -> jnp.ndarray:
        flatten_genotype, _ = ravel_pytree(genotype)
        return flatten_genotype

    # flatten all the genotypes
    flat_genotypes = jax.vmap(flatten_genotype)(self.genotypes)

    jnp.save(path + "genotypes.npy", flat_genotypes)
    # load reads "fitnesses.npy": writing "scores.npy" made a saved
    # repertoire impossible to load back
    jnp.save(path + "fitnesses.npy", self.fitnesses)
load(reconstruction_fn, path='./')
classmethod
¶
Loads a GA Repertoire.
Parameters: |
|
---|
Returns: |
|
---|
Source code in qdax/core/containers/ga_repertoire.py
@classmethod
def load(cls, reconstruction_fn: Callable, path: str = "./") -> GARepertoire:
    """Loads a GA Repertoire.

    Args:
        reconstruction_fn: Function to reconstruct a PyTree
            from a flat array.
        path: Path where the data is saved. Defaults to "./".

    Returns:
        A GA Repertoire.

    Raises:
        FileNotFoundError: if the genotypes file, or both the
            "fitnesses.npy" and the "scores.npy" files, are missing.
    """
    flat_genotypes = jnp.load(path + "genotypes.npy")
    genotypes = jax.vmap(reconstruction_fn)(flat_genotypes)

    try:
        fitnesses = jnp.load(path + "fitnesses.npy")
    except FileNotFoundError:
        # fall back on "scores.npy", the name under which fitnesses have
        # been written by save
        fitnesses = jnp.load(path + "scores.npy")

    return cls(
        genotypes=genotypes,
        fitnesses=fitnesses,
    )
sample(self, random_key, num_samples)
¶
Sample genotypes from the repertoire.
Parameters: |
|
---|
Returns: |
|
---|
Source code in qdax/core/containers/ga_repertoire.py
@partial(jax.jit, static_argnames=("num_samples",))
def sample(self, random_key: RNGKey, num_samples: int) -> Tuple[Genotype, RNGKey]:
    """Draw genotypes from the repertoire.

    Individuals whose fitness entries are all -inf get a zero sampling
    probability; the others are drawn uniformly, without replacement.

    Args:
        random_key: a random key to handle stochasticity.
        num_samples: the number of genotypes to sample.

    Returns:
        The sampled genotypes and the updated random key.
    """
    # an individual can be drawn as soon as one fitness entry is not -inf
    is_drawable = jnp.any(self.fitnesses != -jnp.inf, axis=-1)
    p = is_drawable / jnp.sum(is_drawable)

    random_key, subkey = jax.random.split(random_key)

    def draw(leaf: jnp.ndarray) -> jnp.ndarray:
        # the same subkey is used for every leaf of the PyTree
        return jax.random.choice(
            subkey, leaf, shape=(num_samples,), p=p, replace=False
        )

    samples = jax.tree_util.tree_map(draw, self.genotypes)
    return samples, random_key
add(self, batch_of_genotypes, batch_of_fitnesses)
¶
Implements the repertoire addition rules.
Parents and offsprings are gathered and only the population_size bests are kept. The others are killed.
Parameters: |
|
---|
Returns: |
|
---|
Source code in qdax/core/containers/ga_repertoire.py
@jax.jit
def add(
    self, batch_of_genotypes: Genotype, batch_of_fitnesses: Fitness
) -> GARepertoire:
    """Apply the addition rules of the repertoire.

    The current population and the new individuals are pooled, ranked by
    summed fitness, and only the best ones are kept, as many as the current
    population size.

    Args:
        batch_of_genotypes: new genotypes that we try to add.
        batch_of_fitnesses: fitness of those new genotypes.

    Returns:
        The updated repertoire.
    """
    # pool parents and offsprings
    pooled_fitnesses = jnp.concatenate(
        (self.fitnesses, batch_of_fitnesses), axis=0
    )
    pooled_genotypes = jax.tree_util.tree_map(
        lambda current, incoming: jnp.concatenate((current, incoming), axis=0),
        self.genotypes,
        batch_of_genotypes,
    )

    # rank from best to worst on the summed fitness
    summed_fitnesses = jnp.sum(pooled_fitnesses, axis=1)
    ranking = jnp.argsort(summed_fitnesses)[::-1]

    # the survivors are the first population-size individuals of the ranking
    survivor_indices = ranking[: self.size]

    surviving_genotypes = jax.tree_util.tree_map(
        lambda leaf: leaf[survivor_indices], pooled_genotypes
    )
    surviving_fitnesses = pooled_fitnesses[survivor_indices]

    return self.replace(  # type: ignore
        genotypes=surviving_genotypes, fitnesses=surviving_fitnesses
    )
replace(self, **updates)
¶
Returns a new object replacing the specified fields with new values.
Source code in qdax/core/containers/ga_repertoire.py
def replace(self, **updates):
    """Return a new object with the specified fields replaced.

    Args:
        **updates: field names mapped to their new values.

    Returns:
        A new instance built by dataclasses.replace.
    """
    replaced = dataclasses.replace(self, **updates)
    return replaced
init(genotypes, fitnesses, population_size)
classmethod
¶
Initializes the repertoire.
Start with default values and adds a first batch of genotypes to the repertoire.
Parameters: |
|
---|
Returns: |
|
---|
Source code in qdax/core/containers/ga_repertoire.py
@classmethod
def init(  # type: ignore
    cls,
    genotypes: Genotype,
    fitnesses: Fitness,
    population_size: int,
) -> GARepertoire:
    """Initializes the repertoire.

    Start with default values and adds a first batch of genotypes
    to the repertoire.

    Args:
        genotypes: first batch of genotypes
        fitnesses: corresponding fitnesses
        population_size: size of the population we want to evolve

    Returns:
        An initial repertoire.
    """
    # create default fitnesses: -inf, so that any real individual wins
    default_fitnesses = -jnp.inf * jnp.ones(
        shape=(population_size, fitnesses.shape[-1])
    )

    # create default genotypes; keep the dtype of the given genotypes so that
    # concatenating them with the first batch does not change their type
    default_genotypes = jax.tree_util.tree_map(
        lambda x: jnp.zeros(shape=(population_size,) + x.shape[1:], dtype=x.dtype),
        genotypes,
    )

    # create an initial repertoire with those default values
    repertoire = cls(genotypes=default_genotypes, fitnesses=default_fitnesses)

    new_repertoire = repertoire.add(genotypes, fitnesses)

    return new_repertoire  # type: ignore
mapelites_repertoire
¶
This file contains util functions and a class to define a repertoire, used to store individuals in the MAP-Elites algorithm as well as several variants.
MapElitesRepertoire (PyTreeNode)
dataclass
¶
Class for the repertoire in Map Elites.
Parameters: |
|
---|
Source code in qdax/core/containers/mapelites_repertoire.py
class MapElitesRepertoire(flax.struct.PyTreeNode):
"""Class for the repertoire in Map Elites.
Args:
genotypes: a PyTree containing all the genotypes in the repertoire ordered
by the centroids. Each leaf has a shape (num_centroids, num_features). The
PyTree can be a simple Jax array or a more complex nested structure such
as to represent parameters of neural network in Flax.
fitnesses: an array that contains the fitness of solutions in each cell of the
repertoire, ordered by centroids. The array shape is (num_centroids,).
descriptors: an array that contains the descriptors of solutions in each cell
of the repertoire, ordered by centroids. The array shape
is (num_centroids, num_descriptors).
centroids: an array that contains the centroids of the tessellation. The array
shape is (num_centroids, num_descriptors).
"""
genotypes: Genotype
fitnesses: Fitness
descriptors: Descriptor
centroids: Centroid
def save(self, path: str = "./") -> None:
"""Saves the repertoire on disk in the form of .npy files.
Flattens the genotypes to store it with .npy format. Supposes that
a user will have access to the reconstruction function when loading
the genotypes.
Args:
path: Path where the data will be saved. Defaults to "./".
"""
def flatten_genotype(genotype: Genotype) -> jnp.ndarray:
flatten_genotype, _ = ravel_pytree(genotype)
return flatten_genotype
# flatten all the genotypes
flat_genotypes = jax.vmap(flatten_genotype)(self.genotypes)
# save data
jnp.save(path + "genotypes.npy", flat_genotypes)
jnp.save(path + "fitnesses.npy", self.fitnesses)
jnp.save(path + "descriptors.npy", self.descriptors)
jnp.save(path + "centroids.npy", self.centroids)
@classmethod
def load(cls, reconstruction_fn: Callable, path: str = "./") -> MapElitesRepertoire:
"""Loads a MAP Elites Repertoire.
Args:
reconstruction_fn: Function to reconstruct a PyTree
from a flat array.
path: Path where the data is saved. Defaults to "./".
Returns:
A MAP Elites Repertoire.
"""
flat_genotypes = jnp.load(path + "genotypes.npy")
genotypes = jax.vmap(reconstruction_fn)(flat_genotypes)
fitnesses = jnp.load(path + "fitnesses.npy")
descriptors = jnp.load(path + "descriptors.npy")
centroids = jnp.load(path + "centroids.npy")
return cls(
genotypes=genotypes,
fitnesses=fitnesses,
descriptors=descriptors,
centroids=centroids,
)
@partial(jax.jit, static_argnames=("num_samples",))
def sample(self, random_key: RNGKey, num_samples: int) -> Tuple[Genotype, RNGKey]:
"""Sample elements in the repertoire.
Args:
random_key: a jax PRNG random key
num_samples: the number of elements to be sampled
Returns:
samples: a batch of genotypes sampled in the repertoire
random_key: an updated jax PRNG random key
"""
repertoire_empty = self.fitnesses == -jnp.inf
p = (1.0 - repertoire_empty) / jnp.sum(1.0 - repertoire_empty)
random_key, subkey = jax.random.split(random_key)
samples = jax.tree_util.tree_map(
lambda x: jax.random.choice(subkey, x, shape=(num_samples,), p=p),
self.genotypes,
)
return samples, random_key
@jax.jit
def add(
self,
batch_of_genotypes: Genotype,
batch_of_descriptors: Descriptor,
batch_of_fitnesses: Fitness,
batch_of_extra_scores: Optional[ExtraScores] = None,
) -> MapElitesRepertoire:
"""
Add a batch of elements to the repertoire.
Args:
batch_of_genotypes: a batch of genotypes to be added to the repertoire.
Similarly to the self.genotypes argument, this is a PyTree in which
the leaves have a shape (batch_size, num_features)
batch_of_descriptors: an array that contains the descriptors of the
aforementioned genotypes. Its shape is (batch_size, num_descriptors)
batch_of_fitnesses: an array that contains the fitnesses of the
aforementioned genotypes. Its shape is (batch_size,)
batch_of_extra_scores: unused tree that contains the extra_scores of
aforementioned genotypes.
Returns:
The updated MAP-Elites repertoire.
"""
batch_of_indices = get_cells_indices(batch_of_descriptors, self.centroids)
batch_of_indices = jnp.expand_dims(batch_of_indices, axis=-1)
batch_of_fitnesses = jnp.expand_dims(batch_of_fitnesses, axis=-1)
num_centroids = self.centroids.shape[0]
# get fitness segment max
best_fitnesses = jax.ops.segment_max(
batch_of_fitnesses,
batch_of_indices.astype(jnp.int32).squeeze(axis=-1),
num_segments=num_centroids,
)
cond_values = jnp.take_along_axis(best_fitnesses, batch_of_indices, 0)
# put dominated fitness to -jnp.inf
batch_of_fitnesses = jnp.where(
batch_of_fitnesses == cond_values, x=batch_of_fitnesses, y=-jnp.inf
)
# get addition condition
repertoire_fitnesses = jnp.expand_dims(self.fitnesses, axis=-1)
current_fitnesses = jnp.take_along_axis(
repertoire_fitnesses, batch_of_indices, 0
)
addition_condition = batch_of_fitnesses > current_fitnesses
# assign fake position when relevant : num_centroids is out of bound
batch_of_indices = jnp.where(
addition_condition, x=batch_of_indices, y=num_centroids
)
# create new repertoire
new_repertoire_genotypes = jax.tree_util.tree_map(
lambda repertoire_genotypes, new_genotypes: repertoire_genotypes.at[
batch_of_indices.squeeze(axis=-1)
].set(new_genotypes),
self.genotypes,
batch_of_genotypes,
)
# compute new fitness and descriptors
new_fitnesses = self.fitnesses.at[batch_of_indices.squeeze(axis=-1)].set(
batch_of_fitnesses.squeeze(axis=-1)
)
new_descriptors = self.descriptors.at[batch_of_indices.squeeze(axis=-1)].set(
batch_of_descriptors
)
return MapElitesRepertoire(
genotypes=new_repertoire_genotypes,
fitnesses=new_fitnesses,
descriptors=new_descriptors,
centroids=self.centroids,
)
@classmethod
def init(
cls,
genotypes: Genotype,
fitnesses: Fitness,
descriptors: Descriptor,
centroids: Centroid,
extra_scores: Optional[ExtraScores] = None,
) -> MapElitesRepertoire:
"""
Initialize a Map-Elites repertoire with an initial population of genotypes.
Requires the definition of centroids that can be computed with any method
such as CVT or Euclidean mapping.
Note: this function has been kept outside of the object MapElites, so it can
be called easily called from other modules.
Args:
genotypes: initial genotypes, pytree in which leaves
have shape (batch_size, num_features)
fitnesses: fitness of the initial genotypes of shape (batch_size,)
descriptors: descriptors of the initial genotypes
of shape (batch_size, num_descriptors)
centroids: tesselation centroids of shape (batch_size, num_descriptors)
extra_scores: unused extra_scores of the initial genotypes
Returns:
an initialized MAP-Elite repertoire
"""
warnings.warn(
(
"This type of repertoire does not store the extra scores "
"computed by the scoring function"
),
stacklevel=2,
)
# retrieve one genotype from the population
first_genotype = jax.tree_util.tree_map(lambda x: x[0], genotypes)
# create a repertoire with default values
repertoire = cls.init_default(genotype=first_genotype, centroids=centroids)
# add initial population to the repertoire
new_repertoire = repertoire.add(genotypes, descriptors, fitnesses, extra_scores)
return new_repertoire # type: ignore
@classmethod
def init_default(
    cls,
    genotype: Genotype,
    centroids: Centroid,
) -> MapElitesRepertoire:
    """Create a MAP-Elites repertoire filled with default values.

    Every cell gets a fitness of -inf, an all-zero genotype and an all-zero
    descriptor. The centroids can be computed with any method, such as CVT
    or Euclidean mapping.

    Args:
        genotype: the typical genotype that will be stored; only the shape
            and dtype of its leaves are used.
        centroids: the centroids of the repertoire

    Returns:
        A repertoire filled with default values.
    """
    num_centroids = centroids.shape[0]

    def _zeros_per_cell(leaf: jnp.ndarray) -> jnp.ndarray:
        # one zero-filled copy of the leaf per centroid
        return jnp.zeros(shape=(num_centroids,) + leaf.shape, dtype=leaf.dtype)

    return cls(
        genotypes=jax.tree_util.tree_map(_zeros_per_cell, genotype),
        # -inf marks a cell as empty
        fitnesses=-jnp.inf * jnp.ones(shape=num_centroids),
        descriptors=jnp.zeros_like(centroids),
        centroids=centroids,
    )
replace(self, **updates)
¶
Returns a new object replacing the specified fields with new values.
Source code in qdax/core/containers/mapelites_repertoire.py
def replace(self, **updates):
    """Return a new object replacing the specified fields with new values."""
    replaced = dataclasses.replace(self, **updates)
    return replaced
save(self, path='./')
¶
Saves the repertoire on disk in the form of .npy files.
Flattens the genotypes to store it with .npy format. Supposes that a user will have access to the reconstruction function when loading the genotypes.
Parameters: |
|
---|
Source code in qdax/core/containers/mapelites_repertoire.py
def save(self, path: str = "./") -> None:
    """Write the repertoire to disk as .npy files.

    Each genotype is flattened into a single vector so that it fits the
    .npy format; the user is expected to have the reconstruction function
    at hand when loading the genotypes back.

    Args:
        path: Prefix prepended to every file name. Defaults to "./".
    """

    def flatten_genotype(genotype: Genotype) -> jnp.ndarray:
        flat, _ = ravel_pytree(genotype)
        return flat

    # flatten every genotype of the repertoire
    flat_genotypes = jax.vmap(flatten_genotype)(self.genotypes)

    # one file per stored array
    arrays_to_save = (
        ("genotypes.npy", flat_genotypes),
        ("fitnesses.npy", self.fitnesses),
        ("descriptors.npy", self.descriptors),
        ("centroids.npy", self.centroids),
    )
    for filename, array in arrays_to_save:
        jnp.save(path + filename, array)
load(reconstruction_fn, path='./')
classmethod
¶
Loads a MAP Elites Repertoire.
Parameters: |
|
---|
Returns: |
|
---|
Source code in qdax/core/containers/mapelites_repertoire.py
@classmethod
def load(cls, reconstruction_fn: Callable, path: str = "./") -> MapElitesRepertoire:
    """Read a MAP-Elites repertoire from .npy files.

    Args:
        reconstruction_fn: Function to reconstruct a PyTree
            from a flat array.
        path: Prefix prepended to every file name. Defaults to "./".

    Returns:
        A MAP Elites Repertoire.
    """

    def _read(filename: str) -> jnp.ndarray:
        return jnp.load(path + filename)

    # genotypes are stored flattened: rebuild one pytree per row
    flat_genotypes = _read("genotypes.npy")

    return cls(
        genotypes=jax.vmap(reconstruction_fn)(flat_genotypes),
        fitnesses=_read("fitnesses.npy"),
        descriptors=_read("descriptors.npy"),
        centroids=_read("centroids.npy"),
    )
sample(self, random_key, num_samples)
¶
Sample elements in the repertoire.
Parameters: |
|
---|
Returns: |
|
---|
Source code in qdax/core/containers/mapelites_repertoire.py
@partial(jax.jit, static_argnames=("num_samples",))
def sample(self, random_key: RNGKey, num_samples: int) -> Tuple[Genotype, RNGKey]:
    """Draw genotypes from the non-empty cells of the repertoire.

    Cells whose fitness is -inf are treated as empty and get a zero
    probability; all the other cells are equally likely.

    Args:
        random_key: a jax PRNG random key
        num_samples: the number of elements to be sampled

    Returns:
        samples: a batch of genotypes sampled in the repertoire
        random_key: an updated jax PRNG random key
    """
    # 1.0 for a filled cell, 0.0 for an empty one
    filled = 1.0 - (self.fitnesses == -jnp.inf)
    probabilities = filled / jnp.sum(filled)

    random_key, subkey = jax.random.split(random_key)

    def _draw(leaf: jnp.ndarray) -> jnp.ndarray:
        # the same subkey is used for every leaf
        return jax.random.choice(
            subkey, leaf, shape=(num_samples,), p=probabilities
        )

    samples = jax.tree_util.tree_map(_draw, self.genotypes)
    return samples, random_key
add(self, batch_of_genotypes, batch_of_descriptors, batch_of_fitnesses, batch_of_extra_scores=None)
¶
Add a batch of elements to the repertoire.
Parameters: |
|
---|
Returns: |
|
---|
Source code in qdax/core/containers/mapelites_repertoire.py
@jax.jit
def add(
    self,
    batch_of_genotypes: Genotype,
    batch_of_descriptors: Descriptor,
    batch_of_fitnesses: Fitness,
    batch_of_extra_scores: Optional[ExtraScores] = None,
) -> MapElitesRepertoire:
    """
    Add a batch of elements to the repertoire.

    Each candidate is assigned to the cell of its nearest centroid. Within
    the batch only the candidates matching the best fitness of their cell
    are kept, and a kept candidate is written only if its fitness is
    strictly greater than the fitness currently stored in that cell.

    Args:
        batch_of_genotypes: a batch of genotypes to be added to the repertoire.
            Similarly to the self.genotypes argument, this is a PyTree in which
            the leaves have a shape (batch_size, num_features)
        batch_of_descriptors: an array that contains the descriptors of the
            aforementioned genotypes. Its shape is (batch_size, num_descriptors)
        batch_of_fitnesses: an array that contains the fitnesses of the
            aforementioned genotypes. Its shape is (batch_size,)
        batch_of_extra_scores: unused tree that contains the extra_scores of
            aforementioned genotypes.

    Returns:
        The updated MAP-Elites repertoire.
    """
    num_centroids = self.centroids.shape[0]

    # cell of each candidate, kept as a column: (batch_size, 1)
    batch_of_indices = get_cells_indices(batch_of_descriptors, self.centroids)
    batch_of_indices = jnp.expand_dims(batch_of_indices, axis=-1)
    batch_of_fitnesses = jnp.expand_dims(batch_of_fitnesses, axis=-1)

    # best fitness proposed by the batch for every cell
    best_fitnesses = jax.ops.segment_max(
        batch_of_fitnesses,
        batch_of_indices.astype(jnp.int32).squeeze(axis=-1),
        num_segments=num_centroids,
    )
    cond_values = jnp.take_along_axis(best_fitnesses, batch_of_indices, 0)

    # put dominated fitness to -jnp.inf
    # (x / y are passed positionally: jnp.where does not accept them as
    # keywords in current JAX releases)
    batch_of_fitnesses = jnp.where(
        batch_of_fitnesses == cond_values, batch_of_fitnesses, -jnp.inf
    )

    # a candidate is added only if it beats the current occupant of its cell
    repertoire_fitnesses = jnp.expand_dims(self.fitnesses, axis=-1)
    current_fitnesses = jnp.take_along_axis(
        repertoire_fitnesses, batch_of_indices, 0
    )
    addition_condition = batch_of_fitnesses > current_fitnesses

    # assign fake position when relevant : num_centroids is out of bound,
    # so rejected candidates are not meant to land in any cell
    batch_of_indices = jnp.where(
        addition_condition, batch_of_indices, num_centroids
    )
    flat_indices = batch_of_indices.squeeze(axis=-1)

    # create new repertoire
    new_repertoire_genotypes = jax.tree_util.tree_map(
        lambda repertoire_genotypes, new_genotypes: repertoire_genotypes.at[
            flat_indices
        ].set(new_genotypes),
        self.genotypes,
        batch_of_genotypes,
    )

    # compute new fitness and descriptors
    new_fitnesses = self.fitnesses.at[flat_indices].set(
        batch_of_fitnesses.squeeze(axis=-1)
    )
    new_descriptors = self.descriptors.at[flat_indices].set(batch_of_descriptors)

    return MapElitesRepertoire(
        genotypes=new_repertoire_genotypes,
        fitnesses=new_fitnesses,
        descriptors=new_descriptors,
        centroids=self.centroids,
    )
init(genotypes, fitnesses, descriptors, centroids, extra_scores=None)
classmethod
¶
Initialize a Map-Elites repertoire with an initial population of genotypes. Requires the definition of centroids that can be computed with any method such as CVT or Euclidean mapping.
Note: this function has been kept outside of the object MapElites, so it can be easily called from other modules.
Parameters: |
|
---|
Returns: |
|
---|
Source code in qdax/core/containers/mapelites_repertoire.py
@classmethod
def init(
    cls,
    genotypes: Genotype,
    fitnesses: Fitness,
    descriptors: Descriptor,
    centroids: Centroid,
    extra_scores: Optional[ExtraScores] = None,
) -> MapElitesRepertoire:
    """Build a MAP-Elites repertoire from an initial population.

    A repertoire filled with default values is first created from the
    centroids (see ``init_default``), then the whole population is inserted
    with ``add``. The centroids can be computed with any method, such as CVT
    or Euclidean mapping.

    A warning is always emitted, because this type of repertoire does not
    store the extra scores.

    Args:
        genotypes: initial genotypes, pytree in which leaves
            have shape (batch_size, num_features)
        fitnesses: fitness of the initial genotypes of shape (batch_size,)
        descriptors: descriptors of the initial genotypes
            of shape (batch_size, num_descriptors)
        centroids: tessellation centroids of shape
            (num_centroids, num_descriptors)
        extra_scores: extra scores of the initial genotypes; forwarded to
            ``add`` but not stored

    Returns:
        an initialized MAP-Elites repertoire
    """
    message = (
        "This type of repertoire does not store the extra scores "
        "computed by the scoring function"
    )
    warnings.warn(message, stacklevel=2)

    # one genotype is enough to define the structure of the storage
    template_genotype = jax.tree_util.tree_map(lambda leaf: leaf[0], genotypes)
    empty_repertoire = cls.init_default(
        genotype=template_genotype, centroids=centroids
    )

    # insert the initial population into the default repertoire
    return empty_repertoire.add(  # type: ignore
        genotypes, descriptors, fitnesses, extra_scores
    )
init_default(genotype, centroids)
classmethod
¶
Initialize a Map-Elites repertoire filled with default values. Requires the definition of centroids that can be computed with any method such as CVT or Euclidean mapping.
Note: this function has been kept outside of the object MapElites, so it can be easily called from other modules.
Parameters: |
|
---|
Returns: |
|
---|
Source code in qdax/core/containers/mapelites_repertoire.py
@classmethod
def init_default(
    cls,
    genotype: Genotype,
    centroids: Centroid,
) -> MapElitesRepertoire:
    """Create a MAP-Elites repertoire filled with default values.

    Every cell gets a fitness of -inf, an all-zero genotype and an all-zero
    descriptor. The centroids can be computed with any method, such as CVT
    or Euclidean mapping.

    Args:
        genotype: the typical genotype that will be stored; only the shape
            and dtype of its leaves are used.
        centroids: the centroids of the repertoire

    Returns:
        A repertoire filled with default values.
    """
    num_centroids = centroids.shape[0]

    def _zeros_per_cell(leaf: jnp.ndarray) -> jnp.ndarray:
        # one zero-filled copy of the leaf per centroid
        return jnp.zeros(shape=(num_centroids,) + leaf.shape, dtype=leaf.dtype)

    return cls(
        genotypes=jax.tree_util.tree_map(_zeros_per_cell, genotype),
        # -inf marks a cell as empty
        fitnesses=-jnp.inf * jnp.ones(shape=num_centroids),
        descriptors=jnp.zeros_like(centroids),
        centroids=centroids,
    )
compute_cvt_centroids(num_descriptors, num_init_cvt_samples, num_centroids, minval, maxval, random_key)
¶
Compute centroids for CVT tessellation.
Parameters: |
|
---|
Returns: |
|
---|
Source code in qdax/core/containers/mapelites_repertoire.py
def compute_cvt_centroids(
    num_descriptors: int,
    num_init_cvt_samples: int,
    num_centroids: int,
    minval: Union[float, List[float]],
    maxval: Union[float, List[float]],
    random_key: RNGKey,
) -> Tuple[jnp.ndarray, RNGKey]:
    """Compute centroids for CVT tessellation.

    Points are sampled uniformly in the unit hypercube, clustered with
    k-means, and the cluster centers are rescaled to [minval, maxval].

    Args:
        num_descriptors: number of scalar descriptors
        num_init_cvt_samples: number of sampled points to be used for
            clustering to determine the centroids. The larger the number of
            centroids and the number of descriptors, the higher this value
            must be (e.g. 100000 for 1024 centroids and 100 descriptors).
        num_centroids: number of centroids
        minval: minimum descriptors value
        maxval: maximum descriptors value
        random_key: a jax PRNG random key

    Returns:
        the centroids with shape (num_centroids, num_descriptors)
        random_key: an updated jax PRNG random key
    """
    lower = jnp.array(minval)
    upper = jnp.array(maxval)

    # sample in [0, 1]; the rescaling happens at the very end
    random_key, sampling_key = jax.random.split(random_key)
    samples = jax.random.uniform(
        key=sampling_key, shape=(num_init_cvt_samples, num_descriptors)
    )

    # cluster the samples with k-means
    random_key, clustering_key = jax.random.split(random_key)
    clustering = KMeans(
        init="k-means++",
        n_clusters=num_centroids,
        n_init=1,
        random_state=RandomState(clustering_key),
    )
    clustering.fit(samples)
    unit_centroids = jnp.asarray(clustering.cluster_centers_)

    # map the unit hypercube onto the requested range
    return unit_centroids * (upper - lower) + lower, random_key
compute_euclidean_centroids(grid_shape, minval, maxval)
¶
Compute centroids for square Euclidean tessellation.
Parameters: |
|
---|
Returns: |
|
---|
Source code in qdax/core/containers/mapelites_repertoire.py
def compute_euclidean_centroids(
    grid_shape: Tuple[int, ...],
    minval: Union[float, List[float]],
    maxval: Union[float, List[float]],
) -> jnp.ndarray:
    """Compute centroids for square Euclidean tessellation.

    The centroids are the centers of the cells of a regular grid built in
    the unit hypercube and then rescaled to [minval, maxval].

    Args:
        grid_shape: number of centroids per BD dimension
        minval: minimum descriptors value
        maxval: maximum descriptors value

    Returns:
        the centroids with shape (num_centroids, num_descriptors)
    """

    def _cell_centers(num_cells: int) -> jnp.ndarray:
        # centers of num_cells equal cells that partition [0, 1]
        half_cell = 1 / (2 * num_cells)
        return jnp.linspace(half_cell, 1.0 - half_cell, num_cells)

    axes = [_cell_centers(num_cells) for num_cells in grid_shape]
    meshes = jnp.meshgrid(*axes, sparse=False)

    # one row per grid cell, one column per descriptor dimension
    unit_centroids = jnp.stack([jnp.ravel(mesh) for mesh in meshes], axis=-1)

    lower = jnp.array(minval)
    upper = jnp.array(maxval)
    return jnp.asarray(unit_centroids) * (upper - lower) + lower
get_cells_indices(batch_of_descriptors, centroids)
¶
Returns the array of cells indices for a batch of descriptors given the centroids of the repertoire.
Parameters: |
|
---|
Returns: |
|
---|
Source code in qdax/core/containers/mapelites_repertoire.py
def get_cells_indices(
    batch_of_descriptors: jnp.ndarray, centroids: jnp.ndarray
) -> jnp.ndarray:
    """
    Returns the array of cells indices for a batch of descriptors
    given the centroids of the repertoire.

    Each descriptor is mapped to the index of the centroid with the
    smallest squared Euclidean distance.

    Args:
        batch_of_descriptors: a batch of descriptors
            of shape (batch_size, num_descriptors)
        centroids: centroids array of shape (num_centroids, num_descriptors)

    Returns:
        the indices of the centroids corresponding to each vector of descriptors
            in the batch with shape (batch_size,)
    """

    def _nearest_centroid(descriptor: jnp.ndarray) -> jnp.ndarray:
        # the descriptor broadcasts against every row of the centroids
        squared_distances = jnp.sum(jnp.square(descriptor - centroids), axis=-1)
        return jnp.argmin(squared_distances)

    return jax.vmap(_nearest_centroid)(batch_of_descriptors)
mels_repertoire
¶
This file contains the class to define the repertoire used to store individuals in the MAP-Elites Low-Spread (ME-LS) algorithm.
MELSRepertoire (MapElitesRepertoire)
dataclass
¶
Class for the repertoire in MAP-Elites Low-Spread.
This class inherits from MapElitesRepertoire. In addition to the stored data in MapElitesRepertoire (genotypes, fitnesses, descriptors, centroids), this repertoire also maintains an array of spreads. We overload the save, load, add, and init_default methods of MapElitesRepertoire.
Refer to Mace 2023 for more info on MAP-Elites Low-Spread: https://dl.acm.org/doi/abs/10.1145/3583131.3590433
Parameters: |
|
---|
Source code in qdax/core/containers/mels_repertoire.py
class MELSRepertoire(MapElitesRepertoire):
    """Class for the repertoire in MAP-Elites Low-Spread.

    This class inherits from MapElitesRepertoire. In addition to the stored data in
    MapElitesRepertoire (genotypes, fitnesses, descriptors, centroids), this repertoire
    also maintains an array of spreads. We overload the save, load, add, and
    init_default methods of MapElitesRepertoire.

    Refer to Mace 2023 for more info on MAP-Elites Low-Spread:
    https://dl.acm.org/doi/abs/10.1145/3583131.3590433

    Args:
        genotypes: a PyTree containing all the genotypes in the repertoire ordered
            by the centroids. Each leaf has a shape (num_centroids, num_features). The
            PyTree can be a simple Jax array or a more complex nested structure such
            as to represent parameters of neural network in Flax.
        fitnesses: an array that contains the fitness of solutions in each cell of the
            repertoire, ordered by centroids. The array shape is (num_centroids,).
        descriptors: an array that contains the descriptors of solutions in each cell
            of the repertoire, ordered by centroids. The array shape
            is (num_centroids, num_descriptors).
        centroids: an array that contains the centroids of the tessellation. The array
            shape is (num_centroids, num_descriptors).
        spreads: an array that contains the spread of solutions in each cell of the
            repertoire, ordered by centroids. The array shape is (num_centroids,).
    """

    spreads: Spread

    def save(self, path: str = "./") -> None:
        """Saves the repertoire on disk in the form of .npy files.

        Flattens the genotypes to store it with .npy format. Supposes that
        a user will have access to the reconstruction function when loading
        the genotypes.

        Args:
            path: Prefix prepended to every file name. Defaults to "./".
        """

        def flatten_genotype(genotype: Genotype) -> jnp.ndarray:
            flatten_genotype, _ = ravel_pytree(genotype)
            return flatten_genotype

        # flatten all the genotypes
        flat_genotypes = jax.vmap(flatten_genotype)(self.genotypes)

        # save data
        jnp.save(path + "genotypes.npy", flat_genotypes)
        jnp.save(path + "fitnesses.npy", self.fitnesses)
        jnp.save(path + "descriptors.npy", self.descriptors)
        jnp.save(path + "centroids.npy", self.centroids)
        jnp.save(path + "spreads.npy", self.spreads)

    @classmethod
    def load(cls, reconstruction_fn: Callable, path: str = "./") -> MELSRepertoire:
        """Loads a MAP-Elites Low-Spread Repertoire.

        Args:
            reconstruction_fn: Function to reconstruct a PyTree
                from a flat array.
            path: Prefix prepended to every file name. Defaults to "./".

        Returns:
            A MAP-Elites Low-Spread Repertoire.
        """
        # genotypes are stored flattened: rebuild one pytree per row
        flat_genotypes = jnp.load(path + "genotypes.npy")
        genotypes = jax.vmap(reconstruction_fn)(flat_genotypes)

        fitnesses = jnp.load(path + "fitnesses.npy")
        descriptors = jnp.load(path + "descriptors.npy")
        centroids = jnp.load(path + "centroids.npy")
        spreads = jnp.load(path + "spreads.npy")

        return cls(
            genotypes=genotypes,
            fitnesses=fitnesses,
            descriptors=descriptors,
            centroids=centroids,
            spreads=spreads,
        )

    @jax.jit
    def add(
        self,
        batch_of_genotypes: Genotype,
        batch_of_descriptors: Descriptor,
        batch_of_fitnesses: Fitness,
        batch_of_extra_scores: Optional[ExtraScores] = None,
    ) -> MELSRepertoire:
        """
        Add a batch of elements to the repertoire.

        The key difference between this method and the default add() in
        MapElitesRepertoire is that it expects each individual to be evaluated
        `num_samples` times, resulting in `num_samples` fitnesses and
        `num_samples` descriptors per individual.

        If multiple individuals may be added to a single cell, this method will
        arbitrarily pick one -- the exact choice depends on the implementation of
        jax.at[].set(), which can be non-deterministic:
        https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.ndarray.at.html
        We do not currently check if one of the multiple individuals dominates the
        others (dominate means that the individual has both highest fitness and lowest
        spread among the individuals for that cell).

        If `num_samples` is only 1, the spreads will default to 0.

        Args:
            batch_of_genotypes: a batch of genotypes to be added to the repertoire.
                Similarly to the self.genotypes argument, this is a PyTree in which
                the leaves have a shape (batch_size, num_features)
            batch_of_descriptors: an array that contains the descriptors of the
                aforementioned genotypes over all evals. Its shape is
                (batch_size, num_samples, num_descriptors). Note that we "aggregate"
                descriptors by finding the most frequent cell of each individual. Thus,
                the actual descriptors stored in the repertoire are just the coordinates
                of the centroid of the most frequent cell.
            batch_of_fitnesses: an array that contains the fitnesses of the
                aforementioned genotypes over all evals. Its shape is (batch_size,
                num_samples)
            batch_of_extra_scores: unused tree that contains the extra_scores of
                aforementioned genotypes.

        Returns:
            The updated repertoire.
        """
        batch_size, num_samples = batch_of_fitnesses.shape
        num_centroids = self.centroids.shape[0]

        # Compute indices/cells of all descriptors: (batch_size, num_samples).
        batch_of_all_indices = get_cells_indices(
            batch_of_descriptors.reshape(batch_size * num_samples, -1), self.centroids
        ).reshape((batch_size, num_samples))

        # Compute most frequent cell of each solution: (batch_size, 1).
        batch_of_indices = jax.vmap(_mode)(batch_of_all_indices)[:, None]

        # Compute dispersion / spread. The dispersion is set to zero if
        # num_samples is 1.
        batch_of_spreads = jax.lax.cond(
            num_samples == 1,
            lambda desc: jnp.zeros(batch_size),
            lambda desc: jax.vmap(_dispersion)(
                desc.reshape((batch_size, num_samples, -1))
            ),
            batch_of_descriptors,
        )
        batch_of_spreads = jnp.expand_dims(batch_of_spreads, axis=-1)

        # Compute canonical descriptors as the descriptor of the centroid of the most
        # frequent cell. Note that this line redefines the earlier batch_of_descriptors.
        batch_of_descriptors = jnp.take_along_axis(
            self.centroids, batch_of_indices, axis=0
        )

        # Compute canonical fitnesses as the average fitness.
        #
        # Shape: (batch_size, 1)
        batch_of_fitnesses = batch_of_fitnesses.mean(axis=-1, keepdims=True)

        # get current repertoire fitnesses and spreads
        repertoire_fitnesses = jnp.expand_dims(self.fitnesses, axis=-1)
        current_fitnesses = jnp.take_along_axis(
            repertoire_fitnesses, batch_of_indices, 0
        )
        repertoire_spreads = jnp.expand_dims(self.spreads, axis=-1)
        current_spreads = jnp.take_along_axis(repertoire_spreads, batch_of_indices, 0)

        # a candidate needs a strictly higher fitness AND a spread that is not
        # larger than the one currently stored in its cell
        addition_condition = jnp.logical_and(
            batch_of_fitnesses > current_fitnesses,
            batch_of_spreads <= current_spreads,
        )

        # assign fake position when relevant : num_centroids is out of bound
        # (x / y are passed positionally: jnp.where does not accept them as
        # keywords in current JAX releases)
        batch_of_indices = jnp.where(
            addition_condition, batch_of_indices, num_centroids
        )
        flat_indices = batch_of_indices.squeeze(axis=-1)

        # create new repertoire
        new_repertoire_genotypes = jax.tree_util.tree_map(
            lambda repertoire_genotypes, new_genotypes: repertoire_genotypes.at[
                flat_indices
            ].set(new_genotypes),
            self.genotypes,
            batch_of_genotypes,
        )

        # compute new fitness, descriptors and spreads
        new_fitnesses = self.fitnesses.at[flat_indices].set(
            batch_of_fitnesses.squeeze(axis=-1)
        )
        new_descriptors = self.descriptors.at[flat_indices].set(batch_of_descriptors)
        new_spreads = self.spreads.at[flat_indices].set(
            batch_of_spreads.squeeze(axis=-1)
        )

        return MELSRepertoire(
            genotypes=new_repertoire_genotypes,
            fitnesses=new_fitnesses,
            descriptors=new_descriptors,
            centroids=self.centroids,
            spreads=new_spreads,
        )

    @classmethod
    def init_default(
        cls,
        genotype: Genotype,
        centroids: Centroid,
    ) -> MELSRepertoire:
        """Create a MAP-Elites Low-Spread repertoire filled with default values.

        Every cell gets a fitness of -inf, a spread of +inf, an all-zero
        genotype and an all-zero descriptor. The centroids can be computed
        with any method such as CVT or Euclidean mapping.

        Args:
            genotype: the typical genotype that will be stored.
            centroids: the centroids of the repertoire.

        Returns:
            A repertoire filled with default values.
        """
        # get number of centroids
        num_centroids = centroids.shape[0]

        # default fitness is -inf
        default_fitnesses = -jnp.inf * jnp.ones(shape=num_centroids)

        # default genotypes is all 0
        default_genotypes = jax.tree_util.tree_map(
            lambda x: jnp.zeros(shape=(num_centroids,) + x.shape, dtype=x.dtype),
            genotype,
        )

        # default descriptor is all zeros
        default_descriptors = jnp.zeros_like(centroids)

        # default spread is inf so that any spread will be less
        default_spreads = jnp.full(shape=num_centroids, fill_value=jnp.inf)

        return cls(
            genotypes=default_genotypes,
            fitnesses=default_fitnesses,
            descriptors=default_descriptors,
            centroids=centroids,
            spreads=default_spreads,
        )
save(self, path='./')
¶
Saves the repertoire on disk in the form of .npy files.
Flattens the genotypes to store it with .npy format. Supposes that a user will have access to the reconstruction function when loading the genotypes.
Parameters: |
|
---|
Source code in qdax/core/containers/mels_repertoire.py
def save(self, path: str = "./") -> None:
    """Write the repertoire, spreads included, to disk as .npy files.

    Each genotype is flattened into a single vector so that it fits the
    .npy format; the user is expected to have the reconstruction function
    at hand when loading the genotypes back.

    Args:
        path: Prefix prepended to every file name. Defaults to "./".
    """

    def flatten_genotype(genotype: Genotype) -> jnp.ndarray:
        flat, _ = ravel_pytree(genotype)
        return flat

    # flatten every genotype of the repertoire
    flat_genotypes = jax.vmap(flatten_genotype)(self.genotypes)

    # one file per stored array
    arrays_to_save = (
        ("genotypes.npy", flat_genotypes),
        ("fitnesses.npy", self.fitnesses),
        ("descriptors.npy", self.descriptors),
        ("centroids.npy", self.centroids),
        ("spreads.npy", self.spreads),
    )
    for filename, array in arrays_to_save:
        jnp.save(path + filename, array)
load(reconstruction_fn, path='./')
classmethod
¶
Loads a MAP-Elites Low-Spread Repertoire.
Parameters: |
|
---|
Returns: |
|
---|
Source code in qdax/core/containers/mels_repertoire.py
@classmethod
def load(cls, reconstruction_fn: Callable, path: str = "./") -> MELSRepertoire:
    """Read a MAP-Elites Low-Spread repertoire from .npy files.

    Args:
        reconstruction_fn: Function to reconstruct a PyTree
            from a flat array.
        path: Prefix prepended to every file name. Defaults to "./".

    Returns:
        A MAP-Elites Low-Spread Repertoire.
    """

    def _read(filename: str) -> jnp.ndarray:
        return jnp.load(path + filename)

    # genotypes are stored flattened: rebuild one pytree per row
    flat_genotypes = _read("genotypes.npy")

    return cls(
        genotypes=jax.vmap(reconstruction_fn)(flat_genotypes),
        fitnesses=_read("fitnesses.npy"),
        descriptors=_read("descriptors.npy"),
        centroids=_read("centroids.npy"),
        spreads=_read("spreads.npy"),
    )
replace(self, **updates)
¶
Returns a new object replacing the specified fields with new values.
Source code in qdax/core/containers/mels_repertoire.py
def replace(self, **updates):
    """Return a new object replacing the specified fields with new values."""
    replaced = dataclasses.replace(self, **updates)
    return replaced
add(self, batch_of_genotypes, batch_of_descriptors, batch_of_fitnesses, batch_of_extra_scores=None)
¶
Add a batch of elements to the repertoire.
The key difference between this method and the default add() in
MapElitesRepertoire is that it expects each individual to be evaluated
num_samples
times, resulting in num_samples
fitnesses and
num_samples
descriptors per individual.
If multiple individuals may be added to a single cell, this method will arbitrarily pick one -- the exact choice depends on the implementation of jax.at[].set(), which can be non-deterministic: https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.ndarray.at.html We do not currently check if one of the multiple individuals dominates the others (dominate means that the individual has both highest fitness and lowest spread among the individuals for that cell).
If num_samples
is only 1, the spreads will default to 0.
Parameters: |
|
---|
Returns: |
|
---|
Source code in qdax/core/containers/mels_repertoire.py
@jax.jit
def add(
    self,
    batch_of_genotypes: Genotype,
    batch_of_descriptors: Descriptor,
    batch_of_fitnesses: Fitness,
    batch_of_extra_scores: Optional[ExtraScores] = None,
) -> MELSRepertoire:
    """
    Add a batch of elements to the repertoire.

    The key difference between this method and the default add() in
    MapElitesRepertoire is that it expects each individual to be evaluated
    `num_samples` times, resulting in `num_samples` fitnesses and
    `num_samples` descriptors per individual.

    If multiple individuals may be added to a single cell, this method will
    arbitrarily pick one -- the exact choice depends on the implementation of
    jax.at[].set(), which can be non-deterministic:
    https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.ndarray.at.html
    We do not currently check if one of the multiple individuals dominates the
    others (dominate means that the individual has both highest fitness and lowest
    spread among the individuals for that cell).

    If `num_samples` is only 1, the spreads will default to 0.

    Args:
        batch_of_genotypes: a batch of genotypes to be added to the repertoire.
            Similarly to the self.genotypes argument, this is a PyTree in which
            the leaves have a shape (batch_size, num_features)
        batch_of_descriptors: an array that contains the descriptors of the
            aforementioned genotypes over all evals. Its shape is
            (batch_size, num_samples, num_descriptors). Note that we "aggregate"
            descriptors by finding the most frequent cell of each individual. Thus,
            the actual descriptors stored in the repertoire are just the coordinates
            of the centroid of the most frequent cell.
        batch_of_fitnesses: an array that contains the fitnesses of the
            aforementioned genotypes over all evals. Its shape is (batch_size,
            num_samples)
        batch_of_extra_scores: unused tree that contains the extra_scores of
            aforementioned genotypes.

    Returns:
        The updated repertoire.
    """
    batch_size, num_samples = batch_of_fitnesses.shape
    num_centroids = self.centroids.shape[0]

    # Compute indices/cells of all descriptors: (batch_size, num_samples).
    batch_of_all_indices = get_cells_indices(
        batch_of_descriptors.reshape(batch_size * num_samples, -1), self.centroids
    ).reshape((batch_size, num_samples))

    # Compute most frequent cell of each solution: (batch_size, 1).
    batch_of_indices = jax.vmap(_mode)(batch_of_all_indices)[:, None]

    # Compute dispersion / spread. The dispersion is set to zero if
    # num_samples is 1.
    batch_of_spreads = jax.lax.cond(
        num_samples == 1,
        lambda desc: jnp.zeros(batch_size),
        lambda desc: jax.vmap(_dispersion)(
            desc.reshape((batch_size, num_samples, -1))
        ),
        batch_of_descriptors,
    )
    batch_of_spreads = jnp.expand_dims(batch_of_spreads, axis=-1)

    # Compute canonical descriptors as the descriptor of the centroid of the most
    # frequent cell. Note that this line redefines the earlier batch_of_descriptors.
    batch_of_descriptors = jnp.take_along_axis(
        self.centroids, batch_of_indices, axis=0
    )

    # Compute canonical fitnesses as the average fitness.
    #
    # Shape: (batch_size, 1)
    batch_of_fitnesses = batch_of_fitnesses.mean(axis=-1, keepdims=True)

    # get current repertoire fitnesses and spreads
    repertoire_fitnesses = jnp.expand_dims(self.fitnesses, axis=-1)
    current_fitnesses = jnp.take_along_axis(
        repertoire_fitnesses, batch_of_indices, 0
    )
    repertoire_spreads = jnp.expand_dims(self.spreads, axis=-1)
    current_spreads = jnp.take_along_axis(repertoire_spreads, batch_of_indices, 0)

    # a candidate needs a strictly higher fitness AND a spread that is not
    # larger than the one currently stored in its cell
    addition_condition = jnp.logical_and(
        batch_of_fitnesses > current_fitnesses,
        batch_of_spreads <= current_spreads,
    )

    # assign fake position when relevant : num_centroids is out of bound
    # (x / y are passed positionally: jnp.where does not accept them as
    # keywords in current JAX releases)
    batch_of_indices = jnp.where(
        addition_condition, batch_of_indices, num_centroids
    )
    flat_indices = batch_of_indices.squeeze(axis=-1)

    # create new repertoire
    new_repertoire_genotypes = jax.tree_util.tree_map(
        lambda repertoire_genotypes, new_genotypes: repertoire_genotypes.at[
            flat_indices
        ].set(new_genotypes),
        self.genotypes,
        batch_of_genotypes,
    )

    # compute new fitness, descriptors and spreads
    new_fitnesses = self.fitnesses.at[flat_indices].set(
        batch_of_fitnesses.squeeze(axis=-1)
    )
    new_descriptors = self.descriptors.at[flat_indices].set(batch_of_descriptors)
    new_spreads = self.spreads.at[flat_indices].set(
        batch_of_spreads.squeeze(axis=-1)
    )

    return MELSRepertoire(
        genotypes=new_repertoire_genotypes,
        fitnesses=new_fitnesses,
        descriptors=new_descriptors,
        centroids=self.centroids,
        spreads=new_spreads,
    )
init_default(genotype, centroids)
classmethod
¶
Initialize a MAP-Elites Low-Spread repertoire filled with default values. Requires the definition of centroids that can be computed with any method such as CVT or Euclidean mapping.
Note: this function has been kept outside of the object MELS, so it can be easily called from other modules.
Parameters: |
|
---|
Returns: |
|
---|
Source code in qdax/core/containers/mels_repertoire.py
@classmethod
def init_default(
    cls,
    genotype: Genotype,
    centroids: Centroid,
) -> MELSRepertoire:
    """Create a MAP-Elites Low-Spread repertoire filled with default values.

    Every cell gets a fitness of -inf, a spread of +inf, an all-zero
    genotype and an all-zero descriptor. The centroids can be computed with
    any method such as CVT or Euclidean mapping.

    Args:
        genotype: the typical genotype that will be stored; only the shape
            and dtype of its leaves are used.
        centroids: the centroids of the repertoire.

    Returns:
        A repertoire filled with default values.
    """
    num_centroids = centroids.shape[0]

    def _zeros_per_cell(leaf: jnp.ndarray) -> jnp.ndarray:
        # one zero-filled copy of the leaf per centroid
        return jnp.zeros(shape=(num_centroids,) + leaf.shape, dtype=leaf.dtype)

    return cls(
        genotypes=jax.tree_util.tree_map(_zeros_per_cell, genotype),
        # -inf marks a cell as empty
        fitnesses=-jnp.inf * jnp.ones(shape=num_centroids),
        descriptors=jnp.zeros_like(centroids),
        centroids=centroids,
        # default spread is inf so that any spread will be less
        spreads=jnp.full(shape=num_centroids, fill_value=jnp.inf),
    )
mome_repertoire
¶
This file contains the class to define the repertoire used to store individuals in the Multi-Objective MAP-Elites algorithm as well as several variants.
MOMERepertoire (MapElitesRepertoire)
dataclass
¶
Class for the repertoire in Multi Objective Map Elites
This class inherits from MAPElitesRepertoire. The stored data is the same: genotypes, fitnesses, descriptors, centroids.
The shape of genotypes is (in the case where it's an array): (num_centroids, pareto_front_length, genotype_dim). When the genotypes are a PyTree, the first two dimensions are the same but the remaining ones depend on the leaves.
The shape of fitnesses is: (num_centroids, pareto_front_length, num_criteria)
The shape of descriptors is: (num_centroids, pareto_front_length, num_descriptors). The shape of centroids is: (num_centroids, num_descriptors).
Inherited functions: save and load.
Source code in qdax/core/containers/mome_repertoire.py
class MOMERepertoire(MapElitesRepertoire):
    """Class for the repertoire in Multi Objective Map Elites.

    This class inherits from MAPElitesRepertoire. The stored data
    is the same: genotypes, fitnesses, descriptors, centroids.

    The shape of genotypes is (in the case where it's an array):
    (num_centroids, pareto_front_length, genotype_dim).
    When the genotypes are a PyTree, the first two dimensions are the same
    but the remaining ones depend on the leaves.

    The shape of fitnesses is: (num_centroids, pareto_front_length, num_criteria)

    The shape of descriptors is:
    (num_centroids, pareto_front_length, num_descriptors).

    A slot of a pareto front is considered empty when any of its fitness
    criteria is -inf.

    Inherited functions: save and load.
    """

    @property
    def repertoire_capacity(self) -> int:
        """Returns the maximum number of solutions the repertoire can
        contain which corresponds to the number of cells times the
        maximum pareto front length.

        Returns:
            The repertoire capacity.
        """
        first_leaf = jax.tree_util.tree_leaves(self.genotypes)[0]
        return int(first_leaf.shape[0] * first_leaf.shape[1])

    @jax.jit
    def _sample_in_masked_pareto_front(
        self,
        pareto_front_genotypes: ParetoFront[Genotype],
        mask: Mask,
        random_key: RNGKey,
    ) -> Genotype:
        """Sample one single genotype in masked pareto front.

        Note: do not retrieve a random key because this function
        is to be vmapped. The public method that uses this function
        will return a random key

        Args:
            pareto_front_genotypes: the genotypes of a pareto front
            mask: a mask associated to the front (True for empty slots)
            random_key: a random key to handle stochastic operations

        Returns:
            A single genotype among the pareto front.
        """
        # uniform probability over the non-masked slots, zero elsewhere
        p = (1.0 - mask) / jnp.sum(1.0 - mask)

        # the same key and probabilities are passed for every leaf
        genotype_sample = jax.tree_util.tree_map(
            lambda x: jax.random.choice(random_key, x, shape=(1,), p=p),
            pareto_front_genotypes,
        )
        return genotype_sample

    @partial(jax.jit, static_argnames=("num_samples",))
    def sample(self, random_key: RNGKey, num_samples: int) -> Tuple[Genotype, RNGKey]:
        """Sample elements in the repertoire.

        This method samples a non-empty pareto front, and then samples
        genotypes from this pareto front.

        Args:
            random_key: a random key to handle stochasticity.
            num_samples: number of samples to retrieve from the repertoire.

        Returns:
            A sample of genotypes and a new random key.
        """
        # create sampling probability for the cells
        repertoire_empty = jnp.any(self.fitnesses == -jnp.inf, axis=-1)
        occupied_cells = jnp.any(~repertoire_empty, axis=-1)
        p = occupied_cells / jnp.sum(occupied_cells)

        # possible indices - num cells
        indices = jnp.arange(start=0, stop=repertoire_empty.shape[0])

        # choose idx - among indices of cells that are not empty
        random_key, subkey = jax.random.split(random_key)
        cells_idx = jax.random.choice(subkey, indices, shape=(num_samples,), p=p)

        # get genotypes (front) from the chosen indices
        pareto_front_genotypes = jax.tree_util.tree_map(
            lambda x: x[cells_idx], self.genotypes
        )

        # prepare second sampling function
        sample_in_fronts = jax.vmap(self._sample_in_masked_pareto_front)

        # sample genotypes from the pareto front
        random_key, subkey = jax.random.split(random_key)
        subkeys = jax.random.split(subkey, num=num_samples)
        sampled_genotypes = sample_in_fronts(  # type: ignore
            pareto_front_genotypes=pareto_front_genotypes,
            mask=repertoire_empty[cells_idx],
            random_key=subkeys,
        )

        # remove the dim coming from pareto front
        sampled_genotypes = jax.tree_util.tree_map(
            lambda x: x.squeeze(axis=1), sampled_genotypes
        )
        return sampled_genotypes, random_key

    @jax.jit
    def _update_masked_pareto_front(
        self,
        pareto_front_fitnesses: ParetoFront[Fitness],
        pareto_front_genotypes: ParetoFront[Genotype],
        pareto_front_descriptors: ParetoFront[Descriptor],
        mask: Mask,
        new_batch_of_fitnesses: Fitness,
        new_batch_of_genotypes: Genotype,
        new_batch_of_descriptors: Descriptor,
        new_mask: Mask,
    ) -> Tuple[
        ParetoFront[Fitness], ParetoFront[Genotype], ParetoFront[Descriptor], Mask
    ]:
        """Takes a fixed size pareto front, its mask and new points to add.
        Returns updated front and mask.

        Args:
            pareto_front_fitnesses: fitness of the pareto front
            pareto_front_genotypes: corresponding genotypes
            pareto_front_descriptors: corresponding descriptors
            mask: mask of the front, to hide void parts
            new_batch_of_fitnesses: new batch of fitness that is considered
                to be added to the pareto front
            new_batch_of_genotypes: corresponding genotypes
            new_batch_of_descriptors: corresponding descriptors
            new_mask: corresponding mask (no one is masked)

        Returns:
            The updated pareto front: fitnesses, genotypes, descriptors and
            mask, all truncated to the length of the input front. Masked
            slots are filled with zeros.
        """
        # get dimensions
        batch_size = new_batch_of_fitnesses.shape[0]
        pareto_front_len = pareto_front_fitnesses.shape[0]  # type: ignore
        total_len = pareto_front_len + batch_size

        # gather all data
        cat_mask = jnp.concatenate([mask, new_mask], axis=-1)
        cat_fitnesses = jnp.concatenate(
            [pareto_front_fitnesses, new_batch_of_fitnesses], axis=0
        )
        cat_genotypes = jax.tree_util.tree_map(
            lambda x, y: jnp.concatenate([x, y], axis=0),
            pareto_front_genotypes,
            new_batch_of_genotypes,
        )
        cat_descriptors = jnp.concatenate(
            [pareto_front_descriptors, new_batch_of_descriptors], axis=0
        )

        # get new front
        cat_bool_front = compute_masked_pareto_front(
            batch_of_criteria=cat_fitnesses, mask=cat_mask
        )

        # get corresponding indices: front members keep their own index, the
        # others are sent to the last index so that sorting packs the front
        # members at the beginning
        indices = jnp.arange(start=0, stop=total_len) * cat_bool_front
        indices = indices + ~cat_bool_front * (total_len - 1)
        indices = jnp.sort(indices)

        # get new fitness, genotypes and descriptors
        new_front_fitness = jnp.take(cat_fitnesses, indices, axis=0)
        new_front_genotypes = jax.tree_util.tree_map(
            lambda x: jnp.take(x, indices, axis=0), cat_genotypes
        )
        new_front_descriptors = jnp.take(cat_descriptors, indices, axis=0)

        # the first `num_front_elements` slots hold the actual front
        num_front_elements = jnp.sum(cat_bool_front)
        slot_is_filled = jnp.arange(start=0, stop=total_len) < num_front_elements

        def _zero_unfilled(values: jnp.ndarray) -> jnp.ndarray:
            # Broadcast the slot flags over the trailing dimensions of this
            # particular array, so leaves with different trailing shapes are
            # handled. jnp.where is used rather than a multiplication by the
            # boolean mask: multiplying a non-finite value by 0 gives NaN.
            flags = jnp.reshape(
                slot_is_filled, slot_is_filled.shape + (1,) * (values.ndim - 1)
            )
            cleaned = jnp.where(flags, values, jnp.zeros_like(values))
            # truncate to the fixed front length
            return cleaned[:pareto_front_len]

        new_front_fitness = _zero_unfilled(new_front_fitness)
        new_front_genotypes = jax.tree_util.tree_map(
            _zero_unfilled, new_front_genotypes
        )
        new_front_descriptors = _zero_unfilled(new_front_descriptors)

        # returned mask follows the input convention: True for void slots
        new_mask = ~slot_is_filled[:pareto_front_len]

        return new_front_fitness, new_front_genotypes, new_front_descriptors, new_mask

    @jax.jit
    def add(
        self,
        batch_of_genotypes: Genotype,
        batch_of_descriptors: Descriptor,
        batch_of_fitnesses: Fitness,
        batch_of_extra_scores: Optional[ExtraScores] = None,
    ) -> MOMERepertoire:
        """Insert a batch of elements in the repertoire.

        Shape of the batch_of_genotypes (if an array):
        (batch_size, genotypes_dim)
        Shape of the batch_of_descriptors: (batch_size, num_descriptors)
        Shape of the batch_of_fitnesses: (batch_size, num_criteria)

        Args:
            batch_of_genotypes: a batch of genotypes that we are trying to
                insert into the repertoire.
            batch_of_descriptors: the descriptors of the genotypes we are
                trying to add to the repertoire.
            batch_of_fitnesses: the fitnesses of the genotypes we are trying
                to add to the repertoire.
            batch_of_extra_scores: unused tree that contains the extra_scores of
                aforementioned genotypes.

        Returns:
            The updated repertoire with potential new individuals.
        """
        # get the indices that corresponds to the descriptors in the repertoire
        batch_of_indices = get_cells_indices(batch_of_descriptors, self.centroids)
        batch_of_indices = jnp.expand_dims(batch_of_indices, axis=-1)

        def _add_one(
            carry: MOMERepertoire,
            data: Tuple[Genotype, Descriptor, Fitness, jnp.ndarray],
        ) -> Tuple[MOMERepertoire, Any]:
            # unwrap data
            genotype, descriptors, fitness, index = data
            index = index.astype(jnp.int32)

            # get cell data (leading axis of length one comes from `index`)
            cell_genotype = jax.tree_util.tree_map(lambda x: x[index], carry.genotypes)
            cell_fitness = carry.fitnesses[index]
            cell_descriptor = carry.descriptors[index]
            cell_mask = jnp.any(cell_fitness == -jnp.inf, axis=-1)

            # update pareto front; genotypes go through tree_map so that
            # PyTree genotypes are supported as well as plain arrays
            (
                cell_fitness,
                cell_genotype,
                cell_descriptor,
                cell_mask,
            ) = self._update_masked_pareto_front(
                pareto_front_fitnesses=cell_fitness.squeeze(axis=0),
                pareto_front_genotypes=jax.tree_util.tree_map(
                    lambda x: x.squeeze(axis=0), cell_genotype
                ),
                pareto_front_descriptors=cell_descriptor.squeeze(axis=0),
                mask=cell_mask.squeeze(axis=0),
                new_batch_of_fitnesses=jnp.expand_dims(fitness, axis=0),
                new_batch_of_genotypes=jax.tree_util.tree_map(
                    lambda x: jnp.expand_dims(x, axis=0), genotype
                ),
                new_batch_of_descriptors=jnp.expand_dims(descriptors, axis=0),
                new_mask=jnp.zeros(shape=(1,), dtype=bool),
            )

            # update cell fitness: void slots are marked with -inf. An explicit
            # selection is used because `inf * False` is NaN in floating point.
            cell_fitness = jnp.where(
                jnp.expand_dims(cell_mask, axis=-1), -jnp.inf, cell_fitness
            )

            # update grid
            new_genotypes = jax.tree_util.tree_map(
                lambda x, y: x.at[index].set(y), carry.genotypes, cell_genotype
            )
            new_fitnesses = carry.fitnesses.at[index].set(cell_fitness)
            new_descriptors = carry.descriptors.at[index].set(cell_descriptor)
            carry = carry.replace(  # type: ignore
                genotypes=new_genotypes,
                descriptors=new_descriptors,
                fitnesses=new_fitnesses,
            )

            # return new grid
            return carry, ()

        # scan the addition operation for all the data
        updated_repertoire, _ = jax.lax.scan(
            _add_one,
            self,
            (
                batch_of_genotypes,
                batch_of_descriptors,
                batch_of_fitnesses,
                batch_of_indices,
            ),
        )

        return updated_repertoire

    @classmethod
    def init(  # type: ignore
        cls,
        genotypes: Genotype,
        fitnesses: Fitness,
        descriptors: Descriptor,
        centroids: Centroid,
        pareto_front_max_length: int,
        extra_scores: Optional[ExtraScores] = None,
    ) -> MOMERepertoire:
        """
        Initialize a Multi Objective Map-Elites repertoire with an initial population
        of genotypes. Requires the definition of centroids that can be computed with
        any method such as CVT or Euclidean mapping.

        Note: this function has been kept outside of the object MapElites, so that
        it can easily be called from other modules.

        Args:
            genotypes: initial genotypes, pytree in which leaves
                have shape (batch_size, num_features)
            fitnesses: fitness of the initial genotypes of shape:
                (batch_size, num_criteria)
            descriptors: descriptors of the initial genotypes
                of shape (batch_size, num_descriptors)
            centroids: tessellation centroids of shape
                (num_centroids, num_descriptors)
            pareto_front_max_length: maximum size of the pareto fronts
            extra_scores: unused extra_scores of the initial genotypes

        Returns:
            An initialized MAP-Elite repertoire
        """
        warnings.warn(
            (
                "This type of repertoire does not store the extra scores "
                "computed by the scoring function"
            ),
            stacklevel=2,
        )

        # get dimensions
        num_criteria = fitnesses.shape[1]
        num_descriptors = descriptors.shape[1]
        num_centroids = centroids.shape[0]

        # create default values: -inf fitness marks every slot as empty
        default_fitnesses = -jnp.inf * jnp.ones(
            shape=(num_centroids, pareto_front_max_length, num_criteria)
        )
        default_genotypes = jax.tree_util.tree_map(
            lambda x: jnp.zeros(
                shape=(
                    num_centroids,
                    pareto_front_max_length,
                )
                + x.shape[1:]
            ),
            genotypes,
        )
        default_descriptors = jnp.zeros(
            shape=(num_centroids, pareto_front_max_length, num_descriptors)
        )

        # create repertoire with default values
        repertoire = MOMERepertoire(  # type: ignore
            genotypes=default_genotypes,
            fitnesses=default_fitnesses,
            descriptors=default_descriptors,
            centroids=centroids,
        )

        # add first batch of individuals in the repertoire
        new_repertoire = repertoire.add(genotypes, descriptors, fitnesses, extra_scores)

        return new_repertoire  # type: ignore

    @jax.jit
    def compute_global_pareto_front(
        self,
    ) -> Tuple[ParetoFront[Fitness], Mask]:
        """Merge all the pareto fronts of the MOME repertoire into a single one
        called global pareto front.

        Returns:
            The pareto front and its mask. The front has one row per stored
            slot; rows that are not on the global front are set to -inf. The
            mask is True for the rows that belong to the global front.
        """
        # stack the fronts of all the cells along the first axis
        fitnesses = jnp.concatenate(self.fitnesses, axis=0)
        mask = jnp.any(fitnesses == -jnp.inf, axis=-1)
        pareto_mask = compute_masked_pareto_front(fitnesses, mask)

        # broadcast the mask over the criteria axis (any number of criteria)
        # and select explicitly: `inf * False` would be NaN in floating point
        pareto_front = jnp.where(
            jnp.expand_dims(pareto_mask, axis=-1), fitnesses, -jnp.inf
        )

        return pareto_front, pareto_mask
repertoire_capacity: int
property
readonly
¶
Returns the maximum number of solutions the repertoire can contain which corresponds to the number of cells times the maximum pareto front length.
Returns:
- int: the repertoire capacity.
sample(self, random_key, num_samples)
¶
Sample elements in the repertoire.
This method samples a non-empty pareto front, and then samples genotypes from this pareto front.
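Parameters:
- random_key (RNGKey): a random key to handle stochasticity. Required.
- num_samples (int): number of samples to retrieve from the repertoire. Required.
Returns:
- Tuple[Genotype, RNGKey]: a sample of genotypes and a new random key.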
Source code in qdax/core/containers/mome_repertoire.py
@partial(jax.jit, static_argnames=("num_samples",))
def sample(self, random_key: RNGKey, num_samples: int) -> Tuple[Genotype, RNGKey]:
    """Draw genotypes from the repertoire in two stages.

    A cell is drawn first, among the cells holding at least one individual;
    a genotype is then drawn inside the pareto front of that cell.

    Args:
        random_key: a random key to handle stochasticity.
        num_samples: number of samples to retrieve from the repertoire.

    Returns:
        A sample of genotypes and a new random key.
    """
    # a slot is void as soon as one of its criteria is -inf
    void_slots = jnp.any(self.fitnesses == -jnp.inf, axis=-1)

    # a cell can be drawn when at least one of its slots is filled
    cell_has_content = jnp.any(~void_slots, axis=-1)
    cell_probabilities = cell_has_content / jnp.sum(cell_has_content)

    # stage 1: draw the cells
    random_key, cell_key = jax.random.split(random_key)
    candidate_cells = jnp.arange(start=0, stop=void_slots.shape[0])
    drawn_cells = jax.random.choice(
        cell_key, candidate_cells, shape=(num_samples,), p=cell_probabilities
    )

    # fronts (and their masks) of the drawn cells
    drawn_fronts = jax.tree_util.tree_map(
        lambda leaf: leaf[drawn_cells], self.genotypes
    )
    drawn_masks = void_slots[drawn_cells]

    # stage 2: draw one genotype per front, with one key per sample
    random_key, front_key = jax.random.split(random_key)
    per_sample_keys = jax.random.split(front_key, num=num_samples)
    draw_in_each_front = jax.vmap(self._sample_in_masked_pareto_front)
    drawn_genotypes = draw_in_each_front(  # type: ignore
        pareto_front_genotypes=drawn_fronts,
        mask=drawn_masks,
        random_key=per_sample_keys,
    )

    # drop the length-one axis left by the draw inside the front
    drawn_genotypes = jax.tree_util.tree_map(
        lambda leaf: leaf.squeeze(axis=1), drawn_genotypes
    )
    return drawn_genotypes, random_key
replace(self, **updates)
¶
Returns a new object replacing the specified fields with new values.
Source code in qdax/core/containers/mome_repertoire.py
def replace(self, **updates):
    """Return a new object replacing the specified fields with new values."""
    # dataclasses.replace builds a fresh instance; `self` is left untouched
    updated_copy = dataclasses.replace(self, **updates)
    return updated_copy
add(self, batch_of_genotypes, batch_of_descriptors, batch_of_fitnesses, batch_of_extra_scores=None)
¶
Insert a batch of elements in the repertoire.
Shape of the batch_of_genotypes (if an array): (batch_size, genotypes_dim) Shape of the batch_of_descriptors: (batch_size, num_descriptors) Shape of the batch_of_fitnesses: (batch_size, num_criteria)
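Parameters:
- batch_of_genotypes (Genotype): a batch of genotypes that we are trying to insert into the repertoire. Required.
- batch_of_descriptors (Descriptor): the descriptors of the genotypes we are trying to add to the repertoire. Required.
- batch_of_fitnesses (Fitness): the fitnesses of the genotypes we are trying to add to the repertoire. Required.
- batch_of_extra_scores (Optional[ExtraScores]): unused tree that contains the extra_scores of aforementioned genotypes. Default: None.
Returns:
- MOMERepertoire: the updated repertoire with potential new individuals.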
Source code in qdax/core/containers/mome_repertoire.py
@jax.jit
def add(
    self,
    batch_of_genotypes: Genotype,
    batch_of_descriptors: Descriptor,
    batch_of_fitnesses: Fitness,
    batch_of_extra_scores: Optional[ExtraScores] = None,
) -> MOMERepertoire:
    """Insert a batch of elements in the repertoire.

    Shape of the batch_of_genotypes (if an array):
    (batch_size, genotypes_dim)
    Shape of the batch_of_descriptors: (batch_size, num_descriptors)
    Shape of the batch_of_fitnesses: (batch_size, num_criteria)

    The elements are inserted one after the other with a scan: each one is
    merged into the pareto front of the cell its descriptor falls into.

    Args:
        batch_of_genotypes: a batch of genotypes that we are trying to
            insert into the repertoire.
        batch_of_descriptors: the descriptors of the genotypes we are
            trying to add to the repertoire.
        batch_of_fitnesses: the fitnesses of the genotypes we are trying
            to add to the repertoire.
        batch_of_extra_scores: unused tree that contains the extra_scores of
            aforementioned genotypes.

    Returns:
        The updated repertoire with potential new individuals.
    """
    # get the indices that corresponds to the descriptors in the repertoire
    batch_of_indices = get_cells_indices(batch_of_descriptors, self.centroids)
    batch_of_indices = jnp.expand_dims(batch_of_indices, axis=-1)

    def _add_one(
        carry: MOMERepertoire,
        data: Tuple[Genotype, Descriptor, Fitness, jnp.ndarray],
    ) -> Tuple[MOMERepertoire, Any]:
        # unwrap data
        genotype, descriptors, fitness, index = data
        index = index.astype(jnp.int32)

        # get cell data (leading axis of length one comes from `index`)
        cell_genotype = jax.tree_util.tree_map(lambda x: x[index], carry.genotypes)
        cell_fitness = carry.fitnesses[index]
        cell_descriptor = carry.descriptors[index]
        cell_mask = jnp.any(cell_fitness == -jnp.inf, axis=-1)

        # update pareto front; genotypes go through tree_map so that
        # PyTree genotypes are supported as well as plain arrays
        (
            cell_fitness,
            cell_genotype,
            cell_descriptor,
            cell_mask,
        ) = self._update_masked_pareto_front(
            pareto_front_fitnesses=cell_fitness.squeeze(axis=0),
            pareto_front_genotypes=jax.tree_util.tree_map(
                lambda x: x.squeeze(axis=0), cell_genotype
            ),
            pareto_front_descriptors=cell_descriptor.squeeze(axis=0),
            mask=cell_mask.squeeze(axis=0),
            new_batch_of_fitnesses=jnp.expand_dims(fitness, axis=0),
            new_batch_of_genotypes=jax.tree_util.tree_map(
                lambda x: jnp.expand_dims(x, axis=0), genotype
            ),
            new_batch_of_descriptors=jnp.expand_dims(descriptors, axis=0),
            new_mask=jnp.zeros(shape=(1,), dtype=bool),
        )

        # update cell fitness: void slots are marked with -inf. An explicit
        # selection is used because `inf * False` is NaN in floating point.
        cell_fitness = jnp.where(
            jnp.expand_dims(cell_mask, axis=-1), -jnp.inf, cell_fitness
        )

        # update grid
        new_genotypes = jax.tree_util.tree_map(
            lambda x, y: x.at[index].set(y), carry.genotypes, cell_genotype
        )
        new_fitnesses = carry.fitnesses.at[index].set(cell_fitness)
        new_descriptors = carry.descriptors.at[index].set(cell_descriptor)
        carry = carry.replace(  # type: ignore
            genotypes=new_genotypes,
            descriptors=new_descriptors,
            fitnesses=new_fitnesses,
        )

        # return new grid
        return carry, ()

    # scan the addition operation for all the data
    updated_repertoire, _ = jax.lax.scan(
        _add_one,
        self,
        (
            batch_of_genotypes,
            batch_of_descriptors,
            batch_of_fitnesses,
            batch_of_indices,
        ),
    )

    return updated_repertoire
init(genotypes, fitnesses, descriptors, centroids, pareto_front_max_length, extra_scores=None)
classmethod
¶
Initialize a Multi Objective Map-Elites repertoire with an initial population of genotypes. Requires the definition of centroids that can be computed with any method such as CVT or Euclidean mapping.
Note: this function has been kept outside of the object MapElites, so that it can easily be called from other modules.
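Parameters:
- genotypes (Genotype): initial genotypes, pytree in which leaves have shape (batch_size, num_features). Required.
- fitnesses (Fitness): fitness of the initial genotypes of shape (batch_size, num_criteria). Required.
- descriptors (Descriptor): descriptors of the initial genotypes of shape (batch_size, num_descriptors). Required.
- centroids (Centroid): tessellation centroids of shape (num_centroids, num_descriptors). Required.
- pareto_front_max_length (int): maximum size of the pareto fronts. Required.
- extra_scores (Optional[ExtraScores]): unused extra_scores of the initial genotypes. Default: None.
Returns:
- MOMERepertoire: an initialized MAP-Elite repertoire.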
Source code in qdax/core/containers/mome_repertoire.py
@classmethod
def init(  # type: ignore
    cls,
    genotypes: Genotype,
    fitnesses: Fitness,
    descriptors: Descriptor,
    centroids: Centroid,
    pareto_front_max_length: int,
    extra_scores: Optional[ExtraScores] = None,
) -> MOMERepertoire:
    """Create a Multi Objective Map-Elites repertoire from a first population.

    An empty repertoire is allocated first (one pareto front of fixed maximum
    length per centroid), then the initial population is inserted with `add`.
    The centroids can be computed with any method such as CVT or Euclidean
    mapping. A warning is emitted on every call because extra scores are not
    stored by this repertoire.

    Args:
        genotypes: initial genotypes, pytree in which leaves
            have shape (batch_size, num_features)
        fitnesses: fitness of the initial genotypes of shape:
            (batch_size, num_criteria)
        descriptors: descriptors of the initial genotypes
            of shape (batch_size, num_descriptors)
        centroids: tessellation centroids of shape
            (num_centroids, num_descriptors)
        pareto_front_max_length: maximum size of the pareto fronts
        extra_scores: unused extra_scores of the initial genotypes

    Returns:
        An initialized MAP-Elite repertoire
    """
    warnings.warn(
        (
            "This type of repertoire does not store the extra scores "
            "computed by the scoring function"
        ),
        stacklevel=2,
    )

    # every stored array starts with (number of cells, front length)
    front_grid = (centroids.shape[0], pareto_front_max_length)

    # -inf everywhere: all the slots start empty
    empty_fitnesses = jnp.full(front_grid + (fitnesses.shape[1],), -jnp.inf)

    # zero placeholders; the batch axis of each leaf is replaced by the grid
    empty_genotypes = jax.tree_util.tree_map(
        lambda leaf: jnp.zeros(shape=front_grid + leaf.shape[1:]),
        genotypes,
    )
    empty_descriptors = jnp.zeros(shape=front_grid + (descriptors.shape[1],))

    empty_repertoire = MOMERepertoire(  # type: ignore
        genotypes=empty_genotypes,
        fitnesses=empty_fitnesses,
        descriptors=empty_descriptors,
        centroids=centroids,
    )

    # insert the first batch of individuals
    return empty_repertoire.add(  # type: ignore
        genotypes, descriptors, fitnesses, extra_scores
    )
compute_global_pareto_front(self)
¶
Merge all the pareto fronts of the MOME repertoire into a single one called global pareto front.
Returns:
- Tuple[ParetoFront[Fitness], Mask]: the pareto front and its mask.
Source code in qdax/core/containers/mome_repertoire.py
@jax.jit
def compute_global_pareto_front(
    self,
) -> Tuple[ParetoFront[Fitness], Mask]:
    """Merge all the pareto fronts of the MOME repertoire into a single one
    called global pareto front.

    Returns:
        The pareto front and its mask. The front has one row per stored
        slot; rows that are not on the global front are set to -inf. The
        mask is True for the rows that belong to the global front.
    """
    # stack the fronts of all the cells along the first axis
    fitnesses = jnp.concatenate(self.fitnesses, axis=0)

    # slots with any criterion at -inf are treated as empty
    mask = jnp.any(fitnesses == -jnp.inf, axis=-1)
    pareto_mask = compute_masked_pareto_front(fitnesses, mask)

    # Broadcast the mask over the criteria axis so that any number of
    # criteria is supported, and select explicitly instead of multiplying
    # inf by a boolean: `inf * False` is NaN in floating point.
    pareto_front = jnp.where(
        jnp.expand_dims(pareto_mask, axis=-1), fitnesses, -jnp.inf
    )

    return pareto_front, pareto_mask
nsga2_repertoire
¶
NSGA2Repertoire (GARepertoire)
dataclass
¶
Repertoire used for the NSGA2 algorithm.
Inherits from the GARepertoire. The data stored are the genotypes and their fitness. Several functions are inherited from GARepertoire, including size, save, sample and init.
Source code in qdax/core/containers/nsga2_repertoire.py
class NSGA2Repertoire(GARepertoire):
    """Repertoire used for the NSGA2 algorithm.

    Inherits from the GARepertoire. The data stored are the genotypes
    and their fitness. Several functions are inherited from GARepertoire,
    including size, save, sample and init.
    """

    @jax.jit
    def _compute_crowding_distances(
        self, fitnesses: Fitness, mask: jnp.ndarray
    ) -> jnp.ndarray:
        """Compute crowding distances.

        The crowding distance is the Manhattan Distance in the objective
        space. This is used to rank individuals in the addition function.

        Args:
            fitnesses: fitnesses of the considered individuals. Here,
                fitness are vectors as we are doing multi-objective
                optimization.
            mask: a vector to mask values.

        Returns:
            The crowding distances.
        """
        num_solutions = fitnesses.shape[0]
        num_objective = fitnesses.shape[1]

        # shapes are static, so this branch is resolved at trace time
        if num_solutions <= 2:
            return jnp.array([jnp.inf] * num_solutions)

        # Sort solutions on each objective; masked solutions get their sort
        # key shifted by three times the score amplitude
        mask_dist = jnp.column_stack([mask] * num_objective)
        score_amplitude = jnp.max(fitnesses, axis=0) - jnp.min(fitnesses, axis=0)
        dist_fitnesses = (
            fitnesses + 3 * score_amplitude * jnp.ones_like(fitnesses) * mask_dist
        )
        sorted_index = jnp.argsort(dist_fitnesses, axis=0)
        srt_fitnesses = fitnesses[sorted_index, jnp.arange(num_objective)]

        # Calculate the norm for each objective - zero if all values are
        # equal, in which case the divisions below are not finite
        norm = jnp.max(srt_fitnesses, axis=0) - jnp.min(srt_fitnesses, axis=0)

        # get the distances between consecutive sorted solutions; the two
        # ends are padded with -inf / +inf
        dists = jnp.vstack(
            [srt_fitnesses, jnp.full(num_objective, jnp.inf)]
        ) - jnp.vstack([jnp.full(num_objective, -jnp.inf), srt_fitnesses])

        # distance to the previous and to the next solution
        dist_to_last = dists[:-1] / norm
        dist_to_next = dists[1:] / norm

        # Sum up the distances and reorder
        j = jnp.argsort(sorted_index, axis=0)
        crowding_distances = (
            jnp.sum(
                (
                    dist_to_last[j, jnp.arange(num_objective)]
                    + dist_to_next[j, jnp.arange(num_objective)]
                ),
                axis=1,
            )
            / num_objective
        )

        return crowding_distances

    @jax.jit
    def add(
        self, batch_of_genotypes: Genotype, batch_of_fitnesses: Fitness
    ) -> NSGA2Repertoire:
        """Implements the repertoire addition rules.

        The population is sorted in successive pareto front. The first one
        is the global pareto front. The second one is the pareto front of the
        population where the first pareto front has been removed, etc...

        The successive pareto fronts are kept until the moment where adding a
        full pareto front would exceed the population size.

        To decide the survival of this pareto front, a crowding distance is
        computed in order to keep individuals that are spread in this last pareto
        front. Hence, the individuals with the biggest crowding distances are
        added until the population size is reached.

        Args:
            batch_of_genotypes: new genotypes that we try to add.
            batch_of_fitnesses: fitness of those new genotypes.

        Returns:
            The updated repertoire.
        """
        # All the candidates
        candidates = jax.tree_util.tree_map(
            lambda x, y: jnp.concatenate((x, y), axis=0),
            self.genotypes,
            batch_of_genotypes,
        )
        candidate_fitnesses = jnp.concatenate((self.fitnesses, batch_of_fitnesses))

        first_leaf = jax.tree_util.tree_leaves(candidates)[0]
        num_candidates = first_leaf.shape[0]

        def compute_current_front(
            val: Tuple[jnp.ndarray, jnp.ndarray]
        ) -> Tuple[jnp.ndarray, jnp.ndarray]:
            """Body function for the while loop. Computes the successive
            pareto fronts in the data.

            Args:
                val: Value passed through the while loop. Here, it is
                    a tuple containing two values. The indexes of all
                    solutions to keep and the indexes of the last
                    computed front.

            Returns:
                The updated values to pass through the while loop. Updated
                indexes of the solutions to keep and updated front indexes.
            """
            to_keep_index, _ = val

            # mask the individual that are already kept
            front_index = compute_masked_pareto_front(
                candidate_fitnesses, mask=to_keep_index
            )

            # Add new indexes
            to_keep_index = to_keep_index + front_index

            return to_keep_index, front_index

        def condition_fn_1(val: Tuple[jnp.ndarray, jnp.ndarray]) -> bool:
            """Gives condition to stop the while loop. Makes sure the
            the number of solution is smaller than the maximum size
            of the population.

            Args:
                val: Value passed through the while loop. Here, it is
                    a tuple containing two values. The indexes of all
                    solutions to keep and the indexes of the last
                    computed front.

            Returns:
                True while fewer solutions than the population size are
                kept, i.e. while the loop has to continue.
            """
            to_keep_index, _ = val
            # single array reduction rather than the Python builtin `sum`,
            # which iterates over the array element by element
            return jnp.sum(to_keep_index) < self.size  # type: ignore

        # get indexes of all first successive fronts and indexes of the last front
        to_keep_index, front_index = jax.lax.while_loop(
            condition_fn_1,
            compute_current_front,
            (
                jnp.zeros(num_candidates, dtype=bool),
                jnp.zeros(num_candidates, dtype=bool),
            ),
        )

        # remove the indexes of the last front - gives first indexes to keep
        to_keep_index = jnp.logical_and(to_keep_index, ~front_index)

        # Compute crowding distances; everything outside the last front is
        # masked
        crowding_distances = self._compute_crowding_distances(
            candidate_fitnesses, ~front_index
        )
        # zero the distance outside the last front; explicit selection because
        # an infinite distance multiplied by False would be NaN
        crowding_distances = jnp.where(front_index, crowding_distances, 0.0)
        # ascending order: the largest distances are at the end
        highest_dist = jnp.argsort(crowding_distances)

        def add_to_front(val: Tuple[jnp.ndarray, float]) -> Tuple[jnp.ndarray, Any]:
            """Add the individual with a given distance to the front.
            A index is incremented to get the highest from the non
            selected individuals.

            Args:
                val: a tuple of two elements. A boolean vector with the positions that
                    will be kept, and a cursor with the number of individuals already
                    added during this process.

            Returns:
                The updated tuple, with the new booleans and the number of
                added elements.
            """
            front_index, num = val
            # `num` starts at 0, so the offset must be num + 1: indexing with
            # -0 would select position 0, i.e. the *smallest* distance
            front_index = front_index.at[highest_dist[-(num + 1)]].set(True)
            num = num + 1
            val = front_index, num
            return val

        def condition_fn_2(val: Tuple[jnp.ndarray, jnp.ndarray]) -> bool:
            """Gives condition to stop the while loop. Makes sure the
            the number of solution is smaller than the maximum size
            of the population."""
            front_index, _ = val
            return jnp.sum(to_keep_index + front_index) < self.size  # type: ignore

        # add the individuals with the highest distances
        front_index, _num = jax.lax.while_loop(
            condition_fn_2,
            add_to_front,
            (jnp.zeros(num_candidates, dtype=bool), 0),
        )

        # update index
        to_keep_index = to_keep_index + front_index

        # go from boolean vector to indices - offset by 1
        indices = jnp.arange(start=1, stop=num_candidates + 1) * to_keep_index

        # get rid of the zeros (that correspond to the False from the mask)
        fake_indice = num_candidates + 1  # bigger than all the other indices
        indices = jnp.where(indices == 0, x=fake_indice, y=indices)

        # sort the indices to remove the fake indices
        indices = jnp.sort(indices)[: self.size]

        # remove the offset
        indices = indices - 1

        # keep only the survivors
        new_candidates = jax.tree_util.tree_map(lambda x: x[indices], candidates)
        new_scores = candidate_fitnesses[indices]

        new_repertoire = self.replace(genotypes=new_candidates, fitnesses=new_scores)

        return new_repertoire  # type: ignore
add(self, batch_of_genotypes, batch_of_fitnesses)
¶
Implements the repertoire addition rules.
The population is sorted in successive pareto front. The first one is the global pareto front. The second one is the pareto front of the population where the first pareto front has been removed, etc...
The successive pareto fronts are kept until the moment where adding a full pareto front would exceed the population size.
To decide the survival of this pareto front, a crowding distance is computed in order to keep individuals that are spread in this last pareto front. Hence, the individuals with the biggest crowding distances are added until the population size is reached.
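Parameters:
- batch_of_genotypes (Genotype): new genotypes that we try to add. Required.
- batch_of_fitnesses (Fitness): fitness of those new genotypes. Required.
Returns:
- NSGA2Repertoire: the updated repertoire.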
Source code in qdax/core/containers/nsga2_repertoire.py
@jax.jit
def add(
    self, batch_of_genotypes: Genotype, batch_of_fitnesses: Fitness
) -> NSGA2Repertoire:
    """Implements the repertoire addition rules.

    The population is sorted in successive pareto fronts. The first one
    is the global pareto front. The second one is the pareto front of the
    population where the first pareto front has been removed, etc...

    The successive pareto fronts are kept until the moment where adding a
    full pareto front would exceed the population size.

    To decide the survival of this last pareto front, a crowding distance is
    computed in order to keep individuals that are spread in this front.
    Hence, the individuals with the biggest crowding distances are added
    until the population size is reached.

    Args:
        batch_of_genotypes: new genotypes that we try to add.
        batch_of_fitnesses: fitness of those new genotypes.

    Returns:
        The updated repertoire.
    """
    # All the candidates: current population followed by the new batch
    candidates = jax.tree_util.tree_map(
        lambda x, y: jnp.concatenate((x, y), axis=0),
        self.genotypes,
        batch_of_genotypes,
    )

    candidate_fitnesses = jnp.concatenate((self.fitnesses, batch_of_fitnesses))

    first_leaf = jax.tree_util.tree_leaves(candidates)[0]
    num_candidates = first_leaf.shape[0]

    def compute_current_front(
        val: Tuple[jnp.ndarray, jnp.ndarray]
    ) -> Tuple[jnp.ndarray, jnp.ndarray]:
        """Body function for the while loop. Computes the successive
        pareto fronts in the data.

        Args:
            val: Value passed through the while loop. Here, it is
                a tuple containing two values. The indexes of all
                solutions to keep and the indexes of the last
                computed front.

        Returns:
            The updated values to pass through the while loop. Updated
            indexes of solutions to keep and updated front indexes.
        """
        to_keep_index, _ = val

        # mask the individuals that are already kept
        front_index = compute_masked_pareto_front(
            candidate_fitnesses, mask=to_keep_index
        )

        # Add new indexes (boolean addition acts as a logical or)
        to_keep_index = to_keep_index + front_index

        # Update front & kept solutions
        return to_keep_index, front_index

    def condition_fn_1(val: Tuple[jnp.ndarray, jnp.ndarray]) -> bool:
        """Gives condition to stop the while loop. Makes sure the
        the number of solution is smaller than the maximum size
        of the population.

        Args:
            val: Value passed through the while loop. Here, it is
                a tuple containing two values. The indexes of all
                solutions to keep and the indexes of the last
                computed front.

        Returns:
            True while fewer than `self.size` solutions are kept, i.e.
            while the loop must continue.
        """
        to_keep_index, _ = val
        # jnp.sum gives the same count as the builtin sum without
        # iterating element by element over the traced array.
        return jnp.sum(to_keep_index) < self.size  # type: ignore

    # get indexes of all first successive fronts and indexes of the last front
    to_keep_index, front_index = jax.lax.while_loop(
        condition_fn_1,
        compute_current_front,
        (
            jnp.zeros(num_candidates, dtype=bool),
            jnp.zeros(num_candidates, dtype=bool),
        ),
    )

    # remove the indexes of the last front - gives first indexes to keep
    new_index = jnp.arange(start=1, stop=len(to_keep_index) + 1) * to_keep_index
    new_index = new_index * (~front_index)
    to_keep_index = new_index > 0

    # Compute crowding distances; everything outside the last front is masked
    crowding_distances = self._compute_crowding_distances(
        candidate_fitnesses, ~front_index
    )
    crowding_distances = crowding_distances * (front_index)
    # ascending order: the largest crowding distances sit at the end
    highest_dist = jnp.argsort(crowding_distances)

    def add_to_front(val: Tuple[jnp.ndarray, float]) -> Tuple[jnp.ndarray, Any]:
        """Add the individual with a given distance to the front.
        A index is incremented to get the highest from the non
        selected individuals.

        Args:
            val: a tuple of two elements. A boolean vector with the positions that
                will be kept, and a cursor with the number of individuals already
                added during this process.

        Returns:
            The updated tuple, with the new booleans and the number of
            added elements.
        """
        front_index, num = val
        # The cursor starts at 0, so read from the end with -(num + 1).
        # Indexing with -num would give index 0 on the first iteration,
        # i.e. the candidate with the *smallest* crowding distance.
        front_index = front_index.at[highest_dist[-(num + 1)]].set(True)
        num = num + 1
        val = front_index, num
        return val

    def condition_fn_2(val: Tuple[jnp.ndarray, jnp.ndarray]) -> bool:
        """Gives condition to stop the while loop. Makes sure the
        the number of solution is smaller than the maximum size
        of the population."""
        front_index, _ = val
        return jnp.sum(to_keep_index + front_index) < self.size  # type: ignore

    # add the individuals with the highest distances
    front_index, _num = jax.lax.while_loop(
        condition_fn_2,
        add_to_front,
        (jnp.zeros(num_candidates, dtype=bool), 0),
    )

    # update index
    to_keep_index = to_keep_index + front_index

    # go from boolean vector to indices - offset by 1
    indices = jnp.arange(start=1, stop=num_candidates + 1) * to_keep_index

    # get rid of the zeros (that correspond to the False from the mask)
    fake_indice = num_candidates + 1  # bigger than all the other indices
    indices = jnp.where(indices == 0, x=fake_indice, y=indices)

    # sort the indices to remove the fake indices
    indices = jnp.sort(indices)[: self.size]

    # remove the offset
    indices = indices - 1

    # keep only the survivors
    new_candidates = jax.tree_util.tree_map(lambda x: x[indices], candidates)
    new_scores = candidate_fitnesses[indices]

    new_repertoire = self.replace(genotypes=new_candidates, fitnesses=new_scores)

    return new_repertoire  # type: ignore
replace(self, **updates)
¶
Returns a new object replacing the specified fields with new values.
Source code in qdax/core/containers/nsga2_repertoire.py
def replace(self, **updates):
    """Return a copy of this object with the given fields replaced.

    Args:
        **updates: field names mapped to the values they should take.

    Returns:
        A new object produced by ``dataclasses.replace``.
    """
    updated_copy = dataclasses.replace(self, **updates)
    return updated_copy
repertoire
¶
This file contains util functions and a class to define a repertoire, used to store individuals in the MAP-Elites algorithm as well as several variants.
Repertoire (PyTreeNode, ABC)
dataclass
¶
Abstract class for any repertoire of genotypes.
We decided not to add the `genotypes` attribute here, even though it will be shared by all child classes, because we want to keep the parent class explicit and transparent.
Source code in qdax/core/containers/repertoire.py
class Repertoire(flax.struct.PyTreeNode, ABC):
    """Abstract class for any repertoire of genotypes.

    We decided not to add the attributes Genotypes even if
    it will be shared by all children classes because we want
    to keep the parent classes explicit and transparent.
    """

    # `abstractclassmethod` has been deprecated since Python 3.3; stacking
    # `classmethod` on top of `abstractmethod` is the supported equivalent.
    @classmethod
    @abstractmethod
    def init(cls) -> Repertoire:  # noqa: N805
        """Create a repertoire."""
        pass

    @abstractmethod
    def sample(
        self,
        random_key: RNGKey,
        num_samples: int,
    ) -> Genotype:
        """Sample genotypes from the repertoire.

        Args:
            random_key: a random key to handle stochasticity.
            num_samples: the number of genotypes to sample.

        Returns:
            The sample of genotypes.
        """
        pass

    @abstractmethod
    def add(self) -> Repertoire:
        """Implements the rule to add new genotypes to a
        repertoire.

        Returns:
            The updated repertoire.
        """
        pass
init()
classmethod
¶
Create a repertoire.
Source code in qdax/core/containers/repertoire.py
# `abstractclassmethod` has been deprecated since Python 3.3; stacking
# `classmethod` on top of `abstractmethod` is the supported equivalent.
@classmethod
@abstractmethod
def init(cls) -> Repertoire:  # noqa: N805
    """Create a repertoire."""
    pass
sample(self, random_key, num_samples)
¶
Sample genotypes from the repertoire.
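Parameters:
| Name | Type | Description |
|---|---|---|
| `random_key` | `RNGKey` | A random key to handle stochasticity. |
| `num_samples` | `int` | The number of genotypes to sample. |
Returns:
| Type | Description |
|---|---|
| `Genotype` | The sample of genotypes. |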
Source code in qdax/core/containers/repertoire.py
@abstractmethod
def sample(
    self,
    random_key: RNGKey,
    num_samples: int,
) -> Genotype:
    """Draw genotypes from the repertoire.

    Abstract: concrete repertoires provide the actual sampling rule.

    Args:
        random_key: random key used to handle stochasticity.
        num_samples: how many genotypes to draw.

    Returns:
        The sampled genotypes.
    """
    return None
add(self)
¶
Implements the rule to add new genotypes to a repertoire.
Returns:
| Type | Description |
|---|---|
| `Repertoire` | The updated repertoire. |
Source code in qdax/core/containers/repertoire.py
@abstractmethod
def add(self) -> Repertoire:
    """Apply the rule that inserts new genotypes into a repertoire.

    Abstract: concrete repertoires provide the actual addition rule.

    Returns:
        The updated repertoire.
    """
    return None
replace(self, **updates)
¶
Returns a new object replacing the specified fields with new values.
Source code in qdax/core/containers/repertoire.py
def replace(self, **updates):
    """Return a copy of this object with the given fields replaced.

    Args:
        **updates: field names mapped to the values they should take.

    Returns:
        A new object produced by ``dataclasses.replace``.
    """
    updated_copy = dataclasses.replace(self, **updates)
    return updated_copy
spea2_repertoire
¶
SPEA2Repertoire (GARepertoire)
dataclass
¶
Repertoire used for the SPEA2 algorithm.
Inherits from the GARepertoire. The data stored are the genotypes and their fitnesses. Several functions are inherited from GARepertoire, including size, save, sample.
Source code in qdax/core/containers/spea2_repertoire.py
class SPEA2Repertoire(GARepertoire):
    """Population container used by the SPEA2 algorithm.

    Extends GARepertoire: it stores genotypes together with their
    fitnesses, and relies on the parent class for members such as
    size, save and sample.
    """

    # static (non-pytree) field: number of nearest neighbours used in the
    # density term of the strength score
    num_neighbours: int = flax.struct.field(pytree_node=False)

    @jax.jit
    def _compute_strength_scores(self, batch_of_fitnesses: Fitness) -> jnp.ndarray:
        """Score the stored solutions together with a new batch.

        The score of a solution is the number of solutions that are strictly
        better on every objective, plus ``1 / (1 + d)`` where ``d`` sums the
        ``num_neighbours + 1`` smallest squared distances in fitness space
        (the solution itself included).

        Args:
            batch_of_fitnesses: a batch of fitness vectors.

        Returns:
            One strength score per solution, stored solutions first and
            batch solutions after. NaN scores are replaced by
            ``self.size + 2``.
        """
        all_fitnesses = jnp.concatenate(
            (self.fitnesses, batch_of_fitnesses), axis=0
        )

        # pairwise_diff[i, j] = all_fitnesses[j] - all_fitnesses[i]
        pairwise_diff = all_fitnesses - jnp.expand_dims(all_fitnesses, axis=1)

        # row i counts the solutions strictly better than i on all objectives
        is_dominated_by = jnp.all(pairwise_diff > 0, axis=-1)
        num_dominators = jnp.sum(is_dominated_by, axis=1)

        # density term built from the closest neighbours in fitness space
        squared_distances = jnp.sum(pairwise_diff**2, axis=-1)
        closest = jnp.sort(squared_distances, axis=1)[:, : self.num_neighbours + 1]
        densities = jnp.sum(closest, axis=1)

        raw_scores = num_dominators + 1 / (1 + densities)
        return jnp.nan_to_num(raw_scores, nan=self.size + 2)

    @jax.jit
    def add(
        self,
        batch_of_genotypes: Genotype,
        batch_of_fitnesses: Fitness,
    ) -> SPEA2Repertoire:
        """Merge a batch of solutions into the population.

        Stored and new solutions are pooled, scored with the strength
        score, and the ``self.size`` solutions with the lowest scores are
        kept.

        Args:
            batch_of_genotypes: genotypes of the new individuals that are
                considered to be added to the population.
            batch_of_fitnesses: their corresponding fitnesses.

        Returns:
            Updated repertoire.
        """
        # pool the stored population with the incoming batch
        pooled_genotypes = jax.tree_util.tree_map(
            lambda stored, incoming: jnp.concatenate((stored, incoming), axis=0),
            self.genotypes,
            batch_of_genotypes,
        )
        pooled_fitnesses = jnp.concatenate((self.fitnesses, batch_of_fitnesses))

        # the smaller the score the better (notation of the paper)
        scores = self._compute_strength_scores(batch_of_fitnesses)
        survivors = jnp.argsort(scores)[: self.size]

        kept_genotypes = jax.tree_util.tree_map(
            lambda leaf: leaf[survivors], pooled_genotypes
        )
        kept_fitnesses = pooled_fitnesses[survivors]

        return self.replace(  # type: ignore
            genotypes=kept_genotypes, fitnesses=kept_fitnesses
        )

    @classmethod
    def init(  # type: ignore
        cls,
        genotypes: Genotype,
        fitnesses: Fitness,
        population_size: int,
        num_neighbours: int,
    ) -> GARepertoire:
        """Build a repertoire and insert a first batch of solutions.

        A placeholder population (zero genotypes, -inf fitnesses) is created
        first; the given batch is then inserted with ``add``.

        Args:
            genotypes: first batch of genotypes
            fitnesses: corresponding fitnesses
            population_size: size of the population we want to evolve
            num_neighbours: number of neighbours stored on the repertoire
                and used by the strength score.

        Returns:
            An initial repertoire.
        """
        num_objectives = fitnesses.shape[-1]
        placeholder_fitnesses = -jnp.inf * jnp.ones(
            shape=(population_size, num_objectives)
        )

        def _blank_like(leaf: jnp.ndarray) -> jnp.ndarray:
            # same trailing shape as the leaf, population_size rows of zeros
            return jnp.zeros(shape=(population_size,) + leaf.shape[1:])

        placeholder_genotypes = jax.tree_util.tree_map(_blank_like, genotypes)

        empty_repertoire = cls(
            genotypes=placeholder_genotypes,
            fitnesses=placeholder_fitnesses,
            num_neighbours=num_neighbours,
        )

        return empty_repertoire.add(genotypes, fitnesses)  # type: ignore
add(self, batch_of_genotypes, batch_of_fitnesses)
¶
Updates the population with the new solutions.
To decide which individuals to keep, we count, for each solution, the number of solutions by which it is dominated. We keep only the least dominated solutions.
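Parameters:
| Name | Type | Description |
|---|---|---|
| `batch_of_genotypes` | `Genotype` | Genotypes of the new individuals that are considered to be added to the population. |
| `batch_of_fitnesses` | `Fitness` | Their corresponding fitnesses. |
Returns:
| Type | Description |
|---|---|
| `SPEA2Repertoire` | Updated repertoire. |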
Source code in qdax/core/containers/spea2_repertoire.py
@jax.jit
def add(
    self,
    batch_of_genotypes: Genotype,
    batch_of_fitnesses: Fitness,
) -> SPEA2Repertoire:
    """Merge a batch of solutions into the population.

    Stored and new solutions are pooled, scored with the strength
    score, and the ``self.size`` solutions with the lowest scores are
    kept.

    Args:
        batch_of_genotypes: genotypes of the new individuals that are
            considered to be added to the population.
        batch_of_fitnesses: their corresponding fitnesses.

    Returns:
        Updated repertoire.
    """
    # pool the stored population with the incoming batch
    pooled_genotypes = jax.tree_util.tree_map(
        lambda stored, incoming: jnp.concatenate((stored, incoming), axis=0),
        self.genotypes,
        batch_of_genotypes,
    )
    pooled_fitnesses = jnp.concatenate((self.fitnesses, batch_of_fitnesses))

    # the smaller the score the better (notation of the paper)
    scores = self._compute_strength_scores(batch_of_fitnesses)
    survivors = jnp.argsort(scores)[: self.size]

    kept_genotypes = jax.tree_util.tree_map(
        lambda leaf: leaf[survivors], pooled_genotypes
    )
    kept_fitnesses = pooled_fitnesses[survivors]

    return self.replace(  # type: ignore
        genotypes=kept_genotypes, fitnesses=kept_fitnesses
    )
init(genotypes, fitnesses, population_size, num_neighbours)
classmethod
¶
Initializes the repertoire.
Start with default values and adds a first batch of genotypes to the repertoire.
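Parameters:
| Name | Type | Description |
|---|---|---|
| `genotypes` | `Genotype` | First batch of genotypes. |
| `fitnesses` | `Fitness` | Corresponding fitnesses. |
| `population_size` | `int` | Size of the population we want to evolve. |
| `num_neighbours` | `int` | Number of neighbours stored on the repertoire. |
Returns:
| Type | Description |
|---|---|
| `GARepertoire` | An initial repertoire. |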
Source code in qdax/core/containers/spea2_repertoire.py
@classmethod
def init(  # type: ignore
    cls,
    genotypes: Genotype,
    fitnesses: Fitness,
    population_size: int,
    num_neighbours: int,
) -> GARepertoire:
    """Build a repertoire and insert a first batch of solutions.

    A placeholder population (zero genotypes, -inf fitnesses) is created
    first; the given batch is then inserted with ``add``.

    Args:
        genotypes: first batch of genotypes
        fitnesses: corresponding fitnesses
        population_size: size of the population we want to evolve
        num_neighbours: number of neighbours stored on the repertoire.

    Returns:
        An initial repertoire.
    """
    num_objectives = fitnesses.shape[-1]
    placeholder_fitnesses = -jnp.inf * jnp.ones(
        shape=(population_size, num_objectives)
    )

    def _blank_like(leaf):
        # same trailing shape as the leaf, population_size rows of zeros
        return jnp.zeros(shape=(population_size,) + leaf.shape[1:])

    placeholder_genotypes = jax.tree_util.tree_map(_blank_like, genotypes)

    empty_repertoire = cls(
        genotypes=placeholder_genotypes,
        fitnesses=placeholder_fitnesses,
        num_neighbours=num_neighbours,
    )

    return empty_repertoire.add(genotypes, fitnesses)  # type: ignore
replace(self, **updates)
¶
Returns a new object replacing the specified fields with new values.
Source code in qdax/core/containers/spea2_repertoire.py
def replace(self, **updates):
    """Return a copy of this object with the given fields replaced.

    Args:
        **updates: field names mapped to the values they should take.

    Returns:
        A new object produced by ``dataclasses.replace``.
    """
    updated_copy = dataclasses.replace(self, **updates)
    return updated_copy
uniform_replacement_archive
¶
UniformReplacementArchive (Archive)
dataclass
¶
Stores jnp.ndarray and uses a uniform replacement when the maximum size is reached.
Instead of replacing elements in a FIFO manner, like the Archive, this implementation removes elements uniformly to replace them by the newly added ones.
Most methods are inherited from Archive.
Source code in qdax/core/containers/uniform_replacement_archive.py
class UniformReplacementArchive(Archive):
    """Archive that overwrites uniformly drawn slots once it is full.

    Instead of replacing elements in a FIFO manner, like the Archive,
    this implementation picks the slot to overwrite uniformly at random.

    Most methods are inherited from Archive.
    """

    # PRNG key consumed (and refreshed) at every insertion
    random_key: RNGKey

    @classmethod
    def create(  # type: ignore
        cls,
        acceptance_threshold: float,
        state_descriptor_size: int,
        max_size: int,
        random_key: RNGKey,
    ) -> Archive:
        """Create an Archive instance.

        This class method provides a convenient way to create the archive while
        keeping the __init__ function for more general way to init an archive.

        Args:
            acceptance_threshold: the minimal distance to a stored descriptor to
                be respected for a new descriptor to be added.
            state_descriptor_size: the number of elements in a state descriptor.
            max_size: the maximal size of the archive. In case of overflow,
                previous elements are replaced by new ones.
            random_key: a key to handle random operations.

        Returns:
            A newly initialized archive.
        """
        base_archive = super().create(
            acceptance_threshold,
            state_descriptor_size,
            max_size,
        )
        return base_archive.replace(random_key=random_key)  # type: ignore

    @jax.jit
    def _single_insertion(self, state_descriptor: jnp.ndarray) -> Archive:
        """Insert a single element.

        While the next position is below ``max_size`` the element is written
        there; otherwise it overwrites a slot drawn uniformly in
        ``[0, max_size)``.

        Args:
            state_descriptor: state descriptor to be added.

        Returns:
            The archive with the newly added element, an incremented
            position and a fresh random key.
        """
        next_key, draw_key = jax.random.split(self.random_key)

        next_position = self.current_position + 1
        archive_is_full = next_position >= self.max_size

        # drawn at every call; only used when the archive is full
        random_slot = jax.random.randint(
            draw_key, shape=(1,), minval=0, maxval=self.max_size
        )
        write_index = jnp.where(archive_is_full, random_slot, next_position)

        updated_data = self.data.at[write_index].set(state_descriptor)

        return self.replace(  # type: ignore
            current_position=next_position, data=updated_data, random_key=next_key
        )
create(acceptance_threshold, state_descriptor_size, max_size, random_key)
classmethod
¶
Create an Archive instance.
This class method provides a convenient way to create the archive while keeping the `__init__` function as a more general way to initialize an archive.
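Parameters:
| Name | Type | Description |
|---|---|---|
| `acceptance_threshold` | `float` | The minimal distance to a stored descriptor to be respected for a new descriptor to be added. |
| `state_descriptor_size` | `int` | The number of elements in a state descriptor. |
| `max_size` | `int` | The maximal size of the archive. In case of overflow, previous elements are replaced by new ones. |
| `random_key` | `RNGKey` | A key to handle random operations. |
Returns:
| Type | Description |
|---|---|
| `Archive` | A newly initialized archive. |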
Source code in qdax/core/containers/uniform_replacement_archive.py
@classmethod
def create(  # type: ignore
    cls,
    acceptance_threshold: float,
    state_descriptor_size: int,
    max_size: int,
    random_key: RNGKey,
) -> Archive:
    """Create an Archive instance.

    This class method provides a convenient way to create the archive while
    keeping the __init__ function for more general way to init an archive.

    Args:
        acceptance_threshold: the minimal distance to a stored descriptor to
            be respected for a new descriptor to be added.
        state_descriptor_size: the number of elements in a state descriptor.
        max_size: the maximal size of the archive. In case of overflow,
            previous elements are replaced by new ones.
        random_key: a key to handle random operations.

    Returns:
        A newly initialized archive.
    """
    base_archive = super().create(
        acceptance_threshold,
        state_descriptor_size,
        max_size,
    )
    return base_archive.replace(random_key=random_key)  # type: ignore
replace(self, **updates)
¶
Returns a new object replacing the specified fields with new values.
Source code in qdax/core/containers/uniform_replacement_archive.py
def replace(self, **updates):
    """Return a copy of this object with the given fields replaced.

    Args:
        **updates: field names mapped to the values they should take.

    Returns:
        A new object produced by ``dataclasses.replace``.
    """
    updated_copy = dataclasses.replace(self, **updates)
    return updated_copy
unstructured_repertoire
¶
UnstructuredRepertoire (PyTreeNode)
dataclass
¶
Class for the unstructured repertoire in Map Elites.
Parameters: |
|
---|
Source code in qdax/core/containers/unstructured_repertoire.py
class UnstructuredRepertoire(flax.struct.PyTreeNode):
    """
    Class for the unstructured repertoire in Map Elites.

    Args:
        genotypes: a PyTree containing all the genotypes in the repertoire. Each
            leaf has max_size as leading dimension. The PyTree can be a simple
            Jax array or a more complex nested structure such as to represent
            parameters of neural network in Flax.
        fitnesses: an array that contains the fitness of the solutions stored in
            the repertoire; empty slots hold -inf. The array shape is (max_size,).
        descriptors: an array that contains the descriptors of the solutions
            stored in the repertoire. The array shape
            is (max_size, num_descriptors).
        observations: observations that the genotype gathered in the environment.
        l_value: threshold distance of the repertoire.
        max_size: maximal number of individuals in the repertoire (static,
            not a pytree node).
    """

    genotypes: Genotype
    fitnesses: Fitness
    descriptors: Descriptor
    observations: Observation
    l_value: jnp.ndarray
    max_size: int = flax.struct.field(pytree_node=False)

    def get_maximal_size(self) -> int:
        """Returns the maximal number of individuals in the repertoire."""
        return self.max_size

    def get_number_genotypes(self) -> jnp.ndarray:
        """Returns the number of genotypes in the repertoire."""
        # empty slots are marked by a fitness of -inf
        return jnp.sum(self.fitnesses != -jnp.inf)

    def save(self, path: str = "./") -> None:
        """Saves the grid on disk in the form of .npy files.

        Flattens the genotypes to store it with .npy format. Supposes that
        a user will have access to the reconstruction function when loading
        the genotypes.

        Args:
            path: Path where the data will be saved. Defaults to "./".
        """

        def flatten_genotype(genotype: Genotype) -> jnp.ndarray:
            flatten_genotype, _unravel_pytree = ravel_pytree(genotype)
            return flatten_genotype

        # flatten all the genotypes
        flat_genotypes = jax.vmap(flatten_genotype)(self.genotypes)

        # save data
        jnp.save(path + "genotypes.npy", flat_genotypes)
        jnp.save(path + "fitnesses.npy", self.fitnesses)
        jnp.save(path + "descriptors.npy", self.descriptors)
        jnp.save(path + "observations.npy", self.observations)
        jnp.save(path + "l_value.npy", self.l_value)
        jnp.save(path + "max_size.npy", self.max_size)

    @classmethod
    def load(
        cls, reconstruction_fn: Callable, path: str = "./"
    ) -> UnstructuredRepertoire:
        """Loads an unstructured repertoire.

        Args:
            reconstruction_fn: Function to reconstruct a PyTree
                from a flat array.
            path: Path where the data is saved. Defaults to "./".

        Returns:
            An unstructured repertoire.
        """

        flat_genotypes = jnp.load(path + "genotypes.npy")
        genotypes = jax.vmap(reconstruction_fn)(flat_genotypes)

        fitnesses = jnp.load(path + "fitnesses.npy")
        descriptors = jnp.load(path + "descriptors.npy")
        observations = jnp.load(path + "observations.npy")
        l_value = jnp.load(path + "l_value.npy")
        max_size = int(jnp.load(path + "max_size.npy").item())

        return UnstructuredRepertoire(
            genotypes=genotypes,
            fitnesses=fitnesses,
            descriptors=descriptors,
            observations=observations,
            l_value=l_value,
            max_size=max_size,
        )

    @jax.jit
    def add(
        self,
        batch_of_genotypes: Genotype,
        batch_of_descriptors: Descriptor,
        batch_of_fitnesses: Fitness,
        batch_of_observations: Observation,
    ) -> UnstructuredRepertoire:
        """Adds a batch of genotypes to the repertoire.

        Args:
            batch_of_genotypes: genotypes of the individuals to be considered
                for addition in the repertoire.
            batch_of_descriptors: associated descriptors.
            batch_of_fitnesses: associated fitness.
            batch_of_observations: associated observations.

        Returns:
            A new unstructured repertoire where the relevant individuals have been
            added.
        """

        # We need to replace all the descriptors that are not filled with jnp inf
        filtered_descriptors = jnp.where(
            jnp.expand_dims((self.fitnesses == -jnp.inf), axis=-1),
            jnp.full(self.descriptors.shape[-1], fill_value=jnp.inf),
            self.descriptors,
        )

        # NOTE(review): presumably returns, for each new descriptor, the indices
        # of and distances to its 2 nearest stored descriptors - confirm
        # against get_cells_indices.
        batch_of_indices, batch_of_distances = get_cells_indices(
            batch_of_descriptors, filtered_descriptors, 2
        )

        # Save the second-nearest neighbours to check a condition
        second_neighbours = batch_of_distances.at[..., 1].get()

        # Keep the Nearest neighbours
        batch_of_indices = batch_of_indices.at[..., 0].get()

        # Keep the Nearest neighbours
        batch_of_distances = batch_of_distances.at[..., 0].get()

        # We remove individuals that are too close to the second nn.
        # This avoids having clusters of individuals after adding them.
        not_novel_enough = jnp.where(
            jnp.squeeze(second_neighbours <= self.l_value), True, False
        )

        # batch_of_indices = jnp.expand_dims(batch_of_indices, axis=-1)
        batch_of_fitnesses = jnp.expand_dims(batch_of_fitnesses, axis=-1)
        batch_of_observations = jnp.expand_dims(batch_of_observations, axis=-1)

        # TODO: Doesn't Work if Archive is full. Need to use the closest individuals
        # in that case.
        # indices of the empty slots, padded with -1 up to the batch size
        empty_indexes = jnp.squeeze(
            jnp.nonzero(
                jnp.where(jnp.isinf(self.fitnesses), 1, 0),
                size=batch_of_indices.shape[0],
                fill_value=-1,
            )[0]
        )
        # individuals farther than l_value from their nearest neighbour get
        # the sentinel index -1
        batch_of_indices = jnp.where(
            jnp.squeeze(batch_of_distances <= self.l_value),
            jnp.squeeze(batch_of_indices),
            -1,
        )

        # We get all the indices of the empty bds first and then the filled ones
        # (because of -1)
        sorted_bds = jax.lax.top_k(
            -1 * batch_of_indices.squeeze(), batch_of_indices.shape[0]
        )[1]
        batch_of_indices = jnp.where(
            jnp.squeeze(batch_of_distances.at[sorted_bds].get() <= self.l_value),
            batch_of_indices.at[sorted_bds].get(),
            empty_indexes,
        )

        batch_of_indices = jnp.expand_dims(batch_of_indices, axis=-1)

        # ReIndexing of all the inputs to the correct sorted way
        batch_of_descriptors = batch_of_descriptors.at[sorted_bds].get()
        # jax.tree_util.tree_map replaces the deprecated jax.tree_map alias
        batch_of_genotypes = jax.tree_util.tree_map(
            lambda x: x.at[sorted_bds].get(), batch_of_genotypes
        )
        batch_of_fitnesses = batch_of_fitnesses.at[sorted_bds].get()
        batch_of_observations = batch_of_observations.at[sorted_bds].get()
        not_novel_enough = not_novel_enough.at[sorted_bds].get()

        # Check to find Individuals with same BD within the Batch
        keep_indiv = jax.jit(
            jax.vmap(intra_batch_comp, in_axes=(0, 0, None, None, None), out_axes=(0))
        )(
            batch_of_descriptors.squeeze(),
            jnp.arange(
                0, batch_of_descriptors.shape[0], 1
            ),  # keep track of where we are in the batch to assure right comparisons
            batch_of_descriptors.squeeze(),
            batch_of_fitnesses.squeeze(),
            self.l_value,
        )

        keep_indiv = jnp.logical_and(keep_indiv, jnp.logical_not(not_novel_enough))

        # get fitness segment max
        best_fitnesses = jax.ops.segment_max(
            batch_of_fitnesses,
            batch_of_indices.astype(jnp.int32).squeeze(),
            num_segments=self.max_size,
        )

        cond_values = jnp.take_along_axis(best_fitnesses, batch_of_indices, 0)

        # put dominated fitness to -jnp.inf
        batch_of_fitnesses = jnp.where(
            batch_of_fitnesses == cond_values, x=batch_of_fitnesses, y=-jnp.inf
        )

        # get addition condition
        grid_fitnesses = jnp.expand_dims(self.fitnesses, axis=-1)
        current_fitnesses = jnp.take_along_axis(grid_fitnesses, batch_of_indices, 0)
        addition_condition = batch_of_fitnesses > current_fitnesses
        addition_condition = jnp.logical_and(
            addition_condition, jnp.expand_dims(keep_indiv, axis=-1)
        )

        # assign fake position when relevant : max_size is out of bounds
        # NOTE(review): relies on out-of-bounds scatter updates being dropped
        # by `.at[...].set` - confirm against the JAX indexing semantics.
        batch_of_indices = jnp.where(
            addition_condition,
            x=batch_of_indices,
            y=self.max_size,
        )

        # create new grid
        new_grid_genotypes = jax.tree_util.tree_map(
            lambda grid_genotypes, new_genotypes: grid_genotypes.at[
                batch_of_indices.squeeze()
            ].set(new_genotypes),
            self.genotypes,
            batch_of_genotypes,
        )

        # compute new fitness and descriptors
        new_fitnesses = self.fitnesses.at[batch_of_indices.squeeze()].set(
            batch_of_fitnesses.squeeze()
        )
        new_descriptors = self.descriptors.at[batch_of_indices.squeeze()].set(
            batch_of_descriptors.squeeze()
        )

        new_observations = self.observations.at[batch_of_indices.squeeze()].set(
            batch_of_observations.squeeze()
        )

        return UnstructuredRepertoire(
            genotypes=new_grid_genotypes,
            fitnesses=new_fitnesses.squeeze(),
            descriptors=new_descriptors.squeeze(),
            observations=new_observations.squeeze(),
            l_value=self.l_value,
            max_size=self.max_size,
        )

    @partial(jax.jit, static_argnames=("num_samples",))
    def sample(self, random_key: RNGKey, num_samples: int) -> Tuple[Genotype, RNGKey]:
        """Sample elements in the repertoire.

        Only filled slots (fitness different from -inf) can be drawn, each
        with the same probability.

        Args:
            random_key: a jax PRNG random key
            num_samples: the number of elements to be sampled

        Returns:
            samples: a batch of genotypes sampled in the repertoire
            random_key: an updated jax PRNG random key
        """

        random_key, sub_key = jax.random.split(random_key)
        grid_empty = self.fitnesses == -jnp.inf
        p = (1.0 - grid_empty) / jnp.sum(1.0 - grid_empty)

        # the same sub_key is used for every leaf, so all leaves are sampled
        # at the same positions
        samples = jax.tree_util.tree_map(
            lambda x: jax.random.choice(sub_key, x, shape=(num_samples,), p=p),
            self.genotypes,
        )

        return samples, random_key

    @classmethod
    def init(
        cls,
        genotypes: Genotype,
        fitnesses: Fitness,
        descriptors: Descriptor,
        observations: Observation,
        l_value: jnp.ndarray,
        max_size: int,
    ) -> UnstructuredRepertoire:
        """Initialize a Map-Elites repertoire with an initial population of genotypes.

        An empty repertoire of max_size slots is created first (fitnesses at
        -inf, genotypes and observations at nan, descriptors at zero); the
        initial population is then inserted with `add`.

        Args:
            genotypes: initial genotypes, pytree in which leaves
                have shape (batch_size, num_features)
            fitnesses: fitness of the initial genotypes of shape (batch_size,)
            descriptors: descriptors of the initial genotypes
                of shape (batch_size, num_descriptors)
            observations: observations experienced in the evaluation task.
            l_value: threshold distance of the repertoire.
            max_size: maximal size of the container

        Returns:
            an initialized unstructured repertoire.
        """

        # Initialize grid with default values
        default_fitnesses = -jnp.inf * jnp.ones(shape=max_size)
        default_genotypes = jax.tree_util.tree_map(
            lambda x: jnp.full(shape=(max_size,) + x.shape[1:], fill_value=jnp.nan),
            genotypes,
        )
        default_descriptors = jnp.zeros(shape=(max_size, descriptors.shape[-1]))

        default_observations = jnp.full(
            shape=(max_size,) + observations.shape[1:], fill_value=jnp.nan
        )

        repertoire = UnstructuredRepertoire(
            genotypes=default_genotypes,
            fitnesses=default_fitnesses,
            descriptors=default_descriptors,
            observations=default_observations,
            l_value=l_value,
            max_size=max_size,
        )

        return repertoire.add(  # type: ignore
            genotypes, descriptors, fitnesses, observations
        )
replace(self, **updates)
¶
Returns a new object replacing the specified fields with new values.
Source code in qdax/core/containers/unstructured_repertoire.py
def replace(self, **updates):
    """Return a copy of this object with the given fields replaced.

    Args:
        **updates: field names mapped to the values they should take.

    Returns:
        A new object produced by ``dataclasses.replace``.
    """
    updated_copy = dataclasses.replace(self, **updates)
    return updated_copy
get_maximal_size(self)
¶
Returns the maximal number of individuals in the repertoire.
Source code in qdax/core/containers/unstructured_repertoire.py
def get_maximal_size(self) -> int:
    """Give the capacity of the repertoire.

    Returns:
        The value of the ``max_size`` attribute, i.e. the maximal number of
        individuals the repertoire can hold.
    """
    capacity = self.max_size
    return capacity
get_number_genotypes(self)
¶
Returns the number of genotypes in the repertoire.
Source code in qdax/core/containers/unstructured_repertoire.py
def get_number_genotypes(self) -> jnp.ndarray:
    """Count the genotypes currently stored in the repertoire.

    Returns:
        The number of entries of ``self.fitnesses`` that differ from -inf.
    """
    is_filled = self.fitnesses != -jnp.inf
    return jnp.sum(is_filled)
save(self, path='./')
¶
Saves the grid on disk in the form of .npy files.
Flattens the genotypes to store it with .npy format. Supposes that a user will have access to the reconstruction function when loading the genotypes.
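Parameters:
| Name | Type | Description |
|---|---|---|
| `path` | `str` | Path where the data will be saved. Defaults to `"./"`. |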
Source code in qdax/core/containers/unstructured_repertoire.py
def save(self, path: str = "./") -> None:
    """Write the repertoire to disk as a set of .npy files.

    Each genotype is flattened into a single vector before being stored, so
    the matching reconstruction function is needed to rebuild the genotypes
    when loading.

    Args:
        path: Path where the data will be saved. Defaults to "./".
    """

    def _ravel_single(genotype: Genotype) -> jnp.ndarray:
        flat_vector, _ = ravel_pytree(genotype)
        return flat_vector

    # one (file name, content) pair per file, written in this order
    files_to_write = (
        ("genotypes.npy", jax.vmap(_ravel_single)(self.genotypes)),
        ("fitnesses.npy", self.fitnesses),
        ("descriptors.npy", self.descriptors),
        ("observations.npy", self.observations),
        ("l_value.npy", self.l_value),
        ("max_size.npy", self.max_size),
    )
    for file_name, content in files_to_write:
        jnp.save(path + file_name, content)
load(reconstruction_fn, path='./')
classmethod
¶
Loads an unstructured repertoire.
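Parameters:
| Name | Type | Description |
|---|---|---|
| `reconstruction_fn` | `Callable` | Function to reconstruct a PyTree from a flat array. |
| `path` | `str` | Path where the data is saved. Defaults to `"./"`. |
Returns:
| Type | Description |
|---|---|
| `UnstructuredRepertoire` | An unstructured repertoire. |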
Source code in qdax/core/containers/unstructured_repertoire.py
@classmethod
def load(
    cls, reconstruction_fn: Callable, path: str = "./"
) -> UnstructuredRepertoire:
    """Rebuild an unstructured repertoire from .npy files.

    Args:
        reconstruction_fn: Function to reconstruct a PyTree
            from a flat array.
        path: Path where the data is saved. Defaults to "./".

    Returns:
        An unstructured repertoire.
    """

    def _read(file_name: str) -> jnp.ndarray:
        return jnp.load(path + file_name)

    # genotypes are stored flattened: rebuild one pytree per stored row
    restored_genotypes = jax.vmap(reconstruction_fn)(_read("genotypes.npy"))

    return UnstructuredRepertoire(
        genotypes=restored_genotypes,
        fitnesses=_read("fitnesses.npy"),
        descriptors=_read("descriptors.npy"),
        observations=_read("observations.npy"),
        l_value=_read("l_value.npy"),
        max_size=int(_read("max_size.npy").item()),
    )
add(self, batch_of_genotypes, batch_of_descriptors, batch_of_fitnesses, batch_of_observations)
¶
Adds a batch of genotypes to the repertoire.
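Parameters:
| Name | Type | Description |
|---|---|---|
| `batch_of_genotypes` | `Genotype` | Genotypes of the individuals to be considered for addition in the repertoire. |
| `batch_of_descriptors` | `Descriptor` | Associated descriptors. |
| `batch_of_fitnesses` | `Fitness` | Associated fitness. |
| `batch_of_observations` | `Observation` | Associated observations. |
Returns:
| Type | Description |
|---|---|
| `UnstructuredRepertoire` | A new unstructured repertoire where the relevant individuals have been added. |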
Source code in qdax/core/containers/unstructured_repertoire.py
@jax.jit
def add(
    self,
    batch_of_genotypes: Genotype,
    batch_of_descriptors: Descriptor,
    batch_of_fitnesses: Fitness,
    batch_of_observations: Observation,
) -> UnstructuredRepertoire:
    """Adds a batch of genotypes to the repertoire.

    For every candidate, the two nearest stored descriptors are looked up.
    A candidate whose nearest stored descriptor lies within ``l_value``
    targets that slot; otherwise it targets one of the empty slots (slots
    whose fitness is infinite). Candidates are then filtered: those whose
    second nearest stored descriptor is within ``l_value`` are rejected,
    those dominated by a close-by candidate of the same batch are rejected
    (see ``intra_batch_comp``), and a candidate only replaces the content of
    its target slot if its fitness is strictly higher.

    Args:
        batch_of_genotypes: genotypes of the individuals to be considered
            for addition in the repertoire.
        batch_of_descriptors: associated descriptors.
        batch_of_fitnesses: associated fitness.
        batch_of_observations: associated observations.

    Returns:
        A new unstructured repertoire where the relevant individuals have been
        added.
    """
    # Slots with a fitness of -inf are empty: replace their descriptors with
    # +inf so that they are never reported as near neighbours.
    filtered_descriptors = jnp.where(
        jnp.expand_dims((self.fitnesses == -jnp.inf), axis=-1),
        jnp.full(self.descriptors.shape[-1], fill_value=jnp.inf),
        self.descriptors,
    )

    # Indices of / distances to the two nearest stored descriptors.
    batch_of_indices, batch_of_distances = get_cells_indices(
        batch_of_descriptors, filtered_descriptors, 2
    )

    # Save the second-nearest neighbours to check a condition
    second_neighbours = batch_of_distances.at[..., 1].get()

    # Keep the nearest neighbours (indices, then distances)
    batch_of_indices = batch_of_indices.at[..., 0].get()
    batch_of_distances = batch_of_distances.at[..., 0].get()

    # We remove individuals that are too close to the second nn.
    # This avoids having clusters of individuals after adding them.
    not_novel_enough = jnp.where(
        jnp.squeeze(second_neighbours <= self.l_value), True, False
    )

    # Add a trailing axis so these arrays line up with the (batch, 1) index
    # array used by take_along_axis further down.
    batch_of_fitnesses = jnp.expand_dims(batch_of_fitnesses, axis=-1)
    batch_of_observations = jnp.expand_dims(batch_of_observations, axis=-1)

    # TODO: Doesn't Work if Archive is full. Need to use the closest individuals
    # in that case.
    # Indices of the slots with an infinite fitness, one per candidate; when
    # fewer such slots exist than candidates, the remainder is padded with -1.
    empty_indexes = jnp.squeeze(
        jnp.nonzero(
            jnp.where(jnp.isinf(self.fitnesses), 1, 0),
            size=batch_of_indices.shape[0],
            fill_value=-1,
        )[0]
    )

    # Candidates with no stored descriptor within l_value are tagged with -1.
    batch_of_indices = jnp.where(
        jnp.squeeze(batch_of_distances <= self.l_value),
        jnp.squeeze(batch_of_indices),
        -1,
    )

    # We get all the indices of the empty bds first and then the filled ones
    # (because of -1): top_k on the negated indices ranks the -1 entries first.
    sorted_bds = jax.lax.top_k(
        -1 * batch_of_indices.squeeze(), batch_of_indices.shape[0]
    )[1]

    # After reordering, a candidate close to a stored descriptor keeps that
    # slot; the others take the empty slot found at the same position.
    batch_of_indices = jnp.where(
        jnp.squeeze(batch_of_distances.at[sorted_bds].get() <= self.l_value),
        batch_of_indices.at[sorted_bds].get(),
        empty_indexes,
    )
    batch_of_indices = jnp.expand_dims(batch_of_indices, axis=-1)

    # ReIndexing of all the inputs to the correct sorted way
    batch_of_descriptors = batch_of_descriptors.at[sorted_bds].get()
    # jax.tree_util.tree_map replaces the removed top-level jax.tree_map alias.
    batch_of_genotypes = jax.tree_util.tree_map(
        lambda x: x.at[sorted_bds].get(), batch_of_genotypes
    )
    batch_of_fitnesses = batch_of_fitnesses.at[sorted_bds].get()
    batch_of_observations = batch_of_observations.at[sorted_bds].get()
    not_novel_enough = not_novel_enough.at[sorted_bds].get()

    # Check to find Individuals with same BD within the Batch
    keep_indiv = jax.jit(
        jax.vmap(intra_batch_comp, in_axes=(0, 0, None, None, None), out_axes=(0))
    )(
        batch_of_descriptors.squeeze(),
        jnp.arange(
            0, batch_of_descriptors.shape[0], 1
        ),  # keep track of where we are in the batch to assure right comparisons
        batch_of_descriptors.squeeze(),
        batch_of_fitnesses.squeeze(),
        self.l_value,
    )

    keep_indiv = jnp.logical_and(keep_indiv, jnp.logical_not(not_novel_enough))

    # get fitness segment max: best candidate fitness per targeted slot
    best_fitnesses = jax.ops.segment_max(
        batch_of_fitnesses,
        batch_of_indices.astype(jnp.int32).squeeze(),
        num_segments=self.max_size,
    )

    cond_values = jnp.take_along_axis(best_fitnesses, batch_of_indices, 0)

    # put dominated fitness to -jnp.inf
    batch_of_fitnesses = jnp.where(
        batch_of_fitnesses == cond_values, x=batch_of_fitnesses, y=-jnp.inf
    )

    # get addition condition: strictly better than what the slot holds now,
    # and not rejected by the novelty / intra-batch checks
    grid_fitnesses = jnp.expand_dims(self.fitnesses, axis=-1)
    current_fitnesses = jnp.take_along_axis(grid_fitnesses, batch_of_indices, 0)
    addition_condition = batch_of_fitnesses > current_fitnesses
    addition_condition = jnp.logical_and(
        addition_condition, jnp.expand_dims(keep_indiv, axis=-1)
    )

    # assign fake position when relevant : max_size is out of bounds, so the
    # scatter updates below are expected to skip the rejected candidates
    # (JAX's default handling of out-of-bounds scatter indices).
    batch_of_indices = jnp.where(
        addition_condition,
        x=batch_of_indices,
        y=self.max_size,
    )

    # create new grid
    new_grid_genotypes = jax.tree_util.tree_map(
        lambda grid_genotypes, new_genotypes: grid_genotypes.at[
            batch_of_indices.squeeze()
        ].set(new_genotypes),
        self.genotypes,
        batch_of_genotypes,
    )

    # compute new fitness and descriptors
    new_fitnesses = self.fitnesses.at[batch_of_indices.squeeze()].set(
        batch_of_fitnesses.squeeze()
    )
    new_descriptors = self.descriptors.at[batch_of_indices.squeeze()].set(
        batch_of_descriptors.squeeze()
    )
    new_observations = self.observations.at[batch_of_indices.squeeze()].set(
        batch_of_observations.squeeze()
    )

    return UnstructuredRepertoire(
        genotypes=new_grid_genotypes,
        fitnesses=new_fitnesses.squeeze(),
        descriptors=new_descriptors.squeeze(),
        observations=new_observations.squeeze(),
        l_value=self.l_value,
        max_size=self.max_size,
    )
sample(self, random_key, num_samples)
¶
Sample elements in the repertoire.
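Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| random_key | RNGKey | a jax PRNG random key | required |
| num_samples | int | the number of elements to be sampled | required |

Returns:

| Type | Description |
|---|---|
| Tuple[Genotype, RNGKey] | samples: a batch of genotypes sampled in the repertoire; random_key: an updated jax PRNG random key |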
Source code in qdax/core/containers/unstructured_repertoire.py
@partial(jax.jit, static_argnames=("num_samples",))
def sample(self, random_key: RNGKey, num_samples: int) -> Tuple[Genotype, RNGKey]:
    """Sample elements in the repertoire.

    Only filled slots (fitness different from -inf) can be drawn, each with
    the same probability.

    Args:
        random_key: a jax PRNG random key
        num_samples: the number of elements to be sampled

    Returns:
        samples: a batch of genotypes sampled in the repertoire
        random_key: an updated jax PRNG random key
    """
    random_key, sub_key = jax.random.split(random_key)

    # Uniform distribution over the filled slots, zero on the empty ones.
    grid_empty = self.fitnesses == -jnp.inf
    p = (1.0 - grid_empty) / jnp.sum(1.0 - grid_empty)

    # The same sub_key is used for every leaf so that all leaves are indexed
    # by the same draws. jax.tree_util.tree_map is used because the top-level
    # jax.tree_map alias is no longer available in recent JAX releases.
    samples = jax.tree_util.tree_map(
        lambda x: jax.random.choice(sub_key, x, shape=(num_samples,), p=p),
        self.genotypes,
    )

    return samples, random_key
init(genotypes, fitnesses, descriptors, observations, l_value, max_size)
classmethod
¶
Initialize an unstructured repertoire with an initial population of genotypes. No centroids are needed: an empty container with max_size slots is created and the initial population is then inserted through add.
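Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| genotypes | Genotype | initial genotypes, pytree in which leaves have shape (batch_size, num_features) | required |
| fitnesses | Fitness | fitness of the initial genotypes of shape (batch_size,) | required |
| descriptors | Descriptor | descriptors of the initial genotypes of shape (batch_size, num_descriptors) | required |
| observations | Observation | observations experienced in the evaluation task. | required |
| l_value | jnp.ndarray | threshold distance of the repertoire. | required |
| max_size | int | maximal size of the container | required |

Returns:

| Type | Description |
|---|---|
| UnstructuredRepertoire | an initialized unstructured repertoire. |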
Source code in qdax/core/containers/unstructured_repertoire.py
@classmethod
def init(
    cls,
    genotypes: Genotype,
    fitnesses: Fitness,
    descriptors: Descriptor,
    observations: Observation,
    l_value: jnp.ndarray,
    max_size: int,
) -> UnstructuredRepertoire:
    """Initialize an unstructured repertoire with an initial population.

    An empty container with ``max_size`` slots is created first, then the
    initial population is inserted through ``add``. No centroids are involved.

    Args:
        genotypes: initial genotypes, pytree in which leaves
            have shape (batch_size, num_features)
        fitnesses: fitness of the initial genotypes of shape (batch_size,)
        descriptors: descriptors of the initial genotypes
            of shape (batch_size, num_descriptors)
        observations: observations experienced in the evaluation task.
        l_value: threshold distance of the repertoire.
        max_size: maximal size of the container

    Returns:
        an initialized unstructured repertoire.
    """
    # Empty slots are flagged with a fitness of -inf.
    default_fitnesses = -jnp.inf * jnp.ones(shape=max_size)

    # Placeholder genotypes: same leaf shapes as the input, but with max_size
    # rows, filled with NaN. jax.tree_util.tree_map is used because the
    # top-level jax.tree_map alias is no longer available in recent JAX.
    default_genotypes = jax.tree_util.tree_map(
        lambda x: jnp.full(shape=(max_size,) + x.shape[1:], fill_value=jnp.nan),
        genotypes,
    )

    # Placeholder descriptors are zeros, placeholder observations are NaN.
    default_descriptors = jnp.zeros(shape=(max_size, descriptors.shape[-1]))
    default_observations = jnp.full(
        shape=(max_size,) + observations.shape[1:], fill_value=jnp.nan
    )

    repertoire = UnstructuredRepertoire(
        genotypes=default_genotypes,
        fitnesses=default_fitnesses,
        descriptors=default_descriptors,
        observations=default_observations,
        l_value=l_value,
        max_size=max_size,
    )

    # Insert the initial population with the regular addition rule.
    return repertoire.add(  # type: ignore
        genotypes, descriptors, fitnesses, observations
    )
get_cells_indices(batch_of_descriptors, centroids, k_nn)
¶
Returns the array of cells indices for a batch of descriptors given the centroids of the grid.
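Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| batch_of_descriptors | Descriptor | a batch of descriptors of shape (batch_size, num_descriptors) | required |
| centroids | Centroid | centroids array of shape (num_centroids, num_descriptors) | required |
| k_nn | int | number of nearest centroids returned for each descriptor | required |

Returns:

| Type | Description |
|---|---|
| Tuple[jnp.ndarray, jnp.ndarray] | the indices of the k_nn nearest centroids for each descriptor of the batch, and the corresponding distances |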
Source code in qdax/core/containers/unstructured_repertoire.py
@partial(jax.jit, static_argnames=("k_nn",))
def get_cells_indices(
    batch_of_descriptors: Descriptor, centroids: Centroid, k_nn: int
) -> Tuple[jnp.ndarray, jnp.ndarray]:
    """Find, for each descriptor of a batch, its ``k_nn`` closest centroids.

    Args:
        batch_of_descriptors: a batch of descriptors
            of shape (batch_size, num_descriptors)
        centroids: centroids array of shape (num_centroids, num_descriptors)
        k_nn: number of nearest centroids to retrieve per descriptor
            (static under jit).

    Returns:
        A pair ``(indices, distances)``: for every descriptor, the indices of
        its ``k_nn`` nearest centroids ordered from closest to farthest, and
        the matching Euclidean distances.
    """

    def _nearest_centroids(
        descriptor: jnp.ndarray,
    ) -> Tuple[jnp.ndarray, jnp.ndarray]:
        """Handle a single descriptor against all the centroids."""
        # One Euclidean norm per centroid row.
        dists = jax.vmap(jnp.linalg.norm)(descriptor - centroids)
        # top_k returns the largest entries, hence the search on negated
        # distances; the sign is restored on the way out.
        neg_dists, nearest = jax.lax.top_k(-1 * dists, k_nn)
        return nearest, -1 * neg_dists

    # Map the single-descriptor routine over the leading (batch) axis.
    return jax.vmap(_nearest_centroids)(batch_of_descriptors)  # type: ignore
intra_batch_comp(normed, current_index, normed_all, eval_scores, l_value)
¶
Function to know if an individual should be kept or not.
Source code in qdax/core/containers/unstructured_repertoire.py
@jax.jit
def intra_batch_comp(
    normed: jnp.ndarray,
    current_index: jnp.ndarray,
    normed_all: jnp.ndarray,
    eval_scores: jnp.ndarray,
    l_value: jnp.ndarray,
) -> jnp.ndarray:
    """Tell whether one individual of a batch should be kept.

    The individual is discarded when another member of the batch lies closer
    than ``l_value`` and has a strictly higher score, or when its own
    descriptor contains a NaN.

    Args:
        normed: descriptor of the individual under consideration.
        current_index: position of that individual in the batch.
        normed_all: descriptors of the whole batch.
        eval_scores: scores of the whole batch.
        l_value: distance under which two individuals are compared.

    Returns:
        A boolean, True if the individual is kept.
    """
    num_scores = eval_scores.shape[0]

    # A NaN in the descriptor marks a non-existent individual. It is rejected
    # at the end; meanwhile the NaNs are replaced by +inf to allow computations.
    nan_entries = jnp.isnan(normed)
    is_missing = nan_entries.any()
    normed = jnp.where(nan_entries, jnp.full(normed.shape[-1], jnp.inf), normed)

    # Hide infinite scores behind NaN so that nanmax / nanmin skip them.
    masked_scores = jnp.where(
        jnp.isinf(eval_scores), jnp.full(eval_scores.shape[-1], jnp.nan), eval_scores
    )

    # If all the remaining scores are equal (no real fitness in use), build a
    # virtual fitness: a ramp from 0 to 1 along the batch. Otherwise the ramp
    # is identically zero.
    ramp_end = jnp.where(
        jnp.nanmax(masked_scores) == jnp.nanmin(masked_scores), 1.0, 0.0
    )
    virtual_scores = jnp.linspace(0.0, ramp_end, num=num_scores)

    # Hidden scores become -inf, then the virtual fitness is added on top.
    scores = jnp.where(
        jnp.isnan(masked_scores), jnp.full(num_scores, -jnp.inf), masked_scores
    )
    scores = scores + virtual_scores

    # Rank every member of the batch by its distance to the current one.
    # top_k keeps the largest values, so it runs on negated distances.
    neg_distances, neighbour_ids = jax.lax.top_k(
        -1 * jax.vmap(jnp.linalg.norm)(normed - normed_all), num_scores
    )
    distances = neg_distances * -1

    # Only neighbours under the l-value are compared, the individual itself
    # being excluded from the comparison.
    within_l_value = jnp.squeeze(distances < l_value)
    comparable = jnp.logical_and(neighbour_ids != current_index, within_l_value)

    # The score of the closest entry is used as the individual's own score.
    own_score = jnp.squeeze(scores[neighbour_ids[0]])

    # Discard when a comparable neighbour has a strictly higher score: the
    # current individual would be replaced by that better one anyway.
    has_better_neighbour = jnp.logical_and(
        scores[neighbour_ids] > own_score, comparable
    ).any()

    # Individuals with NaN descriptors are discarded as well.
    discard_indiv = jnp.logical_or(has_better_neighbour, is_missing)

    return jnp.logical_not(discard_indiv)