Containers

qdax.core.containers

archive

Defines an unstructured archive and a Euclidean novelty scorer.

Archive (PyTreeNode) dataclass

Stores jnp.ndarray in a way that makes insertion jittable.

An example use of the archive is the QDPG algorithm: state descriptors are stored in this archive, and a novelty scorer compares new state descriptors to those already stored.

Note: the notation assumes that the stored elements are state descriptors. If this structure were reused for another application, renaming the variables accordingly would be preferable; this does not seem necessary at the moment though.
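
Example (an illustrative usage sketch, not part of the library source; the threshold and sizes below are arbitrary):

import jax.numpy as jnp

from qdax.core.containers.archive import Archive

# Pre-allocate the archive: the data buffer is filled with NaN rows
# ("fake" descriptors) so that later insertions remain jittable.
archive = Archive.create(
    acceptance_threshold=0.1,
    state_descriptor_size=2,
    max_size=1000,
)

# Try to insert a small batch of (arbitrary) state descriptors. The last
# descriptor is closer than acceptance_threshold to the second one, so it
# should be filtered out during the insertion scan.
state_descriptors = jnp.array([[0.0, 0.0], [0.5, 0.5], [0.5, 0.55]])
archive = archive.insert(state_descriptors)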

Source code in qdax/core/containers/archive.py
class Archive(PyTreeNode):
    """Stores jnp.ndarray in a way that makes insertion jittable.

    An example of use of the archive is the algorithm QDPG: state
    descriptors are stored in this archive and a novelty scorer compares
    new state desciptors to the state descriptors stored in this archive.

    Note: notations suppose that the elements are called state desciptors.
    If we where to use this structure for another application, it would be
    better to change the variables name for another one. Does not seem
    necessary at the moment though.
    """

    data: jnp.ndarray  # initialised with nan everywhere
    current_position: int
    acceptance_threshold: float
    state_descriptor_size: int
    max_size: int

    @property
    def size(self) -> float:
        """Compute the number of state descriptors stored in the archive.

        Returns:
            Size of the archive.
        """
        # remove fake borders
        fake_data = jnp.isnan(self.data)

        # count number of real data
        return sum(~fake_data)

    @classmethod
    def create(
        cls,
        acceptance_threshold: float,
        state_descriptor_size: int,
        max_size: int,
    ) -> Archive:
        """Create an Archive instance.

        This class method provides a convenient way to create the archive while
        keeping the __init__ function for more general way to init an archive.

        Args:
            acceptance_threshold: the minimal distance to a stored descriptor to
                be respected for a new descriptor to be added.
            state_descriptor_size: the number of elements in a state descriptor.
            max_size: the maximal size of the archive. In case of overflow, previous
                elements are replaced by new ones. Defaults to 80000.

        Returns:
            A newly initialized archive.
        """
        init_data = jnp.ones((max_size, state_descriptor_size)) * jnp.nan
        return cls(  # type: ignore
            data=init_data,
            current_position=0,
            acceptance_threshold=acceptance_threshold,
            state_descriptor_size=state_descriptor_size,
            max_size=max_size,
        )

    @jax.jit
    def _single_insertion(self, state_descriptor: jnp.ndarray) -> Archive:
        """Insert a single element.

        If the archive is not full yet, the new element replaces a fake
        border, if it is full, it replaces the element that was inserted
        first in the archive.

        Args:
            state_descriptor: state descriptor to be added.

        Returns:
            Return the archive with the newly added element.
        """
        new_current_position = self.current_position + 1
        new_data = jax.lax.dynamic_update_slice_in_dim(
            self.data,
            state_descriptor.reshape(1, -1),
            start_index=self.current_position % self.max_size,
            axis=0,
        )

        return self.replace(  # type: ignore
            current_position=new_current_position, data=new_data
        )

    @jax.jit
    def _conditioned_single_insertion(
        self, condition: bool, state_descriptor: jnp.ndarray
    ) -> Tuple[Archive, jnp.ndarray]:
        """Inserts a single element under a condition.

        The function also retrieves the added elements.

        Args:
            condition: condition for being added in the archive.
            state_descriptor: state descriptor to be added under the
                given condition.

        Returns:
            The new archive and the added elements.
        """

        def true_fun(
            archive: Archive, state_descriptor: jnp.ndarray
        ) -> Tuple[Archive, jnp.ndarray]:
            return archive._single_insertion(state_descriptor), state_descriptor

        def false_fun(
            archive: Archive, state_descriptor: jnp.ndarray
        ) -> Tuple[Archive, jnp.ndarray]:
            return archive, jnp.ones_like(state_descriptor) * jnp.nan

        return jax.lax.cond(  # type: ignore
            condition, true_fun, false_fun, self, state_descriptor
        )

    @jax.jit
    def insert(self, state_descriptors: jnp.ndarray) -> Archive:
        """Tries to insert a batch of state descriptors in the archive.

        1. First, look at the distance of each new state descriptor with the
        already stored ones.
        2. Then, scan the state descriptors, check the distance with
        the descriptors inserted during the scan.
        3. If the state descriptor verified the first condition (not too close
        to a state descriptor in the old archive) and the second (not too close
        from a state descriptor that has just been added), then it is added
        to the archive.

        Note 1: the archive has a fixed size, hence, in case of overflow, the
        first elements added are removed first (FIFO style).
        Note 2: keep in mind that fake descriptors are used to help keep the size
        constant. Those correspond to a descriptor very far away from the typical
        values of the problem at hand.

        Args:
            state_descriptors: state descriptors to be added.

        Returns:
            New archive updated with the state descriptors.
        """
        state_descriptors = state_descriptors.reshape((-1, state_descriptors.shape[-1]))

        # get nearest neigbor for each new state descriptor
        values, _indices = knn(self.data, state_descriptors, 1)

        # get indices where distance bigger than threshold
        relevant_indices = jnp.where(
            values.squeeze() > self.acceptance_threshold, x=0, y=1
        )

        def iterate_fn(
            carry: Tuple[Archive, jnp.ndarray, int], condition_data: Dict
        ) -> Tuple[Tuple[Archive, jnp.ndarray, int], Any]:
            """Iterates over the archive to add elements one after the other.

            Args:
                carry: tuple containing the archive, the state descriptors and the
                    indices.

                condition_data: the first addition condition of the state descriptors
                    given, which corresponds to being sufficiently far away from already
                    stored descriptors.

            Returns:
                The update tuple.
            """
            archive, new_elements, index = carry

            first_condition = condition_data["condition"]
            state_descriptor = condition_data["state_descriptor"]

            # do the filtering among the added elements
            # get nearest neigbor for each new state descriptor
            values, _indices = knn(new_elements, state_descriptor.reshape(1, -1), 1)

            # get indices where distance bigger than threshold
            not_too_close = jnp.where(
                values.squeeze() > self.acceptance_threshold, x=0, y=1
            )
            second_condition = not_too_close.sum()
            condition = (first_condition + second_condition) == 0

            new_archive, added_element = archive._conditioned_single_insertion(
                condition, state_descriptor
            )
            new_elements = new_elements.at[index].set(added_element)
            index += 1

            return (
                (new_archive, new_elements, index),
                (),
            )

        new_elements = jnp.ones_like(state_descriptors) * jnp.nan

        # iterate over the indices
        (new_archive, _, _), _ = jax.lax.scan(
            iterate_fn,
            (self, new_elements, 0),
            {
                "condition": relevant_indices,
                "state_descriptor": state_descriptors,
            },
        )

        return new_archive  # type: ignore
size: float property readonly

Compute the number of state descriptors stored in the archive.

Returns:
  • float – Size of the archive.

create(acceptance_threshold, state_descriptor_size, max_size) classmethod

Create an Archive instance.

This class method provides a convenient way to create the archive while keeping the __init__ function available for more general ways to initialize an archive.

Parameters:
  • acceptance_threshold (float) – the minimal distance to a stored descriptor to be respected for a new descriptor to be added.

  • state_descriptor_size (int) – the number of elements in a state descriptor.

  • max_size (int) – the maximal size of the archive. In case of overflow, previous elements are replaced by new ones. Defaults to 80000.

Returns:
  • Archive – A newly initialized archive.

Source code in qdax/core/containers/archive.py
@classmethod
def create(
    cls,
    acceptance_threshold: float,
    state_descriptor_size: int,
    max_size: int,
) -> Archive:
    """Create an Archive instance.

    This class method provides a convenient way to create the archive while
    keeping the __init__ function for more general way to init an archive.

    Args:
        acceptance_threshold: the minimal distance to a stored descriptor to
            be respected for a new descriptor to be added.
        state_descriptor_size: the number of elements in a state descriptor.
        max_size: the maximal size of the archive. In case of overflow, previous
            elements are replaced by new ones. Defaults to 80000.

    Returns:
        A newly initialized archive.
    """
    init_data = jnp.ones((max_size, state_descriptor_size)) * jnp.nan
    return cls(  # type: ignore
        data=init_data,
        current_position=0,
        acceptance_threshold=acceptance_threshold,
        state_descriptor_size=state_descriptor_size,
        max_size=max_size,
    )
replace(self, **updates)

"Returns a new object replacing the specified fields with new values.

Source code in qdax/core/containers/archive.py
def replace(self, **updates):
  """ "Returns a new object replacing the specified fields with new values."""
  return dataclasses.replace(self, **updates)
insert(self, state_descriptors)

Tries to insert a batch of state descriptors in the archive.

  1. First, compute the distance of each new state descriptor to the ones already stored.
  2. Then, scan the state descriptors, checking the distance to the descriptors inserted earlier in the scan.
  3. If a state descriptor satisfies the first condition (not too close to a state descriptor already in the archive) and the second (not too close to a state descriptor that has just been added), it is added to the archive.

Note 1: the archive has a fixed size; hence, in case of overflow, the elements added first are removed first (FIFO style).
Note 2: keep in mind that fake descriptors are used to keep the size constant. These correspond to descriptors very far away from the typical values of the problem at hand.

Parameters:
  • state_descriptors (jnp.ndarray) – state descriptors to be added.

Returns:
  • Archive – New archive updated with the state descriptors.
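
Example (an illustrative sketch of the FIFO overflow behavior, not part of the library source; values and sizes are arbitrary):

import jax.numpy as jnp

from qdax.core.containers.archive import Archive

# With a small max_size, the archive behaves like a ring buffer: once it
# is full, newly accepted descriptors overwrite the oldest entries.
archive = Archive.create(
    acceptance_threshold=0.1,
    state_descriptor_size=2,
    max_size=2,
)
archive = archive.insert(jnp.array([[0.0, 0.0], [1.0, 1.0], [2.0, 2.0]]))
# archive.data should now hold [[2., 2.], [1., 1.]]: the third accepted
# descriptor has overwritten the first one.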

Source code in qdax/core/containers/archive.py
@jax.jit
def insert(self, state_descriptors: jnp.ndarray) -> Archive:
    """Tries to insert a batch of state descriptors in the archive.

    1. First, look at the distance of each new state descriptor with the
    already stored ones.
    2. Then, scan the state descriptors, check the distance with
    the descriptors inserted during the scan.
    3. If the state descriptor verified the first condition (not too close
    to a state descriptor in the old archive) and the second (not too close
    from a state descriptor that has just been added), then it is added
    to the archive.

    Note 1: the archive has a fixed size, hence, in case of overflow, the
    first elements added are removed first (FIFO style).
    Note 2: keep in mind that fake descriptors are used to help keep the size
    constant. Those correspond to a descriptor very far away from the typical
    values of the problem at hand.

    Args:
        state_descriptors: state descriptors to be added.

    Returns:
        New archive updated with the state descriptors.
    """
    state_descriptors = state_descriptors.reshape((-1, state_descriptors.shape[-1]))

    # get nearest neigbor for each new state descriptor
    values, _indices = knn(self.data, state_descriptors, 1)

    # get indices where distance bigger than threshold
    relevant_indices = jnp.where(
        values.squeeze() > self.acceptance_threshold, x=0, y=1
    )

    def iterate_fn(
        carry: Tuple[Archive, jnp.ndarray, int], condition_data: Dict
    ) -> Tuple[Tuple[Archive, jnp.ndarray, int], Any]:
        """Iterates over the archive to add elements one after the other.

        Args:
            carry: tuple containing the archive, the state descriptors and the
                indices.

            condition_data: the first addition condition of the state descriptors
                given, which corresponds to being sufficiently far away from already
                stored descriptors.

        Returns:
            The update tuple.
        """
        archive, new_elements, index = carry

        first_condition = condition_data["condition"]
        state_descriptor = condition_data["state_descriptor"]

        # do the filtering among the added elements
        # get nearest neigbor for each new state descriptor
        values, _indices = knn(new_elements, state_descriptor.reshape(1, -1), 1)

        # get indices where distance bigger than threshold
        not_too_close = jnp.where(
            values.squeeze() > self.acceptance_threshold, x=0, y=1
        )
        second_condition = not_too_close.sum()
        condition = (first_condition + second_condition) == 0

        new_archive, added_element = archive._conditioned_single_insertion(
            condition, state_descriptor
        )
        new_elements = new_elements.at[index].set(added_element)
        index += 1

        return (
            (new_archive, new_elements, index),
            (),
        )

    new_elements = jnp.ones_like(state_descriptors) * jnp.nan

    # iterate over the indices
    (new_archive, _, _), _ = jax.lax.scan(
        iterate_fn,
        (self, new_elements, 0),
        {
            "condition": relevant_indices,
            "state_descriptor": state_descriptors,
        },
    )

    return new_archive  # type: ignore

score_euclidean_novelty(archive, state_descriptors, num_nearest_neighb, scaling_ratio)

Scores the novelty of a jnp.ndarray with respect to the elements of an archive.

A typical use case is the construction of the diversity rewards in QDPG.

Parameters:
  • archive (Archive) – an archive of state descriptors.

  • state_descriptors (jnp.ndarray) – state descriptors whose novelty must be scored.

  • num_nearest_neighb (int) – the number of nearest neighbors to be considered when scoring.

  • scaling_ratio (float) – the ratio applied to the mean squared distance to obtain the final value.

Returns:
  • jnp.ndarray – The novelty scores of the given state descriptors.
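
Example (an illustrative sketch of novelty scoring, not part of the library source; values are arbitrary):

import jax.numpy as jnp

from qdax.core.containers.archive import Archive, score_euclidean_novelty

archive = Archive.create(
    acceptance_threshold=0.1,
    state_descriptor_size=2,
    max_size=100,
)
archive = archive.insert(jnp.array([[0.0, 0.0], [1.0, 0.0], [0.0, 1.0]]))

# Novelty is the scaled mean squared distance to the nearest stored
# descriptors: points far from the archive get a larger score.
new_descriptors = jnp.array([[0.1, 0.1], [5.0, 5.0]])
novelty = score_euclidean_novelty(
    archive,
    new_descriptors,
    num_nearest_neighb=2,
    scaling_ratio=1.0,
)
# novelty[1] is much larger than novelty[0]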

Source code in qdax/core/containers/archive.py
def score_euclidean_novelty(
    archive: Archive,
    state_descriptors: jnp.ndarray,
    num_nearest_neighb: int,
    scaling_ratio: float,
) -> jnp.ndarray:
    """Scores the novelty of a jnp.ndarray with respect to the elements of an archive.

    Typical use case in the construction of the diversity rewards
    in QDPG.

    Args:
        archive: an archive of state descriptors.
        state_descriptors: state descriptors which novelty must be scored.
        num_nearest_neighb: the number of nearest neighbors to be considered
            when scoring.
        scaling_ratio: the ratio applied to the the mean distance to obtain the
            final value.

    Returns:
        The novelty scores of the given state descriptors.
    """
    values, _indices = knn(archive.data, state_descriptors, num_nearest_neighb)

    summed_distances = jnp.mean(jnp.square(values), axis=1)
    return scaling_ratio * summed_distances

knn(data, new_data, k)

K nearest neighbors: brute-force implementation using Euclidean distance.

Code from https://www.kernel-operations.io/keops/_auto_benchmarks/plot_benchmark_KNN.html

Parameters:
  • data (jnp.ndarray) – given reference data.

  • new_data (jnp.ndarray) – data to be compared to the reference data.

  • k (jnp.ndarray) – number of neighbors to consider.

Returns:
  • Tuple[jnp.ndarray, jnp.ndarray] – The distances and indices of the nearest neighbors.
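
Example (illustrative sketch, not part of the library source; values are arbitrary):

import jax.numpy as jnp

from qdax.core.containers.archive import knn

reference = jnp.array([[0.0, 0.0], [1.0, 0.0], [0.0, 1.0]])
queries = jnp.array([[0.1, 0.0], [0.9, 0.9]])

# For each query point, the Euclidean distances to and indices of its two
# nearest reference points; both outputs have shape (2, 2).
distances, indices = knn(reference, queries, 2)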

Source code in qdax/core/containers/archive.py
@partial(jax.jit, static_argnames=("k"))
def knn(
    data: jnp.ndarray, new_data: jnp.ndarray, k: jnp.ndarray
) -> Tuple[jnp.ndarray, jnp.ndarray]:
    """K nearest neigbors - Brute force implementation.
    Using euclidean distance.

    Code from https://www.kernel-operations.io/keops/_auto_benchmarks/
    plot_benchmark_KNN.html

    Args:
        data: given reference data.
        new_data: data to be compared to the reference data.
        k: number of neigbors to consider.

    Returns:
        The distances and indices of the nearest neighbors.
    """

    # compute distances
    dist = (
        (new_data**2).sum(-1)[:, None]
        + (data**2).sum(-1)[None, :]
        - 2 * new_data @ data.T
    )

    dist = jnp.nan_to_num(dist, nan=jnp.inf)

    # clipping necessary - numerical approx make some distancies negative
    dist = jnp.sqrt(jnp.clip(dist, a_min=0.0))

    # return values, indices
    values, indices = qdax_top_k(-dist, k)

    return -values, indices

qdax_top_k(data, k)

Returns the top k elements of an array.

Interestingly, this naive implementation is faster than the native implementation of jax for small k. See issue: https://github.com/google/jax/issues/9940

Waiting for updates in jax to change this implementation.

Parameters:
  • data (jnp.ndarray) – given data.

  • k (int) – number of top elements to determine.

Returns:
  • Tuple[jnp.ndarray, jnp.ndarray] – The values of the elements and their indices in the array.
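
Example (illustrative sketch, not part of the library source; values are arbitrary):

import jax.numpy as jnp

from qdax.core.containers.archive import qdax_top_k

data = jnp.array([[1.0, 5.0, 3.0], [2.0, 0.0, 4.0]])

# Row-wise top-2 values and their column indices.
values, indices = qdax_top_k(data, 2)
# values  -> [[5., 3.], [4., 2.]]
# indices -> [[1, 2], [2, 0]]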

Source code in qdax/core/containers/archive.py
@partial(jax.jit, static_argnames=("k"))
def qdax_top_k(data: jnp.ndarray, k: int) -> Tuple[jnp.ndarray, jnp.ndarray]:
    """Returns the top k elements of an array.

    Interestingly, this naive implementation is faster than the native implementation
    of jax for small k. See issue: https://github.com/google/jax/issues/9940

    Waiting for updates in jax to change this implementation.

    Args:
        data: given data.
        k: number of top elements to determine.

    Returns:
        The values of the elements and their indices in the array.
    """

    def top_1(data: jnp.ndarray) -> Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray]:
        indice = jnp.argmax(data, axis=1)
        value = jax.vmap(lambda x, y: x[y])(data, indice)
        data = jax.vmap(lambda x, y: x.at[y].set(-jnp.inf))(data, indice)

        return data, value, indice

    def scannable_top_1(
        carry: jnp.ndarray, unused: Any
    ) -> Tuple[jnp.ndarray, Tuple[jnp.ndarray, jnp.ndarray]]:
        data = carry
        data, value, indice = top_1(data)
        return data, (value, indice)

    data, (values, indices) = jax.lax.scan(scannable_top_1, data, (), k)

    return values.T, indices.T

ga_repertoire

Defines a repertoire for simple genetic algorithms.

GARepertoire (Repertoire) dataclass

Class defining a simple repertoire for a genetic algorithm.

Parameters:
  • genotypes (Genotype) – a PyTree containing the genotypes of the individuals in the population. Each leaf has the shape (population_size, num_features).

  • fitnesses (Fitness) – an array containing the fitnesses of the individuals in the population, with shape (population_size, fitness_dim). The implementation of GARepertoire was designed for the case where fitness_dim equals 1, but the class can be inherited from and its rules adapted for cases where fitness_dim is greater than 1.
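
Example (an illustrative sketch of a typical GARepertoire workflow, not part of the library source; the population and fitness values are arbitrary):

import jax
import jax.numpy as jnp

from qdax.core.containers.ga_repertoire import GARepertoire

# A toy population: genotypes are plain arrays with 4 features, fitnesses
# are single-objective, hence of shape (batch_size, 1).
genotypes = jnp.ones((10, 4))
fitnesses = jnp.arange(10.0).reshape(-1, 1)

repertoire = GARepertoire.init(
    genotypes=genotypes,
    fitnesses=fitnesses,
    population_size=16,
)

# Sample parents, then add new candidates: only the population_size best
# individuals (by summed fitness) survive the addition.
random_key = jax.random.PRNGKey(0)
parents, random_key = repertoire.sample(random_key, num_samples=4)
repertoire = repertoire.add(jnp.zeros((5, 4)), 100.0 * jnp.ones((5, 1)))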

Source code in qdax/core/containers/ga_repertoire.py
class GARepertoire(Repertoire):
    """Class for a simple repertoire for a simple genetic
    algorithm.

    Args:
        genotypes: a PyTree containing the genotypes of the
            individuals in the population. Each leaf has the
            shape (population_size, num_features).
        fitnesses: an array containing the fitness of the individuals
            in the population. With shape (population_size, fitness_dim).
            The implementation of GARepertoire was thought for the case
            where fitness_dim equals 1 but the class can be herited and
            rules adapted for cases where fitness_dim is greater than 1.
    """

    genotypes: Genotype
    fitnesses: Fitness

    @property
    def size(self) -> int:
        """Gives the size of the population."""
        first_leaf = jax.tree_util.tree_leaves(self.genotypes)[0]
        return int(first_leaf.shape[0])

    def save(self, path: str = "./") -> None:
        """Saves the repertoire.

        Args:
            path: place to store the files. Defaults to "./".
        """

        def flatten_genotype(genotype: Genotype) -> jnp.ndarray:
            flatten_genotype, _ = ravel_pytree(genotype)
            return flatten_genotype

        # flatten all the genotypes
        flat_genotypes = jax.vmap(flatten_genotype)(self.genotypes)

        jnp.save(path + "genotypes.npy", flat_genotypes)
        jnp.save(path + "scores.npy", self.fitnesses)

    @classmethod
    def load(cls, reconstruction_fn: Callable, path: str = "./") -> GARepertoire:
        """Loads a GA Repertoire.

        Args:
            reconstruction_fn: Function to reconstruct a PyTree
                from a flat array.
            path: Path where the data is saved. Defaults to "./".

        Returns:
            A GA Repertoire.
        """

        flat_genotypes = jnp.load(path + "genotypes.npy")
        genotypes = jax.vmap(reconstruction_fn)(flat_genotypes)

        fitnesses = jnp.load(path + "fitnesses.npy")

        return cls(
            genotypes=genotypes,
            fitnesses=fitnesses,
        )

    @partial(jax.jit, static_argnames=("num_samples",))
    def sample(self, random_key: RNGKey, num_samples: int) -> Tuple[Genotype, RNGKey]:
        """Sample genotypes from the repertoire.

        Args:
            random_key: a random key to handle stochasticity.
            num_samples: the number of genotypes to sample.

        Returns:
            The sample of genotypes.
        """

        # prepare sampling probability
        mask = self.fitnesses != -jnp.inf
        p = jnp.any(mask, axis=-1) / jnp.sum(jnp.any(mask, axis=-1))

        # sample
        random_key, subkey = jax.random.split(random_key)
        samples = jax.tree_util.tree_map(
            lambda x: jax.random.choice(
                subkey, x, shape=(num_samples,), p=p, replace=False
            ),
            self.genotypes,
        )

        return samples, random_key

    @jax.jit
    def add(
        self, batch_of_genotypes: Genotype, batch_of_fitnesses: Fitness
    ) -> GARepertoire:
        """Implements the repertoire addition rules.

        Parents and offsprings are gathered and only the population_size
        bests are kept. The others are killed.

        Args:
            batch_of_genotypes: new genotypes that we try to add.
            batch_of_fitnesses: fitness of those new genotypes.

        Returns:
            The updated repertoire.
        """

        # gather individuals and fitnesses
        candidates = jax.tree_util.tree_map(
            lambda x, y: jnp.concatenate((x, y), axis=0),
            self.genotypes,
            batch_of_genotypes,
        )
        candidates_fitnesses = jnp.concatenate(
            (self.fitnesses, batch_of_fitnesses), axis=0
        )

        # sort by fitnesses
        indices = jnp.argsort(jnp.sum(candidates_fitnesses, axis=1))[::-1]

        # keep only the best ones
        survivor_indices = indices[: self.size]

        # keep only the best ones
        new_candidates = jax.tree_util.tree_map(
            lambda x: x[survivor_indices], candidates
        )

        new_repertoire = self.replace(
            genotypes=new_candidates, fitnesses=candidates_fitnesses[survivor_indices]
        )

        return new_repertoire  # type: ignore

    @classmethod
    def init(  # type: ignore
        cls,
        genotypes: Genotype,
        fitnesses: Fitness,
        population_size: int,
    ) -> GARepertoire:
        """Initializes the repertoire.

        Start with default values and adds a first batch of genotypes
        to the repertoire.

        Args:
            genotypes: first batch of genotypes
            fitnesses: corresponding fitnesses
            population_size: size of the population we want to evolve

        Returns:
            An initial repertoire.
        """
        # create default fitnesses
        default_fitnesses = -jnp.inf * jnp.ones(
            shape=(population_size, fitnesses.shape[-1])
        )

        # create default genotypes
        default_genotypes = jax.tree_util.tree_map(
            lambda x: jnp.zeros(shape=(population_size,) + x.shape[1:]), genotypes
        )

        # create an initial repertoire with those default values
        repertoire = cls(genotypes=default_genotypes, fitnesses=default_fitnesses)

        new_repertoire = repertoire.add(genotypes, fitnesses)

        return new_repertoire  # type: ignore
size: int property readonly

Gives the size of the population.

save(self, path='./')

Saves the repertoire.

Parameters:
  • path (str) – place to store the files. Defaults to "./".

Source code in qdax/core/containers/ga_repertoire.py
def save(self, path: str = "./") -> None:
    """Saves the repertoire.

    Args:
        path: place to store the files. Defaults to "./".
    """

    def flatten_genotype(genotype: Genotype) -> jnp.ndarray:
        flatten_genotype, _ = ravel_pytree(genotype)
        return flatten_genotype

    # flatten all the genotypes
    flat_genotypes = jax.vmap(flatten_genotype)(self.genotypes)

    jnp.save(path + "genotypes.npy", flat_genotypes)
    jnp.save(path + "scores.npy", self.fitnesses)
load(reconstruction_fn, path='./') classmethod

Loads a GA Repertoire.

Parameters:
  • reconstruction_fn (Callable) – Function to reconstruct a PyTree from a flat array.

  • path (str) – Path where the data is saved. Defaults to "./".

Returns:
  • GARepertoire – A GA Repertoire.

Source code in qdax/core/containers/ga_repertoire.py
@classmethod
def load(cls, reconstruction_fn: Callable, path: str = "./") -> GARepertoire:
    """Loads a GA Repertoire.

    Args:
        reconstruction_fn: Function to reconstruct a PyTree
            from a flat array.
        path: Path where the data is saved. Defaults to "./".

    Returns:
        A GA Repertoire.
    """

    flat_genotypes = jnp.load(path + "genotypes.npy")
    genotypes = jax.vmap(reconstruction_fn)(flat_genotypes)

    fitnesses = jnp.load(path + "fitnesses.npy")

    return cls(
        genotypes=genotypes,
        fitnesses=fitnesses,
    )
sample(self, random_key, num_samples)

Sample genotypes from the repertoire.

Parameters:
  • random_key (RNGKey) – a random key to handle stochasticity.

  • num_samples (int) – the number of genotypes to sample.

Returns:
  • Tuple[Genotype, RNGKey] – The sample of genotypes.

Source code in qdax/core/containers/ga_repertoire.py
@partial(jax.jit, static_argnames=("num_samples",))
def sample(self, random_key: RNGKey, num_samples: int) -> Tuple[Genotype, RNGKey]:
    """Sample genotypes from the repertoire.

    Args:
        random_key: a random key to handle stochasticity.
        num_samples: the number of genotypes to sample.

    Returns:
        The sample of genotypes.
    """

    # prepare sampling probability
    mask = self.fitnesses != -jnp.inf
    p = jnp.any(mask, axis=-1) / jnp.sum(jnp.any(mask, axis=-1))

    # sample
    random_key, subkey = jax.random.split(random_key)
    samples = jax.tree_util.tree_map(
        lambda x: jax.random.choice(
            subkey, x, shape=(num_samples,), p=p, replace=False
        ),
        self.genotypes,
    )

    return samples, random_key
add(self, batch_of_genotypes, batch_of_fitnesses)

Implements the repertoire addition rules.

Parents and offspring are gathered, and only the population_size best individuals are kept; the others are discarded.

Parameters:
  • batch_of_genotypes (Genotype) – new genotypes that we try to add.

  • batch_of_fitnesses (Fitness) – fitness of those new genotypes.

Returns:
  • GARepertoire – The updated repertoire.

Source code in qdax/core/containers/ga_repertoire.py
@jax.jit
def add(
    self, batch_of_genotypes: Genotype, batch_of_fitnesses: Fitness
) -> GARepertoire:
    """Implements the repertoire addition rules.

    Parents and offsprings are gathered and only the population_size
    bests are kept. The others are killed.

    Args:
        batch_of_genotypes: new genotypes that we try to add.
        batch_of_fitnesses: fitness of those new genotypes.

    Returns:
        The updated repertoire.
    """

    # gather individuals and fitnesses
    candidates = jax.tree_util.tree_map(
        lambda x, y: jnp.concatenate((x, y), axis=0),
        self.genotypes,
        batch_of_genotypes,
    )
    candidates_fitnesses = jnp.concatenate(
        (self.fitnesses, batch_of_fitnesses), axis=0
    )

    # sort by fitnesses
    indices = jnp.argsort(jnp.sum(candidates_fitnesses, axis=1))[::-1]

    # keep only the best ones
    survivor_indices = indices[: self.size]

    # keep only the best ones
    new_candidates = jax.tree_util.tree_map(
        lambda x: x[survivor_indices], candidates
    )

    new_repertoire = self.replace(
        genotypes=new_candidates, fitnesses=candidates_fitnesses[survivor_indices]
    )

    return new_repertoire  # type: ignore
replace(self, **updates)

"Returns a new object replacing the specified fields with new values.

Source code in qdax/core/containers/ga_repertoire.py
def replace(self, **updates):
  """ "Returns a new object replacing the specified fields with new values."""
  return dataclasses.replace(self, **updates)
init(genotypes, fitnesses, population_size) classmethod

Initializes the repertoire.

Starts with default values and adds a first batch of genotypes to the repertoire.

Parameters:
  • genotypes (Genotype) – first batch of genotypes

  • fitnesses (Fitness) – corresponding fitnesses

  • population_size (int) – size of the population we want to evolve

Returns:
  • GARepertoire – An initial repertoire.

Source code in qdax/core/containers/ga_repertoire.py
@classmethod
def init(  # type: ignore
    cls,
    genotypes: Genotype,
    fitnesses: Fitness,
    population_size: int,
) -> GARepertoire:
    """Initializes the repertoire.

    Start with default values and adds a first batch of genotypes
    to the repertoire.

    Args:
        genotypes: first batch of genotypes
        fitnesses: corresponding fitnesses
        population_size: size of the population we want to evolve

    Returns:
        An initial repertoire.
    """
    # create default fitnesses
    default_fitnesses = -jnp.inf * jnp.ones(
        shape=(population_size, fitnesses.shape[-1])
    )

    # create default genotypes
    default_genotypes = jax.tree_util.tree_map(
        lambda x: jnp.zeros(shape=(population_size,) + x.shape[1:]), genotypes
    )

    # create an initial repertoire with those default values
    repertoire = cls(genotypes=default_genotypes, fitnesses=default_fitnesses)

    new_repertoire = repertoire.add(genotypes, fitnesses)

    return new_repertoire  # type: ignore

mapelites_repertoire

This file contains utility functions and a class defining the repertoire used to store individuals in the MAP-Elites algorithm, as well as in several of its variants.

MapElitesRepertoire (PyTreeNode) dataclass

Class for the repertoire in Map Elites.

Parameters:
  • genotypes (Genotype) – a PyTree containing all the genotypes in the repertoire, ordered by the centroids. Each leaf has a shape (num_centroids, num_features). The PyTree can be a simple JAX array or a more complex nested structure, such as the parameters of a neural network in Flax.

  • fitnesses (Fitness) – an array that contains the fitness of solutions in each cell of the repertoire, ordered by centroids. The array shape is (num_centroids,).

  • descriptors (Descriptor) – an array that contains the descriptors of solutions in each cell of the repertoire, ordered by centroids. The array shape is (num_centroids, num_descriptors).

  • centroids (Centroid) – an array that contains the centroids of the tessellation. The array shape is (num_centroids, num_descriptors).
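
Example (an illustrative sketch of a typical MAP-Elites repertoire workflow, not part of the library source; shapes and values are arbitrary):

import jax
import jax.numpy as jnp

from qdax.core.containers.mapelites_repertoire import (
    MapElitesRepertoire,
    compute_euclidean_centroids,
)

# A 5x5 grid tessellation over a 2D descriptor space in [0, 1]^2.
centroids = compute_euclidean_centroids(grid_shape=(5, 5), minval=0.0, maxval=1.0)

# A toy initial population: 8 genotypes with 3 features each, with their
# fitnesses and 2D descriptors.
genotypes = jnp.ones((8, 3))
fitnesses = jnp.arange(8.0)
descriptors = jax.random.uniform(jax.random.PRNGKey(0), (8, 2))

repertoire = MapElitesRepertoire.init(
    genotypes=genotypes,
    fitnesses=fitnesses,
    descriptors=descriptors,
    centroids=centroids,
)

# Later additions only replace a cell's content when the new fitness is
# strictly better than the one already stored in that cell.
new_descriptors = jax.random.uniform(jax.random.PRNGKey(1), (4, 2))
repertoire = repertoire.add(jnp.zeros((4, 3)), new_descriptors, 10.0 * jnp.ones(4))

# Sample genotypes from non-empty cells.
random_key = jax.random.PRNGKey(2)
samples, random_key = repertoire.sample(random_key, num_samples=4)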

Source code in qdax/core/containers/mapelites_repertoire.py
class MapElitesRepertoire(flax.struct.PyTreeNode):
    """Class for the repertoire in Map Elites.

    Args:
        genotypes: a PyTree containing all the genotypes in the repertoire ordered
            by the centroids. Each leaf has a shape (num_centroids, num_features). The
            PyTree can be a simple Jax array or a more complex nested structure such
            as to represent parameters of neural network in Flax.
        fitnesses: an array that contains the fitness of solutions in each cell of the
            repertoire, ordered by centroids. The array shape is (num_centroids,).
        descriptors: an array that contains the descriptors of solutions in each cell
            of the repertoire, ordered by centroids. The array shape
            is (num_centroids, num_descriptors).
        centroids: an array that contains the centroids of the tessellation. The array
            shape is (num_centroids, num_descriptors).
    """

    genotypes: Genotype
    fitnesses: Fitness
    descriptors: Descriptor
    centroids: Centroid

    def save(self, path: str = "./") -> None:
        """Saves the repertoire on disk in the form of .npy files.

        Flattens the genotypes to store it with .npy format. Supposes that
        a user will have access to the reconstruction function when loading
        the genotypes.

        Args:
            path: Path where the data will be saved. Defaults to "./".
        """

        def flatten_genotype(genotype: Genotype) -> jnp.ndarray:
            flatten_genotype, _ = ravel_pytree(genotype)
            return flatten_genotype

        # flatten all the genotypes
        flat_genotypes = jax.vmap(flatten_genotype)(self.genotypes)

        # save data
        jnp.save(path + "genotypes.npy", flat_genotypes)
        jnp.save(path + "fitnesses.npy", self.fitnesses)
        jnp.save(path + "descriptors.npy", self.descriptors)
        jnp.save(path + "centroids.npy", self.centroids)

    @classmethod
    def load(cls, reconstruction_fn: Callable, path: str = "./") -> MapElitesRepertoire:
        """Loads a MAP Elites Repertoire.

        Args:
            reconstruction_fn: Function to reconstruct a PyTree
                from a flat array.
            path: Path where the data is saved. Defaults to "./".

        Returns:
            A MAP Elites Repertoire.
        """

        flat_genotypes = jnp.load(path + "genotypes.npy")
        genotypes = jax.vmap(reconstruction_fn)(flat_genotypes)

        fitnesses = jnp.load(path + "fitnesses.npy")
        descriptors = jnp.load(path + "descriptors.npy")
        centroids = jnp.load(path + "centroids.npy")

        return cls(
            genotypes=genotypes,
            fitnesses=fitnesses,
            descriptors=descriptors,
            centroids=centroids,
        )

    @partial(jax.jit, static_argnames=("num_samples",))
    def sample(self, random_key: RNGKey, num_samples: int) -> Tuple[Genotype, RNGKey]:
        """Sample elements in the repertoire.

        Args:
            random_key: a jax PRNG random key
            num_samples: the number of elements to be sampled

        Returns:
            samples: a batch of genotypes sampled in the repertoire
            random_key: an updated jax PRNG random key
        """

        repertoire_empty = self.fitnesses == -jnp.inf
        p = (1.0 - repertoire_empty) / jnp.sum(1.0 - repertoire_empty)

        random_key, subkey = jax.random.split(random_key)
        samples = jax.tree_util.tree_map(
            lambda x: jax.random.choice(subkey, x, shape=(num_samples,), p=p),
            self.genotypes,
        )

        return samples, random_key

    @jax.jit
    def add(
        self,
        batch_of_genotypes: Genotype,
        batch_of_descriptors: Descriptor,
        batch_of_fitnesses: Fitness,
        batch_of_extra_scores: Optional[ExtraScores] = None,
    ) -> MapElitesRepertoire:
        """
        Add a batch of elements to the repertoire.

        Args:
            batch_of_genotypes: a batch of genotypes to be added to the repertoire.
                Similarly to the self.genotypes argument, this is a PyTree in which
                the leaves have a shape (batch_size, num_features)
            batch_of_descriptors: an array that contains the descriptors of the
                aforementioned genotypes. Its shape is (batch_size, num_descriptors)
            batch_of_fitnesses: an array that contains the fitnesses of the
                aforementioned genotypes. Its shape is (batch_size,)
            batch_of_extra_scores: unused tree that contains the extra_scores of
                aforementioned genotypes.

        Returns:
            The updated MAP-Elites repertoire.
        """

        batch_of_indices = get_cells_indices(batch_of_descriptors, self.centroids)
        batch_of_indices = jnp.expand_dims(batch_of_indices, axis=-1)
        batch_of_fitnesses = jnp.expand_dims(batch_of_fitnesses, axis=-1)

        num_centroids = self.centroids.shape[0]

        # get fitness segment max
        best_fitnesses = jax.ops.segment_max(
            batch_of_fitnesses,
            batch_of_indices.astype(jnp.int32).squeeze(axis=-1),
            num_segments=num_centroids,
        )

        cond_values = jnp.take_along_axis(best_fitnesses, batch_of_indices, 0)

        # put dominated fitness to -jnp.inf
        batch_of_fitnesses = jnp.where(
            batch_of_fitnesses == cond_values, x=batch_of_fitnesses, y=-jnp.inf
        )

        # get addition condition
        repertoire_fitnesses = jnp.expand_dims(self.fitnesses, axis=-1)
        current_fitnesses = jnp.take_along_axis(
            repertoire_fitnesses, batch_of_indices, 0
        )
        addition_condition = batch_of_fitnesses > current_fitnesses

        # assign fake position when relevant : num_centroids is out of bound
        batch_of_indices = jnp.where(
            addition_condition, x=batch_of_indices, y=num_centroids
        )

        # create new repertoire
        new_repertoire_genotypes = jax.tree_util.tree_map(
            lambda repertoire_genotypes, new_genotypes: repertoire_genotypes.at[
                batch_of_indices.squeeze(axis=-1)
            ].set(new_genotypes),
            self.genotypes,
            batch_of_genotypes,
        )

        # compute new fitness and descriptors
        new_fitnesses = self.fitnesses.at[batch_of_indices.squeeze(axis=-1)].set(
            batch_of_fitnesses.squeeze(axis=-1)
        )
        new_descriptors = self.descriptors.at[batch_of_indices.squeeze(axis=-1)].set(
            batch_of_descriptors
        )

        return MapElitesRepertoire(
            genotypes=new_repertoire_genotypes,
            fitnesses=new_fitnesses,
            descriptors=new_descriptors,
            centroids=self.centroids,
        )

    @classmethod
    def init(
        cls,
        genotypes: Genotype,
        fitnesses: Fitness,
        descriptors: Descriptor,
        centroids: Centroid,
        extra_scores: Optional[ExtraScores] = None,
    ) -> MapElitesRepertoire:
        """
        Initialize a Map-Elites repertoire with an initial population of genotypes.
        Requires the definition of centroids that can be computed with any method
        such as CVT or Euclidean mapping.

        Note: this function has been kept outside of the object MapElites, so it can
        be called easily called from other modules.

        Args:
            genotypes: initial genotypes, pytree in which leaves
                have shape (batch_size, num_features)
            fitnesses: fitness of the initial genotypes of shape (batch_size,)
            descriptors: descriptors of the initial genotypes
                of shape (batch_size, num_descriptors)
            centroids: tesselation centroids of shape (batch_size, num_descriptors)
            extra_scores: unused extra_scores of the initial genotypes

        Returns:
            an initialized MAP-Elite repertoire
        """
        warnings.warn(
            (
                "This type of repertoire does not store the extra scores "
                "computed by the scoring function"
            ),
            stacklevel=2,
        )

        # retrieve one genotype from the population
        first_genotype = jax.tree_util.tree_map(lambda x: x[0], genotypes)

        # create a repertoire with default values
        repertoire = cls.init_default(genotype=first_genotype, centroids=centroids)

        # add initial population to the repertoire
        new_repertoire = repertoire.add(genotypes, descriptors, fitnesses, extra_scores)

        return new_repertoire  # type: ignore

    @classmethod
    def init_default(
        cls,
        genotype: Genotype,
        centroids: Centroid,
    ) -> MapElitesRepertoire:
        """Initialize a Map-Elites repertoire with an initial population of
        genotypes. Requires the definition of centroids that can be computed
        with any method such as CVT or Euclidean mapping.

        Note: this function has been kept outside of the object MapElites, so
        it can be called easily called from other modules.

        Args:
            genotype: the typical genotype that will be stored.
            centroids: the centroids of the repertoire

        Returns:
            A repertoire filled with default values.
        """

        # get number of centroids
        num_centroids = centroids.shape[0]

        # default fitness is -inf
        default_fitnesses = -jnp.inf * jnp.ones(shape=num_centroids)

        # default genotypes is all 0
        default_genotypes = jax.tree_util.tree_map(
            lambda x: jnp.zeros(shape=(num_centroids,) + x.shape, dtype=x.dtype),
            genotype,
        )

        # default descriptor is all zeros
        default_descriptors = jnp.zeros_like(centroids)

        return cls(
            genotypes=default_genotypes,
            fitnesses=default_fitnesses,
            descriptors=default_descriptors,
            centroids=centroids,
        )
replace(self, **updates)

"Returns a new object replacing the specified fields with new values.

Source code in qdax/core/containers/mapelites_repertoire.py
def replace(self, **updates):
  """ "Returns a new object replacing the specified fields with new values."""
  return dataclasses.replace(self, **updates)
save(self, path='./')

Saves the repertoire on disk in the form of .npy files.

Flattens the genotypes to store them in .npy format. It is assumed that the user will have access to the reconstruction function when loading the genotypes.

Parameters:
  • path (str) – Path where the data will be saved. Defaults to "./".

Source code in qdax/core/containers/mapelites_repertoire.py
def save(self, path: str = "./") -> None:
    """Saves the repertoire on disk in the form of .npy files.

    Flattens the genotypes to store it with .npy format. Supposes that
    a user will have access to the reconstruction function when loading
    the genotypes.

    Args:
        path: Path where the data will be saved. Defaults to "./".
    """

    def flatten_genotype(genotype: Genotype) -> jnp.ndarray:
        flatten_genotype, _ = ravel_pytree(genotype)
        return flatten_genotype

    # flatten all the genotypes
    flat_genotypes = jax.vmap(flatten_genotype)(self.genotypes)

    # save data
    jnp.save(path + "genotypes.npy", flat_genotypes)
    jnp.save(path + "fitnesses.npy", self.fitnesses)
    jnp.save(path + "descriptors.npy", self.descriptors)
    jnp.save(path + "centroids.npy", self.centroids)
load(reconstruction_fn, path='./') classmethod

Loads a MAP Elites Repertoire.

Parameters:
  • reconstruction_fn (Callable) – Function to reconstruct a PyTree from a flat array.

  • path (str) – Path where the data is saved. Defaults to "./".

Returns:
  • MapElitesRepertoire – A MAP Elites Repertoire.
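
Example (an illustrative save/load round trip, not part of the library source; it assumes writing to the current directory is acceptable and builds the reconstruction function with jax.flatten_util.ravel_pytree):

import jax
import jax.numpy as jnp
from jax.flatten_util import ravel_pytree

from qdax.core.containers.mapelites_repertoire import (
    MapElitesRepertoire,
    compute_euclidean_centroids,
)

centroids = compute_euclidean_centroids(grid_shape=(4, 4), minval=0.0, maxval=1.0)
repertoire = MapElitesRepertoire.init(
    genotypes=jnp.ones((6, 3)),
    fitnesses=jnp.arange(6.0),
    descriptors=jax.random.uniform(jax.random.PRNGKey(0), (6, 2)),
    centroids=centroids,
)

# Writes genotypes.npy, fitnesses.npy, descriptors.npy and centroids.npy
# to the current directory.
repertoire.save(path="./")

# The reconstruction function maps a flat genotype back to its pytree
# structure; here it is built from a single stored genotype.
single_genotype = jax.tree_util.tree_map(lambda x: x[0], repertoire.genotypes)
_, reconstruction_fn = ravel_pytree(single_genotype)

loaded_repertoire = MapElitesRepertoire.load(
    reconstruction_fn=reconstruction_fn, path="./"
)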

Source code in qdax/core/containers/mapelites_repertoire.py
@classmethod
def load(cls, reconstruction_fn: Callable, path: str = "./") -> MapElitesRepertoire:
    """Loads a MAP Elites Repertoire.

    Args:
        reconstruction_fn: Function to reconstruct a PyTree
            from a flat array.
        path: Path where the data is saved. Defaults to "./".

    Returns:
        A MAP Elites Repertoire.
    """

    flat_genotypes = jnp.load(path + "genotypes.npy")
    genotypes = jax.vmap(reconstruction_fn)(flat_genotypes)

    fitnesses = jnp.load(path + "fitnesses.npy")
    descriptors = jnp.load(path + "descriptors.npy")
    centroids = jnp.load(path + "centroids.npy")

    return cls(
        genotypes=genotypes,
        fitnesses=fitnesses,
        descriptors=descriptors,
        centroids=centroids,
    )
sample(self, random_key, num_samples)

Sample elements in the repertoire.

Parameters:
  • random_key (RNGKey) – a jax PRNG random key

  • num_samples (int) – the number of elements to be sampled

Returns:
  • samples – a batch of genotypes sampled in the repertoire
  • random_key – an updated jax PRNG random key

Source code in qdax/core/containers/mapelites_repertoire.py
@partial(jax.jit, static_argnames=("num_samples",))
def sample(self, random_key: RNGKey, num_samples: int) -> Tuple[Genotype, RNGKey]:
    """Sample elements in the repertoire.

    Args:
        random_key: a jax PRNG random key
        num_samples: the number of elements to be sampled

    Returns:
        samples: a batch of genotypes sampled in the repertoire
        random_key: an updated jax PRNG random key
    """

    repertoire_empty = self.fitnesses == -jnp.inf
    p = (1.0 - repertoire_empty) / jnp.sum(1.0 - repertoire_empty)

    random_key, subkey = jax.random.split(random_key)
    samples = jax.tree_util.tree_map(
        lambda x: jax.random.choice(subkey, x, shape=(num_samples,), p=p),
        self.genotypes,
    )

    return samples, random_key
add(self, batch_of_genotypes, batch_of_descriptors, batch_of_fitnesses, batch_of_extra_scores=None)

Add a batch of elements to the repertoire.

Parameters:
  • batch_of_genotypes (Genotype) – a batch of genotypes to be added to the repertoire. Similarly to the self.genotypes argument, this is a PyTree in which the leaves have a shape (batch_size, num_features)

  • batch_of_descriptors (Descriptor) – an array that contains the descriptors of the aforementioned genotypes. Its shape is (batch_size, num_descriptors)

  • batch_of_fitnesses (Fitness) – an array that contains the fitnesses of the aforementioned genotypes. Its shape is (batch_size,)

  • batch_of_extra_scores (Optional[ExtraScores]) – unused tree that contains the extra_scores of aforementioned genotypes.

Returns:
  • MapElitesRepertoire – The updated MAP-Elites repertoire.

Source code in qdax/core/containers/mapelites_repertoire.py
@jax.jit
def add(
    self,
    batch_of_genotypes: Genotype,
    batch_of_descriptors: Descriptor,
    batch_of_fitnesses: Fitness,
    batch_of_extra_scores: Optional[ExtraScores] = None,
) -> MapElitesRepertoire:
    """
    Add a batch of elements to the repertoire.

    Args:
        batch_of_genotypes: a batch of genotypes to be added to the repertoire.
            Similarly to the self.genotypes argument, this is a PyTree in which
            the leaves have a shape (batch_size, num_features)
        batch_of_descriptors: an array that contains the descriptors of the
            aforementioned genotypes. Its shape is (batch_size, num_descriptors)
        batch_of_fitnesses: an array that contains the fitnesses of the
            aforementioned genotypes. Its shape is (batch_size,)
        batch_of_extra_scores: unused tree that contains the extra_scores of
            aforementioned genotypes.

    Returns:
        The updated MAP-Elites repertoire.
    """

    batch_of_indices = get_cells_indices(batch_of_descriptors, self.centroids)
    batch_of_indices = jnp.expand_dims(batch_of_indices, axis=-1)
    batch_of_fitnesses = jnp.expand_dims(batch_of_fitnesses, axis=-1)

    num_centroids = self.centroids.shape[0]

    # get fitness segment max
    best_fitnesses = jax.ops.segment_max(
        batch_of_fitnesses,
        batch_of_indices.astype(jnp.int32).squeeze(axis=-1),
        num_segments=num_centroids,
    )

    cond_values = jnp.take_along_axis(best_fitnesses, batch_of_indices, 0)

    # put dominated fitness to -jnp.inf
    batch_of_fitnesses = jnp.where(
        batch_of_fitnesses == cond_values, x=batch_of_fitnesses, y=-jnp.inf
    )

    # get addition condition
    repertoire_fitnesses = jnp.expand_dims(self.fitnesses, axis=-1)
    current_fitnesses = jnp.take_along_axis(
        repertoire_fitnesses, batch_of_indices, 0
    )
    addition_condition = batch_of_fitnesses > current_fitnesses

    # assign fake position when relevant : num_centroids is out of bound
    batch_of_indices = jnp.where(
        addition_condition, x=batch_of_indices, y=num_centroids
    )

    # create new repertoire
    new_repertoire_genotypes = jax.tree_util.tree_map(
        lambda repertoire_genotypes, new_genotypes: repertoire_genotypes.at[
            batch_of_indices.squeeze(axis=-1)
        ].set(new_genotypes),
        self.genotypes,
        batch_of_genotypes,
    )

    # compute new fitness and descriptors
    new_fitnesses = self.fitnesses.at[batch_of_indices.squeeze(axis=-1)].set(
        batch_of_fitnesses.squeeze(axis=-1)
    )
    new_descriptors = self.descriptors.at[batch_of_indices.squeeze(axis=-1)].set(
        batch_of_descriptors
    )

    return MapElitesRepertoire(
        genotypes=new_repertoire_genotypes,
        fitnesses=new_fitnesses,
        descriptors=new_descriptors,
        centroids=self.centroids,
    )
init(genotypes, fitnesses, descriptors, centroids, extra_scores=None) classmethod

Initialize a Map-Elites repertoire with an initial population of genotypes. Requires the definition of centroids that can be computed with any method such as CVT or Euclidean mapping.

Note: this function has been kept outside of the MapElites object so that it can easily be called from other modules.

Parameters:
  • genotypes (Genotype) – initial genotypes, pytree in which leaves have shape (batch_size, num_features)

  • fitnesses (Fitness) – fitness of the initial genotypes of shape (batch_size,)

  • descriptors (Descriptor) – descriptors of the initial genotypes of shape (batch_size, num_descriptors)

  • centroids (Centroid) – tessellation centroids of shape (batch_size, num_descriptors)

  • extra_scores (Optional[ExtraScores]) – unused extra_scores of the initial genotypes

Returns:
  • MapElitesRepertoire – an initialized MAP-Elite repertoire

Source code in qdax/core/containers/mapelites_repertoire.py
@classmethod
def init(
    cls,
    genotypes: Genotype,
    fitnesses: Fitness,
    descriptors: Descriptor,
    centroids: Centroid,
    extra_scores: Optional[ExtraScores] = None,
) -> MapElitesRepertoire:
    """
    Initialize a Map-Elites repertoire with an initial population of genotypes.
    Requires the definition of centroids that can be computed with any method
    such as CVT or Euclidean mapping.

    Note: this function has been kept outside of the object MapElites, so it can
    be called easily called from other modules.

    Args:
        genotypes: initial genotypes, pytree in which leaves
            have shape (batch_size, num_features)
        fitnesses: fitness of the initial genotypes of shape (batch_size,)
        descriptors: descriptors of the initial genotypes
            of shape (batch_size, num_descriptors)
        centroids: tesselation centroids of shape (batch_size, num_descriptors)
        extra_scores: unused extra_scores of the initial genotypes

    Returns:
        an initialized MAP-Elite repertoire
    """
    warnings.warn(
        (
            "This type of repertoire does not store the extra scores "
            "computed by the scoring function"
        ),
        stacklevel=2,
    )

    # retrieve one genotype from the population
    first_genotype = jax.tree_util.tree_map(lambda x: x[0], genotypes)

    # create a repertoire with default values
    repertoire = cls.init_default(genotype=first_genotype, centroids=centroids)

    # add initial population to the repertoire
    new_repertoire = repertoire.add(genotypes, descriptors, fitnesses, extra_scores)

    return new_repertoire  # type: ignore
init_default(genotype, centroids) classmethod

Initialize a Map-Elites repertoire with an initial population of genotypes. Requires the definition of centroids that can be computed with any method such as CVT or Euclidean mapping.

Note: this function has been kept outside of the MapElites object so that it can easily be called from other modules.

Parameters:
  • genotype (Genotype) – the typical genotype that will be stored.

  • centroids (Centroid) – the centroids of the repertoire

Returns:
  • MapElitesRepertoire – A repertoire filled with default values.

Source code in qdax/core/containers/mapelites_repertoire.py
@classmethod
def init_default(
    cls,
    genotype: Genotype,
    centroids: Centroid,
) -> MapElitesRepertoire:
    """Initialize a Map-Elites repertoire with an initial population of
    genotypes. Requires the definition of centroids that can be computed
    with any method such as CVT or Euclidean mapping.

    Note: this function has been kept outside of the object MapElites, so
    it can be called easily called from other modules.

    Args:
        genotype: the typical genotype that will be stored.
        centroids: the centroids of the repertoire

    Returns:
        A repertoire filled with default values.
    """

    # get number of centroids
    num_centroids = centroids.shape[0]

    # default fitness is -inf
    default_fitnesses = -jnp.inf * jnp.ones(shape=num_centroids)

    # default genotypes is all 0
    default_genotypes = jax.tree_util.tree_map(
        lambda x: jnp.zeros(shape=(num_centroids,) + x.shape, dtype=x.dtype),
        genotype,
    )

    # default descriptor is all zeros
    default_descriptors = jnp.zeros_like(centroids)

    return cls(
        genotypes=default_genotypes,
        fitnesses=default_fitnesses,
        descriptors=default_descriptors,
        centroids=centroids,
    )

compute_cvt_centroids(num_descriptors, num_init_cvt_samples, num_centroids, minval, maxval, random_key)

Compute centroids for CVT tessellation.

Parameters:
  • num_descriptors (int) – number of scalar descriptors

  • num_init_cvt_samples (int) – number of sampled points used for the clustering that determines the centroids. The larger the number of centroids and the number of descriptors, the higher this value must be (e.g. 100000 for 1024 centroids and 100 descriptors).

  • num_centroids (int) – number of centroids

  • minval (Union[float, List[float]]) – minimum descriptors value

  • maxval (Union[float, List[float]]) – maximum descriptors value

  • random_key (RNGKey) – a jax PRNG random key

Returns:
  • the centroids with shape (num_centroids, num_descriptors)
  • random_key – an updated jax PRNG random key
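
Example (illustrative sketch, not part of the library source; the numbers of centroids and samples are arbitrary):

import jax

from qdax.core.containers.mapelites_repertoire import compute_cvt_centroids

random_key = jax.random.PRNGKey(0)

# 128 CVT centroids for a 2D descriptor space in [0, 1]^2, obtained by
# clustering 10000 uniformly sampled points.
centroids, random_key = compute_cvt_centroids(
    num_descriptors=2,
    num_init_cvt_samples=10000,
    num_centroids=128,
    minval=0.0,
    maxval=1.0,
    random_key=random_key,
)
# centroids has shape (128, 2)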

Source code in qdax/core/containers/mapelites_repertoire.py
def compute_cvt_centroids(
    num_descriptors: int,
    num_init_cvt_samples: int,
    num_centroids: int,
    minval: Union[float, List[float]],
    maxval: Union[float, List[float]],
    random_key: RNGKey,
) -> Tuple[jnp.ndarray, RNGKey]:
    """Compute centroids for CVT tessellation.

    Args:
        num_descriptors: number of scalar descriptors
        num_init_cvt_samples: number of sampled point to be sued for clustering to
            determine the centroids. The larger the number of centroids and the
            number of descriptors, the higher this value must be (e.g. 100000 for
            1024 centroids and 100 descriptors).
        num_centroids: number of centroids
        minval: minimum descriptors value
        maxval: maximum descriptors value
        random_key: a jax PRNG random key

    Returns:
        the centroids with shape (num_centroids, num_descriptors)
        random_key: an updated jax PRNG random key
    """
    minval = jnp.array(minval)
    maxval = jnp.array(maxval)

    # assume here all values are in [0, 1] and rescale later
    random_key, subkey = jax.random.split(random_key)
    x = jax.random.uniform(key=subkey, shape=(num_init_cvt_samples, num_descriptors))

    # compute k means
    random_key, subkey = jax.random.split(random_key)
    k_means = KMeans(
        init="k-means++",
        n_clusters=num_centroids,
        n_init=1,
        random_state=RandomState(subkey),
    )
    k_means.fit(x)
    centroids = k_means.cluster_centers_
    # rescale now
    return jnp.asarray(centroids) * (maxval - minval) + minval, random_key
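
A minimal usage sketch with illustrative sizes; note that the k-means step relies on scikit-learn (as shown in the source above), so this function runs outside of jit.

import jax
from qdax.core.containers.mapelites_repertoire import compute_cvt_centroids

random_key = jax.random.PRNGKey(0)

# 64 CVT centroids for a 2D descriptor space in [0, 1]^2,
# fitted on 10 000 uniformly sampled points (illustrative values).
centroids, random_key = compute_cvt_centroids(
    num_descriptors=2,
    num_init_cvt_samples=10_000,
    num_centroids=64,
    minval=0.0,
    maxval=1.0,
    random_key=random_key,
)
# centroids has shape (64, 2); random_key is an updated PRNG key.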

compute_euclidean_centroids(grid_shape, minval, maxval)

Compute centroids for square Euclidean tessellation.

Parameters:
  • grid_shape (Tuple[int, ...]) – number of centroids per BD dimension

  • minval (Union[float, List[float]]) – minimum descriptors value

  • maxval (Union[float, List[float]]) – maximum descriptors value

Returns:
  • jnp.ndarray – the centroids with shape (num_centroids, num_descriptors)

Source code in qdax/core/containers/mapelites_repertoire.py
def compute_euclidean_centroids(
    grid_shape: Tuple[int, ...],
    minval: Union[float, List[float]],
    maxval: Union[float, List[float]],
) -> jnp.ndarray:
    """Compute centroids for square Euclidean tessellation.

    Args:
        grid_shape: number of centroids per BD dimension
        minval: minimum descriptors value
        maxval: maximum descriptors value

    Returns:
        the centroids with shape (num_centroids, num_descriptors)
    """
    # get number of descriptors
    num_descriptors = len(grid_shape)

    # prepare list of linspaces
    linspace_list = []
    for num_centroids_in_dim in grid_shape:
        offset = 1 / (2 * num_centroids_in_dim)
        linspace = jnp.linspace(offset, 1.0 - offset, num_centroids_in_dim)
        linspace_list.append(linspace)

    meshes = jnp.meshgrid(*linspace_list, sparse=False)

    # create centroids
    centroids = jnp.stack(
        [jnp.ravel(meshes[i]) for i in range(num_descriptors)], axis=-1
    )
    minval = jnp.array(minval)
    maxval = jnp.array(maxval)
    return jnp.asarray(centroids) * (maxval - minval) + minval
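
A minimal usage sketch with per-dimension bounds (illustrative values):

from qdax.core.containers.mapelites_repertoire import compute_euclidean_centroids

# 5 cells along the first descriptor and 4 along the second,
# with different bounds for each dimension.
centroids = compute_euclidean_centroids(
    grid_shape=(5, 4),
    minval=[0.0, -1.0],
    maxval=[1.0, 1.0],
)
# centroids has shape (20, 2): one centroid at the middle of each grid cell.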

get_cells_indices(batch_of_descriptors, centroids)

Returns the array of cell indices for a batch of descriptors, given the centroids of the repertoire.

Parameters:
  • batch_of_descriptors (jnp.ndarray) – a batch of descriptors of shape (batch_size, num_descriptors)

  • centroids (jnp.ndarray) – centroids array of shape (num_centroids, num_descriptors)

Returns:
  • jnp.ndarray – the indices of the centroids corresponding to each vector of descriptors in the batch with shape (batch_size,)

Source code in qdax/core/containers/mapelites_repertoire.py
def get_cells_indices(
    batch_of_descriptors: jnp.ndarray, centroids: jnp.ndarray
) -> jnp.ndarray:
    """
    Returns the array of cells indices for a batch of descriptors
    given the centroids of the repertoire.

    Args:
        batch_of_descriptors: a batch of descriptors
            of shape (batch_size, num_descriptors)
        centroids: centroids array of shape (num_centroids, num_descriptors)

    Returns:
        the indices of the centroids corresponding to each vector of descriptors
            in the batch with shape (batch_size,)
    """

    def _get_cells_indices(
        descriptors: jnp.ndarray, centroids: jnp.ndarray
    ) -> jnp.ndarray:
        """Set_of_descriptors of shape (1, num_descriptors)
        centroids of shape (num_centroids, num_descriptors)
        """
        return jnp.argmin(
            jnp.sum(jnp.square(jnp.subtract(descriptors, centroids)), axis=-1)
        )

    func = jax.vmap(lambda x: _get_cells_indices(x, centroids))
    return func(batch_of_descriptors)
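
A minimal sketch mapping two descriptors to their closest centroids (the descriptor values are arbitrary):

import jax.numpy as jnp
from qdax.core.containers.mapelites_repertoire import (
    compute_euclidean_centroids,
    get_cells_indices,
)

centroids = compute_euclidean_centroids(grid_shape=(3, 3), minval=0.0, maxval=1.0)

# One descriptor near the lower-left corner and one near the centre.
batch_of_descriptors = jnp.array([[0.1, 0.1], [0.5, 0.5]])

indices = get_cells_indices(batch_of_descriptors, centroids)
# indices has shape (2,): the index of the closest centroid for each descriptor.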

mels_repertoire

This file contains the class defining the repertoire used to store individuals in the MAP-Elites Low-Spread algorithm as well as several variants.

MELSRepertoire (MapElitesRepertoire) dataclass

Class for the repertoire in MAP-Elites Low-Spread.

This class inherits from MapElitesRepertoire. In addition to the stored data in MapElitesRepertoire (genotypes, fitnesses, descriptors, centroids), this repertoire also maintains an array of spreads. We overload the save, load, add, and init_default methods of MapElitesRepertoire.

Refer to Mace 2023 for more info on MAP-Elites Low-Spread: https://dl.acm.org/doi/abs/10.1145/3583131.3590433

Parameters:
  • genotypes (Genotype) – a PyTree containing all the genotypes in the repertoire ordered by the centroids. Each leaf has a shape (num_centroids, num_features). The PyTree can be a simple Jax array or a more complex nested structure, such as the parameters of a neural network in Flax.

  • fitnesses (Fitness) – an array that contains the fitness of solutions in each cell of the repertoire, ordered by centroids. The array shape is (num_centroids,).

  • descriptors (Descriptor) – an array that contains the descriptors of solutions in each cell of the repertoire, ordered by centroids. The array shape is (num_centroids, num_descriptors).

  • centroids (Centroid) – an array that contains the centroids of the tessellation. The array shape is (num_centroids, num_descriptors).

  • spreads (Spread) – an array that contains the spread of solutions in each cell of the repertoire, ordered by centroids. The array shape is (num_centroids,).

Source code in qdax/core/containers/mels_repertoire.py
class MELSRepertoire(MapElitesRepertoire):
    """Class for the repertoire in MAP-Elites Low-Spread.

    This class inherits from MapElitesRepertoire. In addition to the stored data in
    MapElitesRepertoire (genotypes, fitnesses, descriptors, centroids), this repertoire
    also maintains an array of spreads. We overload the save, load, add, and
    init_default methods of MapElitesRepertoire.

    Refer to Mace 2023 for more info on MAP-Elites Low-Spread:
    https://dl.acm.org/doi/abs/10.1145/3583131.3590433

    Args:
        genotypes: a PyTree containing all the genotypes in the repertoire ordered
            by the centroids. Each leaf has a shape (num_centroids, num_features). The
            PyTree can be a simple Jax array or a more complex nested structure such
            as to represent parameters of neural network in Flax.
        fitnesses: an array that contains the fitness of solutions in each cell of the
            repertoire, ordered by centroids. The array shape is (num_centroids,).
        descriptors: an array that contains the descriptors of solutions in each cell
            of the repertoire, ordered by centroids. The array shape
            is (num_centroids, num_descriptors).
        centroids: an array that contains the centroids of the tessellation. The array
            shape is (num_centroids, num_descriptors).
        spreads: an array that contains the spread of solutions in each cell of the
            repertoire, ordered by centroids. The array shape is (num_centroids,).
    """

    spreads: Spread

    def save(self, path: str = "./") -> None:
        """Saves the repertoire on disk in the form of .npy files.

        Flattens the genotypes to store it with .npy format. Supposes that
        a user will have access to the reconstruction function when loading
        the genotypes.

        Args:
            path: Path where the data will be saved. Defaults to "./".
        """

        def flatten_genotype(genotype: Genotype) -> jnp.ndarray:
            flatten_genotype, _ = ravel_pytree(genotype)
            return flatten_genotype

        # flatten all the genotypes
        flat_genotypes = jax.vmap(flatten_genotype)(self.genotypes)

        # save data
        jnp.save(path + "genotypes.npy", flat_genotypes)
        jnp.save(path + "fitnesses.npy", self.fitnesses)
        jnp.save(path + "descriptors.npy", self.descriptors)
        jnp.save(path + "centroids.npy", self.centroids)
        jnp.save(path + "spreads.npy", self.spreads)

    @classmethod
    def load(cls, reconstruction_fn: Callable, path: str = "./") -> MELSRepertoire:
        """Loads a MAP-Elites Low-Spread Repertoire.

        Args:
            reconstruction_fn: Function to reconstruct a PyTree
                from a flat array.
            path: Path where the data is saved. Defaults to "./".

        Returns:
            A MAP-Elites Low-Spread Repertoire.
        """

        flat_genotypes = jnp.load(path + "genotypes.npy")
        genotypes = jax.vmap(reconstruction_fn)(flat_genotypes)

        fitnesses = jnp.load(path + "fitnesses.npy")
        descriptors = jnp.load(path + "descriptors.npy")
        centroids = jnp.load(path + "centroids.npy")
        spreads = jnp.load(path + "spreads.npy")

        return cls(
            genotypes=genotypes,
            fitnesses=fitnesses,
            descriptors=descriptors,
            centroids=centroids,
            spreads=spreads,
        )

    @jax.jit
    def add(
        self,
        batch_of_genotypes: Genotype,
        batch_of_descriptors: Descriptor,
        batch_of_fitnesses: Fitness,
        batch_of_extra_scores: Optional[ExtraScores] = None,
    ) -> MELSRepertoire:
        """
        Add a batch of elements to the repertoire.

        The key difference between this method and the default add() in
        MapElitesRepertoire is that it expects each individual to be evaluated
        `num_samples` times, resulting in `num_samples` fitnesses and
        `num_samples` descriptors per individual.

        If multiple individuals may be added to a single cell, this method will
        arbitrarily pick one -- the exact choice depends on the implementation of
        jax.at[].set(), which can be non-deterministic:
        https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.ndarray.at.html
        We do not currently check if one of the multiple individuals dominates the
        others (dominate means that the individual has both highest fitness and lowest
        spread among the individuals for that cell).

        If `num_samples` is only 1, the spreads will default to 0.

        Args:
            batch_of_genotypes: a batch of genotypes to be added to the repertoire.
                Similarly to the self.genotypes argument, this is a PyTree in which
                the leaves have a shape (batch_size, num_features)
            batch_of_descriptors: an array that contains the descriptors of the
                aforementioned genotypes over all evals. Its shape is
                (batch_size, num_samples, num_descriptors). Note that we "aggregate"
                descriptors by finding the most frequent cell of each individual. Thus,
                the actual descriptors stored in the repertoire are just the coordinates
                of the centroid of the most frequent cell.
            batch_of_fitnesses: an array that contains the fitnesses of the
                aforementioned genotypes over all evals. Its shape is (batch_size,
                num_samples)
            batch_of_extra_scores: unused tree that contains the extra_scores of
                aforementioned genotypes.

        Returns:
            The updated repertoire.
        """
        batch_size, num_samples = batch_of_fitnesses.shape

        # Compute indices/cells of all descriptors.
        batch_of_all_indices = get_cells_indices(
            batch_of_descriptors.reshape(batch_size * num_samples, -1), self.centroids
        ).reshape((batch_size, num_samples))

        # Compute most frequent cell of each solution.
        batch_of_indices = jax.vmap(_mode)(batch_of_all_indices)[:, None]

        # Compute dispersion / spread. The dispersion is set to zero if
        # num_samples is 1.
        batch_of_spreads = jax.lax.cond(
            num_samples == 1,
            lambda desc: jnp.zeros(batch_size),
            lambda desc: jax.vmap(_dispersion)(
                desc.reshape((batch_size, num_samples, -1))
            ),
            batch_of_descriptors,
        )
        batch_of_spreads = jnp.expand_dims(batch_of_spreads, axis=-1)

        # Compute canonical descriptors as the descriptor of the centroid of the most
        # frequent cell. Note that this line redefines the earlier batch_of_descriptors.
        batch_of_descriptors = jnp.take_along_axis(
            self.centroids, batch_of_indices, axis=0
        )

        # Compute canonical fitnesses as the average fitness.
        #
        # Shape: (batch_size, 1)
        batch_of_fitnesses = batch_of_fitnesses.mean(axis=-1, keepdims=True)

        num_centroids = self.centroids.shape[0]

        # get current repertoire fitnesses and spreads
        repertoire_fitnesses = jnp.expand_dims(self.fitnesses, axis=-1)
        current_fitnesses = jnp.take_along_axis(
            repertoire_fitnesses, batch_of_indices, 0
        )

        repertoire_spreads = jnp.expand_dims(self.spreads, axis=-1)
        current_spreads = jnp.take_along_axis(repertoire_spreads, batch_of_indices, 0)

        # get addition condition
        addition_condition_fitness = batch_of_fitnesses > current_fitnesses
        addition_condition_spread = batch_of_spreads <= current_spreads
        addition_condition = jnp.logical_and(
            addition_condition_fitness, addition_condition_spread
        )

        # assign fake position when relevant : num_centroids is out of bound
        batch_of_indices = jnp.where(
            addition_condition, x=batch_of_indices, y=num_centroids
        )

        # create new repertoire
        new_repertoire_genotypes = jax.tree_util.tree_map(
            lambda repertoire_genotypes, new_genotypes: repertoire_genotypes.at[
                batch_of_indices.squeeze(axis=-1)
            ].set(new_genotypes),
            self.genotypes,
            batch_of_genotypes,
        )

        # compute new fitness and descriptors
        new_fitnesses = self.fitnesses.at[batch_of_indices.squeeze(axis=-1)].set(
            batch_of_fitnesses.squeeze(axis=-1)
        )
        new_descriptors = self.descriptors.at[batch_of_indices.squeeze(axis=-1)].set(
            batch_of_descriptors
        )
        new_spreads = self.spreads.at[batch_of_indices.squeeze(axis=-1)].set(
            batch_of_spreads.squeeze(axis=-1)
        )

        return MELSRepertoire(
            genotypes=new_repertoire_genotypes,
            fitnesses=new_fitnesses,
            descriptors=new_descriptors,
            centroids=self.centroids,
            spreads=new_spreads,
        )

    @classmethod
    def init_default(
        cls,
        genotype: Genotype,
        centroids: Centroid,
    ) -> MELSRepertoire:
        """Initialize a MAP-Elites Low-Spread repertoire with an initial population of
        genotypes. Requires the definition of centroids that can be computed with any
        method such as CVT or Euclidean mapping.

        Note: this function has been kept outside of the object MELS, so
        it can be easily called from other modules.

        Args:
            genotype: the typical genotype that will be stored.
            centroids: the centroids of the repertoire.

        Returns:
            A repertoire filled with default values.
        """

        # get number of centroids
        num_centroids = centroids.shape[0]

        # default fitness is -inf
        default_fitnesses = -jnp.inf * jnp.ones(shape=num_centroids)

        # default genotypes is all 0
        default_genotypes = jax.tree_util.tree_map(
            lambda x: jnp.zeros(shape=(num_centroids,) + x.shape, dtype=x.dtype),
            genotype,
        )

        # default descriptor is all zeros
        default_descriptors = jnp.zeros_like(centroids)

        # default spread is inf so that any spread will be less
        default_spreads = jnp.full(shape=num_centroids, fill_value=jnp.inf)

        return cls(
            genotypes=default_genotypes,
            fitnesses=default_fitnesses,
            descriptors=default_descriptors,
            centroids=centroids,
            spreads=default_spreads,
        )
save(self, path='./')

Saves the repertoire on disk in the form of .npy files.

Flattens the genotypes to store them in .npy format. Assumes that a user will have access to the reconstruction function when loading the genotypes.

Parameters:
  • path (str) – Path where the data will be saved. Defaults to "./".

Source code in qdax/core/containers/mels_repertoire.py
def save(self, path: str = "./") -> None:
    """Saves the repertoire on disk in the form of .npy files.

    Flattens the genotypes to store it with .npy format. Supposes that
    a user will have access to the reconstruction function when loading
    the genotypes.

    Args:
        path: Path where the data will be saved. Defaults to "./".
    """

    def flatten_genotype(genotype: Genotype) -> jnp.ndarray:
        flatten_genotype, _ = ravel_pytree(genotype)
        return flatten_genotype

    # flatten all the genotypes
    flat_genotypes = jax.vmap(flatten_genotype)(self.genotypes)

    # save data
    jnp.save(path + "genotypes.npy", flat_genotypes)
    jnp.save(path + "fitnesses.npy", self.fitnesses)
    jnp.save(path + "descriptors.npy", self.descriptors)
    jnp.save(path + "centroids.npy", self.centroids)
    jnp.save(path + "spreads.npy", self.spreads)
load(reconstruction_fn, path='./') classmethod

Loads a MAP-Elites Low-Spread Repertoire.

Parameters:
  • reconstruction_fn (Callable) – Function to reconstruct a PyTree from a flat array.

  • path (str) – Path where the data is saved. Defaults to "./".

Returns:
  • MELSRepertoire – A MAP-Elites Low-Spread Repertoire.

Source code in qdax/core/containers/mels_repertoire.py
@classmethod
def load(cls, reconstruction_fn: Callable, path: str = "./") -> MELSRepertoire:
    """Loads a MAP-Elites Low-Spread Repertoire.

    Args:
        reconstruction_fn: Function to reconstruct a PyTree
            from a flat array.
        path: Path where the data is saved. Defaults to "./".

    Returns:
        A MAP-Elites Low-Spread Repertoire.
    """

    flat_genotypes = jnp.load(path + "genotypes.npy")
    genotypes = jax.vmap(reconstruction_fn)(flat_genotypes)

    fitnesses = jnp.load(path + "fitnesses.npy")
    descriptors = jnp.load(path + "descriptors.npy")
    centroids = jnp.load(path + "centroids.npy")
    spreads = jnp.load(path + "spreads.npy")

    return cls(
        genotypes=genotypes,
        fitnesses=fitnesses,
        descriptors=descriptors,
        centroids=centroids,
        spreads=spreads,
    )
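
A sketch of a save/load round trip, assuming MELSRepertoire is importable from qdax.core.containers.mels_repertoire (the source path above). The reconstruction function is built with jax.flatten_util.ravel_pytree on a single genotype, mirroring the flattening done in save; the genotype structure and the path prefix are placeholders.

import jax.numpy as jnp
from jax.flatten_util import ravel_pytree
from qdax.core.containers.mapelites_repertoire import compute_euclidean_centroids
from qdax.core.containers.mels_repertoire import MELSRepertoire

centroids = compute_euclidean_centroids(grid_shape=(4, 4), minval=0.0, maxval=1.0)
genotype = {"weights": jnp.zeros((3, 2)), "bias": jnp.zeros(2)}

repertoire = MELSRepertoire.init_default(genotype=genotype, centroids=centroids)
repertoire.save(path="./mels_")

# ravel_pytree on a single genotype returns the unravel function that
# load expects as reconstruction_fn.
_, reconstruction_fn = ravel_pytree(genotype)
loaded = MELSRepertoire.load(reconstruction_fn=reconstruction_fn, path="./mels_")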
replace(self, **updates)

"Returns a new object replacing the specified fields with new values.

Source code in qdax/core/containers/mels_repertoire.py
def replace(self, **updates):
  """ "Returns a new object replacing the specified fields with new values."""
  return dataclasses.replace(self, **updates)
add(self, batch_of_genotypes, batch_of_descriptors, batch_of_fitnesses, batch_of_extra_scores=None)

Add a batch of elements to the repertoire.

The key difference between this method and the default add() in MapElitesRepertoire is that it expects each individual to be evaluated num_samples times, resulting in num_samples fitnesses and num_samples descriptors per individual.

If multiple individuals may be added to a single cell, this method will arbitrarily pick one -- the exact choice depends on the implementation of jax.at[].set(), which can be non-deterministic: https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.ndarray.at.html We do not currently check if one of the multiple individuals dominates the others (dominate means that the individual has both highest fitness and lowest spread among the individuals for that cell).

If num_samples is only 1, the spreads will default to 0.

Parameters:
  • batch_of_genotypes (Genotype) – a batch of genotypes to be added to the repertoire. Similarly to the self.genotypes argument, this is a PyTree in which the leaves have a shape (batch_size, num_features)

  • batch_of_descriptors (Descriptor) – an array that contains the descriptors of the aforementioned genotypes over all evals. Its shape is (batch_size, num_samples, num_descriptors). Note that we "aggregate" descriptors by finding the most frequent cell of each individual. Thus, the actual descriptors stored in the repertoire are just the coordinates of the centroid of the most frequent cell.

  • batch_of_fitnesses (Fitness) – an array that contains the fitnesses of the aforementioned genotypes over all evals. Its shape is (batch_size, num_samples)

  • batch_of_extra_scores (Optional[ExtraScores]) – unused tree that contains the extra_scores of aforementioned genotypes.

Returns:
  • MELSRepertoire – The updated repertoire.

Source code in qdax/core/containers/mels_repertoire.py
@jax.jit
def add(
    self,
    batch_of_genotypes: Genotype,
    batch_of_descriptors: Descriptor,
    batch_of_fitnesses: Fitness,
    batch_of_extra_scores: Optional[ExtraScores] = None,
) -> MELSRepertoire:
    """
    Add a batch of elements to the repertoire.

    The key difference between this method and the default add() in
    MapElitesRepertoire is that it expects each individual to be evaluated
    `num_samples` times, resulting in `num_samples` fitnesses and
    `num_samples` descriptors per individual.

    If multiple individuals may be added to a single cell, this method will
    arbitrarily pick one -- the exact choice depends on the implementation of
    jax.at[].set(), which can be non-deterministic:
    https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.ndarray.at.html
    We do not currently check if one of the multiple individuals dominates the
    others (dominate means that the individual has both highest fitness and lowest
    spread among the individuals for that cell).

    If `num_samples` is only 1, the spreads will default to 0.

    Args:
        batch_of_genotypes: a batch of genotypes to be added to the repertoire.
            Similarly to the self.genotypes argument, this is a PyTree in which
            the leaves have a shape (batch_size, num_features)
        batch_of_descriptors: an array that contains the descriptors of the
            aforementioned genotypes over all evals. Its shape is
            (batch_size, num_samples, num_descriptors). Note that we "aggregate"
            descriptors by finding the most frequent cell of each individual. Thus,
            the actual descriptors stored in the repertoire are just the coordinates
            of the centroid of the most frequent cell.
        batch_of_fitnesses: an array that contains the fitnesses of the
            aforementioned genotypes over all evals. Its shape is (batch_size,
            num_samples)
        batch_of_extra_scores: unused tree that contains the extra_scores of
            aforementioned genotypes.

    Returns:
        The updated repertoire.
    """
    batch_size, num_samples = batch_of_fitnesses.shape

    # Compute indices/cells of all descriptors.
    batch_of_all_indices = get_cells_indices(
        batch_of_descriptors.reshape(batch_size * num_samples, -1), self.centroids
    ).reshape((batch_size, num_samples))

    # Compute most frequent cell of each solution.
    batch_of_indices = jax.vmap(_mode)(batch_of_all_indices)[:, None]

    # Compute dispersion / spread. The dispersion is set to zero if
    # num_samples is 1.
    batch_of_spreads = jax.lax.cond(
        num_samples == 1,
        lambda desc: jnp.zeros(batch_size),
        lambda desc: jax.vmap(_dispersion)(
            desc.reshape((batch_size, num_samples, -1))
        ),
        batch_of_descriptors,
    )
    batch_of_spreads = jnp.expand_dims(batch_of_spreads, axis=-1)

    # Compute canonical descriptors as the descriptor of the centroid of the most
    # frequent cell. Note that this line redefines the earlier batch_of_descriptors.
    batch_of_descriptors = jnp.take_along_axis(
        self.centroids, batch_of_indices, axis=0
    )

    # Compute canonical fitnesses as the average fitness.
    #
    # Shape: (batch_size, 1)
    batch_of_fitnesses = batch_of_fitnesses.mean(axis=-1, keepdims=True)

    num_centroids = self.centroids.shape[0]

    # get current repertoire fitnesses and spreads
    repertoire_fitnesses = jnp.expand_dims(self.fitnesses, axis=-1)
    current_fitnesses = jnp.take_along_axis(
        repertoire_fitnesses, batch_of_indices, 0
    )

    repertoire_spreads = jnp.expand_dims(self.spreads, axis=-1)
    current_spreads = jnp.take_along_axis(repertoire_spreads, batch_of_indices, 0)

    # get addition condition
    addition_condition_fitness = batch_of_fitnesses > current_fitnesses
    addition_condition_spread = batch_of_spreads <= current_spreads
    addition_condition = jnp.logical_and(
        addition_condition_fitness, addition_condition_spread
    )

    # assign fake position when relevant : num_centroids is out of bound
    batch_of_indices = jnp.where(
        addition_condition, x=batch_of_indices, y=num_centroids
    )

    # create new repertoire
    new_repertoire_genotypes = jax.tree_util.tree_map(
        lambda repertoire_genotypes, new_genotypes: repertoire_genotypes.at[
            batch_of_indices.squeeze(axis=-1)
        ].set(new_genotypes),
        self.genotypes,
        batch_of_genotypes,
    )

    # compute new fitness and descriptors
    new_fitnesses = self.fitnesses.at[batch_of_indices.squeeze(axis=-1)].set(
        batch_of_fitnesses.squeeze(axis=-1)
    )
    new_descriptors = self.descriptors.at[batch_of_indices.squeeze(axis=-1)].set(
        batch_of_descriptors
    )
    new_spreads = self.spreads.at[batch_of_indices.squeeze(axis=-1)].set(
        batch_of_spreads.squeeze(axis=-1)
    )

    return MELSRepertoire(
        genotypes=new_repertoire_genotypes,
        fitnesses=new_fitnesses,
        descriptors=new_descriptors,
        centroids=self.centroids,
        spreads=new_spreads,
    )
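
A sketch of the shapes expected by add when each individual is evaluated num_samples times; all sizes below are illustrative and the random arrays stand in for real evaluation results.

import jax
import jax.numpy as jnp
from qdax.core.containers.mapelites_repertoire import compute_euclidean_centroids
from qdax.core.containers.mels_repertoire import MELSRepertoire

batch_size, num_samples, num_descriptors = 4, 3, 2

centroids = compute_euclidean_centroids(grid_shape=(5, 5), minval=0.0, maxval=1.0)
repertoire = MELSRepertoire.init_default(genotype=jnp.zeros(6), centroids=centroids)

key = jax.random.PRNGKey(0)
key_g, key_d, key_f = jax.random.split(key, 3)
genotypes = jax.random.normal(key_g, (batch_size, 6))
# One descriptor and one fitness per evaluation of each individual.
descriptors = jax.random.uniform(key_d, (batch_size, num_samples, num_descriptors))
fitnesses = jax.random.normal(key_f, (batch_size, num_samples))

repertoire = repertoire.add(genotypes, descriptors, fitnesses)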
init_default(genotype, centroids) classmethod

Initialize a MAP-Elites Low-Spread repertoire with an initial population of genotypes. Requires the definition of centroids that can be computed with any method such as CVT or Euclidean mapping.

Note: this function has been kept outside of the object MELS, so it can be easily called from other modules.

Parameters:
  • genotype (Genotype) – the typical genotype that will be stored.

  • centroids (Centroid) – the centroids of the repertoire.

Returns:
  • MELSRepertoire – A repertoire filled with default values.

Source code in qdax/core/containers/mels_repertoire.py
@classmethod
def init_default(
    cls,
    genotype: Genotype,
    centroids: Centroid,
) -> MELSRepertoire:
    """Initialize a MAP-Elites Low-Spread repertoire with an initial population of
    genotypes. Requires the definition of centroids that can be computed with any
    method such as CVT or Euclidean mapping.

    Note: this function has been kept outside of the object MELS, so
    it can be easily called from other modules.

    Args:
        genotype: the typical genotype that will be stored.
        centroids: the centroids of the repertoire.

    Returns:
        A repertoire filled with default values.
    """

    # get number of centroids
    num_centroids = centroids.shape[0]

    # default fitness is -inf
    default_fitnesses = -jnp.inf * jnp.ones(shape=num_centroids)

    # default genotypes is all 0
    default_genotypes = jax.tree_util.tree_map(
        lambda x: jnp.zeros(shape=(num_centroids,) + x.shape, dtype=x.dtype),
        genotype,
    )

    # default descriptor is all zeros
    default_descriptors = jnp.zeros_like(centroids)

    # default spread is inf so that any spread will be less
    default_spreads = jnp.full(shape=num_centroids, fill_value=jnp.inf)

    return cls(
        genotypes=default_genotypes,
        fitnesses=default_fitnesses,
        descriptors=default_descriptors,
        centroids=centroids,
        spreads=default_spreads,
    )

mome_repertoire

This file contains the class to define the repertoire used to store individuals in the Multi-Objective MAP-Elites algorithm as well as several variants.

MOMERepertoire (MapElitesRepertoire) dataclass

Class for the repertoire in Multi-Objective MAP-Elites (MOME).

This class inherits from MapElitesRepertoire. The stored data is the same: genotypes, fitnesses, descriptors, centroids.

The shape of genotypes is (in the case where it is an array): (num_centroids, pareto_front_length, genotype_dim). When the genotypes form a PyTree, the first two dimensions are the same for every leaf, while the third depends on the leaf.

The shape of fitnesses is (num_centroids, pareto_front_length, num_criteria).

The shape of descriptors is (num_centroids, pareto_front_length, num_descriptors) and the shape of centroids is (num_centroids, num_descriptors), as in the init classmethod below.

Inherited functions: save and load.

Source code in qdax/core/containers/mome_repertoire.py
class MOMERepertoire(MapElitesRepertoire):
    """Class for the repertoire in Multi Objective Map Elites

    This class inherits from MAPElitesRepertoire. The stored data
    is the same: genotypes, fitnesses, descriptors, centroids.

    The shape of genotypes is (in the case where it's an array):
    (num_centroids, pareto_front_length, genotype_dim).
    When the genotypes is a PyTree, the two first dimensions are the same
    but the third will depend on the leafs.

    The shape of fitnesses is: (num_centroids, pareto_front_length, num_criteria)

    The shape of descriptors and centroids are:
    (num_centroids, num_descriptors, pareto_front_length).

    Inherited functions: save and load.
    """

    @property
    def repertoire_capacity(self) -> int:
        """Returns the maximum number of solutions the repertoire can
        contain which corresponds to the number of cells times the
        maximum pareto front length.

        Returns:
            The repertoire capacity.
        """
        first_leaf = jax.tree_util.tree_leaves(self.genotypes)[0]
        return int(first_leaf.shape[0] * first_leaf.shape[1])

    @jax.jit
    def _sample_in_masked_pareto_front(
        self,
        pareto_front_genotypes: ParetoFront[Genotype],
        mask: Mask,
        random_key: RNGKey,
    ) -> Genotype:
        """Sample one single genotype in masked pareto front.

        Note: do not retrieve a random key because this function
        is to be vmapped. The public method that uses this function
        will return a random key

        Args:
            pareto_front_genotypes: the genotypes of a pareto front
            mask: a mask associated to the front
            random_key: a random key to handle stochastic operations

        Returns:
            A single genotype among the pareto front.
        """
        p = (1.0 - mask) / jnp.sum(1.0 - mask)

        genotype_sample = jax.tree_util.tree_map(
            lambda x: jax.random.choice(random_key, x, shape=(1,), p=p),
            pareto_front_genotypes,
        )

        return genotype_sample

    @partial(jax.jit, static_argnames=("num_samples",))
    def sample(self, random_key: RNGKey, num_samples: int) -> Tuple[Genotype, RNGKey]:
        """Sample elements in the repertoire.

        This method samples a non-empty pareto front, and then samples
        genotypes from this pareto front.

        Args:
            random_key: a random key to handle stochasticity.
            num_samples: number of samples to retrieve from the repertoire.

        Returns:
            A sample of genotypes and a new random key.
        """

        # create sampling probability for the cells
        repertoire_empty = jnp.any(self.fitnesses == -jnp.inf, axis=-1)
        occupied_cells = jnp.any(~repertoire_empty, axis=-1)

        p = occupied_cells / jnp.sum(occupied_cells)

        # possible indices - num cells
        indices = jnp.arange(start=0, stop=repertoire_empty.shape[0])

        # choose idx - among indices of cells that are not empty
        random_key, subkey = jax.random.split(random_key)
        cells_idx = jax.random.choice(subkey, indices, shape=(num_samples,), p=p)

        # get genotypes (front) from the chosen indices
        pareto_front_genotypes = jax.tree_util.tree_map(
            lambda x: x[cells_idx], self.genotypes
        )

        # prepare second sampling function
        sample_in_fronts = jax.vmap(self._sample_in_masked_pareto_front)

        # sample genotypes from the pareto front
        random_key, subkey = jax.random.split(random_key)
        subkeys = jax.random.split(subkey, num=num_samples)
        sampled_genotypes = sample_in_fronts(  # type: ignore
            pareto_front_genotypes=pareto_front_genotypes,
            mask=repertoire_empty[cells_idx],
            random_key=subkeys,
        )

        # remove the dim coming from pareto front
        sampled_genotypes = jax.tree_util.tree_map(
            lambda x: x.squeeze(axis=1), sampled_genotypes
        )

        return sampled_genotypes, random_key

    @jax.jit
    def _update_masked_pareto_front(
        self,
        pareto_front_fitnesses: ParetoFront[Fitness],
        pareto_front_genotypes: ParetoFront[Genotype],
        pareto_front_descriptors: ParetoFront[Descriptor],
        mask: Mask,
        new_batch_of_fitnesses: Fitness,
        new_batch_of_genotypes: Genotype,
        new_batch_of_descriptors: Descriptor,
        new_mask: Mask,
    ) -> Tuple[
        ParetoFront[Fitness], ParetoFront[Genotype], ParetoFront[Descriptor], Mask
    ]:
        """Takes a fixed size pareto front, its mask and new points to add.
        Returns updated front and mask.

        Args:
            pareto_front_fitnesses: fitness of the pareto front
            pareto_front_genotypes: corresponding genotypes
            pareto_front_descriptors: corresponding descriptors
            mask: mask of the front, to hide void parts
            new_batch_of_fitnesses: new batch of fitness that is considered
                to be added to the pareto front
            new_batch_of_genotypes: corresponding genotypes
            new_batch_of_descriptors: corresponding descriptors
            new_mask: corresponding mask (no one is masked)

        Returns:
            The updated pareto front.
        """
        # get dimensions
        batch_size = new_batch_of_fitnesses.shape[0]
        num_criteria = new_batch_of_fitnesses.shape[1]

        pareto_front_len = pareto_front_fitnesses.shape[0]  # type: ignore

        first_leaf = jax.tree_util.tree_leaves(new_batch_of_genotypes)[0]
        genotypes_dim = first_leaf.shape[1]

        descriptors_dim = new_batch_of_descriptors.shape[1]

        # gather all data
        cat_mask = jnp.concatenate([mask, new_mask], axis=-1)
        cat_fitnesses = jnp.concatenate(
            [pareto_front_fitnesses, new_batch_of_fitnesses], axis=0
        )
        cat_genotypes = jax.tree_util.tree_map(
            lambda x, y: jnp.concatenate([x, y], axis=0),
            pareto_front_genotypes,
            new_batch_of_genotypes,
        )
        cat_descriptors = jnp.concatenate(
            [pareto_front_descriptors, new_batch_of_descriptors], axis=0
        )

        # get new front
        cat_bool_front = compute_masked_pareto_front(
            batch_of_criteria=cat_fitnesses, mask=cat_mask
        )

        # get corresponding indices
        indices = (
            jnp.arange(start=0, stop=pareto_front_len + batch_size) * cat_bool_front
        )
        indices = indices + ~cat_bool_front * (batch_size + pareto_front_len - 1)
        indices = jnp.sort(indices)

        # get new fitness, genotypes and descriptors
        new_front_fitness = jnp.take(cat_fitnesses, indices, axis=0)
        new_front_genotypes = jax.tree_util.tree_map(
            lambda x: jnp.take(x, indices, axis=0), cat_genotypes
        )
        new_front_descriptors = jnp.take(cat_descriptors, indices, axis=0)

        # compute new mask
        num_front_elements = jnp.sum(cat_bool_front)
        new_mask_indices = jnp.arange(start=0, stop=batch_size + pareto_front_len)
        new_mask_indices = (num_front_elements - new_mask_indices) > 0

        new_mask = jnp.where(
            new_mask_indices,
            jnp.ones(shape=batch_size + pareto_front_len, dtype=bool),
            jnp.zeros(shape=batch_size + pareto_front_len, dtype=bool),
        )

        fitness_mask = jnp.repeat(
            jnp.expand_dims(new_mask, axis=-1), num_criteria, axis=-1
        )
        new_front_fitness = new_front_fitness * fitness_mask

        front_size = len(pareto_front_fitnesses)  # type: ignore
        new_front_fitness = new_front_fitness[:front_size, :]

        genotypes_mask = jnp.repeat(
            jnp.expand_dims(new_mask, axis=-1), genotypes_dim, axis=-1
        )
        new_front_genotypes = jax.tree_util.tree_map(
            lambda x: x * genotypes_mask, new_front_genotypes
        )
        new_front_genotypes = jax.tree_util.tree_map(
            lambda x: x[:front_size, :], new_front_genotypes
        )

        descriptors_mask = jnp.repeat(
            jnp.expand_dims(new_mask, axis=-1), descriptors_dim, axis=-1
        )
        new_front_descriptors = new_front_descriptors * descriptors_mask
        new_front_descriptors = new_front_descriptors[:front_size, :]

        new_mask = ~new_mask[:front_size]

        return new_front_fitness, new_front_genotypes, new_front_descriptors, new_mask

    @jax.jit
    def add(
        self,
        batch_of_genotypes: Genotype,
        batch_of_descriptors: Descriptor,
        batch_of_fitnesses: Fitness,
        batch_of_extra_scores: Optional[ExtraScores] = None,
    ) -> MOMERepertoire:
        """Insert a batch of elements in the repertoire.

        Shape of the batch_of_genotypes (if an array):
        (batch_size, genotypes_dim)
        Shape of the batch_of_descriptors: (batch_size, num_descriptors)
        Shape of the batch_of_fitnesses: (batch_size, num_criteria)

        Args:
            batch_of_genotypes: a batch of genotypes that we are trying to
                insert into the repertoire.
            batch_of_descriptors: the descriptors of the genotypes we are
                trying to add to the repertoire.
            batch_of_fitnesses: the fitnesses of the genotypes we are trying
                to add to the repertoire.
            batch_of_extra_scores: unused tree that contains the extra_scores of
                aforementioned genotypes.

        Returns:
            The updated repertoire with potential new individuals.
        """

        # get the indices that corresponds to the descriptors in the repertoire
        batch_of_indices = get_cells_indices(batch_of_descriptors, self.centroids)
        batch_of_indices = jnp.expand_dims(batch_of_indices, axis=-1)

        def _add_one(
            carry: MOMERepertoire,
            data: Tuple[Genotype, Descriptor, Fitness, jnp.ndarray],
        ) -> Tuple[MOMERepertoire, Any]:
            # unwrap data
            genotype, descriptors, fitness, index = data

            index = index.astype(jnp.int32)

            # get cell data
            cell_genotype = jax.tree_util.tree_map(lambda x: x[index], carry.genotypes)
            cell_fitness = carry.fitnesses[index]
            cell_descriptor = carry.descriptors[index]
            cell_mask = jnp.any(cell_fitness == -jnp.inf, axis=-1)

            # update pareto front
            (
                cell_fitness,
                cell_genotype,
                cell_descriptor,
                cell_mask,
            ) = self._update_masked_pareto_front(
                pareto_front_fitnesses=cell_fitness.squeeze(axis=0),
                pareto_front_genotypes=cell_genotype.squeeze(axis=0),
                pareto_front_descriptors=cell_descriptor.squeeze(axis=0),
                mask=cell_mask.squeeze(axis=0),
                new_batch_of_fitnesses=jnp.expand_dims(fitness, axis=0),
                new_batch_of_genotypes=jnp.expand_dims(genotype, axis=0),
                new_batch_of_descriptors=jnp.expand_dims(descriptors, axis=0),
                new_mask=jnp.zeros(shape=(1,), dtype=bool),
            )

            # update cell fitness
            cell_fitness = cell_fitness - jnp.inf * jnp.expand_dims(cell_mask, axis=-1)

            # update grid
            new_genotypes = jax.tree_util.tree_map(
                lambda x, y: x.at[index].set(y), carry.genotypes, cell_genotype
            )
            new_fitnesses = carry.fitnesses.at[index].set(cell_fitness)
            new_descriptors = carry.descriptors.at[index].set(cell_descriptor)
            carry = carry.replace(  # type: ignore
                genotypes=new_genotypes,
                descriptors=new_descriptors,
                fitnesses=new_fitnesses,
            )

            # return new grid
            return carry, ()

        # scan the addition operation for all the data
        self, _ = jax.lax.scan(
            _add_one,
            self,
            (
                batch_of_genotypes,
                batch_of_descriptors,
                batch_of_fitnesses,
                batch_of_indices,
            ),
        )

        return self

    @classmethod
    def init(  # type: ignore
        cls,
        genotypes: Genotype,
        fitnesses: Fitness,
        descriptors: Descriptor,
        centroids: Centroid,
        pareto_front_max_length: int,
        extra_scores: Optional[ExtraScores] = None,
    ) -> MOMERepertoire:
        """
        Initialize a Multi Objective Map-Elites repertoire with an initial population
        of genotypes. Requires the definition of centroids that can be computed with
        any method such as CVT or Euclidean mapping.

        Note: this function has been kept outside of the object MapElites, so it can
        be easily called from other modules.

        Args:
            genotypes: initial genotypes, pytree in which leaves
                have shape (batch_size, num_features)
            fitnesses: fitness of the initial genotypes of shape:
                (batch_size, num_criteria)
            descriptors: descriptors of the initial genotypes
                of shape (batch_size, num_descriptors)
            centroids: tessellation centroids of shape (batch_size, num_descriptors)
            pareto_front_max_length: maximum size of the pareto fronts
            extra_scores: unused extra_scores of the initial genotypes

        Returns:
            An initialized MAP-Elite repertoire
        """

        warnings.warn(
            (
                "This type of repertoire does not store the extra scores "
                "computed by the scoring function"
            ),
            stacklevel=2,
        )

        # get dimensions
        num_criteria = fitnesses.shape[1]
        num_descriptors = descriptors.shape[1]
        num_centroids = centroids.shape[0]

        # create default values
        default_fitnesses = -jnp.inf * jnp.ones(
            shape=(num_centroids, pareto_front_max_length, num_criteria)
        )
        default_genotypes = jax.tree_util.tree_map(
            lambda x: jnp.zeros(
                shape=(
                    num_centroids,
                    pareto_front_max_length,
                )
                + x.shape[1:]
            ),
            genotypes,
        )
        default_descriptors = jnp.zeros(
            shape=(num_centroids, pareto_front_max_length, num_descriptors)
        )

        # create repertoire with default values
        repertoire = MOMERepertoire(  # type: ignore
            genotypes=default_genotypes,
            fitnesses=default_fitnesses,
            descriptors=default_descriptors,
            centroids=centroids,
        )

        # add first batch of individuals in the repertoire
        new_repertoire = repertoire.add(genotypes, descriptors, fitnesses, extra_scores)

        return new_repertoire  # type: ignore

    @jax.jit
    def compute_global_pareto_front(
        self,
    ) -> Tuple[ParetoFront[Fitness], Mask]:
        """Merge all the pareto fronts of the MOME repertoire into a single one
        called global pareto front.

        Returns:
            The pareto front and its mask.
        """
        fitnesses = jnp.concatenate(self.fitnesses, axis=0)
        mask = jnp.any(fitnesses == -jnp.inf, axis=-1)
        pareto_mask = compute_masked_pareto_front(fitnesses, mask)
        pareto_front = fitnesses - jnp.inf * (~jnp.array([pareto_mask, pareto_mask]).T)

        return pareto_front, pareto_mask
repertoire_capacity: int property readonly

Returns the maximum number of solutions the repertoire can contain, which corresponds to the number of cells times the maximum pareto front length.

Returns:
  • int – The repertoire capacity.

sample(self, random_key, num_samples)

Sample elements in the repertoire.

This method samples a non-empty pareto front, and then samples genotypes from this pareto front.

Parameters:
  • random_key (RNGKey) – a random key to handle stochasticity.

  • num_samples (int) – number of samples to retrieve from the repertoire.

Returns:
  • Tuple[Genotype, RNGKey] – A sample of genotypes and a new random key.

Source code in qdax/core/containers/mome_repertoire.py
@partial(jax.jit, static_argnames=("num_samples",))
def sample(self, random_key: RNGKey, num_samples: int) -> Tuple[Genotype, RNGKey]:
    """Sample elements in the repertoire.

    This method samples a non-empty pareto front, and then samples
    genotypes from this pareto front.

    Args:
        random_key: a random key to handle stochasticity.
        num_samples: number of samples to retrieve from the repertoire.

    Returns:
        A sample of genotypes and a new random key.
    """

    # create sampling probability for the cells
    repertoire_empty = jnp.any(self.fitnesses == -jnp.inf, axis=-1)
    occupied_cells = jnp.any(~repertoire_empty, axis=-1)

    p = occupied_cells / jnp.sum(occupied_cells)

    # possible indices - num cells
    indices = jnp.arange(start=0, stop=repertoire_empty.shape[0])

    # choose idx - among indices of cells that are not empty
    random_key, subkey = jax.random.split(random_key)
    cells_idx = jax.random.choice(subkey, indices, shape=(num_samples,), p=p)

    # get genotypes (front) from the chosen indices
    pareto_front_genotypes = jax.tree_util.tree_map(
        lambda x: x[cells_idx], self.genotypes
    )

    # prepare second sampling function
    sample_in_fronts = jax.vmap(self._sample_in_masked_pareto_front)

    # sample genotypes from the pareto front
    random_key, subkey = jax.random.split(random_key)
    subkeys = jax.random.split(subkey, num=num_samples)
    sampled_genotypes = sample_in_fronts(  # type: ignore
        pareto_front_genotypes=pareto_front_genotypes,
        mask=repertoire_empty[cells_idx],
        random_key=subkeys,
    )

    # remove the dim coming from pareto front
    sampled_genotypes = jax.tree_util.tree_map(
        lambda x: x.squeeze(axis=1), sampled_genotypes
    )

    return sampled_genotypes, random_key
replace(self, **updates)

"Returns a new object replacing the specified fields with new values.

Source code in qdax/core/containers/mome_repertoire.py
def replace(self, **updates):
  """ "Returns a new object replacing the specified fields with new values."""
  return dataclasses.replace(self, **updates)
add(self, batch_of_genotypes, batch_of_descriptors, batch_of_fitnesses, batch_of_extra_scores=None)

Insert a batch of elements in the repertoire.

Shape of the batch_of_genotypes (if an array): (batch_size, genotypes_dim)
Shape of the batch_of_descriptors: (batch_size, num_descriptors)
Shape of the batch_of_fitnesses: (batch_size, num_criteria)

Parameters:
  • batch_of_genotypes (Genotype) – a batch of genotypes that we are trying to insert into the repertoire.

  • batch_of_descriptors (Descriptor) – the descriptors of the genotypes we are trying to add to the repertoire.

  • batch_of_fitnesses (Fitness) – the fitnesses of the genotypes we are trying to add to the repertoire.

  • batch_of_extra_scores (Optional[ExtraScores]) – unused tree that contains the extra_scores of aforementioned genotypes.

Returns:
  • MOMERepertoire – The updated repertoire with potential new individuals.

Source code in qdax/core/containers/mome_repertoire.py
@jax.jit
def add(
    self,
    batch_of_genotypes: Genotype,
    batch_of_descriptors: Descriptor,
    batch_of_fitnesses: Fitness,
    batch_of_extra_scores: Optional[ExtraScores] = None,
) -> MOMERepertoire:
    """Insert a batch of elements in the repertoire.

    Shape of the batch_of_genotypes (if an array):
    (batch_size, genotypes_dim)
    Shape of the batch_of_descriptors: (batch_size, num_descriptors)
    Shape of the batch_of_fitnesses: (batch_size, num_criteria)

    Args:
        batch_of_genotypes: a batch of genotypes that we are trying to
            insert into the repertoire.
        batch_of_descriptors: the descriptors of the genotypes we are
            trying to add to the repertoire.
        batch_of_fitnesses: the fitnesses of the genotypes we are trying
            to add to the repertoire.
        batch_of_extra_scores: unused tree that contains the extra_scores of
            aforementioned genotypes.

    Returns:
        The updated repertoire with potential new individuals.
    """

    # get the indices that corresponds to the descriptors in the repertoire
    batch_of_indices = get_cells_indices(batch_of_descriptors, self.centroids)
    batch_of_indices = jnp.expand_dims(batch_of_indices, axis=-1)

    def _add_one(
        carry: MOMERepertoire,
        data: Tuple[Genotype, Descriptor, Fitness, jnp.ndarray],
    ) -> Tuple[MOMERepertoire, Any]:
        # unwrap data
        genotype, descriptors, fitness, index = data

        index = index.astype(jnp.int32)

        # get cell data
        cell_genotype = jax.tree_util.tree_map(lambda x: x[index], carry.genotypes)
        cell_fitness = carry.fitnesses[index]
        cell_descriptor = carry.descriptors[index]
        cell_mask = jnp.any(cell_fitness == -jnp.inf, axis=-1)

        # update pareto front
        (
            cell_fitness,
            cell_genotype,
            cell_descriptor,
            cell_mask,
        ) = self._update_masked_pareto_front(
            pareto_front_fitnesses=cell_fitness.squeeze(axis=0),
            pareto_front_genotypes=cell_genotype.squeeze(axis=0),
            pareto_front_descriptors=cell_descriptor.squeeze(axis=0),
            mask=cell_mask.squeeze(axis=0),
            new_batch_of_fitnesses=jnp.expand_dims(fitness, axis=0),
            new_batch_of_genotypes=jnp.expand_dims(genotype, axis=0),
            new_batch_of_descriptors=jnp.expand_dims(descriptors, axis=0),
            new_mask=jnp.zeros(shape=(1,), dtype=bool),
        )

        # update cell fitness
        cell_fitness = cell_fitness - jnp.inf * jnp.expand_dims(cell_mask, axis=-1)

        # update grid
        new_genotypes = jax.tree_util.tree_map(
            lambda x, y: x.at[index].set(y), carry.genotypes, cell_genotype
        )
        new_fitnesses = carry.fitnesses.at[index].set(cell_fitness)
        new_descriptors = carry.descriptors.at[index].set(cell_descriptor)
        carry = carry.replace(  # type: ignore
            genotypes=new_genotypes,
            descriptors=new_descriptors,
            fitnesses=new_fitnesses,
        )

        # return new grid
        return carry, ()

    # scan the addition operation for all the data
    self, _ = jax.lax.scan(
        _add_one,
        self,
        (
            batch_of_genotypes,
            batch_of_descriptors,
            batch_of_fitnesses,
            batch_of_indices,
        ),
    )

    return self
init(genotypes, fitnesses, descriptors, centroids, pareto_front_max_length, extra_scores=None) classmethod

Initialize a Multi Objective Map-Elites repertoire with an initial population of genotypes. Requires the definition of centroids that can be computed with any method such as CVT or Euclidean mapping.

Note: this function has been kept outside of the object MapElites, so it can be easily called from other modules.

Parameters:
  • genotypes (Genotype) – initial genotypes, pytree in which leaves have shape (batch_size, num_features)

  • fitnesses (Fitness) – fitness of the initial genotypes of shape: (batch_size, num_criteria)

  • descriptors (Descriptor) – descriptors of the initial genotypes of shape (batch_size, num_descriptors)

  • centroids (Centroid) – tessellation centroids of shape (batch_size, num_descriptors)

  • pareto_front_max_length (int) – maximum size of the pareto fronts

  • extra_scores (Optional[ExtraScores]) – unused extra_scores of the initial genotypes

Returns:
  • MOMERepertoire – An initialized MAP-Elite repertoire

Source code in qdax/core/containers/mome_repertoire.py
@classmethod
def init(  # type: ignore
    cls,
    genotypes: Genotype,
    fitnesses: Fitness,
    descriptors: Descriptor,
    centroids: Centroid,
    pareto_front_max_length: int,
    extra_scores: Optional[ExtraScores] = None,
) -> MOMERepertoire:
    """
    Initialize a Multi Objective Map-Elites repertoire with an initial population
    of genotypes. Requires the definition of centroids that can be computed with
    any method such as CVT or Euclidean mapping.

    Note: this function has been kept outside of the object MapElites, so it can
    be easily called from other modules.

    Args:
        genotypes: initial genotypes, pytree in which leaves
            have shape (batch_size, num_features)
        fitnesses: fitness of the initial genotypes of shape:
            (batch_size, num_criteria)
        descriptors: descriptors of the initial genotypes
            of shape (batch_size, num_descriptors)
        centroids: tessellation centroids of shape (batch_size, num_descriptors)
        pareto_front_max_length: maximum size of the pareto fronts
        extra_scores: unused extra_scores of the initial genotypes

    Returns:
        An initialized MAP-Elite repertoire
    """

    warnings.warn(
        (
            "This type of repertoire does not store the extra scores "
            "computed by the scoring function"
        ),
        stacklevel=2,
    )

    # get dimensions
    num_criteria = fitnesses.shape[1]
    num_descriptors = descriptors.shape[1]
    num_centroids = centroids.shape[0]

    # create default values
    default_fitnesses = -jnp.inf * jnp.ones(
        shape=(num_centroids, pareto_front_max_length, num_criteria)
    )
    default_genotypes = jax.tree_util.tree_map(
        lambda x: jnp.zeros(
            shape=(
                num_centroids,
                pareto_front_max_length,
            )
            + x.shape[1:]
        ),
        genotypes,
    )
    default_descriptors = jnp.zeros(
        shape=(num_centroids, pareto_front_max_length, num_descriptors)
    )

    # create repertoire with default values
    repertoire = MOMERepertoire(  # type: ignore
        genotypes=default_genotypes,
        fitnesses=default_fitnesses,
        descriptors=default_descriptors,
        centroids=centroids,
    )

    # add first batch of individuals in the repertoire
    new_repertoire = repertoire.add(genotypes, descriptors, fitnesses, extra_scores)

    return new_repertoire  # type: ignore
compute_global_pareto_front(self)

Merge all the pareto fronts of the MOME repertoire into a single one, called the global pareto front.

Returns:
  • Tuple[ParetoFront[Fitness], Mask] – The pareto front and its mask.

Source code in qdax/core/containers/mome_repertoire.py
@jax.jit
def compute_global_pareto_front(
    self,
) -> Tuple[ParetoFront[Fitness], Mask]:
    """Merge all the pareto fronts of the MOME repertoire into a single one
    called global pareto front.

    Returns:
        The pareto front and its mask.
    """
    fitnesses = jnp.concatenate(self.fitnesses, axis=0)
    mask = jnp.any(fitnesses == -jnp.inf, axis=-1)
    pareto_mask = compute_masked_pareto_front(fitnesses, mask)
    pareto_front = fitnesses - jnp.inf * (~jnp.array([pareto_mask, pareto_mask]).T)

    return pareto_front, pareto_mask
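
To tie the entries above together, here is an end-to-end sketch covering init, add, sample and compute_global_pareto_front. It assumes MOMERepertoire is importable from qdax.core.containers.mome_repertoire (the source path above); the population size, number of criteria and pareto_front_max_length are illustrative.

import jax
import jax.numpy as jnp
from qdax.core.containers.mapelites_repertoire import compute_euclidean_centroids
from qdax.core.containers.mome_repertoire import MOMERepertoire

# 2D descriptor space, 2 objectives, at most 5 solutions per cell's pareto front.
centroids = compute_euclidean_centroids(grid_shape=(6, 6), minval=0.0, maxval=1.0)

key = jax.random.PRNGKey(0)
key_g, key_f, key_d = jax.random.split(key, 3)
genotypes = jax.random.normal(key_g, (16, 4))     # (batch_size, genotype_dim)
fitnesses = jax.random.normal(key_f, (16, 2))     # (batch_size, num_criteria)
descriptors = jax.random.uniform(key_d, (16, 2))  # (batch_size, num_descriptors)

repertoire = MOMERepertoire.init(
    genotypes=genotypes,
    fitnesses=fitnesses,
    descriptors=descriptors,
    centroids=centroids,
    pareto_front_max_length=5,
)

# Later batches are inserted with add (same shapes as above).
repertoire = repertoire.add(genotypes, descriptors, fitnesses)

# Sample genotypes from non-empty cells and merge all local pareto fronts.
samples, key = repertoire.sample(random_key=key, num_samples=8)
global_front, global_mask = repertoire.compute_global_pareto_front()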

nsga2_repertoire

NSGA2Repertoire (GARepertoire) dataclass

Repertoire used for the NSGA2 algorithm.

Inherits from the GARepertoire. The data stored are the genotypes and their fitnesses. Several functions are inherited from GARepertoire, including size, save, sample and init.

Source code in qdax/core/containers/nsga2_repertoire.py
class NSGA2Repertoire(GARepertoire):
    """Repertoire used for the NSGA2 algorithm.

    Inherits from the GARepertoire. The data stored are the genotypes
    and their fitnesses. Several functions are inherited from GARepertoire,
    including size, save, sample and init.
    """

    @jax.jit
    def _compute_crowding_distances(
        self, fitnesses: Fitness, mask: jnp.ndarray
    ) -> jnp.ndarray:
        """Compute crowding distances.

        The crowding distance is the Manhattan distance in the objective
        space. This is used to rank individuals in the addition function.

        Args:
            fitnesses: fitnesses of the considered individuals. Here,
                fitness are vectors as we are doing multi-objective
                optimization.
            mask: a vector to mask values.

        Returns:
            The crowding distances.
        """
        # Retrieve only non masked solutions
        num_solutions = fitnesses.shape[0]
        num_objective = fitnesses.shape[1]
        if num_solutions <= 2:
            return jnp.array([jnp.inf] * num_solutions)

        else:
            # Sort solutions on each objective
            mask_dist = jnp.column_stack([mask] * fitnesses.shape[1])
            score_amplitude = jnp.max(fitnesses, axis=0) - jnp.min(fitnesses, axis=0)
            dist_fitnesses = (
                fitnesses + 3 * score_amplitude * jnp.ones_like(fitnesses) * mask_dist
            )
            sorted_index = jnp.argsort(dist_fitnesses, axis=0)
            srt_fitnesses = fitnesses[sorted_index, jnp.arange(num_objective)]

            # Calculate the norm for each objective - set to NaN if all values are equal
            norm = jnp.max(srt_fitnesses, axis=0) - jnp.min(srt_fitnesses, axis=0)

            # get the distances
            dists = jnp.row_stack(
                [srt_fitnesses, jnp.full(num_objective, jnp.inf)]
            ) - jnp.row_stack([jnp.full(num_objective, -jnp.inf), srt_fitnesses])

            # Prepare the distance to last and next vectors
            dist_to_last, dist_to_next = dists, dists
            dist_to_last = dists[:-1] / norm
            dist_to_next = dists[1:] / norm

            # Sum up the distances and reorder
            j = jnp.argsort(sorted_index, axis=0)
            crowding_distances = (
                jnp.sum(
                    (
                        dist_to_last[j, jnp.arange(num_objective)]
                        + dist_to_next[j, jnp.arange(num_objective)]
                    ),
                    axis=1,
                )
                / num_objective
            )

            return crowding_distances

    @jax.jit
    def add(
        self, batch_of_genotypes: Genotype, batch_of_fitnesses: Fitness
    ) -> NSGA2Repertoire:
        """Implements the repertoire addition rules.

        The population is sorted into successive pareto fronts. The first one
        is the global pareto front. The second one is the pareto front of the
        population where the first pareto front has been removed, etc...

        The successive pareto fronts are kept until the moment where adding a
        full pareto front would exceed the population size.

        To decide the survival of this pareto front, a crowding distance is
        computed in order to keep individuals that are spread in this last pareto
        front. Hence, the individuals with the biggest crowding distances are
        added until the population size is reached.

        Args:
            batch_of_genotypes: new genotypes that we try to add.
            batch_of_fitnesses: fitness of those new genotypes.

        Returns:
            The updated repertoire.
        """
        # All the candidates
        candidates = jax.tree_util.tree_map(
            lambda x, y: jnp.concatenate((x, y), axis=0),
            self.genotypes,
            batch_of_genotypes,
        )

        candidate_fitnesses = jnp.concatenate((self.fitnesses, batch_of_fitnesses))

        first_leaf = jax.tree_util.tree_leaves(candidates)[0]
        num_candidates = first_leaf.shape[0]

        def compute_current_front(
            val: Tuple[jnp.ndarray, jnp.ndarray]
        ) -> Tuple[jnp.ndarray, jnp.ndarray]:
            """Body function for the while loop. Computes the successive
            pareto fronts in the data.

            Args:
                val: Value passed through the while loop. Here, it is
                    a tuple containing two values. The indexes of all
                    solutions to keep and the indexes of the last
                    computed front.

            Returns:
                The updated values to pass through the while loop. Updated
                number of solutions and updated front indexes.
            """
            to_keep_index, _ = val

            # mask the individual that are already kept
            front_index = compute_masked_pareto_front(
                candidate_fitnesses, mask=to_keep_index
            )

            # Add new indexes
            to_keep_index = to_keep_index + front_index

            # Update front & number of solutions
            return to_keep_index, front_index

        def condition_fn_1(val: Tuple[jnp.ndarray, jnp.ndarray]) -> bool:
            """Gives condition to stop the while loop. Makes sure the
            the number of solution is smaller than the maximum size
            of the population.

            Args:
                val: Value passed through the while loop. Here, it is
                    a tuple containing two values. The indexes of all
                    solutions to keep and the indexes of the last
                    computed front.

            Returns:
                Returns True if we have reached the maximum number of
                solutions we can keep in the population.
            """
            to_keep_index, _ = val
            return sum(to_keep_index) < self.size  # type: ignore

        # get indexes of all first successive fronts and indexes of the last front
        to_keep_index, front_index = jax.lax.while_loop(
            condition_fn_1,
            compute_current_front,
            (
                jnp.zeros(num_candidates, dtype=bool),
                jnp.zeros(num_candidates, dtype=bool),
            ),
        )

        # remove the indexes of the last front - gives first indexes to keep
        new_index = jnp.arange(start=1, stop=len(to_keep_index) + 1) * to_keep_index
        new_index = new_index * (~front_index)
        to_keep_index = new_index > 0

        # Compute crowding distances
        crowding_distances = self._compute_crowding_distances(
            candidate_fitnesses, ~front_index
        )
        crowding_distances = crowding_distances * (front_index)
        highest_dist = jnp.argsort(crowding_distances)

        def add_to_front(val: Tuple[jnp.ndarray, float]) -> Tuple[jnp.ndarray, Any]:
            """Add the individual with a given distance to the front.
            An index is incremented to get the highest distance among the
            non-selected individuals.

            Args:
                val: a tuple of two elements. A boolean vector with the positions that
                    will be kept, and a cursor with the number of individuals already
                    added during this process.

            Returns:
                The updated tuple, with the new booleans and the number of
                added elements.
            """
            front_index, num = val
            front_index = front_index.at[highest_dist[-num]].set(True)
            num = num + 1
            val = front_index, num
            return val

        def condition_fn_2(val: Tuple[jnp.ndarray, jnp.ndarray]) -> bool:
            """Gives condition to stop the while loop. Makes sure the
            the number of solution is smaller than the maximum size
            of the population."""
            front_index, _ = val
            return sum(to_keep_index + front_index) < self.size  # type: ignore

        # add the individuals with the highest distances
        front_index, _num = jax.lax.while_loop(
            condition_fn_2,
            add_to_front,
            (jnp.zeros(num_candidates, dtype=bool), 0),
        )

        # update index
        to_keep_index = to_keep_index + front_index

        # go from boolean vector to indices - offset by 1
        indices = jnp.arange(start=1, stop=num_candidates + 1) * to_keep_index

        # get rid of the zeros (that correspond to the False from the mask)
        fake_indice = num_candidates + 1  # bigger than all the other indices
        indices = jnp.where(indices == 0, x=fake_indice, y=indices)

        # sort the indices to remove the fake indices
        indices = jnp.sort(indices)[: self.size]

        # remove the offset
        indices = indices - 1

        # keep only the survivors
        new_candidates = jax.tree_util.tree_map(lambda x: x[indices], candidates)
        new_scores = candidate_fitnesses[indices]

        new_repertoire = self.replace(genotypes=new_candidates, fitnesses=new_scores)

        return new_repertoire  # type: ignore
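A short usage sketch with placeholder sizes, assuming the init signature inherited from GARepertoire (genotypes, fitnesses, population_size): the repertoire is created once and then updated with the non-dominated sorting rule implemented by add.

import jax
import jax.numpy as jnp

from qdax.core.containers.nsga2_repertoire import NSGA2Repertoire

key = jax.random.PRNGKey(0)
k1, k2, k3, k4 = jax.random.split(key, 4)

population_size, num_criteria, num_features = 100, 2, 8
init_genotypes = jax.random.normal(k1, (population_size, num_features))
init_fitnesses = jax.random.normal(k2, (population_size, num_criteria))

repertoire = NSGA2Repertoire.init(
    genotypes=init_genotypes,
    fitnesses=init_fitnesses,
    population_size=population_size,
)

# new candidates compete with the current population; the best successive
# pareto fronts survive, ties on the last front broken by crowding distance
new_genotypes = jax.random.normal(k3, (population_size, num_features))
new_fitnesses = jax.random.normal(k4, (population_size, num_criteria))
new_repertoire = repertoire.add(new_genotypes, new_fitnesses)
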
add(self, batch_of_genotypes, batch_of_fitnesses)

Implements the repertoire addition rules.

The population is sorted into successive pareto fronts. The first one is the global pareto front. The second one is the pareto front of the population once the first pareto front has been removed, etc.

The successive pareto fronts are kept until adding a full pareto front would exceed the population size.

To decide the survival of this pareto front, a crowding distance is computed in order to keep individuals that are spread in this last pareto front. Hence, the individuals with the biggest crowding distances are added until the population size is reached.

Parameters:
  • batch_of_genotypes (Genotype) – new genotypes that we try to add.

  • batch_of_fitnesses (Fitness) – fitness of those new genotypes.

Returns:
  • NSGA2Repertoire – The updated repertoire.

Source code in qdax/core/containers/nsga2_repertoire.py
@jax.jit
def add(
    self, batch_of_genotypes: Genotype, batch_of_fitnesses: Fitness
) -> NSGA2Repertoire:
    """Implements the repertoire addition rules.

    The population is sorted into successive pareto fronts. The first one
    is the global pareto front. The second one is the pareto front of the
    population where the first pareto front has been removed, etc...

    The successive pareto fronts are kept until the moment where adding a
    full pareto front would exceed the population size.

    To decide the survival of this pareto front, a crowding distance is
    computed in order to keep individuals that are spread in this last pareto
    front. Hence, the individuals with the biggest crowding distances are
    added until the population size is reached.

    Args:
        batch_of_genotypes: new genotypes that we try to add.
        batch_of_fitnesses: fitness of those new genotypes.

    Returns:
        The updated repertoire.
    """
    # All the candidates
    candidates = jax.tree_util.tree_map(
        lambda x, y: jnp.concatenate((x, y), axis=0),
        self.genotypes,
        batch_of_genotypes,
    )

    candidate_fitnesses = jnp.concatenate((self.fitnesses, batch_of_fitnesses))

    first_leaf = jax.tree_util.tree_leaves(candidates)[0]
    num_candidates = first_leaf.shape[0]

    def compute_current_front(
        val: Tuple[jnp.ndarray, jnp.ndarray]
    ) -> Tuple[jnp.ndarray, jnp.ndarray]:
        """Body function for the while loop. Computes the successive
        pareto fronts in the data.

        Args:
            val: Value passed through the while loop. Here, it is
                a tuple containing two values. The indexes of all
                solutions to keep and the indexes of the last
                computed front.

        Returns:
            The updated values to pass through the while loop. Updated
            number of solutions and updated front indexes.
        """
        to_keep_index, _ = val

        # mask the individual that are already kept
        front_index = compute_masked_pareto_front(
            candidate_fitnesses, mask=to_keep_index
        )

        # Add new indexes
        to_keep_index = to_keep_index + front_index

        # Update front & number of solutions
        return to_keep_index, front_index

    def condition_fn_1(val: Tuple[jnp.ndarray, jnp.ndarray]) -> bool:
        """Gives condition to stop the while loop. Makes sure the
        the number of solution is smaller than the maximum size
        of the population.

        Args:
            val: Value passed through the while loop. Here, it is
                a tuple containing two values. The indexes of all
                solutions to keep and the indexes of the last
                computed front.

        Returns:
            Returns True if we have reached the maximum number of
            solutions we can keep in the population.
        """
        to_keep_index, _ = val
        return sum(to_keep_index) < self.size  # type: ignore

    # get indexes of all first successive fronts and indexes of the last front
    to_keep_index, front_index = jax.lax.while_loop(
        condition_fn_1,
        compute_current_front,
        (
            jnp.zeros(num_candidates, dtype=bool),
            jnp.zeros(num_candidates, dtype=bool),
        ),
    )

    # remove the indexes of the last front - gives first indexes to keep
    new_index = jnp.arange(start=1, stop=len(to_keep_index) + 1) * to_keep_index
    new_index = new_index * (~front_index)
    to_keep_index = new_index > 0

    # Compute crowding distances
    crowding_distances = self._compute_crowding_distances(
        candidate_fitnesses, ~front_index
    )
    crowding_distances = crowding_distances * (front_index)
    highest_dist = jnp.argsort(crowding_distances)

    def add_to_front(val: Tuple[jnp.ndarray, float]) -> Tuple[jnp.ndarray, Any]:
        """Add the individual with a given distance to the front.
        An index is incremented to get the highest distance among the
        non-selected individuals.

        Args:
            val: a tuple of two elements. A boolean vector with the positions that
                will be kept, and a cursor with the number of individuals already
                added during this process.

        Returns:
            The updated tuple, with the new booleans and the number of
            added elements.
        """
        front_index, num = val
        front_index = front_index.at[highest_dist[-num]].set(True)
        num = num + 1
        val = front_index, num
        return val

    def condition_fn_2(val: Tuple[jnp.ndarray, jnp.ndarray]) -> bool:
        """Gives condition to stop the while loop. Makes sure the
        the number of solution is smaller than the maximum size
        of the population."""
        front_index, _ = val
        return sum(to_keep_index + front_index) < self.size  # type: ignore

    # add the individuals with the highest distances
    front_index, _num = jax.lax.while_loop(
        condition_fn_2,
        add_to_front,
        (jnp.zeros(num_candidates, dtype=bool), 0),
    )

    # update index
    to_keep_index = to_keep_index + front_index

    # go from boolean vector to indices - offset by 1
    indices = jnp.arange(start=1, stop=num_candidates + 1) * to_keep_index

    # get rid of the zeros (that correspond to the False from the mask)
    fake_indice = num_candidates + 1  # bigger than all the other indices
    indices = jnp.where(indices == 0, x=fake_indice, y=indices)

    # sort the indices to remove the fake indices
    indices = jnp.sort(indices)[: self.size]

    # remove the offset
    indices = indices - 1

    # keep only the survivors
    new_candidates = jax.tree_util.tree_map(lambda x: x[indices], candidates)
    new_scores = candidate_fitnesses[indices]

    new_repertoire = self.replace(genotypes=new_candidates, fitnesses=new_scores)

    return new_repertoire  # type: ignore
replace(self, **updates)

"Returns a new object replacing the specified fields with new values.

Source code in qdax/core/containers/nsga2_repertoire.py
def replace(self, **updates):
  """ "Returns a new object replacing the specified fields with new values."""
  return dataclasses.replace(self, **updates)

repertoire

This file contains util functions and a class to define a repertoire, used to store individuals in the MAP-Elites algorithm as well as several variants.

Repertoire (PyTreeNode, ABC) dataclass

Abstract class for any repertoire of genotypes.

We decided not to add a genotypes attribute, even though it will be shared by all child classes, because we want to keep the parent class explicit and transparent.

Source code in qdax/core/containers/repertoire.py
class Repertoire(flax.struct.PyTreeNode, ABC):
    """Abstract class for any repertoire of genotypes.

    We decided not to add the attributes Genotypes even if
    it will be shared by all children classes because we want
    to keep the parent classes explicit and transparent.
    """

    @abstractclassmethod
    def init(cls) -> Repertoire:  # noqa: N805
        """Create a repertoire."""
        pass

    @abstractmethod
    def sample(
        self,
        random_key: RNGKey,
        num_samples: int,
    ) -> Genotype:
        """Sample genotypes from the repertoire.

        Args:
            random_key: a random key to handle stochasticity.
            num_samples: the number of genotypes to sample.

        Returns:
            The sample of genotypes.
        """
        pass

    @abstractmethod
    def add(self) -> Repertoire:
        """Implements the rule to add new genotypes to a
        repertoire.

        Returns:
            The updated repertoire.
        """
        pass
init() classmethod

Create a repertoire.

Source code in qdax/core/containers/repertoire.py
@abstractclassmethod
def init(cls) -> Repertoire:  # noqa: N805
    """Create a repertoire."""
    pass
sample(self, random_key, num_samples)

Sample genotypes from the repertoire.

Parameters:
  • random_key (RNGKey) – a random key to handle stochasticity.

  • num_samples (int) – the number of genotypes to sample.

Returns:
  • Genotype – The sample of genotypes.

Source code in qdax/core/containers/repertoire.py
@abstractmethod
def sample(
    self,
    random_key: RNGKey,
    num_samples: int,
) -> Genotype:
    """Sample genotypes from the repertoire.

    Args:
        random_key: a random key to handle stochasticity.
        num_samples: the number of genotypes to sample.

    Returns:
        The sample of genotypes.
    """
    pass
add(self)

Implements the rule to add new genotypes to a repertoire.

Returns:
  • Repertoire – The updated repertoire.

Source code in qdax/core/containers/repertoire.py
@abstractmethod
def add(self) -> Repertoire:
    """Implements the rule to add new genotypes to a
    repertoire.

    Returns:
        The updated repertoire.
    """
    pass
replace(self, **updates)

"Returns a new object replacing the specified fields with new values.

Source code in qdax/core/containers/repertoire.py
def replace(self, **updates):
  """ "Returns a new object replacing the specified fields with new values."""
  return dataclasses.replace(self, **updates)
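To illustrate the contract imposed by this abstract class, here is a toy, purely hypothetical concrete repertoire storing a flat array of genotypes. It is not an actual QDax container; it only shows how init, sample and add are meant to be overridden.

from __future__ import annotations

import jax
import jax.numpy as jnp

from qdax.core.containers.repertoire import Repertoire


class ToyRepertoire(Repertoire):
    """Minimal concrete repertoire keeping a flat array of genotypes."""

    genotypes: jnp.ndarray

    @classmethod
    def init(cls, genotypes: jnp.ndarray) -> ToyRepertoire:
        return cls(genotypes=genotypes)

    def sample(self, random_key: jnp.ndarray, num_samples: int) -> jnp.ndarray:
        # uniform sampling among the stored genotypes
        return jax.random.choice(random_key, self.genotypes, shape=(num_samples,))

    def add(self, new_genotypes: jnp.ndarray) -> ToyRepertoire:
        # naive addition rule: simply append the new genotypes
        return self.replace(
            genotypes=jnp.concatenate([self.genotypes, new_genotypes], axis=0)
        )
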

spea2_repertoire

SPEA2Repertoire (GARepertoire) dataclass

Repertoire used for the SPEA2 algorithm.

Inherits from the GARepertoire. The data stored are the genotypes and their fitnesses. Several functions are inherited from GARepertoire, including size, save, sample.

Source code in qdax/core/containers/spea2_repertoire.py
class SPEA2Repertoire(GARepertoire):
    """Repertoire used for the SPEA2 algorithm.

    Inherits from the GARepertoire. The data stored are the genotypes
    and their fitnesses. Several functions are inherited from GARepertoire,
    including size, save, sample.
    """

    num_neighbours: int = flax.struct.field(pytree_node=False)

    @jax.jit
    def _compute_strength_scores(self, batch_of_fitnesses: Fitness) -> jnp.ndarray:
        """Compute the strength scores (defined for a solution by the number of
        solutions dominating it plus the inverse of the density of solution in the
        fitness space).

        Args:
            batch_of_fitnesses: a batch of fitness vectors.

        Returns:
            Strength score of each solution corresponding to the fitnesses.
        """
        fitnesses = jnp.concatenate((self.fitnesses, batch_of_fitnesses), axis=0)
        # dominating solutions
        dominates = jnp.all(
            (fitnesses - jnp.expand_dims(fitnesses, axis=1)) > 0, axis=-1
        )
        strength_scores = jnp.sum(dominates, axis=1)

        # density
        distance_matrix = jnp.sum(
            (fitnesses - jnp.expand_dims(fitnesses, axis=1)) ** 2, axis=-1
        )
        densities = jnp.sum(
            jnp.sort(distance_matrix, axis=1)[:, : self.num_neighbours + 1], axis=1
        )

        # sum both terms
        strength_scores = strength_scores + 1 / (1 + densities)
        strength_scores = jnp.nan_to_num(strength_scores, nan=self.size + 2)

        return strength_scores

    @jax.jit
    def add(
        self,
        batch_of_genotypes: Genotype,
        batch_of_fitnesses: Fitness,
    ) -> SPEA2Repertoire:
        """Updates the population with the new solutions.

        To decide which individuals to keep, we count, for each solution,
        the number of solutions by which they are dominated. We keep only
        the least dominated solutions.

        Args:
            batch_of_genotypes: genotypes of the new individuals that are
                considered to be added to the population.
            batch_of_fitnesses: their corresponding fitnesses.

        Returns:
            Updated repertoire.
        """
        # All the candidates
        candidates = jax.tree_util.tree_map(
            lambda x, y: jnp.concatenate((x, y), axis=0),
            self.genotypes,
            batch_of_genotypes,
        )

        candidates_fitnesses = jnp.concatenate((self.fitnesses, batch_of_fitnesses))

        # compute strength score for all solutions
        strength_scores = self._compute_strength_scores(batch_of_fitnesses)

        # sort the strengths (the smaller the better (sic, respect paper's notation))
        indices = jnp.argsort(strength_scores)[: self.size]

        # keep the survivors
        new_candidates = jax.tree_util.tree_map(lambda x: x[indices], candidates)
        new_fitnesses = candidates_fitnesses[indices]

        new_repertoire = self.replace(genotypes=new_candidates, fitnesses=new_fitnesses)

        return new_repertoire  # type: ignore

    @classmethod
    def init(  # type: ignore
        cls,
        genotypes: Genotype,
        fitnesses: Fitness,
        population_size: int,
        num_neighbours: int,
    ) -> GARepertoire:
        """Initializes the repertoire.

        Start with default values and adds a first batch of genotypes
        to the repertoire.

        Args:
            genotypes: first batch of genotypes
            fitnesses: corresponding fitnesses
            population_size: size of the population we want to evolve

        Returns:
            An initial repertoire.
        """
        # create default fitnesses
        default_fitnesses = -jnp.inf * jnp.ones(
            shape=(population_size, fitnesses.shape[-1])
        )

        # create default genotypes
        default_genotypes = jax.tree_util.tree_map(
            lambda x: jnp.zeros(shape=(population_size,) + x.shape[1:]), genotypes
        )

        # create an initial repertoire with those default values
        repertoire = cls(
            genotypes=default_genotypes,
            fitnesses=default_fitnesses,
            num_neighbours=num_neighbours,
        )

        new_repertoire = repertoire.add(genotypes, fitnesses)

        return new_repertoire  # type: ignore
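The following usage sketch uses placeholder sizes; num_neighbours controls the density term of the strength scores used by the addition rule.

import jax
import jax.numpy as jnp

from qdax.core.containers.spea2_repertoire import SPEA2Repertoire

key = jax.random.PRNGKey(0)
k1, k2 = jax.random.split(key)

population_size, num_criteria, num_features = 100, 2, 8
genotypes = jax.random.normal(k1, (population_size, num_features))
fitnesses = jax.random.normal(k2, (population_size, num_criteria))

repertoire = SPEA2Repertoire.init(
    genotypes=genotypes,
    fitnesses=fitnesses,
    population_size=population_size,
    num_neighbours=3,
)

# candidates with the smallest strength scores survive
new_repertoire = repertoire.add(genotypes, fitnesses)
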
add(self, batch_of_genotypes, batch_of_fitnesses)

Updates the population with the new solutions.

To decide which individuals to keep, we count, for each solution, the number of solutions by which they are dominated. We keep only the least dominated solutions.

Parameters:
  • batch_of_genotypes (Genotype) – genotypes of the new individuals that are considered to be added to the population.

  • batch_of_fitnesses (Fitness) – their corresponding fitnesses.

Returns:
  • SPEA2Repertoire – Updated repertoire.

Source code in qdax/core/containers/spea2_repertoire.py
@jax.jit
def add(
    self,
    batch_of_genotypes: Genotype,
    batch_of_fitnesses: Fitness,
) -> SPEA2Repertoire:
    """Updates the population with the new solutions.

    To decide which individuals to keep, we count, for each solution,
        the number of solutions by which they are dominated. We keep only
        the least dominated solutions.

    Args:
        batch_of_genotypes: genotypes of the new individuals that are
            considered to be added to the population.
        batch_of_fitnesses: their corresponding fitnesses.

    Returns:
        Updated repertoire.
    """
    # All the candidates
    candidates = jax.tree_util.tree_map(
        lambda x, y: jnp.concatenate((x, y), axis=0),
        self.genotypes,
        batch_of_genotypes,
    )

    candidates_fitnesses = jnp.concatenate((self.fitnesses, batch_of_fitnesses))

    # compute strength score for all solutions
    strength_scores = self._compute_strength_scores(batch_of_fitnesses)

    # sort the strengths (the smaller the better (sic, respect paper's notation))
    indices = jnp.argsort(strength_scores)[: self.size]

    # keep the survivors
    new_candidates = jax.tree_util.tree_map(lambda x: x[indices], candidates)
    new_fitnesses = candidates_fitnesses[indices]

    new_repertoire = self.replace(genotypes=new_candidates, fitnesses=new_fitnesses)

    return new_repertoire  # type: ignore
init(genotypes, fitnesses, population_size, num_neighbours) classmethod

Initializes the repertoire.

Start with default values and adds a first batch of genotypes to the repertoire.

Parameters:
  • genotypes (Genotype) – first batch of genotypes

  • fitnesses (Fitness) – corresponding fitnesses

  • population_size (int) – size of the population we want to evolve

Returns:
  • GARepertoire – An initial repertoire.

Source code in qdax/core/containers/spea2_repertoire.py
@classmethod
def init(  # type: ignore
    cls,
    genotypes: Genotype,
    fitnesses: Fitness,
    population_size: int,
    num_neighbours: int,
) -> GARepertoire:
    """Initializes the repertoire.

    Start with default values and adds a first batch of genotypes
    to the repertoire.

    Args:
        genotypes: first batch of genotypes
        fitnesses: corresponding fitnesses
        population_size: size of the population we want to evolve

    Returns:
        An initial repertoire.
    """
    # create default fitnesses
    default_fitnesses = -jnp.inf * jnp.ones(
        shape=(population_size, fitnesses.shape[-1])
    )

    # create default genotypes
    default_genotypes = jax.tree_util.tree_map(
        lambda x: jnp.zeros(shape=(population_size,) + x.shape[1:]), genotypes
    )

    # create an initial repertoire with those default values
    repertoire = cls(
        genotypes=default_genotypes,
        fitnesses=default_fitnesses,
        num_neighbours=num_neighbours,
    )

    new_repertoire = repertoire.add(genotypes, fitnesses)

    return new_repertoire  # type: ignore
replace(self, **updates)

"Returns a new object replacing the specified fields with new values.

Source code in qdax/core/containers/spea2_repertoire.py
def replace(self, **updates):
  """ "Returns a new object replacing the specified fields with new values."""
  return dataclasses.replace(self, **updates)

uniform_replacement_archive

UniformReplacementArchive (Archive) dataclass

Stores jnp.ndarray and uses uniform replacement when the maximum size is reached.

Instead of replacing elements in a FIFO manner, like the Archive, this implementation removes uniformly chosen elements and replaces them with the newly added ones.

Most methods are inherited from Archive.

Source code in qdax/core/containers/uniform_replacement_archive.py
class UniformReplacementArchive(Archive):
    """Stores jnp.ndarray and use a uniform replacement when the
    maximum size is reached.

    Instead of replacing elements in a FIFO manner, like the Archive,
    this implementation removes elements uniformly to replace them by
    the newly added ones.

    Most methods are inherited from Archive.
    """

    random_key: RNGKey

    @classmethod
    def create(  # type: ignore
        cls,
        acceptance_threshold: float,
        state_descriptor_size: int,
        max_size: int,
        random_key: RNGKey,
    ) -> Archive:
        """Create an Archive instance.

        This class method provides a convenient way to create the archive while
        keeping the __init__ function for more general way to init an archive.

        Args:
            acceptance_threshold: the minimal distance to a stored descriptor to
                be respected for a new descriptor to be added.
            state_descriptor_size: the number of elements in a state descriptor.
            max_size: the maximal size of the archive. In case of overflow, previous
                elements are replaced by new ones. Defaults to 80000.
            random_key: a key to handle random operations. Defaults to key with
                seed = 0.

        Returns:
            A newly initialized archive.
        """

        archive = super().create(
            acceptance_threshold,
            state_descriptor_size,
            max_size,
        )

        return archive.replace(random_key=random_key)  # type: ignore

    @jax.jit
    def _single_insertion(self, state_descriptor: jnp.ndarray) -> Archive:
        """Insert a single element.

        If the archive is not full yet, the new element replaces a fake
        border, if it is full, it replaces a random element from the archive.

        Args:
            state_descriptor: state descriptor to be added.

        Returns:
            Return the archive with the newly added element."""
        new_current_position = self.current_position + 1
        is_full = new_current_position >= self.max_size

        random_key, subkey = jax.random.split(self.random_key)
        random_index = jax.random.randint(
            subkey, shape=(1,), minval=0, maxval=self.max_size
        )

        index = jnp.where(condition=is_full, x=random_index, y=new_current_position)

        new_data = self.data.at[index].set(state_descriptor)

        return self.replace(  # type: ignore
            current_position=new_current_position, data=new_data, random_key=random_key
        )
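A brief usage sketch with arbitrary values, assuming the batch insertion method insert inherited from Archive: descriptors closer than acceptance_threshold to stored ones are expected to be rejected.

import jax

from qdax.core.containers.uniform_replacement_archive import (
    UniformReplacementArchive,
)

archive = UniformReplacementArchive.create(
    acceptance_threshold=0.1,
    state_descriptor_size=2,
    max_size=1000,
    random_key=jax.random.PRNGKey(0),
)

# insert a batch of new state descriptors; once the archive is full,
# replaced slots are chosen uniformly at random
new_descriptors = jax.random.uniform(jax.random.PRNGKey(1), (32, 2))
archive = archive.insert(new_descriptors)
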
create(acceptance_threshold, state_descriptor_size, max_size, random_key) classmethod

Create an Archive instance.

This class method provides a convenient way to create the archive while keeping the __init__ function available for more general ways to initialize an archive.

Parameters:
  • acceptance_threshold (float) – the minimal distance to a stored descriptor to be respected for a new descriptor to be added.

  • state_descriptor_size (int) – the number of elements in a state descriptor.

  • max_size (int) – the maximal size of the archive. In case of overflow, previous elements are replaced by new ones. Defaults to 80000.

  • random_key (RNGKey) – a key to handle random operations. Defaults to key with seed = 0.

Returns:
  • Archive – A newly initialized archive.

Source code in qdax/core/containers/uniform_replacement_archive.py
@classmethod
def create(  # type: ignore
    cls,
    acceptance_threshold: float,
    state_descriptor_size: int,
    max_size: int,
    random_key: RNGKey,
) -> Archive:
    """Create an Archive instance.

    This class method provides a convenient way to create the archive while
    keeping the __init__ function for more general way to init an archive.

    Args:
        acceptance_threshold: the minimal distance to a stored descriptor to
            be respected for a new descriptor to be added.
        state_descriptor_size: the number of elements in a state descriptor.
        max_size: the maximal size of the archive. In case of overflow, previous
            elements are replaced by new ones. Defaults to 80000.
        random_key: a key to handle random operations. Defaults to key with
            seed = 0.

    Returns:
        A newly initialized archive.
    """

    archive = super().create(
        acceptance_threshold,
        state_descriptor_size,
        max_size,
    )

    return archive.replace(random_key=random_key)  # type: ignore
replace(self, **updates)

"Returns a new object replacing the specified fields with new values.

Source code in qdax/core/containers/uniform_replacement_archive.py
def replace(self, **updates):
  """ "Returns a new object replacing the specified fields with new values."""
  return dataclasses.replace(self, **updates)

unstructured_repertoire

UnstructuredRepertoire (PyTreeNode) dataclass

Class for the unstructured repertoire in Map Elites.

Parameters:
  • genotypes (Genotype) – a PyTree containing all the genotypes in the repertoire ordered by the centroids. Each leaf has a shape (num_centroids, num_features). The PyTree can be a simple Jax array or a more complex nested structure such as to represent parameters of neural network in Flax.

  • fitnesses (Fitness) – an array that contains the fitness of solutions in each cell of the repertoire, ordered by centroids. The array shape is (num_centroids,).

  • descriptors (Descriptor) – an array that contains the descriptors of solutions in each cell of the repertoire, ordered by centroids. The array shape is (num_centroids, num_descriptors).

  • centroids – an array that contains the centroids of the tessellation. The array shape is (num_centroids, num_descriptors).

  • observations (Observation) – observations that the genotype gathered in the environment.

Source code in qdax/core/containers/unstructured_repertoire.py
class UnstructuredRepertoire(flax.struct.PyTreeNode):
    """
    Class for the unstructured repertoire in Map Elites.

    Args:
        genotypes: a PyTree containing all the genotypes in the repertoire ordered
            by the centroids. Each leaf has a shape (num_centroids, num_features). The
            PyTree can be a simple Jax array or a more complex nested structure such
            as to represent parameters of neural network in Flax.
        fitnesses: an array that contains the fitness of solutions in each cell of the
            repertoire, ordered by centroids. The array shape is (num_centroids,).
        descriptors: an array that contains the descriptors of solutions in each cell
            of the repertoire, ordered by centroids. The array shape
            is (num_centroids, num_descriptors).
        centroids: an array that contains the centroids of the tessellation. The array
            shape is (num_centroids, num_descriptors).
        observations: observations that the genotype gathered in the environment.
    """

    genotypes: Genotype
    fitnesses: Fitness
    descriptors: Descriptor
    observations: Observation
    l_value: jnp.ndarray
    max_size: int = flax.struct.field(pytree_node=False)

    def get_maximal_size(self) -> int:
        """Returns the maximal number of individuals in the repertoire."""
        return self.max_size

    def get_number_genotypes(self) -> jnp.ndarray:
        """Returns the number of genotypes in the repertoire."""
        return jnp.sum(self.fitnesses != -jnp.inf)

    def save(self, path: str = "./") -> None:
        """Saves the grid on disk in the form of .npy files.

        Flattens the genotypes to store it with .npy format. Supposes that
        a user will have access to the reconstruction function when loading
        the genotypes.

        Args:
            path: Path where the data will be saved. Defaults to "./".
        """

        def flatten_genotype(genotype: Genotype) -> jnp.ndarray:
            flatten_genotype, _unravel_pytree = ravel_pytree(genotype)
            return flatten_genotype

        # flatten all the genotypes
        flat_genotypes = jax.vmap(flatten_genotype)(self.genotypes)

        # save data
        jnp.save(path + "genotypes.npy", flat_genotypes)
        jnp.save(path + "fitnesses.npy", self.fitnesses)
        jnp.save(path + "descriptors.npy", self.descriptors)
        jnp.save(path + "observations.npy", self.observations)
        jnp.save(path + "l_value.npy", self.l_value)
        jnp.save(path + "max_size.npy", self.max_size)

    @classmethod
    def load(
        cls, reconstruction_fn: Callable, path: str = "./"
    ) -> UnstructuredRepertoire:
        """Loads an unstructured repertoire.

        Args:
            reconstruction_fn: Function to reconstruct a PyTree
                from a flat array.
            path: Path where the data is saved. Defaults to "./".

        Returns:
            An unstructured repertoire.
        """

        flat_genotypes = jnp.load(path + "genotypes.npy")
        genotypes = jax.vmap(reconstruction_fn)(flat_genotypes)

        fitnesses = jnp.load(path + "fitnesses.npy")
        descriptors = jnp.load(path + "descriptors.npy")
        observations = jnp.load(path + "observations.npy")
        l_value = jnp.load(path + "l_value.npy")
        max_size = int(jnp.load(path + "max_size.npy").item())

        return UnstructuredRepertoire(
            genotypes=genotypes,
            fitnesses=fitnesses,
            descriptors=descriptors,
            observations=observations,
            l_value=l_value,
            max_size=max_size,
        )

    @jax.jit
    def add(
        self,
        batch_of_genotypes: Genotype,
        batch_of_descriptors: Descriptor,
        batch_of_fitnesses: Fitness,
        batch_of_observations: Observation,
    ) -> UnstructuredRepertoire:
        """Adds a batch of genotypes to the repertoire.

        Args:
            batch_of_genotypes: genotypes of the individuals to be considered
                for addition in the repertoire.
            batch_of_descriptors: associated descriptors.
            batch_of_fitnesses: associated fitness.
            batch_of_observations: associated observations.

        Returns:
            A new unstructured repertoire where the relevant individuals have been
            added.
        """

        # We need to replace all the descriptors that are not filled with jnp inf
        filtered_descriptors = jnp.where(
            jnp.expand_dims((self.fitnesses == -jnp.inf), axis=-1),
            jnp.full(self.descriptors.shape[-1], fill_value=jnp.inf),
            self.descriptors,
        )

        batch_of_indices, batch_of_distances = get_cells_indices(
            batch_of_descriptors, filtered_descriptors, 2
        )

        # Save the second-nearest neighbours to check a condition
        second_neighbours = batch_of_distances.at[..., 1].get()

        # Keep the Nearest neighbours
        batch_of_indices = batch_of_indices.at[..., 0].get()

        # Keep the Nearest neighbours
        batch_of_distances = batch_of_distances.at[..., 0].get()

        # We remove individuals that are too close to the second nn.
        # This avoids having clusters of individuals after adding them.
        not_novel_enough = jnp.where(
            jnp.squeeze(second_neighbours <= self.l_value), True, False
        )

        # batch_of_indices = jnp.expand_dims(batch_of_indices, axis=-1)
        batch_of_fitnesses = jnp.expand_dims(batch_of_fitnesses, axis=-1)
        batch_of_observations = jnp.expand_dims(batch_of_observations, axis=-1)

        # TODO: Doesn't Work if Archive is full. Need to use the closest individuals
        # in that case.
        empty_indexes = jnp.squeeze(
            jnp.nonzero(
                jnp.where(jnp.isinf(self.fitnesses), 1, 0),
                size=batch_of_indices.shape[0],
                fill_value=-1,
            )[0]
        )
        batch_of_indices = jnp.where(
            jnp.squeeze(batch_of_distances <= self.l_value),
            jnp.squeeze(batch_of_indices),
            -1,
        )

        # We get all the indices of the empty bds first and then the filled ones
        # (because of -1)
        sorted_bds = jax.lax.top_k(
            -1 * batch_of_indices.squeeze(), batch_of_indices.shape[0]
        )[1]
        batch_of_indices = jnp.where(
            jnp.squeeze(batch_of_distances.at[sorted_bds].get() <= self.l_value),
            batch_of_indices.at[sorted_bds].get(),
            empty_indexes,
        )

        batch_of_indices = jnp.expand_dims(batch_of_indices, axis=-1)

        # ReIndexing of all the inputs to the correct sorted way
        batch_of_descriptors = batch_of_descriptors.at[sorted_bds].get()
        batch_of_genotypes = jax.tree_map(
            lambda x: x.at[sorted_bds].get(), batch_of_genotypes
        )
        batch_of_fitnesses = batch_of_fitnesses.at[sorted_bds].get()
        batch_of_observations = batch_of_observations.at[sorted_bds].get()
        not_novel_enough = not_novel_enough.at[sorted_bds].get()

        # Check to find Individuals with same BD within the Batch
        keep_indiv = jax.jit(
            jax.vmap(intra_batch_comp, in_axes=(0, 0, None, None, None), out_axes=(0))
        )(
            batch_of_descriptors.squeeze(),
            jnp.arange(
                0, batch_of_descriptors.shape[0], 1
            ),  # keep track of where we are in the batch to assure right comparisons
            batch_of_descriptors.squeeze(),
            batch_of_fitnesses.squeeze(),
            self.l_value,
        )

        keep_indiv = jnp.logical_and(keep_indiv, jnp.logical_not(not_novel_enough))

        # get fitness segment max
        best_fitnesses = jax.ops.segment_max(
            batch_of_fitnesses,
            batch_of_indices.astype(jnp.int32).squeeze(),
            num_segments=self.max_size,
        )

        cond_values = jnp.take_along_axis(best_fitnesses, batch_of_indices, 0)

        # put dominated fitness to -jnp.inf
        batch_of_fitnesses = jnp.where(
            batch_of_fitnesses == cond_values, x=batch_of_fitnesses, y=-jnp.inf
        )

        # get addition condition
        grid_fitnesses = jnp.expand_dims(self.fitnesses, axis=-1)
        current_fitnesses = jnp.take_along_axis(grid_fitnesses, batch_of_indices, 0)
        addition_condition = batch_of_fitnesses > current_fitnesses
        addition_condition = jnp.logical_and(
            addition_condition, jnp.expand_dims(keep_indiv, axis=-1)
        )

        # assign fake position when relevant : num_centroids is out of bounds
        batch_of_indices = jnp.where(
            addition_condition,
            x=batch_of_indices,
            y=self.max_size,
        )

        # create new grid
        new_grid_genotypes = jax.tree_map(
            lambda grid_genotypes, new_genotypes: grid_genotypes.at[
                batch_of_indices.squeeze()
            ].set(new_genotypes),
            self.genotypes,
            batch_of_genotypes,
        )

        # compute new fitness and descriptors
        new_fitnesses = self.fitnesses.at[batch_of_indices.squeeze()].set(
            batch_of_fitnesses.squeeze()
        )
        new_descriptors = self.descriptors.at[batch_of_indices.squeeze()].set(
            batch_of_descriptors.squeeze()
        )

        new_observations = self.observations.at[batch_of_indices.squeeze()].set(
            batch_of_observations.squeeze()
        )

        return UnstructuredRepertoire(
            genotypes=new_grid_genotypes,
            fitnesses=new_fitnesses.squeeze(),
            descriptors=new_descriptors.squeeze(),
            observations=new_observations.squeeze(),
            l_value=self.l_value,
            max_size=self.max_size,
        )

    @partial(jax.jit, static_argnames=("num_samples",))
    def sample(self, random_key: RNGKey, num_samples: int) -> Tuple[Genotype, RNGKey]:
        """Sample elements in the repertoire.

        Args:
            random_key: a jax PRNG random key
            num_samples: the number of elements to be sampled

        Returns:
            samples: a batch of genotypes sampled in the repertoire
            random_key: an updated jax PRNG random key
        """

        random_key, sub_key = jax.random.split(random_key)
        grid_empty = self.fitnesses == -jnp.inf
        p = (1.0 - grid_empty) / jnp.sum(1.0 - grid_empty)

        samples = jax.tree_map(
            lambda x: jax.random.choice(sub_key, x, shape=(num_samples,), p=p),
            self.genotypes,
        )

        return samples, random_key

    @classmethod
    def init(
        cls,
        genotypes: Genotype,
        fitnesses: Fitness,
        descriptors: Descriptor,
        observations: Observation,
        l_value: jnp.ndarray,
        max_size: int,
    ) -> UnstructuredRepertoire:
        """Initialize a Map-Elites repertoire with an initial population of genotypes.
        Requires the definition of centroids that can be computed with any method
        such as CVT or Euclidean mapping.

        Args:
            genotypes: initial genotypes, pytree in which leaves
                have shape (batch_size, num_features)
            fitnesses: fitness of the initial genotypes of shape (batch_size,)
            descriptors: descriptors of the initial genotypes
                of shape (batch_size, num_descriptors)
            observations: observations experienced in the evaluation task.
            l_value: threshold distance of the repertoire.
            max_size: maximal size of the container

        Returns:
            an initialized unstructured repertoire.
        """

        # Initialize grid with default values
        default_fitnesses = -jnp.inf * jnp.ones(shape=max_size)
        default_genotypes = jax.tree_map(
            lambda x: jnp.full(shape=(max_size,) + x.shape[1:], fill_value=jnp.nan),
            genotypes,
        )
        default_descriptors = jnp.zeros(shape=(max_size, descriptors.shape[-1]))

        default_observations = jnp.full(
            shape=(max_size,) + observations.shape[1:], fill_value=jnp.nan
        )

        repertoire = UnstructuredRepertoire(
            genotypes=default_genotypes,
            fitnesses=default_fitnesses,
            descriptors=default_descriptors,
            observations=default_observations,
            l_value=l_value,
            max_size=max_size,
        )

        return repertoire.add(  # type: ignore
            genotypes, descriptors, fitnesses, observations
        )
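A usage sketch with placeholder shapes: the unstructured repertoire is parameterized by a distance threshold l_value and a maximal size instead of fixed centroids.

import jax
import jax.numpy as jnp

from qdax.core.containers.unstructured_repertoire import UnstructuredRepertoire

key = jax.random.PRNGKey(0)
k1, k2, k3, k4 = jax.random.split(key, 4)

batch_size, num_features, num_descriptors, obs_dim = 64, 8, 2, 4
genotypes = jax.random.normal(k1, (batch_size, num_features))
fitnesses = jax.random.normal(k2, (batch_size,))
descriptors = jax.random.uniform(k3, (batch_size, num_descriptors))
observations = jax.random.normal(k4, (batch_size, obs_dim))

repertoire = UnstructuredRepertoire.init(
    genotypes=genotypes,
    fitnesses=fitnesses,
    descriptors=descriptors,
    observations=observations,
    l_value=jnp.array(0.1),
    max_size=1024,
)
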
replace(self, **updates)

"Returns a new object replacing the specified fields with new values.

Source code in qdax/core/containers/unstructured_repertoire.py
def replace(self, **updates):
  """ "Returns a new object replacing the specified fields with new values."""
  return dataclasses.replace(self, **updates)
get_maximal_size(self)

Returns the maximal number of individuals in the repertoire.

Source code in qdax/core/containers/unstructured_repertoire.py
def get_maximal_size(self) -> int:
    """Returns the maximal number of individuals in the repertoire."""
    return self.max_size
get_number_genotypes(self)

Returns the number of genotypes in the repertoire.

Source code in qdax/core/containers/unstructured_repertoire.py
def get_number_genotypes(self) -> jnp.ndarray:
    """Returns the number of genotypes in the repertoire."""
    return jnp.sum(self.fitnesses != -jnp.inf)
save(self, path='./')

Saves the grid on disk in the form of .npy files.

Flattens the genotypes to store them in .npy format. Assumes that a user will have access to the reconstruction function when loading the genotypes.

Parameters:
  • path (str) – Path where the data will be saved. Defaults to "./".

Source code in qdax/core/containers/unstructured_repertoire.py
def save(self, path: str = "./") -> None:
    """Saves the grid on disk in the form of .npy files.

    Flattens the genotypes to store it with .npy format. Supposes that
    a user will have access to the reconstruction function when loading
    the genotypes.

    Args:
        path: Path where the data will be saved. Defaults to "./".
    """

    def flatten_genotype(genotype: Genotype) -> jnp.ndarray:
        flatten_genotype, _unravel_pytree = ravel_pytree(genotype)
        return flatten_genotype

    # flatten all the genotypes
    flat_genotypes = jax.vmap(flatten_genotype)(self.genotypes)

    # save data
    jnp.save(path + "genotypes.npy", flat_genotypes)
    jnp.save(path + "fitnesses.npy", self.fitnesses)
    jnp.save(path + "descriptors.npy", self.descriptors)
    jnp.save(path + "observations.npy", self.observations)
    jnp.save(path + "l_value.npy", self.l_value)
    jnp.save(path + "max_size.npy", self.max_size)
load(reconstruction_fn, path='./') classmethod

Loads an unstructured repertoire.

Parameters:
  • reconstruction_fn (Callable) – Function to reconstruct a PyTree from a flat array.

  • path (str) – Path where the data is saved. Defaults to "./".

Returns:
  • UnstructuredRepertoire – An unstructured repertoire.

Source code in qdax/core/containers/unstructured_repertoire.py
@classmethod
def load(
    cls, reconstruction_fn: Callable, path: str = "./"
) -> UnstructuredRepertoire:
    """Loads an unstructured repertoire.

    Args:
        reconstruction_fn: Function to reconstruct a PyTree
            from a flat array.
        path: Path where the data is saved. Defaults to "./".

    Returns:
        An unstructured repertoire.
    """

    flat_genotypes = jnp.load(path + "genotypes.npy")
    genotypes = jax.vmap(reconstruction_fn)(flat_genotypes)

    fitnesses = jnp.load(path + "fitnesses.npy")
    descriptors = jnp.load(path + "descriptors.npy")
    observations = jnp.load(path + "observations.npy")
    l_value = jnp.load(path + "l_value.npy")
    max_size = int(jnp.load(path + "max_size.npy").item())

    return UnstructuredRepertoire(
        genotypes=genotypes,
        fitnesses=fitnesses,
        descriptors=descriptors,
        observations=observations,
        l_value=l_value,
        max_size=max_size,
    )
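Sketch of a save/load round trip, reusing the hypothetical repertoire built above; the reconstruction function is derived from a single template genotype with jax.flatten_util.ravel_pytree, mirroring the flattening done in save.

from jax.flatten_util import ravel_pytree

repertoire.save(path="./")

# build a function mapping a flat vector back to a single genotype
single_genotype = jax.tree_map(lambda x: x[0], repertoire.genotypes)
_, reconstruction_fn = ravel_pytree(single_genotype)

loaded = UnstructuredRepertoire.load(reconstruction_fn=reconstruction_fn, path="./")
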
add(self, batch_of_genotypes, batch_of_descriptors, batch_of_fitnesses, batch_of_observations)

Adds a batch of genotypes to the repertoire.

Parameters:
  • batch_of_genotypes (Genotype) – genotypes of the individuals to be considered for addition in the repertoire.

  • batch_of_descriptors (Descriptor) – associated descriptors.

  • batch_of_fitnesses (Fitness) – associated fitness.

  • batch_of_observations (Observation) – associated observations.

Returns:
  • UnstructuredRepertoire – A new unstructured repertoire where the relevant individuals have been added.

Source code in qdax/core/containers/unstructured_repertoire.py
@jax.jit
def add(
    self,
    batch_of_genotypes: Genotype,
    batch_of_descriptors: Descriptor,
    batch_of_fitnesses: Fitness,
    batch_of_observations: Observation,
) -> UnstructuredRepertoire:
    """Adds a batch of genotypes to the repertoire.

    Args:
        batch_of_genotypes: genotypes of the individuals to be considered
            for addition in the repertoire.
        batch_of_descriptors: associated descriptors.
        batch_of_fitnesses: associated fitness.
        batch_of_observations: associated observations.

    Returns:
        A new unstructured repertoire where the relevant individuals have been
        added.
    """

    # We need to replace all the descriptors that are not filled with jnp inf
    filtered_descriptors = jnp.where(
        jnp.expand_dims((self.fitnesses == -jnp.inf), axis=-1),
        jnp.full(self.descriptors.shape[-1], fill_value=jnp.inf),
        self.descriptors,
    )

    batch_of_indices, batch_of_distances = get_cells_indices(
        batch_of_descriptors, filtered_descriptors, 2
    )

    # Save the second-nearest neighbours to check a condition
    second_neighbours = batch_of_distances.at[..., 1].get()

    # Keep the Nearest neighbours
    batch_of_indices = batch_of_indices.at[..., 0].get()

    # Keep the Nearest neighbours
    batch_of_distances = batch_of_distances.at[..., 0].get()

    # We remove individuals that are too close to the second nn.
    # This avoids having clusters of individuals after adding them.
    not_novel_enough = jnp.where(
        jnp.squeeze(second_neighbours <= self.l_value), True, False
    )

    # batch_of_indices = jnp.expand_dims(batch_of_indices, axis=-1)
    batch_of_fitnesses = jnp.expand_dims(batch_of_fitnesses, axis=-1)
    batch_of_observations = jnp.expand_dims(batch_of_observations, axis=-1)

    # TODO: Doesn't Work if Archive is full. Need to use the closest individuals
    # in that case.
    empty_indexes = jnp.squeeze(
        jnp.nonzero(
            jnp.where(jnp.isinf(self.fitnesses), 1, 0),
            size=batch_of_indices.shape[0],
            fill_value=-1,
        )[0]
    )
    batch_of_indices = jnp.where(
        jnp.squeeze(batch_of_distances <= self.l_value),
        jnp.squeeze(batch_of_indices),
        -1,
    )

    # We get all the indices of the empty bds first and then the filled ones
    # (because of -1)
    sorted_bds = jax.lax.top_k(
        -1 * batch_of_indices.squeeze(), batch_of_indices.shape[0]
    )[1]
    batch_of_indices = jnp.where(
        jnp.squeeze(batch_of_distances.at[sorted_bds].get() <= self.l_value),
        batch_of_indices.at[sorted_bds].get(),
        empty_indexes,
    )

    batch_of_indices = jnp.expand_dims(batch_of_indices, axis=-1)

    # ReIndexing of all the inputs to the correct sorted way
    batch_of_descriptors = batch_of_descriptors.at[sorted_bds].get()
    batch_of_genotypes = jax.tree_map(
        lambda x: x.at[sorted_bds].get(), batch_of_genotypes
    )
    batch_of_fitnesses = batch_of_fitnesses.at[sorted_bds].get()
    batch_of_observations = batch_of_observations.at[sorted_bds].get()
    not_novel_enough = not_novel_enough.at[sorted_bds].get()

    # Check to find Individuals with same BD within the Batch
    keep_indiv = jax.jit(
        jax.vmap(intra_batch_comp, in_axes=(0, 0, None, None, None), out_axes=(0))
    )(
        batch_of_descriptors.squeeze(),
        jnp.arange(
            0, batch_of_descriptors.shape[0], 1
        ),  # keep track of where we are in the batch to assure right comparisons
        batch_of_descriptors.squeeze(),
        batch_of_fitnesses.squeeze(),
        self.l_value,
    )

    keep_indiv = jnp.logical_and(keep_indiv, jnp.logical_not(not_novel_enough))

    # get fitness segment max
    best_fitnesses = jax.ops.segment_max(
        batch_of_fitnesses,
        batch_of_indices.astype(jnp.int32).squeeze(),
        num_segments=self.max_size,
    )

    cond_values = jnp.take_along_axis(best_fitnesses, batch_of_indices, 0)

    # put dominated fitness to -jnp.inf
    batch_of_fitnesses = jnp.where(
        batch_of_fitnesses == cond_values, x=batch_of_fitnesses, y=-jnp.inf
    )

    # get addition condition
    grid_fitnesses = jnp.expand_dims(self.fitnesses, axis=-1)
    current_fitnesses = jnp.take_along_axis(grid_fitnesses, batch_of_indices, 0)
    addition_condition = batch_of_fitnesses > current_fitnesses
    addition_condition = jnp.logical_and(
        addition_condition, jnp.expand_dims(keep_indiv, axis=-1)
    )

    # assign fake position when relevant : num_centroids is out of bounds
    batch_of_indices = jnp.where(
        addition_condition,
        x=batch_of_indices,
        y=self.max_size,
    )

    # create new grid
    new_grid_genotypes = jax.tree_map(
        lambda grid_genotypes, new_genotypes: grid_genotypes.at[
            batch_of_indices.squeeze()
        ].set(new_genotypes),
        self.genotypes,
        batch_of_genotypes,
    )

    # compute new fitness and descriptors
    new_fitnesses = self.fitnesses.at[batch_of_indices.squeeze()].set(
        batch_of_fitnesses.squeeze()
    )
    new_descriptors = self.descriptors.at[batch_of_indices.squeeze()].set(
        batch_of_descriptors.squeeze()
    )

    new_observations = self.observations.at[batch_of_indices.squeeze()].set(
        batch_of_observations.squeeze()
    )

    return UnstructuredRepertoire(
        genotypes=new_grid_genotypes,
        fitnesses=new_fitnesses.squeeze(),
        descriptors=new_descriptors.squeeze(),
        observations=new_observations.squeeze(),
        l_value=self.l_value,
        max_size=self.max_size,
    )
sample(self, random_key, num_samples)

Sample elements in the repertoire.

Parameters:
  • random_key (RNGKey) – a jax PRNG random key

  • num_samples (int) – the number of elements to be sampled

Returns:
  • samples – a batch of genotypes sampled in the repertoire

  • random_key – an updated jax PRNG random key

Source code in qdax/core/containers/unstructured_repertoire.py
@partial(jax.jit, static_argnames=("num_samples",))
def sample(self, random_key: RNGKey, num_samples: int) -> Tuple[Genotype, RNGKey]:
    """Sample elements in the repertoire.

    Args:
        random_key: a jax PRNG random key
        num_samples: the number of elements to be sampled

    Returns:
        samples: a batch of genotypes sampled in the repertoire
        random_key: an updated jax PRNG random key
    """

    random_key, sub_key = jax.random.split(random_key)
    grid_empty = self.fitnesses == -jnp.inf
    p = (1.0 - grid_empty) / jnp.sum(1.0 - grid_empty)

    samples = jax.tree_map(
        lambda x: jax.random.choice(sub_key, x, shape=(num_samples,), p=p),
        self.genotypes,
    )

    return samples, random_key
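
For intuition, here is a minimal sketch of the sampling rule above, with illustrative values rather than library code: empty positions are marked with -inf fitness and receive zero probability, so jax.random.choice only ever draws filled positions, uniformly.

import jax
import jax.numpy as jnp

# Four positions, two of which are empty (-inf fitness).
fitnesses = jnp.array([-jnp.inf, 1.0, -jnp.inf, 0.5])
genotypes = jnp.arange(4.0)  # one scalar genotype per position, for simplicity

grid_empty = fitnesses == -jnp.inf
p = (1.0 - grid_empty) / jnp.sum(1.0 - grid_empty)  # uniform over filled positions

key = jax.random.PRNGKey(0)
samples = jax.random.choice(key, genotypes, shape=(3,), p=p)
# samples only contains 1.0 or 3.0, the genotypes of the filled positions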
init(genotypes, fitnesses, descriptors, observations, l_value, max_size) classmethod

Initialize an unstructured repertoire with an initial population of genotypes. Unlike a MAP-Elites repertoire, it does not require centroids: additions are governed by the distance threshold l_value.

Parameters:
  • genotypes (Genotype) – initial genotypes, pytree in which leaves have shape (batch_size, num_features)

  • fitnesses (Fitness) – fitness of the initial genotypes of shape (batch_size,)

  • descriptors (Descriptor) – descriptors of the initial genotypes of shape (batch_size, num_descriptors)

  • observations (Observation) – observations experienced in the evaluation task.

  • l_value (jnp.ndarray) – threshold distance of the repertoire.

  • max_size (int) – maximal size of the container

Returns:
  • UnstructuredRepertoire – an initialized unstructured repertoire.

Source code in qdax/core/containers/unstructured_repertoire.py
@classmethod
def init(
    cls,
    genotypes: Genotype,
    fitnesses: Fitness,
    descriptors: Descriptor,
    observations: Observation,
    l_value: jnp.ndarray,
    max_size: int,
) -> UnstructuredRepertoire:
    """Initialize a Map-Elites repertoire with an initial population of genotypes.
    Requires the definition of centroids that can be computed with any method
    such as CVT or Euclidean mapping.

    Args:
        genotypes: initial genotypes, pytree in which leaves
            have shape (batch_size, num_features)
        fitnesses: fitness of the initial genotypes of shape (batch_size,)
        descriptors: descriptors of the initial genotypes
            of shape (batch_size, num_descriptors)
        observations: observations experienced in the evaluation task.
        l_value: threshold distance of the repertoire.
        max_size: maximal size of the container

    Returns:
        an initialized unstructured repertoire.
    """

    # Initialize grid with default values
    default_fitnesses = -jnp.inf * jnp.ones(shape=max_size)
    default_genotypes = jax.tree_map(
        lambda x: jnp.full(shape=(max_size,) + x.shape[1:], fill_value=jnp.nan),
        genotypes,
    )
    default_descriptors = jnp.zeros(shape=(max_size, descriptors.shape[-1]))

    default_observations = jnp.full(
        shape=(max_size,) + observations.shape[1:], fill_value=jnp.nan
    )

    repertoire = UnstructuredRepertoire(
        genotypes=default_genotypes,
        fitnesses=default_fitnesses,
        descriptors=default_descriptors,
        observations=default_observations,
        l_value=l_value,
        max_size=max_size,
    )

    return repertoire.add(  # type: ignore
        genotypes, descriptors, fitnesses, observations
    )
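
A minimal end-to-end usage sketch, assuming hypothetical shapes and an arbitrary l_value; it relies only on the init and sample signatures documented above.

import jax
import jax.numpy as jnp

from qdax.core.containers.unstructured_repertoire import UnstructuredRepertoire

# Hypothetical dimensions, chosen for illustration only.
batch_size, num_params, num_descriptors, obs_size = 8, 16, 2, 4
key = jax.random.PRNGKey(0)
key, subkey_1, subkey_2, subkey_3 = jax.random.split(key, 4)

genotypes = jax.random.normal(subkey_1, (batch_size, num_params))
fitnesses = jax.random.uniform(subkey_2, (batch_size,))
descriptors = jax.random.uniform(subkey_3, (batch_size, num_descriptors))
observations = jnp.zeros((batch_size, obs_size))

repertoire = UnstructuredRepertoire.init(
    genotypes=genotypes,
    fitnesses=fitnesses,
    descriptors=descriptors,
    observations=observations,
    l_value=jnp.array(0.2),
    max_size=128,
)

samples, key = repertoire.sample(key, num_samples=4)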

get_cells_indices(batch_of_descriptors, centroids, k_nn)

Returns the array of cells indices for a batch of descriptors given the centroids of the grid.

Parameters:
  • batch_of_descriptors (Descriptor) – a batch of descriptors of shape (batch_size, num_descriptors)

  • centroids (Centroid) – centroids array of shape (num_centroids, num_descriptors)

  • k_nn (int) – the number of nearest centroids to return for each descriptor

Returns:
  • Tuple[jnp.ndarray, jnp.ndarray] – the indices of the k_nn nearest centroids for each descriptor in the batch and the corresponding distances, both of shape (batch_size, k_nn)

Source code in qdax/core/containers/unstructured_repertoire.py
@partial(jax.jit, static_argnames=("k_nn",))
def get_cells_indices(
    batch_of_descriptors: Descriptor, centroids: Centroid, k_nn: int
) -> Tuple[jnp.ndarray, jnp.ndarray]:
    """Returns the array of cells indices for a batch of descriptors
    given the centroids of the grid.

    Args:
        batch_of_descriptors: a batch of descriptors
            of shape (batch_size, num_descriptors)
        centroids: centroids array of shape (num_centroids, num_descriptors)

    Returns:
        the indices of the centroids corresponding to each vector of descriptors
            in the batch with shape (batch_size,)
    """

    def _get_cells_indices(
        _descriptors: jnp.ndarray,
        _centroids: jnp.ndarray,
        _k_nn: int,
    ) -> Tuple[jnp.ndarray, jnp.ndarray]:
        """Inner function.

        descriptors of shape (1, num_descriptors)
        centroids of shape (num_centroids, num_descriptors)
        """

        distances = jax.vmap(jnp.linalg.norm)(_descriptors - _centroids)

        # Negating distances because we want the smallest ones
        min_dist, min_args = jax.lax.top_k(-1 * distances, _k_nn)

        return min_args, -1 * min_dist

    func = jax.vmap(
        _get_cells_indices,
        in_axes=(
            0,
            None,
            None,
        ),
    )

    return func(batch_of_descriptors, centroids, k_nn)  # type: ignore
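
A toy call, with illustrative arrays only: three descriptors are matched against four reference points and the two nearest indices and distances are returned for each descriptor.

import jax.numpy as jnp

from qdax.core.containers.unstructured_repertoire import get_cells_indices

descriptors = jnp.array([[0.0, 0.0], [1.0, 1.0], [0.9, 0.9]])
centroids = jnp.array([[0.1, 0.0], [1.0, 0.9], [2.0, 2.0], [0.5, 0.5]])

indices, distances = get_cells_indices(descriptors, centroids, k_nn=2)
# indices has shape (3, 2): the 2 nearest reference points per descriptor
# distances has shape (3, 2): the corresponding euclidean distances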

intra_batch_comp(normed, current_index, normed_all, eval_scores, l_value)

Decides whether an individual in the batch should be kept: an individual is discarded if another individual in the batch has a descriptor within l_value of its own and a higher fitness, or if its descriptor contains NaNs.

Source code in qdax/core/containers/unstructured_repertoire.py
@jax.jit
def intra_batch_comp(
    normed: jnp.ndarray,
    current_index: jnp.ndarray,
    normed_all: jnp.ndarray,
    eval_scores: jnp.ndarray,
    l_value: jnp.ndarray,
) -> jnp.ndarray:
    """Function to know if an individual should be kept or not."""

    # Check for individuals that are Nans, we remove them at the end
    not_existent = jnp.where((jnp.isnan(normed)).any(), True, False)

    # Fill in Nans to do computations
    normed = jnp.where(jnp.isnan(normed), jnp.full(normed.shape[-1], jnp.inf), normed)
    eval_scores = jnp.where(
        jnp.isinf(eval_scores), jnp.full(eval_scores.shape[-1], jnp.nan), eval_scores
    )

    # If we do not use a fitness (i.e same fitness everywhere), we create a virtual
    # fitness function to add individuals with the same bd
    additional_score = jnp.where(
        jnp.nanmax(eval_scores) == jnp.nanmin(eval_scores), 1.0, 0.0
    )
    additional_scores = jnp.linspace(0.0, additional_score, num=eval_scores.shape[0])

    # Add scores to empty individuals
    eval_scores = jnp.where(
        jnp.isnan(eval_scores), jnp.full(eval_scores.shape[0], -jnp.inf), eval_scores
    )
    # Virtual eval_scores
    eval_scores = eval_scores + additional_scores

    # For each point we check what other points are the closest ones.
    knn_relevant_scores, knn_relevant_indices = jax.lax.top_k(
        -1 * jax.vmap(jnp.linalg.norm)(normed - normed_all), eval_scores.shape[0]
    )
    # We negated the scores to use top_k so we reverse it.
    knn_relevant_scores = knn_relevant_scores * -1

    # Check if the individual is close enough to compare (under l-value)
    fitness = jnp.where(jnp.squeeze(knn_relevant_scores < l_value), True, False)

    # We want to eliminate the same individual (distance 0)
    fitness = jnp.where(knn_relevant_indices == current_index, False, fitness)
    current_fitness = jnp.squeeze(
        eval_scores.at[knn_relevant_indices.at[0].get()].get()
    )

    # Is the fitness of the other individual higher?
    # If both are True then we discard the current individual since this individual
    # would be replaced by the better one.
    discard_indiv = jnp.logical_and(
        jnp.where(
            eval_scores.at[knn_relevant_indices].get() > current_fitness, True, False
        ),
        fitness,
    ).any()

    # Discard Individuals with Nans as their BD (mainly for the readdition where we
    # have NaN bds)
    discard_indiv = jnp.logical_or(discard_indiv, not_existent)

    # Negate to know if we keep the individual
    return jnp.logical_not(discard_indiv)
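
Below is a sketch of the vmapped call used inside add, with illustrative data: each descriptor is compared against the whole batch, and an individual is dropped when another individual within l_value of it has a higher fitness.

import jax
import jax.numpy as jnp

from qdax.core.containers.unstructured_repertoire import intra_batch_comp

descriptors = jnp.array([[0.0, 0.0], [0.05, 0.0], [1.0, 1.0]])
fitnesses = jnp.array([1.0, 2.0, 0.5])
l_value = jnp.array(0.1)

keep = jax.vmap(intra_batch_comp, in_axes=(0, 0, None, None, None))(
    descriptors,
    jnp.arange(descriptors.shape[0]),  # position of each individual in the batch
    descriptors,
    fitnesses,
    l_value,
)
# keep == [False, True, True]: the first individual is discarded because its
# neighbour within l_value (the second one) has a higher fitness.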