Emitters

qdax.core.emitters special

cma_emitter

CMAEmitterState (EmitterState) dataclass

Emitter state for the CMA-ME emitter.

Parameters:
  • random_key (RNGKey) – a random key to handle stochastic operations. Used for state updates only; another key is used to emit. This might be subject to refactoring discussions in the future.

  • cmaes_state (CMAESState) – state of the underlying CMA-ES algorithm

  • previous_fitnesses (Fitness) – stores the last fitnesses of the repertoire; used to compute the improvement.

  • emit_count (int) – counts the number of emission events.

Source code in qdax/core/emitters/cma_emitter.py
class CMAEmitterState(EmitterState):
    """
    Emitter state for the CMA-ME emitter.

    Args:
        random_key: a random key to handle stochastic operations. Used for
            state update only, another key is used to emit. This might be
            subject to refactoring discussions in the future.
        cmaes_state: state of the underlying CMA-ES algorithm
        previous_fitnesses: store last fitnesses of the repertoire. Used to
            compute the improvment.
        emit_count: count the number of emission events.
    """

    random_key: RNGKey
    cmaes_state: CMAESState
    previous_fitnesses: Fitness
    emit_count: int
replace(self, **updates)

"Returns a new object replacing the specified fields with new values.

Source code in qdax/core/emitters/cma_emitter.py
def replace(self, **updates):
    """Returns a new object, replacing the specified fields with new values."""
    return dataclasses.replace(self, **updates)
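Since replace delegates to dataclasses.replace, a new state can be derived without mutating the original, for example:

new_emitter_state = emitter_state.replace(emit_count=0)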

CMAEmitter (Emitter, ABC)

Source code in qdax/core/emitters/cma_emitter.py
class CMAEmitter(Emitter, ABC):
    def __init__(
        self,
        batch_size: int,
        genotype_dim: int,
        centroids: Centroid,
        sigma_g: float,
        min_count: Optional[int] = None,
        max_count: Optional[float] = None,
    ):
        """
        Class for the emitter of CMA ME from "Covariance Matrix Adaptation for the
        Rapid Illumination of Behavior Space" by Fontaine et al.

        Args:
            batch_size: number of solutions sampled at each iteration
            genotype_dim: dimension of the genotype space.
            centroids: centroids used for the repertoire.
            sigma_g: standard deviation for the coefficients - called step size.
            min_count: minimum number of CMAES opt step before being considered for
                reinitialisation.
            max_count: maximum number of CMAES opt step authorized.
        """
        self._batch_size = batch_size

        # define a CMAES instance
        self._cmaes = CMAES(
            population_size=batch_size,
            search_dim=genotype_dim,
            # no need for fitness function in that specific case
            fitness_function=None,  # type: ignore
            num_best=batch_size,
            init_sigma=sigma_g,
            mean_init=None,  # will be init at zeros in cmaes
            bias_weights=True,
            delay_eigen_decomposition=True,
        )

        # minimum number of emitted solution before an emitter can be re-initialized
        if min_count is None:
            min_count = 0

        self._min_count = min_count

        if max_count is None:
            max_count = jnp.inf

        self._max_count = max_count

        self._centroids = centroids

        self._cma_initial_state = self._cmaes.init()

    @property
    def batch_size(self) -> int:
        """
        Returns:
            the batch size emitted by the emitter.
        """
        return self._batch_size

    @partial(jax.jit, static_argnames=("self",))
    def init(
        self, init_genotypes: Genotype, random_key: RNGKey
    ) -> Tuple[CMAEmitterState, RNGKey]:
        """
        Initializes the CMA-MEGA emitter


        Args:
            init_genotypes: initial genotypes to add to the grid.
            random_key: a random key to handle stochastic operations.

        Returns:
            The initial state of the emitter.
        """

        # Initialize repertoire with default values
        num_centroids = self._centroids.shape[0]
        default_fitnesses = -jnp.inf * jnp.ones(shape=num_centroids)

        # return the initial state
        random_key, subkey = jax.random.split(random_key)
        return (
            CMAEmitterState(
                random_key=subkey,
                cmaes_state=self._cma_initial_state,
                previous_fitnesses=default_fitnesses,
                emit_count=0,
            ),
            random_key,
        )

    @partial(jax.jit, static_argnames=("self",))
    def emit(
        self,
        repertoire: Optional[MapElitesRepertoire],
        emitter_state: CMAEmitterState,
        random_key: RNGKey,
    ) -> Tuple[Genotype, RNGKey]:
        """
        Emits new individuals. Interestingly, this method does not directly modifies
        individuals from the repertoire but sample from a distribution. Hence the
        repertoire is not used in the emit function.

        Args:
            repertoire: a repertoire of genotypes (unused).
            emitter_state: the state of the CMA-MEGA emitter.
            random_key: a random key to handle random operations.

        Returns:
            New genotypes and a new random key.
        """
        # emit from CMA-ES
        offsprings, random_key = self._cmaes.sample(
            cmaes_state=emitter_state.cmaes_state, random_key=random_key
        )

        return offsprings, random_key

    @partial(
        jax.jit,
        static_argnames=("self",),
    )
    def state_update(
        self,
        emitter_state: CMAEmitterState,
        repertoire: MapElitesRepertoire,
        genotypes: Genotype,
        fitnesses: Fitness,
        descriptors: Descriptor,
        extra_scores: Optional[ExtraScores] = None,
    ) -> Optional[EmitterState]:
        """
        Updates the CMA-ME emitter state.

        Note: we use the update_state function from CMAES, a function that assumes
        that the candidates are already sorted. We do this because we have to sort
        them in this function anyway, in order to apply the right weights to the
        terms when update theta.

        Args:
            emitter_state: current emitter state
            repertoire: the current genotypes repertoire
            genotypes: the genotypes of the batch of emitted offspring (unused).
            fitnesses: the fitnesses of the batch of emitted offspring.
            descriptors: the descriptors of the emitted offspring.
            extra_scores: unused

        Returns:
            The updated emitter state.
        """

        # retrieve elements from the emitter state
        cmaes_state = emitter_state.cmaes_state

        # Compute the improvements - needed for re-init condition
        indices = get_cells_indices(descriptors, repertoire.centroids)
        improvements = fitnesses - emitter_state.previous_fitnesses[indices]

        ranking_criteria = self._ranking_criteria(
            emitter_state=emitter_state,
            repertoire=repertoire,
            genotypes=genotypes,
            fitnesses=fitnesses,
            descriptors=descriptors,
            extra_scores=extra_scores,
            improvements=improvements,
        )

        # get the indices
        sorted_indices = jnp.flip(jnp.argsort(ranking_criteria))

        # sort the candidates
        sorted_candidates = jax.tree_util.tree_map(
            lambda x: x[sorted_indices], genotypes
        )
        sorted_improvements = improvements[sorted_indices]

        # compute reinitialize condition
        emit_count = emitter_state.emit_count + 1

        # check if the criteria are too similar
        sorted_criteria = ranking_criteria[sorted_indices]
        flat_criteria_condition = (
            jnp.linalg.norm(sorted_criteria[0] - sorted_criteria[-1]) < 1e-12
        )

        # check all conditions
        reinitialize = (
            jnp.all(improvements < 0) * (emit_count > self._min_count)
            + (emit_count > self._max_count)
            + self._cmaes.stop_condition(cmaes_state)
            + flat_criteria_condition
        )

        # If true, draw randomly and re-initialize parameters
        def update_and_reinit(
            operand: Tuple[
                CMAESState, CMAEmitterState, MapElitesRepertoire, int, RNGKey
            ],
        ) -> Tuple[CMAEmitterState, RNGKey]:
            return self._update_and_init_emitter_state(*operand)

        def update_wo_reinit(
            operand: Tuple[
                CMAESState, CMAEmitterState, MapElitesRepertoire, int, RNGKey
            ],
        ) -> Tuple[CMAEmitterState, RNGKey]:
            """Update the emitter when no reinit event happened.

            Here lies a divergence compared to the original implementation. We
            are getting better results when using no mask and doing the update
            with the whole batch of individuals rather than keeping only the one
            than were added to the archive.

            Interestingly, keeping the best half was not doing better. We think that
            this might be due to the small batch size used.

            This applies for the setting from the paper CMA-ME. Those facts might
            not be true with other problems and hyperparameters.

            To replicate the code described in the paper, replace:
            `mask = jnp.ones_like(sorted_improvements)`

            by:
            ```
            mask = sorted_improvements >= 0
            mask = mask + 1e-6
            ```

            RMQ: the addition of 1e-6 is here to fix a numerical
            instability.
            """

            (cmaes_state, emitter_state, repertoire, emit_count, random_key) = operand

            # Update CMA Parameters
            mask = jnp.ones_like(sorted_improvements)

            cmaes_state = self._cmaes.update_state_with_mask(
                cmaes_state, sorted_candidates, mask=mask
            )

            emitter_state = emitter_state.replace(
                cmaes_state=cmaes_state,
                emit_count=emit_count,
            )

            return emitter_state, random_key

        # Update CMA Parameters
        emitter_state, random_key = jax.lax.cond(
            reinitialize,
            update_and_reinit,
            update_wo_reinit,
            operand=(
                cmaes_state,
                emitter_state,
                repertoire,
                emit_count,
                emitter_state.random_key,
            ),
        )

        # update the emitter state
        emitter_state = emitter_state.replace(
            random_key=random_key, previous_fitnesses=repertoire.fitnesses
        )

        return emitter_state

    def _update_and_init_emitter_state(
        self,
        cmaes_state: CMAESState,
        emitter_state: CMAEmitterState,
        repertoire: MapElitesRepertoire,
        emit_count: int,
        random_key: RNGKey,
    ) -> Tuple[CMAEmitterState, RNGKey]:
        """Update the emitter state in the case of a reinit event.
        Reinit the cmaes state and use an individual from the repertoire
        as the starting mean.

        Args:
            cmaes_state: current cmaes state
            emitter_state: current cmame state
            repertoire: most recent repertoire
            emit_count: counter of the emitter
            random_key: key to handle stochastic events

        Returns:
            The updated emitter state.
        """

        # re-sample
        random_genotype, random_key = repertoire.sample(random_key, 1)

        # remove the batch dim
        new_mean = jax.tree_util.tree_map(lambda x: x.squeeze(0), random_genotype)

        cmaes_init_state = self._cma_initial_state.replace(mean=new_mean, num_updates=0)

        emitter_state = emitter_state.replace(
            cmaes_state=cmaes_init_state, emit_count=0
        )

        return emitter_state, random_key

    @abstractmethod
    def _ranking_criteria(
        self,
        emitter_state: CMAEmitterState,
        repertoire: MapElitesRepertoire,
        genotypes: Genotype,
        fitnesses: Fitness,
        descriptors: Descriptor,
        extra_scores: Optional[ExtraScores],
        improvements: jnp.ndarray,
    ) -> jnp.ndarray:
        """Defines how the genotypes should be sorted. Impacts the update
        of the CMAES state. In the end, this defines the type of CMAES emitter
        used (optimizing, random direction or improvement).

        Args:
            emitter_state: current state of the emitter.
            repertoire: latest repertoire of genotypes.
            genotypes: emitted genotypes.
            fitnesses: corresponding fitnesses.
            descriptors: corresponding fitnesses.
            extra_scores: corresponding extra scores.
            improvements: improvments of the emitted genotypes. This corresponds
                to the difference between their fitness and the fitness of the
                individual occupying the cell of corresponding fitness.

        Returns:
            The values to take into account in order to rank the emitted genotypes.
            Here, it's the improvement, or the fitness when the cell was previously
            unoccupied. Additionally, genotypes that discovered a new cell are
            given on offset to be ranked in front of other genotypes.
        """

        pass
batch_size: int property readonly
Returns:
  • int – the batch size emitted by the emitter.

__init__(self, batch_size, genotype_dim, centroids, sigma_g, min_count=None, max_count=None) special

Class for the CMA-ME emitter from "Covariance Matrix Adaptation for the Rapid Illumination of Behavior Space" by Fontaine et al.

Parameters:
  • batch_size (int) – number of solutions sampled at each iteration

  • genotype_dim (int) – dimension of the genotype space.

  • centroids (Centroid) – centroids used for the repertoire.

  • sigma_g (float) – standard deviation for the coefficients, also called the step size.

  • min_count (Optional[int]) – minimum number of CMA-ES optimization steps before the emitter can be considered for reinitialization.

  • max_count (Optional[float]) – maximum number of CMA-ES optimization steps allowed.
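A rough construction sketch is shown below. Since CMAEmitter is abstract, the improvement variant documented further down is instantiated; the centroid helper compute_euclidean_centroids and its import path are assumptions based on QDax's MAP-Elites container module and may differ between versions.

from qdax.core.containers.mapelites_repertoire import compute_euclidean_centroids
from qdax.core.emitters.cma_improvement_emitter import CMAImprovementEmitter

# grid of centroids over a 2-dimensional descriptor space (assumed helper)
centroids = compute_euclidean_centroids(grid_shape=(50, 50), minval=0.0, maxval=1.0)

emitter = CMAImprovementEmitter(
    batch_size=36,
    genotype_dim=100,
    centroids=centroids,
    sigma_g=0.5,
    min_count=1,
)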

Source code in qdax/core/emitters/cma_emitter.py
def __init__(
    self,
    batch_size: int,
    genotype_dim: int,
    centroids: Centroid,
    sigma_g: float,
    min_count: Optional[int] = None,
    max_count: Optional[float] = None,
):
    """
    Class for the emitter of CMA ME from "Covariance Matrix Adaptation for the
    Rapid Illumination of Behavior Space" by Fontaine et al.

    Args:
        batch_size: number of solutions sampled at each iteration
        genotype_dim: dimension of the genotype space.
        centroids: centroids used for the repertoire.
        sigma_g: standard deviation for the coefficients - called step size.
        min_count: minimum number of CMAES opt step before being considered for
            reinitialisation.
        max_count: maximum number of CMAES opt step authorized.
    """
    self._batch_size = batch_size

    # define a CMAES instance
    self._cmaes = CMAES(
        population_size=batch_size,
        search_dim=genotype_dim,
        # no need for fitness function in that specific case
        fitness_function=None,  # type: ignore
        num_best=batch_size,
        init_sigma=sigma_g,
        mean_init=None,  # will be init at zeros in cmaes
        bias_weights=True,
        delay_eigen_decomposition=True,
    )

    # minimum number of emitted solution before an emitter can be re-initialized
    if min_count is None:
        min_count = 0

    self._min_count = min_count

    if max_count is None:
        max_count = jnp.inf

    self._max_count = max_count

    self._centroids = centroids

    self._cma_initial_state = self._cmaes.init()
init(self, init_genotypes, random_key)

Initializes the CMA-ME emitter.

Parameters:
  • init_genotypes (Genotype) – initial genotypes to add to the grid.

  • random_key (RNGKey) – a random key to handle stochastic operations.

Returns:
  • Tuple[CMAEmitterState, RNGKey] – The initial state of the emitter.
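Continuing the sketch above (init_genotypes is a hypothetical placeholder; this emitter's init does not actually read it):

import jax
import jax.numpy as jnp

random_key = jax.random.PRNGKey(0)
init_genotypes = jnp.zeros((1, 100))  # hypothetical initial genotypes

emitter_state, random_key = emitter.init(init_genotypes, random_key)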

Source code in qdax/core/emitters/cma_emitter.py
@partial(jax.jit, static_argnames=("self",))
def init(
    self, init_genotypes: Genotype, random_key: RNGKey
) -> Tuple[CMAEmitterState, RNGKey]:
    """
    Initializes the CMA-MEGA emitter


    Args:
        init_genotypes: initial genotypes to add to the grid.
        random_key: a random key to handle stochastic operations.

    Returns:
        The initial state of the emitter.
    """

    # Initialize repertoire with default values
    num_centroids = self._centroids.shape[0]
    default_fitnesses = -jnp.inf * jnp.ones(shape=num_centroids)

    # return the initial state
    random_key, subkey = jax.random.split(random_key)
    return (
        CMAEmitterState(
            random_key=subkey,
            cmaes_state=self._cma_initial_state,
            previous_fitnesses=default_fitnesses,
            emit_count=0,
        ),
        random_key,
    )
emit(self, repertoire, emitter_state, random_key)

Emits new individuals. Note that this method does not directly modify individuals from the repertoire but samples from a distribution; hence, the repertoire is not used in the emit function.

Parameters:
  • repertoire (Optional[MapElitesRepertoire]) – a repertoire of genotypes (unused).

  • emitter_state (CMAEmitterState) – the state of the CMA-ME emitter.

  • random_key (RNGKey) – a random key to handle random operations.

Returns:
  • Tuple[Genotype, RNGKey] – New genotypes and a new random key.
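Continuing the sketch, a batch of offspring can then be drawn. Because the repertoire is unused here, None is acceptable; in a full MAP-Elites loop the current repertoire would be passed instead:

offspring, random_key = emitter.emit(
    repertoire=None, emitter_state=emitter_state, random_key=random_key
)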

Source code in qdax/core/emitters/cma_emitter.py
@partial(jax.jit, static_argnames=("self",))
def emit(
    self,
    repertoire: Optional[MapElitesRepertoire],
    emitter_state: CMAEmitterState,
    random_key: RNGKey,
) -> Tuple[Genotype, RNGKey]:
    """
    Emits new individuals. Interestingly, this method does not directly modifies
    individuals from the repertoire but sample from a distribution. Hence the
    repertoire is not used in the emit function.

    Args:
        repertoire: a repertoire of genotypes (unused).
        emitter_state: the state of the CMA-MEGA emitter.
        random_key: a random key to handle random operations.

    Returns:
        New genotypes and a new random key.
    """
    # emit from CMA-ES
    offsprings, random_key = self._cmaes.sample(
        cmaes_state=emitter_state.cmaes_state, random_key=random_key
    )

    return offsprings, random_key
state_update(self, emitter_state, repertoire, genotypes, fitnesses, descriptors, extra_scores=None)

Updates the CMA-ME emitter state.

Note: we use the update_state function from CMAES, a function that assumes that the candidates are already sorted. We do this because we have to sort them in this function anyway, in order to apply the right weights to the terms when updating theta.

Parameters:
  • emitter_state (CMAEmitterState) – current emitter state

  • repertoire (MapElitesRepertoire) – the current genotypes repertoire

  • genotypes (Genotype) – the genotypes of the batch of emitted offspring (unused).

  • fitnesses (Fitness) – the fitnesses of the batch of emitted offspring.

  • descriptors (Descriptor) – the descriptors of the emitted offspring.

  • extra_scores (Optional[ExtraScores]) – unused

Returns:
  • Optional[EmitterState] – The updated emitter state.
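In a full loop, state_update is called after the emitted offspring have been scored and inserted into the repertoire. A minimal sketch, assuming a problem-specific scoring_fn and an existing MapElitesRepertoire named repertoire (the exact repertoire API is an assumption here):

# one CMA-ME iteration (sketch)
offspring, random_key = emitter.emit(repertoire, emitter_state, random_key)
fitnesses, descriptors, extra_scores, random_key = scoring_fn(offspring, random_key)

# insert the scored offspring into the repertoire, then update the emitter
repertoire = repertoire.add(offspring, descriptors, fitnesses)
emitter_state = emitter.state_update(
    emitter_state=emitter_state,
    repertoire=repertoire,
    genotypes=offspring,
    fitnesses=fitnesses,
    descriptors=descriptors,
    extra_scores=extra_scores,
)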

Source code in qdax/core/emitters/cma_emitter.py
@partial(
    jax.jit,
    static_argnames=("self",),
)
def state_update(
    self,
    emitter_state: CMAEmitterState,
    repertoire: MapElitesRepertoire,
    genotypes: Genotype,
    fitnesses: Fitness,
    descriptors: Descriptor,
    extra_scores: Optional[ExtraScores] = None,
) -> Optional[EmitterState]:
    """
    Updates the CMA-ME emitter state.

    Note: we use the update_state function from CMAES, a function that assumes
    that the candidates are already sorted. We do this because we have to sort
    them in this function anyway, in order to apply the right weights to the
    terms when update theta.

    Args:
        emitter_state: current emitter state
        repertoire: the current genotypes repertoire
        genotypes: the genotypes of the batch of emitted offspring (unused).
        fitnesses: the fitnesses of the batch of emitted offspring.
        descriptors: the descriptors of the emitted offspring.
        extra_scores: unused

    Returns:
        The updated emitter state.
    """

    # retrieve elements from the emitter state
    cmaes_state = emitter_state.cmaes_state

    # Compute the improvements - needed for re-init condition
    indices = get_cells_indices(descriptors, repertoire.centroids)
    improvements = fitnesses - emitter_state.previous_fitnesses[indices]

    ranking_criteria = self._ranking_criteria(
        emitter_state=emitter_state,
        repertoire=repertoire,
        genotypes=genotypes,
        fitnesses=fitnesses,
        descriptors=descriptors,
        extra_scores=extra_scores,
        improvements=improvements,
    )

    # get the indices
    sorted_indices = jnp.flip(jnp.argsort(ranking_criteria))

    # sort the candidates
    sorted_candidates = jax.tree_util.tree_map(
        lambda x: x[sorted_indices], genotypes
    )
    sorted_improvements = improvements[sorted_indices]

    # compute reinitialize condition
    emit_count = emitter_state.emit_count + 1

    # check if the criteria are too similar
    sorted_criteria = ranking_criteria[sorted_indices]
    flat_criteria_condition = (
        jnp.linalg.norm(sorted_criteria[0] - sorted_criteria[-1]) < 1e-12
    )

    # check all conditions
    reinitialize = (
        jnp.all(improvements < 0) * (emit_count > self._min_count)
        + (emit_count > self._max_count)
        + self._cmaes.stop_condition(cmaes_state)
        + flat_criteria_condition
    )

    # If true, draw randomly and re-initialize parameters
    def update_and_reinit(
        operand: Tuple[
            CMAESState, CMAEmitterState, MapElitesRepertoire, int, RNGKey
        ],
    ) -> Tuple[CMAEmitterState, RNGKey]:
        return self._update_and_init_emitter_state(*operand)

    def update_wo_reinit(
        operand: Tuple[
            CMAESState, CMAEmitterState, MapElitesRepertoire, int, RNGKey
        ],
    ) -> Tuple[CMAEmitterState, RNGKey]:
        """Update the emitter when no reinit event happened.

        Here lies a divergence compared to the original implementation. We
        are getting better results when using no mask and doing the update
        with the whole batch of individuals rather than keeping only the one
        than were added to the archive.

        Interestingly, keeping the best half was not doing better. We think that
        this might be due to the small batch size used.

        This applies for the setting from the paper CMA-ME. Those facts might
        not be true with other problems and hyperparameters.

        To replicate the code described in the paper, replace:
        `mask = jnp.ones_like(sorted_improvements)`

        by:
        ```
        mask = sorted_improvements >= 0
        mask = mask + 1e-6
        ```

        RMQ: the addition of 1e-6 is here to fix a numerical
        instability.
        """

        (cmaes_state, emitter_state, repertoire, emit_count, random_key) = operand

        # Update CMA Parameters
        mask = jnp.ones_like(sorted_improvements)

        cmaes_state = self._cmaes.update_state_with_mask(
            cmaes_state, sorted_candidates, mask=mask
        )

        emitter_state = emitter_state.replace(
            cmaes_state=cmaes_state,
            emit_count=emit_count,
        )

        return emitter_state, random_key

    # Update CMA Parameters
    emitter_state, random_key = jax.lax.cond(
        reinitialize,
        update_and_reinit,
        update_wo_reinit,
        operand=(
            cmaes_state,
            emitter_state,
            repertoire,
            emit_count,
            emitter_state.random_key,
        ),
    )

    # update the emitter state
    emitter_state = emitter_state.replace(
        random_key=random_key, previous_fitnesses=repertoire.fitnesses
    )

    return emitter_state

cma_improvement_emitter

CMAImprovementEmitter (CMAEmitter)

Class for the CMA-ME emitter from "Covariance Matrix Adaptation for the Rapid Illumination of Behavior Space" by Fontaine et al.

This class implements the improvement emitter, where the update of the distribution is biased towards solutions that improve the QD score.

Parameters:
  • batch_size – number of solutions sampled at each iteration

  • genotype_dim – dimension of the genotype space.

  • centroids – centroids used for the repertoire.

  • sigma_g – standard deviation for the coefficients, also called the step size.

  • min_count – minimum number of CMA-ES optimization steps before the emitter can be considered for reinitialization.

  • max_count – maximum number of CMA-ES optimization steps allowed.
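Concretely, the ranking keeps the improvement for offspring that landed in an occupied cell and substitutes the fitness, shifted by an offset, for offspring that filled an empty cell, so that discoveries are always ranked first. A small standalone sketch:

import jax.numpy as jnp

fitnesses = jnp.array([0.2, 0.9, 0.5])
improvements = jnp.array([jnp.inf, -0.1, 0.3])  # the first offspring found a new cell

new_cell = improvements == jnp.inf
criteria = jnp.where(new_cell, fitnesses, improvements)
offset = jnp.max(criteria) - jnp.min(criteria)
criteria = jnp.where(new_cell, criteria + offset, criteria)
# criteria -> [0.6, -0.1, 0.3]: the new-cell offspring is ranked first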

Source code in qdax/core/emitters/cma_improvement_emitter.py
class CMAImprovementEmitter(CMAEmitter):
    """Class for the emitter of CMA ME from "Covariance Matrix Adaptation
    for the Rapid Illumination of Behavior Space" by Fontaine et al.

    This class implements the improvement emitter, where the update of the
    distribution is biased towards solution that improve the QD score.

    Args:
        batch_size: number of solutions sampled at each iteration
        genotype_dim: dimension of the genotype space.
        centroids: centroids used for the repertoire.
        sigma_g: standard deviation for the coefficients - called step size.
        min_count: minimum number of CMAES opt step before being considered for
            reinitialisation.
        max_count: maximum number of CMAES opt step authorized.
    """

    def _ranking_criteria(
        self,
        emitter_state: CMAEmitterState,
        repertoire: MapElitesRepertoire,
        genotypes: Genotype,
        fitnesses: Fitness,
        descriptors: Descriptor,
        extra_scores: Optional[ExtraScores],
        improvements: jnp.ndarray,
    ) -> jnp.ndarray:
        """Defines how the genotypes should be sorted. Impacts the update
        of the CMAES state. In the end, this defines the type of CMAES emitter
        used (optimizing, random direction or improvement).

        Args:
            emitter_state: current state of the emitter.
            repertoire: latest repertoire of genotypes.
            genotypes: emitted genotypes.
            fitnesses: corresponding fitnesses.
            descriptors: corresponding fitnesses.
            extra_scores: corresponding extra scores.
            improvements: improvments of the emitted genotypes. This corresponds
                to the difference between their fitness and the fitness of the
                individual occupying the cell of corresponding fitness.

        Returns:
            The values to take into account in order to rank the emitted genotypes.
            Here, it's the improvement, or the fitness when the cell was previously
            unoccupied. Additionally, genotypes that discovered a new cell are
            given on offset to be ranked in front of other genotypes.
        """

        # condition for being a new cell
        condition = improvements == jnp.inf

        # criteria: fitness if new cell, improvement else
        ranking_criteria = jnp.where(condition, x=fitnesses, y=improvements)

        # make sure to have all the new cells first
        new_cell_offset = jnp.max(ranking_criteria) - jnp.min(ranking_criteria)

        ranking_criteria = jnp.where(
            condition, x=ranking_criteria + new_cell_offset, y=ranking_criteria
        )

        return ranking_criteria  # type: ignore

cma_mega_emitter

CMAMEGAState (EmitterState) dataclass

Emitter state for the CMA-MEGA emitter.

Parameters:
  • theta (Genotype) – current genotype from where candidates will be drawn.

  • theta_grads (Gradient) – normalized fitness and descriptor gradients of theta.

  • random_key (RNGKey) – a random key to handle stochastic operations. Used for state updates only; another key is used to emit. This might be subject to refactoring discussions in the future.

  • cmaes_state (CMAESState) – state of the underlying CMA-ES algorithm

  • previous_fitnesses (Fitness) – stores the last fitnesses of the repertoire; used to compute the improvement.

Source code in qdax/core/emitters/cma_mega_emitter.py
class CMAMEGAState(EmitterState):
    """
    Emitter state for the CMA-MEGA emitter.

    Args:
        theta: current genotype from where candidates will be drawn.
        theta_grads: normalized fitness and descriptors gradients of theta.
        random_key: a random key to handle stochastic operations. Used for
            state update only, another key is used to emit. This might be
            subject to refactoring discussions in the future.
        cmaes_state: state of the underlying CMA-ES algorithm
        previous_fitnesses: store last fitnesses of the repertoire. Used to
            compute the improvment.
    """

    theta: Genotype
    theta_grads: Gradient
    random_key: RNGKey
    cmaes_state: CMAESState
    previous_fitnesses: Fitness
replace(self, **updates)

"Returns a new object replacing the specified fields with new values.

Source code in qdax/core/emitters/cma_mega_emitter.py
def replace(self, **updates):
    """Returns a new object, replacing the specified fields with new values."""
    return dataclasses.replace(self, **updates)

CMAMEGAEmitter (Emitter)

Source code in qdax/core/emitters/cma_mega_emitter.py
class CMAMEGAEmitter(Emitter):
    def __init__(
        self,
        scoring_function: Callable[
            [Genotype, RNGKey], Tuple[Fitness, Descriptor, ExtraScores, RNGKey]
        ],
        batch_size: int,
        learning_rate: float,
        num_descriptors: int,
        centroids: Centroid,
        sigma_g: float,
    ):
        """
        Class for the emitter of CMA Mega from "Differentiable Quality Diversity" by
        Fontaine et al.

        Args:
            scoring_function: a function to score individuals, outputing fitness,
                descriptors and extra scores. With this emitter, the extra score
                contains gradients and normalized gradients.
            batch_size: number of solutions sampled at each iteration
            learning_rate: rate at which the mean of the distribution is updated.
            num_descriptors: number of descriptors
            centroids: centroids of the repertoire used to store the genotypes
            sigma_g: standard deviation for the coefficients
        """

        self._scoring_function = scoring_function
        self._batch_size = batch_size
        self._learning_rate = learning_rate

        # weights used to update the gradient direction through a linear combination
        self._weights = jnp.expand_dims(
            jnp.log(batch_size + 0.5) - jnp.log(jnp.arange(1, batch_size + 1)), axis=-1
        )
        self._weights = self._weights / (self._weights.sum())

        # define a CMAES instance - used to update the coeffs
        self._cmaes = CMAES(
            population_size=batch_size,
            search_dim=num_descriptors + 1,
            # no need for fitness function in that specific case
            fitness_function=None,  # type: ignore
            num_best=batch_size,
            init_sigma=sigma_g,
            bias_weights=True,
            delay_eigen_decomposition=True,
        )

        self._centroids = centroids

        self._cma_initial_state = self._cmaes.init()

    @partial(jax.jit, static_argnames=("self",))
    def init(
        self, init_genotypes: Genotype, random_key: RNGKey
    ) -> Tuple[CMAMEGAState, RNGKey]:
        """
        Initializes the CMA-MEGA emitter.


        Args:
            init_genotypes: initial genotypes to add to the grid.
            random_key: a random key to handle stochastic operations.

        Returns:
            The initial state of the emitter.
        """

        # define init theta as 0
        theta = jax.tree_util.tree_map(
            lambda x: jnp.zeros_like(x[:1, ...]),
            init_genotypes,
        )

        # score it
        _, _, extra_score, random_key = self._scoring_function(theta, random_key)
        theta_grads = extra_score["normalized_grads"]

        # Initialize repertoire with default values
        num_centroids = self._centroids.shape[0]
        default_fitnesses = -jnp.inf * jnp.ones(shape=num_centroids)

        # return the initial state
        random_key, subkey = jax.random.split(random_key)
        return (
            CMAMEGAState(
                theta=theta,
                theta_grads=theta_grads,
                random_key=subkey,
                cmaes_state=self._cma_initial_state,
                previous_fitnesses=default_fitnesses,
            ),
            random_key,
        )

    @partial(jax.jit, static_argnames=("self",))
    def emit(
        self,
        repertoire: Optional[MapElitesRepertoire],
        emitter_state: CMAMEGAState,
        random_key: RNGKey,
    ) -> Tuple[Genotype, RNGKey]:
        """
        Emits new individuals. Interestingly, this method does not directly modifies
        individuals from the repertoire but sample from a distribution. Hence the
        repertoire is not used in the emit function.

        Args:
            repertoire: a repertoire of genotypes (unused).
            emitter_state: the state of the CMA-MEGA emitter.
            random_key: a random key to handle random operations.

        Returns:
            New genotypes and a new random key.
        """

        # retrieve elements from the emitter state
        theta = jnp.nan_to_num(emitter_state.theta)
        cmaes_state = emitter_state.cmaes_state

        # get grads - remove nan and first dimension
        grads = jnp.nan_to_num(emitter_state.theta_grads.squeeze(axis=0))

        # Draw random coefficients - use the emitter state key
        coeffs, random_key = self._cmaes.sample(
            cmaes_state=cmaes_state, random_key=emitter_state.random_key
        )

        # make sure the fitness coefficient is positive
        coeffs = coeffs.at[:, 0].set(jnp.abs(coeffs[:, 0]))
        update_grad = coeffs @ grads.T

        # Compute new candidates
        new_thetas = jax.tree_util.tree_map(lambda x, y: x + y, theta, update_grad)

        return new_thetas, random_key

    @partial(
        jax.jit,
        static_argnames=("self",),
    )
    def state_update(
        self,
        emitter_state: CMAMEGAState,
        repertoire: MapElitesRepertoire,
        genotypes: Genotype,
        fitnesses: Fitness,
        descriptors: Descriptor,
        extra_scores: Optional[ExtraScores] = None,
    ) -> Optional[EmitterState]:
        """
        Updates the CMA-MEGA emitter state.

        Note: in order to recover the coeffs that where used to sample the genotypes,
        we reuse the emitter state's random key in this function.

        Note: we use the update_state function from CMAES, a function that suppose
        that the candidates are already sorted. We do this because we have to sort
        them in this function anyway, in order to apply the right weights to the
        terms when update theta.

        Args:
            emitter_state: current emitter state
            repertoire: the current genotypes repertoire
            genotypes: the genotypes of the batch of emitted offspring (unused).
            fitnesses: the fitnesses of the batch of emitted offspring.
            descriptors: the descriptors of the emitted offspring.
            extra_scores: unused

        Returns:
            The updated emitter state.
        """

        # retrieve elements from the emitter state
        cmaes_state = emitter_state.cmaes_state
        theta = jnp.nan_to_num(emitter_state.theta)
        grads = jnp.nan_to_num(emitter_state.theta_grads[0])

        # Update the archive and compute the improvements
        indices = get_cells_indices(descriptors, repertoire.centroids)
        improvements = fitnesses - emitter_state.previous_fitnesses[indices]

        # condition for being a new cell
        condition = improvements == jnp.inf

        # criteria: fitness if new cell, improvement else
        ranking_criteria = jnp.where(condition, x=fitnesses, y=improvements)

        # make sure to have all the new cells first
        new_cell_offset = jnp.max(ranking_criteria) - jnp.min(ranking_criteria)

        ranking_criteria = jnp.where(
            condition, x=ranking_criteria + new_cell_offset, y=ranking_criteria
        )

        # sort indices according to the criteria
        sorted_indices = jnp.flip(jnp.argsort(ranking_criteria))

        # Draw the coeffs - reuse the emitter state key to get same coeffs
        coeffs, random_key = self._cmaes.sample(
            cmaes_state=cmaes_state, random_key=emitter_state.random_key
        )
        # make sure the fitness coeff is positive
        coeffs = coeffs.at[:, 0].set(jnp.abs(coeffs[:, 0]))

        # get the gradients that must be applied
        update_grad = coeffs @ grads.T

        # weight terms - based on improvement rank
        gradient_step = jnp.sum(self._weights[sorted_indices] * update_grad, axis=0)

        # update theta
        theta = jax.tree_util.tree_map(
            lambda x, y: x + self._learning_rate * y, theta, gradient_step
        )

        # Update CMA Parameters
        sorted_candidates = coeffs[sorted_indices]
        cmaes_state = self._cmaes.update_state(cmaes_state, sorted_candidates)

        # If no improvement draw randomly and re-initialize parameters
        reinitialize = jnp.all(improvements < 0) + self._cmaes.stop_condition(
            cmaes_state
        )

        # re-sample
        random_theta, random_key = repertoire.sample(random_key, 1)

        # update theta in case of reinit
        theta = jax.tree_util.tree_map(
            lambda x, y: jnp.where(reinitialize, x=x, y=y), random_theta, theta
        )

        # update cmaes state in case of reinit
        cmaes_state = jax.tree_util.tree_map(
            lambda x, y: jnp.where(reinitialize, x=x, y=y),
            self._cma_initial_state,
            cmaes_state,
        )

        # score theta
        _, _, extra_score, random_key = self._scoring_function(theta, random_key)

        # create new emitter state
        emitter_state = CMAMEGAState(
            theta=theta,
            theta_grads=extra_score["normalized_grads"],
            random_key=random_key,
            cmaes_state=cmaes_state,
            previous_fitnesses=repertoire.fitnesses,
        )

        return emitter_state

    @property
    def batch_size(self) -> int:
        """
        Returns:
            the batch size emitted by the emitter.
        """
        return self._batch_size
batch_size: int property readonly
Returns:
  • int – the batch size emitted by the emitter.

__init__(self, scoring_function, batch_size, learning_rate, num_descriptors, centroids, sigma_g) special

Class for the CMA-MEGA emitter from "Differentiable Quality Diversity" by Fontaine et al.

Parameters:
  • scoring_function (Callable[[Genotype, RNGKey], Tuple[Fitness, Descriptor, ExtraScores, RNGKey]]) – a function to score individuals, outputting fitness, descriptors and extra scores. With this emitter, the extra scores contain gradients and normalized gradients.

  • batch_size (int) – number of solutions sampled at each iteration

  • learning_rate (float) – rate at which the mean of the distribution is updated.

  • num_descriptors (int) – number of descriptors

  • centroids (Centroid) – centroids of the repertoire used to store the genotypes

  • sigma_g (float) – standard deviation for the coefficients
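A rough construction sketch; scoring_fn and centroids are assumed to be defined elsewhere, and the scoring function must expose the gradients through extra_scores["normalized_grads"] as described above:

from qdax.core.emitters.cma_mega_emitter import CMAMEGAEmitter

# scoring_fn: (genotypes, random_key) -> (fitnesses, descriptors, extra_scores, random_key),
# with extra_scores providing the "normalized_grads" entry used by this emitter
cma_mega_emitter = CMAMEGAEmitter(
    scoring_function=scoring_fn,
    batch_size=36,
    learning_rate=1.0,
    num_descriptors=2,
    centroids=centroids,  # e.g. the centroids built for the CMA-ME example above
    sigma_g=0.05,
)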

Source code in qdax/core/emitters/cma_mega_emitter.py
def __init__(
    self,
    scoring_function: Callable[
        [Genotype, RNGKey], Tuple[Fitness, Descriptor, ExtraScores, RNGKey]
    ],
    batch_size: int,
    learning_rate: float,
    num_descriptors: int,
    centroids: Centroid,
    sigma_g: float,
):
    """
    Class for the emitter of CMA Mega from "Differentiable Quality Diversity" by
    Fontaine et al.

    Args:
        scoring_function: a function to score individuals, outputing fitness,
            descriptors and extra scores. With this emitter, the extra score
            contains gradients and normalized gradients.
        batch_size: number of solutions sampled at each iteration
        learning_rate: rate at which the mean of the distribution is updated.
        num_descriptors: number of descriptors
        centroids: centroids of the repertoire used to store the genotypes
        sigma_g: standard deviation for the coefficients
    """

    self._scoring_function = scoring_function
    self._batch_size = batch_size
    self._learning_rate = learning_rate

    # weights used to update the gradient direction through a linear combination
    self._weights = jnp.expand_dims(
        jnp.log(batch_size + 0.5) - jnp.log(jnp.arange(1, batch_size + 1)), axis=-1
    )
    self._weights = self._weights / (self._weights.sum())

    # define a CMAES instance - used to update the coeffs
    self._cmaes = CMAES(
        population_size=batch_size,
        search_dim=num_descriptors + 1,
        # no need for fitness function in that specific case
        fitness_function=None,  # type: ignore
        num_best=batch_size,
        init_sigma=sigma_g,
        bias_weights=True,
        delay_eigen_decomposition=True,
    )

    self._centroids = centroids

    self._cma_initial_state = self._cmaes.init()
init(self, init_genotypes, random_key)

Initializes the CMA-MEGA emitter.

Parameters:
  • init_genotypes (Genotype) – initial genotypes to add to the grid.

  • random_key (RNGKey) – a random key to handle stochastic operations.

Returns:
  • Tuple[CMAMEGAState, RNGKey] – The initial state of the emitter.

Source code in qdax/core/emitters/cma_mega_emitter.py
@partial(jax.jit, static_argnames=("self",))
def init(
    self, init_genotypes: Genotype, random_key: RNGKey
) -> Tuple[CMAMEGAState, RNGKey]:
    """
    Initializes the CMA-MEGA emitter.


    Args:
        init_genotypes: initial genotypes to add to the grid.
        random_key: a random key to handle stochastic operations.

    Returns:
        The initial state of the emitter.
    """

    # define init theta as 0
    theta = jax.tree_util.tree_map(
        lambda x: jnp.zeros_like(x[:1, ...]),
        init_genotypes,
    )

    # score it
    _, _, extra_score, random_key = self._scoring_function(theta, random_key)
    theta_grads = extra_score["normalized_grads"]

    # Initialize repertoire with default values
    num_centroids = self._centroids.shape[0]
    default_fitnesses = -jnp.inf * jnp.ones(shape=num_centroids)

    # return the initial state
    random_key, subkey = jax.random.split(random_key)
    return (
        CMAMEGAState(
            theta=theta,
            theta_grads=theta_grads,
            random_key=subkey,
            cmaes_state=self._cma_initial_state,
            previous_fitnesses=default_fitnesses,
        ),
        random_key,
    )
emit(self, repertoire, emitter_state, random_key)

Emits new individuals. Note that this method does not directly modify individuals from the repertoire but samples from a distribution; hence, the repertoire is not used in the emit function.

Parameters:
  • repertoire (Optional[MapElitesRepertoire]) – a repertoire of genotypes (unused).

  • emitter_state (CMAMEGAState) – the state of the CMA-MEGA emitter.

  • random_key (RNGKey) – a random key to handle random operations.

Returns:
  • Tuple[Genotype, RNGKey] – New genotypes and a new random key.
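Each candidate is the current search point theta shifted by a random linear combination of the objective and descriptor gradients, with coefficients drawn from the CMA-ES distribution and the objective coefficient forced positive. A standalone sketch of that combination (shapes are illustrative assumptions):

import jax
import jax.numpy as jnp

genotype_dim, num_descriptors, batch_size = 10, 2, 4
theta = jnp.zeros((1, genotype_dim))
grads = jnp.ones((genotype_dim, num_descriptors + 1))  # columns: fitness gradient first, then descriptor gradients
coeffs = jax.random.normal(jax.random.PRNGKey(0), (batch_size, num_descriptors + 1))
coeffs = coeffs.at[:, 0].set(jnp.abs(coeffs[:, 0]))  # keep the fitness coefficient positive
new_thetas = theta + coeffs @ grads.T  # one candidate per row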

Source code in qdax/core/emitters/cma_mega_emitter.py
@partial(jax.jit, static_argnames=("self",))
def emit(
    self,
    repertoire: Optional[MapElitesRepertoire],
    emitter_state: CMAMEGAState,
    random_key: RNGKey,
) -> Tuple[Genotype, RNGKey]:
    """
    Emits new individuals. Interestingly, this method does not directly modifies
    individuals from the repertoire but sample from a distribution. Hence the
    repertoire is not used in the emit function.

    Args:
        repertoire: a repertoire of genotypes (unused).
        emitter_state: the state of the CMA-MEGA emitter.
        random_key: a random key to handle random operations.

    Returns:
        New genotypes and a new random key.
    """

    # retrieve elements from the emitter state
    theta = jnp.nan_to_num(emitter_state.theta)
    cmaes_state = emitter_state.cmaes_state

    # get grads - remove nan and first dimension
    grads = jnp.nan_to_num(emitter_state.theta_grads.squeeze(axis=0))

    # Draw random coefficients - use the emitter state key
    coeffs, random_key = self._cmaes.sample(
        cmaes_state=cmaes_state, random_key=emitter_state.random_key
    )

    # make sure the fitness coefficient is positive
    coeffs = coeffs.at[:, 0].set(jnp.abs(coeffs[:, 0]))
    update_grad = coeffs @ grads.T

    # Compute new candidates
    new_thetas = jax.tree_util.tree_map(lambda x, y: x + y, theta, update_grad)

    return new_thetas, random_key
state_update(self, emitter_state, repertoire, genotypes, fitnesses, descriptors, extra_scores=None)

Updates the CMA-MEGA emitter state.

Note: in order to recover the coeffs that were used to sample the genotypes, we reuse the emitter state's random key in this function.

Note: we use the update_state function from CMAES, a function that assumes that the candidates are already sorted. We do this because we have to sort them in this function anyway, in order to apply the right weights to the terms when updating theta.

Parameters:
  • emitter_state (CMAMEGAState) – current emitter state

  • repertoire (MapElitesRepertoire) – the current genotypes repertoire

  • genotypes (Genotype) – the genotypes of the batch of emitted offspring (unused).

  • fitnesses (Fitness) – the fitnesses of the batch of emitted offspring.

  • descriptors (Descriptor) – the descriptors of the emitted offspring.

  • extra_scores (Optional[ExtraScores]) – unused

Returns:
  • Optional[EmitterState] – The updated emitter state.
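The weights applied to the ranked terms are the log-rank weights built in __init__; a small sketch of how they are formed and combined with a batch of per-candidate gradients, mirroring the expression in the source below (values are illustrative):

import jax.numpy as jnp

batch_size, genotype_dim = 4, 10
# log-rank weights, as built in __init__
weights = jnp.log(batch_size + 0.5) - jnp.log(jnp.arange(1, batch_size + 1))
weights = jnp.expand_dims(weights / weights.sum(), axis=-1)

update_grad = jnp.ones((batch_size, genotype_dim))  # illustrative per-candidate gradients
sorted_indices = jnp.array([2, 0, 3, 1])            # illustrative improvement ranking
gradient_step = jnp.sum(weights[sorted_indices] * update_grad, axis=0)
# theta is then moved by learning_rate * gradient_step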

Source code in qdax/core/emitters/cma_mega_emitter.py
@partial(
    jax.jit,
    static_argnames=("self",),
)
def state_update(
    self,
    emitter_state: CMAMEGAState,
    repertoire: MapElitesRepertoire,
    genotypes: Genotype,
    fitnesses: Fitness,
    descriptors: Descriptor,
    extra_scores: Optional[ExtraScores] = None,
) -> Optional[EmitterState]:
    """
    Updates the CMA-MEGA emitter state.

    Note: in order to recover the coeffs that where used to sample the genotypes,
    we reuse the emitter state's random key in this function.

    Note: we use the update_state function from CMAES, a function that suppose
    that the candidates are already sorted. We do this because we have to sort
    them in this function anyway, in order to apply the right weights to the
    terms when update theta.

    Args:
        emitter_state: current emitter state
        repertoire: the current genotypes repertoire
        genotypes: the genotypes of the batch of emitted offspring (unused).
        fitnesses: the fitnesses of the batch of emitted offspring.
        descriptors: the descriptors of the emitted offspring.
        extra_scores: unused

    Returns:
        The updated emitter state.
    """

    # retrieve elements from the emitter state
    cmaes_state = emitter_state.cmaes_state
    theta = jnp.nan_to_num(emitter_state.theta)
    grads = jnp.nan_to_num(emitter_state.theta_grads[0])

    # Update the archive and compute the improvements
    indices = get_cells_indices(descriptors, repertoire.centroids)
    improvements = fitnesses - emitter_state.previous_fitnesses[indices]

    # condition for being a new cell
    condition = improvements == jnp.inf

    # criteria: fitness if new cell, improvement else
    ranking_criteria = jnp.where(condition, x=fitnesses, y=improvements)

    # make sure to have all the new cells first
    new_cell_offset = jnp.max(ranking_criteria) - jnp.min(ranking_criteria)

    ranking_criteria = jnp.where(
        condition, x=ranking_criteria + new_cell_offset, y=ranking_criteria
    )

    # sort indices according to the criteria
    sorted_indices = jnp.flip(jnp.argsort(ranking_criteria))

    # Draw the coeffs - reuse the emitter state key to get same coeffs
    coeffs, random_key = self._cmaes.sample(
        cmaes_state=cmaes_state, random_key=emitter_state.random_key
    )
    # make sure the fitness coeff is positive
    coeffs = coeffs.at[:, 0].set(jnp.abs(coeffs[:, 0]))

    # get the gradients that must be applied
    update_grad = coeffs @ grads.T

    # weight terms - based on improvement rank
    gradient_step = jnp.sum(self._weights[sorted_indices] * update_grad, axis=0)

    # update theta
    theta = jax.tree_util.tree_map(
        lambda x, y: x + self._learning_rate * y, theta, gradient_step
    )

    # Update CMA Parameters
    sorted_candidates = coeffs[sorted_indices]
    cmaes_state = self._cmaes.update_state(cmaes_state, sorted_candidates)

    # If no improvement draw randomly and re-initialize parameters
    reinitialize = jnp.all(improvements < 0) + self._cmaes.stop_condition(
        cmaes_state
    )

    # re-sample
    random_theta, random_key = repertoire.sample(random_key, 1)

    # update theta in case of reinit
    theta = jax.tree_util.tree_map(
        lambda x, y: jnp.where(reinitialize, x=x, y=y), random_theta, theta
    )

    # update cmaes state in case of reinit
    cmaes_state = jax.tree_util.tree_map(
        lambda x, y: jnp.where(reinitialize, x=x, y=y),
        self._cma_initial_state,
        cmaes_state,
    )

    # score theta
    _, _, extra_score, random_key = self._scoring_function(theta, random_key)

    # create new emitter state
    emitter_state = CMAMEGAState(
        theta=theta,
        theta_grads=extra_score["normalized_grads"],
        random_key=random_key,
        cmaes_state=cmaes_state,
        previous_fitnesses=repertoire.fitnesses,
    )

    return emitter_state

cma_pool_emitter

CMAPoolEmitterState (EmitterState) dataclass

Emitter state for the pool of CMA emitters.

This is for a pool of homogeneous emitters.

Parameters:
  • current_index (int) – the index of the current emitter state used.

  • emitter_states (CMAEmitterState) – the batch of emitter states currently used.

Source code in qdax/core/emitters/cma_pool_emitter.py
class CMAPoolEmitterState(EmitterState):
    """
    Emitter state for the pool of CMA emitters.

    This is for a pool of homogeneous emitters.

    Args:
        current_index: the index of the current emitter state used.
        emitter_states: the batch of emitter states currently used.
    """

    current_index: int
    emitter_states: CMAEmitterState
replace(self, **updates)

"Returns a new object replacing the specified fields with new values.

Source code in qdax/core/emitters/cma_pool_emitter.py
def replace(self, **updates):
    """Returns a new object, replacing the specified fields with new values."""
    return dataclasses.replace(self, **updates)

CMAPoolEmitter (Emitter)

Source code in qdax/core/emitters/cma_pool_emitter.py
class CMAPoolEmitter(Emitter):
    def __init__(self, num_states: int, emitter: CMAEmitter):
        """Instantiate a pool of homogeneous emitters.

        Args:
            num_states: the number of emitters to consider. We can use a
                single emitter object and a batched emitter state.
            emitter: the type of emitter for the pool.
        """
        self._num_states = num_states
        self._emitter = emitter

    @property
    def batch_size(self) -> int:
        """
        Returns:
            the batch size emitted by the emitter.
        """
        return self._emitter.batch_size

    @partial(jax.jit, static_argnames=("self",))
    def init(
        self, init_genotypes: Genotype, random_key: RNGKey
    ) -> Tuple[CMAPoolEmitterState, RNGKey]:
        """
        Initializes the CMA-MEGA emitter


        Args:
            init_genotypes: initial genotypes to add to the grid.
            random_key: a random key to handle stochastic operations.

        Returns:
            The initial state of the emitter.
        """

        def scan_emitter_init(
            carry: RNGKey, unused: Any
        ) -> Tuple[RNGKey, CMAEmitterState]:
            random_key = carry
            emitter_state, random_key = self._emitter.init(init_genotypes, random_key)
            return random_key, emitter_state

        # init all the emitter states
        random_key, emitter_states = jax.lax.scan(
            scan_emitter_init, random_key, (), length=self._num_states
        )

        # define the emitter state of the pool
        emitter_state = CMAPoolEmitterState(
            current_index=0, emitter_states=emitter_states
        )

        return (
            emitter_state,
            random_key,
        )

    @partial(jax.jit, static_argnames=("self",))
    def emit(
        self,
        repertoire: Optional[MapElitesRepertoire],
        emitter_state: CMAPoolEmitterState,
        random_key: RNGKey,
    ) -> Tuple[Genotype, RNGKey]:
        """
        Emits new individuals.

        Args:
            repertoire: a repertoire of genotypes (unused).
            emitter_state: the state of the CMA-MEGA emitter.
            random_key: a random key to handle random operations.

        Returns:
            New genotypes and a new random key.
        """

        # retrieve the relevant emitter state
        current_index = emitter_state.current_index
        used_emitter_state = jax.tree_util.tree_map(
            lambda x: x[current_index], emitter_state.emitter_states
        )

        # use it to emit offsprings
        offsprings, random_key = self._emitter.emit(
            repertoire, used_emitter_state, random_key
        )

        return offsprings, random_key

    @partial(
        jax.jit,
        static_argnames=("self",),
    )
    def state_update(
        self,
        emitter_state: CMAPoolEmitterState,
        repertoire: MapElitesRepertoire,
        genotypes: Genotype,
        fitnesses: Fitness,
        descriptors: Descriptor,
        extra_scores: Optional[ExtraScores] = None,
    ) -> Optional[EmitterState]:
        """
        Updates the emitter state.

        Args:
            emitter_state: current emitter state
            repertoire: the current genotypes repertoire
            genotypes: the genotypes of the batch of emitted offspring (unused).
            fitnesses: the fitnesses of the batch of emitted offspring.
            descriptors: the descriptors of the emitted offspring.
            extra_scores: unused

        Returns:
            The updated emitter state.
        """

        # retrieve the emitter that has been used and it's emitter state
        current_index = emitter_state.current_index
        emitter_states = emitter_state.emitter_states

        used_emitter_state = jax.tree_util.tree_map(
            lambda x: x[current_index], emitter_states
        )

        # update the used emitter state
        used_emitter_state = self._emitter.state_update(
            emitter_state=used_emitter_state,
            repertoire=repertoire,
            genotypes=genotypes,
            fitnesses=fitnesses,
            descriptors=descriptors,
            extra_scores=extra_scores,
        )

        # update the emitter state
        emitter_states = jax.tree_util.tree_map(
            lambda x, y: x.at[current_index].set(y), emitter_states, used_emitter_state
        )

        # determine the next emitter to be used
        emit_counts = emitter_states.emit_count

        new_index = jnp.argmin(emit_counts)

        emitter_state = emitter_state.replace(
            current_index=new_index, emitter_states=emitter_states
        )

        return emitter_state  # type: ignore
batch_size: int property readonly
Returns:
  • int – the batch size emitted by the emitter.

__init__(self, num_states, emitter) special

Instantiate a pool of homogeneous emitters.

Parameters:
  • num_states (int) – the number of emitters to consider. We can use a single emitter object and a batched emitter state.

  • emitter (CMAEmitter) – the type of emitter for the pool.
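A minimal wrapping sketch, reusing the CMA-ME emitter constructed in the earlier example (the name emitter refers to that object):

from qdax.core.emitters.cma_pool_emitter import CMAPoolEmitter

pool_emitter = CMAPoolEmitter(num_states=3, emitter=emitter)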

Source code in qdax/core/emitters/cma_pool_emitter.py
def __init__(self, num_states: int, emitter: CMAEmitter):
    """Instantiate a pool of homogeneous emitters.

    Args:
        num_states: the number of emitters to consider. We can use a
            single emitter object and a batched emitter state.
        emitter: the type of emitter for the pool.
    """
    self._num_states = num_states
    self._emitter = emitter
init(self, init_genotypes, random_key)

Initializes the CMA pool emitter

Parameters:
  • init_genotypes (Genotype) – initial genotypes to add to the grid.

  • random_key (RNGKey) – a random key to handle stochastic operations.

Returns:
  • Tuple[CMAPoolEmitterState, RNGKey] – The initial state of the emitter.

Source code in qdax/core/emitters/cma_pool_emitter.py
@partial(jax.jit, static_argnames=("self",))
def init(
    self, init_genotypes: Genotype, random_key: RNGKey
) -> Tuple[CMAPoolEmitterState, RNGKey]:
    """
    Initializes the CMA-MEGA emitter


    Args:
        init_genotypes: initial genotypes to add to the grid.
        random_key: a random key to handle stochastic operations.

    Returns:
        The initial state of the emitter.
    """

    def scan_emitter_init(
        carry: RNGKey, unused: Any
    ) -> Tuple[RNGKey, CMAEmitterState]:
        random_key = carry
        emitter_state, random_key = self._emitter.init(init_genotypes, random_key)
        return random_key, emitter_state

    # init all the emitter states
    random_key, emitter_states = jax.lax.scan(
        scan_emitter_init, random_key, (), length=self._num_states
    )

    # define the emitter state of the pool
    emitter_state = CMAPoolEmitterState(
        current_index=0, emitter_states=emitter_states
    )

    return (
        emitter_state,
        random_key,
    )
emit(self, repertoire, emitter_state, random_key)

Emits new individuals.

Parameters:
  • repertoire (Optional[MapElitesRepertoire]) – a repertoire of genotypes (unused).

  • emitter_state (CMAPoolEmitterState) – the state of the CMA pool emitter.

  • random_key (RNGKey) – a random key to handle random operations.

Returns:
  • Tuple[Genotype, RNGKey] – New genotypes and a new random key.

Source code in qdax/core/emitters/cma_pool_emitter.py
@partial(jax.jit, static_argnames=("self",))
def emit(
    self,
    repertoire: Optional[MapElitesRepertoire],
    emitter_state: CMAPoolEmitterState,
    random_key: RNGKey,
) -> Tuple[Genotype, RNGKey]:
    """
    Emits new individuals.

    Args:
        repertoire: a repertoire of genotypes (unused).
        emitter_state: the state of the CMA-MEGA emitter.
        random_key: a random key to handle random operations.

    Returns:
        New genotypes and a new random key.
    """

    # retrieve the relevant emitter state
    current_index = emitter_state.current_index
    used_emitter_state = jax.tree_util.tree_map(
        lambda x: x[current_index], emitter_state.emitter_states
    )

    # use it to emit offsprings
    offsprings, random_key = self._emitter.emit(
        repertoire, used_emitter_state, random_key
    )

    return offsprings, random_key
state_update(self, emitter_state, repertoire, genotypes, fitnesses, descriptors, extra_scores=None)

Updates the emitter state.

Parameters:
  • emitter_state (CMAPoolEmitterState) – current emitter state

  • repertoire (MapElitesRepertoire) – the current genotypes repertoire

  • genotypes (Genotype) – the genotypes of the batch of emitted offspring (unused).

  • fitnesses (Fitness) – the fitnesses of the batch of emitted offspring.

  • descriptors (Descriptor) – the descriptors of the emitted offspring.

  • extra_scores (Optional[ExtraScores]) – unused

Returns:
  • Optional[EmitterState] – The updated emitter state.

Source code in qdax/core/emitters/cma_pool_emitter.py
@partial(
    jax.jit,
    static_argnames=("self",),
)
def state_update(
    self,
    emitter_state: CMAPoolEmitterState,
    repertoire: MapElitesRepertoire,
    genotypes: Genotype,
    fitnesses: Fitness,
    descriptors: Descriptor,
    extra_scores: Optional[ExtraScores] = None,
) -> Optional[EmitterState]:
    """
    Updates the emitter state.

    Args:
        emitter_state: current emitter state
        repertoire: the current genotypes repertoire
        genotypes: the genotypes of the batch of emitted offspring (unused).
        fitnesses: the fitnesses of the batch of emitted offspring.
        descriptors: the descriptors of the emitted offspring.
        extra_scores: unused

    Returns:
        The updated emitter state.
    """

    # retrieve the emitter that has been used and its emitter state
    current_index = emitter_state.current_index
    emitter_states = emitter_state.emitter_states

    used_emitter_state = jax.tree_util.tree_map(
        lambda x: x[current_index], emitter_states
    )

    # update the used emitter state
    used_emitter_state = self._emitter.state_update(
        emitter_state=used_emitter_state,
        repertoire=repertoire,
        genotypes=genotypes,
        fitnesses=fitnesses,
        descriptors=descriptors,
        extra_scores=extra_scores,
    )

    # update the emitter state
    emitter_states = jax.tree_util.tree_map(
        lambda x, y: x.at[current_index].set(y), emitter_states, used_emitter_state
    )

    # determine the next emitter to be used
    emit_counts = emitter_states.emit_count

    new_index = jnp.argmin(emit_counts)

    emitter_state = emitter_state.replace(
        current_index=new_index, emitter_states=emitter_states
    )

    return emitter_state  # type: ignore
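
Below is a minimal, self-contained sketch (toy values and hypothetical field names, not the QDax classes) of the pool mechanism documented above: the batched emitter states are indexed with tree_map, the state of the emitter in use is written back with .at[current_index].set, and the next emitter to be used is the one with the fewest emissions.

import jax
import jax.numpy as jnp

# hypothetical batched state for a pool of 3 emitters: a mean vector and an
# emission counter per emitter
emitter_states = {
    "mean": jnp.zeros((3, 4)),
    "emit_count": jnp.array([2, 0, 1]),
}

# the emitter with the fewest emissions is used next
current_index = jnp.argmin(emitter_states["emit_count"])  # -> 1

# extract the state of the emitter currently in use
used_state = jax.tree_util.tree_map(lambda x: x[current_index], emitter_states)

# ... emit with this state, then update it (dummy update here) ...
used_state = {
    "mean": used_state["mean"] + 1.0,
    "emit_count": used_state["emit_count"] + 1,
}

# write the updated state back into the batched pool state
emitter_states = jax.tree_util.tree_map(
    lambda x, y: x.at[current_index].set(y), emitter_states, used_state
)
print(emitter_states["emit_count"])  # [2 1 1]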

cma_rnd_emitter

CMARndEmitterState (CMAEmitterState) dataclass

Emitter state for the CMA-ME random direction emitter.

Parameters:
  • random_key (RNGKey) – a random key to handle stochastic operations. Used for state update only, another key is used to emit. This might be subject to refactoring discussions in the future.

  • cmaes_state (CMAESState) – state of the underlying CMA-ES algorithm

  • previous_fitnesses (Fitness) – store the last fitnesses of the repertoire. Used to compute the improvement.

  • emit_count (int) – count the number of emission events.

  • random_direction (Descriptor) – direction of the behavior space we are trying to explore.

Source code in qdax/core/emitters/cma_rnd_emitter.py
class CMARndEmitterState(CMAEmitterState):
    """
    Emitter state for the CMA-ME random direction emitter.


    Args:
        random_key: a random key to handle stochastic operations. Used for
            state update only, another key is used to emit. This might be
            subject to refactoring discussions in the future.
        cmaes_state: state of the underlying CMA-ES algorithm
        previous_fitnesses: store last fitnesses of the repertoire. Used to
            compute the improvment.
        emit_count: count the number of emission events.
        random_direction: direction of the behavior space we are trying to
            explore.
    """

    random_direction: Descriptor
replace(self, **updates)

"Returns a new object replacing the specified fields with new values.

Source code in qdax/core/emitters/cma_rnd_emitter.py
def replace(self, **updates):
  """ "Returns a new object replacing the specified fields with new values."""
  return dataclasses.replace(self, **updates)

CMARndEmitter (CMAEmitter)

Source code in qdax/core/emitters/cma_rnd_emitter.py
class CMARndEmitter(CMAEmitter):
    @partial(jax.jit, static_argnames=("self",))
    def init(
        self, init_genotypes: Genotype, random_key: RNGKey
    ) -> Tuple[CMARndEmitterState, RNGKey]:
        """
        Initializes the CMA-MEGA emitter


        Args:
            init_genotypes: initial genotypes to add to the grid.
            random_key: a random key to handle stochastic operations.

        Returns:
            The initial state of the emitter.
        """

        # Initialize repertoire with default values
        num_centroids = self._centroids.shape[0]
        default_fitnesses = -jnp.inf * jnp.ones(shape=num_centroids)

        # take a random direction
        random_key, subkey = jax.random.split(random_key)
        random_direction = jax.random.uniform(
            subkey,
            shape=(self._centroids.shape[-1],),
        )

        # return the initial state
        random_key, subkey = jax.random.split(random_key)

        return (
            CMARndEmitterState(
                random_key=subkey,
                cmaes_state=self._cma_initial_state,
                previous_fitnesses=default_fitnesses,
                emit_count=0,
                random_direction=random_direction,
            ),
            random_key,
        )

    def _update_and_init_emitter_state(
        self,
        cmaes_state: CMAESState,
        emitter_state: CMAEmitterState,
        repertoire: MapElitesRepertoire,
        emit_count: int,
        random_key: RNGKey,
    ) -> Tuple[CMAEmitterState, RNGKey]:
        """Update the emitter state in the case of a reinit event.
        Reinit the cmaes state and use an individual from the repertoire
        as the starting mean.

        Args:
            cmaes_state: current cmaes state
            emitter_state: current cmame state
            repertoire: most recent repertoire
            emit_count: counter of the emitter
            random_key: key to handle stochastic events

        Returns:
            The updated emitter state.
        """

        # re-sample
        random_genotype, random_key = repertoire.sample(random_key, 1)

        # get new mean - remove the batch dim
        new_mean = jax.tree_util.tree_map(lambda x: x.squeeze(0), random_genotype)

        # define the corresponding cmaes init state
        cmaes_init_state = self._cma_initial_state.replace(mean=new_mean, num_updates=0)

        # take a new random direction
        random_key, subkey = jax.random.split(random_key)
        random_direction = jax.random.uniform(
            subkey,
            shape=(self._centroids.shape[-1],),
        )

        emitter_state = emitter_state.replace(
            cmaes_state=cmaes_init_state,
            emit_count=0,
            random_direction=random_direction,
        )

        return emitter_state, random_key

    def _ranking_criteria(
        self,
        emitter_state: CMARndEmitterState,
        repertoire: MapElitesRepertoire,
        genotypes: Genotype,
        fitnesses: Fitness,
        descriptors: Descriptor,
        extra_scores: Optional[ExtraScores],
        improvements: jnp.ndarray,
    ) -> jnp.ndarray:
        """Defines how the genotypes should be sorted. Impacts the update
        of the CMAES state. In the end, this defines the type of CMAES emitter
        used (optimizing, random direction or improvement).

        Args:
            emitter_state: current state of the emitter.
            repertoire: latest repertoire of genotypes.
            genotypes: emitted genotypes.
            fitnesses: corresponding fitnesses.
            descriptors: corresponding descriptors.
            extra_scores: corresponding extra scores.
            improvements: improvements of the emitted genotypes. This corresponds
                to the difference between their fitness and the fitness of the
                individual currently occupying the corresponding cell.

        Returns:
            The values to take into account in order to rank the emitted genotypes.
            Here, it is the dot product of the descriptor with the current random
            direction.
        """

        # criteria: projection of the descriptors along the random direction
        ranking_criteria = jnp.dot(descriptors, emitter_state.random_direction)

        # make sure to have all the new cells first
        new_cell_offset = jnp.max(ranking_criteria) - jnp.min(ranking_criteria)

        # condition for being a new cell
        condition = improvements == jnp.inf

        ranking_criteria = jnp.where(
            condition, x=ranking_criteria + new_cell_offset, y=ranking_criteria
        )

        return ranking_criteria  # type: ignore
init(self, init_genotypes, random_key)

Initializes the CMA-ME random direction emitter

Parameters:
  • init_genotypes (Genotype) – initial genotypes to add to the grid.

  • random_key (RNGKey) – a random key to handle stochastic operations.

Returns:
  • Tuple[CMARndEmitterState, RNGKey] – The initial state of the emitter.

Source code in qdax/core/emitters/cma_rnd_emitter.py
@partial(jax.jit, static_argnames=("self",))
def init(
    self, init_genotypes: Genotype, random_key: RNGKey
) -> Tuple[CMARndEmitterState, RNGKey]:
    """
    Initializes the CMA-MEGA emitter


    Args:
        init_genotypes: initial genotypes to add to the grid.
        random_key: a random key to handle stochastic operations.

    Returns:
        The initial state of the emitter.
    """

    # Initialize repertoire with default values
    num_centroids = self._centroids.shape[0]
    default_fitnesses = -jnp.inf * jnp.ones(shape=num_centroids)

    # take a random direction
    random_key, subkey = jax.random.split(random_key)
    random_direction = jax.random.uniform(
        subkey,
        shape=(self._centroids.shape[-1],),
    )

    # return the initial state
    random_key, subkey = jax.random.split(random_key)

    return (
        CMARndEmitterState(
            random_key=subkey,
            cmaes_state=self._cma_initial_state,
            previous_fitnesses=default_fitnesses,
            emit_count=0,
            random_direction=random_direction,
        ),
        random_key,
    )
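
The ranking criterion of this emitter can be illustrated with a small, self-contained sketch (toy values, no QDax classes): descriptors are projected on the current random direction, and genotypes that discovered a new cell (infinite improvement) are offset so that they rank above all others.

import jax.numpy as jnp

descriptors = jnp.array([[0.1, 0.9], [0.8, 0.2], [0.5, 0.5]])
random_direction = jnp.array([1.0, 0.0])
improvements = jnp.array([0.3, jnp.inf, -0.1])  # the second genotype filled a new cell

# projection of the descriptors along the random direction
ranking_criteria = jnp.dot(descriptors, random_direction)

# offset the new cells so they come first
new_cell_offset = jnp.max(ranking_criteria) - jnp.min(ranking_criteria)
ranking_criteria = jnp.where(
    improvements == jnp.inf, ranking_criteria + new_cell_offset, ranking_criteria
)
print(ranking_criteria)  # [0.1 1.5 0.5] -> the new-cell genotype ranks first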

dpg_emitter

Implements the Diversity PG emitter, inspired by the QDPG algorithm, in JAX for Brax environments. Based on: https://arxiv.org/abs/2006.08505

DiversityPGConfig (QualityPGConfig) dataclass

Configuration for DiversityPG Emitter

Source code in qdax/core/emitters/dpg_emitter.py
@dataclass
class DiversityPGConfig(QualityPGConfig):
    """Configuration for DiversityPG Emitter"""

    # inherits fields from QualityPGConfig

    # Archive params
    archive_acceptance_threshold: float = 0.1
    archive_max_size: int = 10000

DiversityPGEmitterState (QualityPGEmitterState) dataclass

Contains training state for the learner.

Source code in qdax/core/emitters/dpg_emitter.py
class DiversityPGEmitterState(QualityPGEmitterState):
    """Contains training state for the learner."""

    # inherits from QualityPGEmitterState

    archive: Archive
replace(self, **updates)

"Returns a new object replacing the specified fields with new values.

Source code in qdax/core/emitters/dpg_emitter.py
def replace(self, **updates):
  """ "Returns a new object replacing the specified fields with new values."""
  return dataclasses.replace(self, **updates)

DiversityPGEmitter (QualityPGEmitter)

A diversity policy gradient emitter used to implement the QDPG algorithm.

Please note that the inheritance between DiversityPGEmitter and QualityPGEmitter could be increased with changes in the way transition samples are handled in the QualityPGEmitter. However, this would modify the computation/memory strategy of the current implementation, so it is not applied yet and will be discussed with the development team.

Source code in qdax/core/emitters/dpg_emitter.py
class DiversityPGEmitter(QualityPGEmitter):
    """
    A diversity policy gradient emitter used to implement QDPG algorithm.

    Please not that the inheritence between DiversityPGEmitter and QualityPGEmitter
    could be increased with changes in the way transitions samples are handled in
    the QualityPGEmitter. But this would modify the computation/memory strategy of the
    current implementation. Hence, we won't apply this yet and will discuss this with
    the development team.
    """

    def __init__(
        self,
        config: DiversityPGConfig,
        policy_network: nn.Module,
        env: QDEnv,
        score_novelty: Callable[[Archive, StateDescriptor], Reward],
    ) -> None:

        # usual init operations from PGAME
        super().__init__(config, policy_network, env)

        self._config: DiversityPGConfig = config

        # define scoring function
        self._score_novelty = score_novelty

    def init(
        self, init_genotypes: Genotype, random_key: RNGKey
    ) -> Tuple[DiversityPGEmitterState, RNGKey]:
        """Initializes the emitter state.

        Args:
            init_genotypes: The initial population.
            random_key: A random key.

        Returns:
            The initial state of the PGAMEEmitter, a new random key.
        """

        # init elements of diversity emitter state with QualityEmitterState.init()
        diversity_emitter_state, random_key = super().init(init_genotypes, random_key)

        # store elements in a dictionary
        attributes_dict = vars(diversity_emitter_state)

        # init archive
        archive = Archive.create(
            acceptance_threshold=self._config.archive_acceptance_threshold,
            state_descriptor_size=self._env.state_descriptor_length,
            max_size=self._config.archive_max_size,
        )

        # init emitter state
        emitter_state = DiversityPGEmitterState(
            # retrieve all attributes from the QualityPGEmitterState
            **attributes_dict,
            # add the last element: archive
            archive=archive,
        )

        return emitter_state, random_key

    @partial(jax.jit, static_argnames=("self",))
    def state_update(
        self,
        emitter_state: DiversityPGEmitterState,
        repertoire: Optional[Repertoire],
        genotypes: Optional[Genotype],
        fitnesses: Optional[Fitness],
        descriptors: Optional[Descriptor],
        extra_scores: ExtraScores,
    ) -> DiversityPGEmitterState:
        """This function gives an opportunity to update the emitter state
        after the genotypes have been scored.

        Here it is used to fill the Replay Buffer with the transitions
        from the scoring of the genotypes, and then the training of the
        critic/actor happens. Hence the params of critic/actor are updated,
        as well as their optimizer states.

        Args:
            emitter_state: current emitter state.
            repertoire: the current genotypes repertoire
            genotypes: unused here - but compulsory in the signature.
            fitnesses: unused here - but compulsory in the signature.
            descriptors: unused here - but compulsory in the signature.
            extra_scores: extra information coming from the scoring function,
                this contains the transitions added to the replay buffer.

        Returns:
            New emitter state where the replay buffer has been filled with
            the new experienced transitions.
        """
        # get the transitions out of the dictionary
        assert "transitions" in extra_scores.keys(), "Missing transitions or wrong key"
        transitions = extra_scores["transitions"]

        # add transitions in the replay buffer
        replay_buffer = emitter_state.replay_buffer.insert(transitions)
        emitter_state = emitter_state.replace(replay_buffer=replay_buffer)

        archive = emitter_state.archive.insert(transitions.state_desc)

        def scan_train_critics(
            carry: DiversityPGEmitterState, transitions: QDTransition
        ) -> Tuple[DiversityPGEmitterState, Any]:
            emitter_state = carry
            new_emitter_state = self._train_critics(emitter_state, transitions)
            return new_emitter_state, ()

        # sample transitions
        (transitions, random_key,) = emitter_state.replay_buffer.sample(
            random_key=emitter_state.random_key,
            sample_size=self._config.num_critic_training_steps
            * self._config.batch_size,
        )

        # update the rewards - diversity rewards
        state_descriptors = transitions.state_desc
        diversity_rewards = self._score_novelty(archive, state_descriptors)
        transitions = transitions.replace(rewards=diversity_rewards)

        # reshape the transitions
        transitions = jax.tree_util.tree_map(
            lambda x: x.reshape(
                (
                    self._config.num_critic_training_steps,
                    self._config.batch_size,
                )
                + x.shape[1:]
            ),
            transitions,
        )

        # Train critics and greedy actor
        emitter_state, _ = jax.lax.scan(
            scan_train_critics,
            emitter_state,
            (transitions),
            length=self._config.num_critic_training_steps,
        )

        emitter_state = emitter_state.replace(archive=archive)

        return emitter_state  # type: ignore

    @partial(jax.jit, static_argnames=("self",))
    def _train_critics(
        self, emitter_state: DiversityPGEmitterState, transitions: QDTransition
    ) -> DiversityPGEmitterState:
        """Apply one gradient step to critics and to the greedy actor
        (contained in carry in training_state), then soft update target critics
        and target greedy actor.

        Those updates are very similar to those made in TD3.

        Args:
            emitter_state: actual emitter state

        Returns:
            New emitter state where the critic and the greedy actor have been
            updated. Optimizer states have also been updated in the process.
        """

        # Update Critic
        (
            critic_optimizer_state,
            critic_params,
            target_critic_params,
            random_key,
        ) = self._update_critic(
            critic_params=emitter_state.critic_params,
            target_critic_params=emitter_state.target_critic_params,
            target_actor_params=emitter_state.target_actor_params,
            critic_optimizer_state=emitter_state.critic_optimizer_state,
            transitions=transitions,
            random_key=emitter_state.random_key,
        )

        # Update greedy policy
        (policy_optimizer_state, actor_params, target_actor_params,) = jax.lax.cond(
            emitter_state.steps % self._config.policy_delay == 0,
            lambda x: self._update_actor(*x),
            lambda _: (
                emitter_state.actor_opt_state,
                emitter_state.actor_params,
                emitter_state.target_actor_params,
            ),
            operand=(
                emitter_state.actor_params,
                emitter_state.actor_opt_state,
                emitter_state.target_actor_params,
                emitter_state.critic_params,
                transitions,
            ),
        )

        # Create new training state
        new_emitter_state = emitter_state.replace(
            critic_params=critic_params,
            critic_optimizer_state=critic_optimizer_state,
            actor_params=actor_params,
            actor_opt_state=policy_optimizer_state,
            target_critic_params=target_critic_params,
            target_actor_params=target_actor_params,
            random_key=random_key,
            steps=emitter_state.steps + 1,
            replay_buffer=emitter_state.replay_buffer,
        )

        return new_emitter_state  # type: ignore

    @partial(jax.jit, static_argnames=("self",))
    def _mutation_function_pg(
        self,
        policy_params: Genotype,
        emitter_state: DiversityPGEmitterState,
    ) -> Genotype:
        """Apply pg mutation to a policy via multiple steps of gradient descent.

        Args:
            policy_params: a policy, supposed to be a differentiable neural
                network.
            emitter_state: the current state of the emitter, containing among others,
                the replay buffer, the critic.

        Returns:
            the updated params of the neural network.
        """

        # Define new policy optimizer state
        policy_optimizer_state = self._policies_optimizer.init(policy_params)

        def scan_train_policy(
            carry: Tuple[DiversityPGEmitterState, Genotype, optax.OptState],
            transitions: QDTransition,
        ) -> Tuple[Tuple[DiversityPGEmitterState, Genotype, optax.OptState], Any]:
            emitter_state, policy_params, policy_optimizer_state = carry
            (
                new_emitter_state,
                new_policy_params,
                new_policy_optimizer_state,
            ) = self._train_policy(
                emitter_state,
                policy_params,
                policy_optimizer_state,
                transitions,
            )
            return (
                new_emitter_state,
                new_policy_params,
                new_policy_optimizer_state,
            ), ()

        # sample transitions
        transitions, _random_key = emitter_state.replay_buffer.sample(
            random_key=emitter_state.random_key,
            sample_size=self._config.num_pg_training_steps * self._config.batch_size,
        )

        # update the rewards - diversity rewards
        state_descriptors = transitions.state_desc
        diversity_rewards = self._score_novelty(
            emitter_state.archive, state_descriptors
        )
        transitions = transitions.replace(rewards=diversity_rewards)

        # reshape the transitions
        transitions = jax.tree_util.tree_map(
            lambda x: x.reshape(
                (
                    self._config.num_pg_training_steps,
                    self._config.batch_size,
                )
                + x.shape[1:]
            ),
            transitions,
        )

        (emitter_state, policy_params, policy_optimizer_state,), _ = jax.lax.scan(
            scan_train_policy,
            (emitter_state, policy_params, policy_optimizer_state),
            (transitions),
            length=self._config.num_pg_training_steps,
        )

        return policy_params

    @partial(jax.jit, static_argnames=("self",))
    def _train_policy(
        self,
        emitter_state: DiversityPGEmitterState,
        policy_params: Params,
        policy_optimizer_state: optax.OptState,
        transitions: QDTransition,
    ) -> Tuple[DiversityPGEmitterState, Params, optax.OptState]:
        """Apply one gradient step to a policy (called policies_params).

        Args:
            emitter_state: current state of the emitter.
            policy_params: parameters corresponding to the weights and bias of
                the neural network that defines the policy.

        Returns:
            The new emitter state and new params of the NN.
        """

        # update policy
        policy_optimizer_state, policy_params = self._update_policy(
            critic_params=emitter_state.critic_params,
            policy_optimizer_state=policy_optimizer_state,
            policy_params=policy_params,
            transitions=transitions,
        )

        return emitter_state, policy_params, policy_optimizer_state
init(self, init_genotypes, random_key)

Initializes the emitter state.

Parameters:
  • init_genotypes (Union[jax.Array, numpy.ndarray, numpy.bool_, numpy.number, Iterable[ArrayTree], Mapping[Any, ArrayTree]]) – The initial population.

  • random_key (Array) – A random key.

Returns:
  • Tuple[qdax.core.emitters.dpg_emitter.DiversityPGEmitterState, jax.Array] – The initial emitter state and a new random key.

Source code in qdax/core/emitters/dpg_emitter.py
def init(
    self, init_genotypes: Genotype, random_key: RNGKey
) -> Tuple[DiversityPGEmitterState, RNGKey]:
    """Initializes the emitter state.

    Args:
        init_genotypes: The initial population.
        random_key: A random key.

    Returns:
        The initial state of the PGAMEEmitter, a new random key.
    """

    # init elements of diversity emitter state with QualityEmitterState.init()
    diversity_emitter_state, random_key = super().init(init_genotypes, random_key)

    # store elements in a dictionary
    attributes_dict = vars(diversity_emitter_state)

    # init archive
    archive = Archive.create(
        acceptance_threshold=self._config.archive_acceptance_threshold,
        state_descriptor_size=self._env.state_descriptor_length,
        max_size=self._config.archive_max_size,
    )

    # init emitter state
    emitter_state = DiversityPGEmitterState(
        # retrieve all attributes from the QualityPGEmitterState
        **attributes_dict,
        # add the last element: archive
        archive=archive,
    )

    return emitter_state, random_key
state_update(self, emitter_state, repertoire, genotypes, fitnesses, descriptors, extra_scores)

This function gives an opportunity to update the emitter state after the genotypes have been scored.

Here it is used to fill the Replay Buffer with the transitions from the scoring of the genotypes, and then the training of the critic/actor happens. Hence the params of critic/actor are updated, as well as their optimizer states.

Parameters:
  • emitter_state (DiversityPGEmitterState) – current emitter state.

  • repertoire (Optional[qdax.core.containers.repertoire.Repertoire]) – the current genotypes repertoire

  • genotypes (Union[jax.Array, numpy.ndarray, numpy.bool_, numpy.number, Iterable[ArrayTree], Mapping[Any, ArrayTree]]) – unused here - but compulsory in the signature.

  • fitnesses (Optional[jax.Array]) – unused here - but compulsory in the signature.

  • descriptors (Optional[jax.Array]) – unused here - but compulsory in the signature.

  • extra_scores (Dict[str, Union[jax.Array, numpy.ndarray, numpy.bool_, numpy.number, Iterable[ArrayTree], Mapping[Any, ArrayTree]]]) – extra information coming from the scoring function, this contains the transitions added to the replay buffer.

Returns:
  • DiversityPGEmitterState – New emitter state where the replay buffer has been filled with the new experienced transitions.

Source code in qdax/core/emitters/dpg_emitter.py
@partial(jax.jit, static_argnames=("self",))
def state_update(
    self,
    emitter_state: DiversityPGEmitterState,
    repertoire: Optional[Repertoire],
    genotypes: Optional[Genotype],
    fitnesses: Optional[Fitness],
    descriptors: Optional[Descriptor],
    extra_scores: ExtraScores,
) -> DiversityPGEmitterState:
    """This function gives an opportunity to update the emitter state
    after the genotypes have been scored.

    Here it is used to fill the Replay Buffer with the transitions
    from the scoring of the genotypes, and then the training of the
    critic/actor happens. Hence the params of critic/actor are updated,
    as well as their optimizer states.

    Args:
        emitter_state: current emitter state.
        repertoire: the current genotypes repertoire
        genotypes: unused here - but compulsory in the signature.
        fitnesses: unused here - but compulsory in the signature.
        descriptors: unused here - but compulsory in the signature.
        extra_scores: extra information coming from the scoring function,
            this contains the transitions added to the replay buffer.

    Returns:
        New emitter state where the replay buffer has been filled with
        the new experienced transitions.
    """
    # get the transitions out of the dictionary
    assert "transitions" in extra_scores.keys(), "Missing transitions or wrong key"
    transitions = extra_scores["transitions"]

    # add transitions in the replay buffer
    replay_buffer = emitter_state.replay_buffer.insert(transitions)
    emitter_state = emitter_state.replace(replay_buffer=replay_buffer)

    archive = emitter_state.archive.insert(transitions.state_desc)

    def scan_train_critics(
        carry: DiversityPGEmitterState, transitions: QDTransition
    ) -> Tuple[DiversityPGEmitterState, Any]:
        emitter_state = carry
        new_emitter_state = self._train_critics(emitter_state, transitions)
        return new_emitter_state, ()

    # sample transitions
    (transitions, random_key,) = emitter_state.replay_buffer.sample(
        random_key=emitter_state.random_key,
        sample_size=self._config.num_critic_training_steps
        * self._config.batch_size,
    )

    # update the rewards - diversity rewards
    state_descriptors = transitions.state_desc
    diversity_rewards = self._score_novelty(archive, state_descriptors)
    transitions = transitions.replace(rewards=diversity_rewards)

    # reshape the transitions
    transitions = jax.tree_util.tree_map(
        lambda x: x.reshape(
            (
                self._config.num_critic_training_steps,
                self._config.batch_size,
            )
            + x.shape[1:]
        ),
        transitions,
    )

    # Train critics and greedy actor
    emitter_state, _ = jax.lax.scan(
        scan_train_critics,
        emitter_state,
        (transitions),
        length=self._config.num_critic_training_steps,
    )

    emitter_state = emitter_state.replace(archive=archive)

    return emitter_state  # type: ignore
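
The diversity rewards used above come from the user-provided score_novelty callable, of type Callable[[Archive, StateDescriptor], Reward]. The sketch below shows what such a scorer could look like; it assumes the archive content is available as a plain array of stored state descriptors and that the shapes are (archive_size, descriptor_size) and (batch_size, descriptor_size), so wrapping it to accept the QDax Archive object is left to the reader.

import functools

import jax
import jax.numpy as jnp


@functools.partial(jax.jit, static_argnames=("num_nearest_neighbors",))
def novelty_reward(
    archive_data: jnp.ndarray,       # (archive_size, descriptor_size), assumed
    state_descriptors: jnp.ndarray,  # (batch_size, descriptor_size), assumed
    num_nearest_neighbors: int = 10,
) -> jnp.ndarray:
    """Novelty as the average distance to the k nearest archived descriptors."""
    # pairwise Euclidean distances between the new descriptors and the archive
    distances = jnp.linalg.norm(
        state_descriptors[:, None, :] - archive_data[None, :, :], axis=-1
    )
    # k smallest distances per descriptor, obtained via top_k on negated distances
    neg_k_nearest, _ = jax.lax.top_k(-distances, num_nearest_neighbors)
    return jnp.mean(-neg_k_nearest, axis=-1)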

emitter

EmitterState (PyTreeNode) dataclass

The state of an emitter. Emitters are used to suggest offspring when evolving a population of genotypes. To emit new genotypes, some emitters need a state that carries useful information, such as running means, distribution parameters, critics, replay buffers, etc.

The emitter state object is used to store this information and is updated throughout the process.

Parameters:
  • PyTreeNode – the EmitterState base class inherits from the PyTreeNode object of the flax.struct package. This registers objects as Pytree nodes automatically and provides the same benefits as the classic Python @dataclass decorator.

Source code in qdax/core/emitters/emitter.py
class EmitterState(PyTreeNode):
    """The state of an emitter. Emitters are used to suggest offspring
    when evolving a population of genotypes. To emit new genotypes, some
    emitters need to have a state, that carries useful informations, like
    running means, distribution parameters, critics, replay buffers etc...

    The object emitter state is used to store them and is updated along
    the process.

    Args:
        PyTreeNode: EmitterState base class inherits from PyTreeNode object
            from flax.struct package. It help registering objects as Pytree
            nodes automatically, and as the same benefits as classic Python
            @dataclass decorator.
    """

    pass
replace(self, **updates)

"Returns a new object replacing the specified fields with new values.

Source code in qdax/core/emitters/emitter.py
def replace(self, **updates):
  """ "Returns a new object replacing the specified fields with new values."""
  return dataclasses.replace(self, **updates)

Emitter (ABC)

Source code in qdax/core/emitters/emitter.py
class Emitter(ABC):
    def init(
        self, init_genotypes: Optional[Genotype], random_key: RNGKey
    ) -> Tuple[Optional[EmitterState], RNGKey]:
        """Initialises the state of the emitter. Some emitters do
        not need a state, in which case, the value None can be
        outputted.

        Args:
            init_genotypes: The genotypes of the initial population.
            random_key: a random key to handle stochastic operations.

        Returns:
            The initial emitter state and a random key.
        """
        return None, random_key

    @abstractmethod
    def emit(
        self,
        repertoire: Optional[Repertoire],
        emitter_state: Optional[EmitterState],
        random_key: RNGKey,
    ) -> Tuple[Genotype, RNGKey]:
        """Function used to emit a population of offspring by any possible
        mean. New population can be sampled from a distribution or obtained
        through mutations of individuals sampled from the repertoire.


        Args:
            repertoire: a repertoire of genotypes.
            emitter_state: the state of the emitter.
            random_key: a random key to handle random operations.

        Returns:
            A batch of offspring, a new random key.
        """
        pass

    @partial(
        jax.jit,
        static_argnames=("self",),
    )
    def state_update(
        self,
        emitter_state: Optional[EmitterState],
        repertoire: Optional[Repertoire] = None,
        genotypes: Optional[Genotype] = None,
        fitnesses: Optional[Fitness] = None,
        descriptors: Optional[Descriptor] = None,
        extra_scores: Optional[ExtraScores] = None,
    ) -> Optional[EmitterState]:
        """This function gives an opportunity to update the emitter state
        after the genotypes have been scored.

        As a matter of fact, many emitter states needs informations from
        the evaluations of the genotypes in order to be updated, for instance:
        - CMA emitter: to update the rank of the covariance matrix
        - PGA emitter: to fill the replay buffer and update the critic/greedy
            couple.

        This function does not need to be overridden. By default, it output
        the same emitter state.

        Args:
            emitter_state: current emitter state
            repertoire: the current genotypes repertoire
            genotypes: the genotypes of the batch of emitted offspring.
            fitnesses: the fitnesses of the batch of emitted offspring.
            descriptors: the descriptors of the emitted offspring.
            extra_scores: a dictionary with other values outputted by the
                scoring function.

        Returns:
            The modified emitter state.
        """
        return emitter_state

    @property
    @abstractmethod
    def batch_size(self) -> int:
        """
        Returns:
            the batch size emitted by the emitter.
        """
        pass

    @property
    def use_all_data(self) -> bool:
        """Whether to use all data or not when used along other emitters.

        Used when an emitter is used in a multi emitter setting.

        Some emitter only the information from the genotypes they emitted when
        they update their state (for instance, the CMA emitters); but other use data
        from genotypes emitted by others (for instance, QualityPGEmitter and
        DiversityPGEmitter). The meta emitters like MultiEmitter need to know which
        data to give the sub emitter when udapting them. This property is used at
        this moment.

        Default behavior is to used only the data related to what was emitted.

        Returns:
            Whether to pass only the genotypes (and their evaluations) that the emitter
            emitted when updating it or all the genotypes emitted by all the emitters.
        """
        return False
batch_size: int property readonly
Returns:
  • int – the batch size emitted by the emitter.

use_all_data: bool property readonly

Whether to use all data or not when used along other emitters.

Used when an emitter is used in a multi emitter setting.

Some emitters only use the information from the genotypes they emitted when they update their state (for instance, the CMA emitters), while others use data from genotypes emitted by other emitters (for instance, QualityPGEmitter and DiversityPGEmitter). Meta emitters like MultiEmitter need to know which data to give a sub-emitter when updating it; this property is used for that purpose.

The default behavior is to use only the data related to what was emitted.

Returns:
  • bool – Whether to pass only the genotypes (and their evaluations) that the emitter emitted when updating it or all the genotypes emitted by all the emitters.

init(self, init_genotypes, random_key)

Initialises the state of the emitter. Some emitters do not need a state, in which case the value None can be returned.

Parameters:
  • init_genotypes (Union[jax.Array, numpy.ndarray, numpy.bool_, numpy.number, Iterable[ArrayTree], Mapping[Any, ArrayTree]]) – The genotypes of the initial population.

  • random_key (Array) – a random key to handle stochastic operations.

Returns:
  • Tuple[Optional[qdax.core.emitters.emitter.EmitterState], jax.Array] – The initial emitter state and a random key.

Source code in qdax/core/emitters/emitter.py
def init(
    self, init_genotypes: Optional[Genotype], random_key: RNGKey
) -> Tuple[Optional[EmitterState], RNGKey]:
    """Initialises the state of the emitter. Some emitters do
    not need a state, in which case, the value None can be
    outputted.

    Args:
        init_genotypes: The genotypes of the initial population.
        random_key: a random key to handle stochastic operations.

    Returns:
        The initial emitter state and a random key.
    """
    return None, random_key
emit(self, repertoire, emitter_state, random_key)

Function used to emit a population of offspring by any possible means. The new population can be sampled from a distribution or obtained through mutations of individuals sampled from the repertoire.

Parameters:
  • repertoire (Optional[qdax.core.containers.repertoire.Repertoire]) – a repertoire of genotypes.

  • emitter_state (Optional[qdax.core.emitters.emitter.EmitterState]) – the state of the emitter.

  • random_key (Array) – a random key to handle random operations.

Returns:
  • Tuple[Union[jax.Array, numpy.ndarray, numpy.bool_, numpy.number, Iterable[ArrayTree], Mapping[Any, ArrayTree]], jax.Array] – A batch of offspring, a new random key.

Source code in qdax/core/emitters/emitter.py
@abstractmethod
def emit(
    self,
    repertoire: Optional[Repertoire],
    emitter_state: Optional[EmitterState],
    random_key: RNGKey,
) -> Tuple[Genotype, RNGKey]:
    """Function used to emit a population of offspring by any possible
    mean. New population can be sampled from a distribution or obtained
    through mutations of individuals sampled from the repertoire.


    Args:
        repertoire: a repertoire of genotypes.
        emitter_state: the state of the emitter.
        random_key: a random key to handle random operations.

    Returns:
        A batch of offspring, a new random key.
    """
    pass
state_update(self, emitter_state, repertoire=None, genotypes=None, fitnesses=None, descriptors=None, extra_scores=None)

This function gives an opportunity to update the emitter state after the genotypes have been scored.

Many emitter states need information from the evaluations of the genotypes in order to be updated, for instance: the CMA emitter, to update the rank of the covariance matrix; the PGA emitter, to fill the replay buffer and update the critic/greedy couple.

This function does not need to be overridden. By default, it returns the same emitter state.

Parameters:
  • emitter_state (Optional[qdax.core.emitters.emitter.EmitterState]) – current emitter state

  • repertoire (Optional[qdax.core.containers.repertoire.Repertoire]) – the current genotypes repertoire

  • genotypes (Union[jax.Array, numpy.ndarray, numpy.bool_, numpy.number, Iterable[ArrayTree], Mapping[Any, ArrayTree]]) – the genotypes of the batch of emitted offspring.

  • fitnesses (Optional[jax.Array]) – the fitnesses of the batch of emitted offspring.

  • descriptors (Optional[jax.Array]) – the descriptors of the emitted offspring.

  • extra_scores (Optional[Dict[str, Union[jax.Array, numpy.ndarray, numpy.bool_, numpy.number, Iterable[ArrayTree], Mapping[Any, ArrayTree]]]]) – a dictionary with other values outputted by the scoring function.

Returns:
  • Optional[qdax.core.emitters.emitter.EmitterState] – The modified emitter state.

Source code in qdax/core/emitters/emitter.py
@partial(
    jax.jit,
    static_argnames=("self",),
)
def state_update(
    self,
    emitter_state: Optional[EmitterState],
    repertoire: Optional[Repertoire] = None,
    genotypes: Optional[Genotype] = None,
    fitnesses: Optional[Fitness] = None,
    descriptors: Optional[Descriptor] = None,
    extra_scores: Optional[ExtraScores] = None,
) -> Optional[EmitterState]:
    """This function gives an opportunity to update the emitter state
    after the genotypes have been scored.

    As a matter of fact, many emitter states needs informations from
    the evaluations of the genotypes in order to be updated, for instance:
    - CMA emitter: to update the rank of the covariance matrix
    - PGA emitter: to fill the replay buffer and update the critic/greedy
        couple.

    This function does not need to be overridden. By default, it output
    the same emitter state.

    Args:
        emitter_state: current emitter state
        repertoire: the current genotypes repertoire
        genotypes: the genotypes of the batch of emitted offspring.
        fitnesses: the fitnesses of the batch of emitted offspring.
        descriptors: the descriptors of the emitted offspring.
        extra_scores: a dictionary with other values outputted by the
            scoring function.

    Returns:
        The modified emitter state.
    """
    return emitter_state
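
As a minimal sketch of this interface, the emitter below relies on the default init and state_update and only implements emit and batch_size. The class name is illustrative, and it assumes that genotypes are a single array and that the repertoire exposes sample(random_key, num_samples), as MapElitesRepertoire does; it is not part of the library.

from functools import partial
from typing import Optional, Tuple

import jax
import jax.numpy as jnp

from qdax.core.emitters.emitter import Emitter, EmitterState


class GaussianMutationEmitter(Emitter):
    """Stateless emitter: mutate repertoire samples with Gaussian noise."""

    def __init__(self, batch_size: int, sigma: float):
        self._batch_size = batch_size
        self._sigma = sigma

    @partial(jax.jit, static_argnames=("self",))
    def emit(
        self,
        repertoire,
        emitter_state: Optional[EmitterState],
        random_key: jnp.ndarray,
    ) -> Tuple[jnp.ndarray, jnp.ndarray]:
        # sample parents and perturb them (assumes genotypes are a single array;
        # pytree genotypes would need one key per leaf)
        parents, random_key = repertoire.sample(random_key, self._batch_size)
        random_key, subkey = jax.random.split(random_key)
        offspring = parents + self._sigma * jax.random.normal(subkey, parents.shape)
        return offspring, random_key

    @property
    def batch_size(self) -> int:
        return self._batch_size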

mees_emitter

Emitter and utils necessary to reproduce the MAP-Elites-ES algorithm from "Scaling MAP-Elites to Deep Neuroevolution" by Colas et al.: https://dl.acm.org/doi/pdf/10.1145/3377930.3390217

NoveltyArchive (PyTreeNode) dataclass

Novelty Archive used by the MAP-Elites-ES emitter.

Parameters:
  • archive (jnp.ndarray) – content of the archive

  • size (int) – total size of the archive

  • position (jnp.ndarray) – current position in the archive

Source code in qdax/core/emitters/mees_emitter.py
class NoveltyArchive(flax.struct.PyTreeNode):
    """Novelty Archive used by the MAP-Elites-ES emitter.

    Args:
        archive: content of the archive
        size: total size of the archive
        position: current position in the archive
    """

    archive: jnp.ndarray
    size: int = flax.struct.field(pytree_node=False)
    position: jnp.ndarray = flax.struct.field()

    @classmethod
    def init(
        cls,
        size: int,
        num_descriptors: int,
    ) -> NoveltyArchive:
        archive = jnp.zeros((size, num_descriptors))
        return cls(archive=archive, size=size, position=jnp.array(0, dtype=int))

    @jax.jit
    def update(
        self,
        descriptor: Descriptor,
    ) -> NoveltyArchive:
        """Update the content of the novelty archive with newly generated descriptor.

        Args:
            descriptor: new descriptor generated by MAP-Elites-ES
        Returns:
            The updated NoveltyArchive
        """

        new_archive = jax.lax.dynamic_update_slice_in_dim(
            self.archive,
            descriptor,
            self.position,
            axis=0,
        )
        new_position = (self.position + 1) % self.size
        return NoveltyArchive(
            archive=new_archive, size=self.size, position=new_position
        )

    @partial(jax.jit, static_argnames=("num_nearest_neighbors",))
    def novelty(
        self,
        descriptors: Descriptor,
        num_nearest_neighbors: int,
    ) -> jnp.ndarray:
        """Compute the novelty of the given descriptors as the average distance
        to the k nearest neighbours in the archive.

        Args:
            descriptors: the descriptors to compute novelty for
            num_nearest_neighbors: k used to compute the k-nearest-neighbours
        Returns:
            the novelty of each descriptor in descriptors.
        """

        # Compute all distances with archive content
        def distance(x: jnp.ndarray, y: jnp.ndarray) -> jnp.ndarray:
            return jnp.sqrt(jnp.sum(jnp.square(x - y)))

        distances = jax.vmap(
            jax.vmap(partial(distance), in_axes=(None, 0)), in_axes=(0, None)
        )(descriptors, self.archive)

        # Filter distance with empty slot of archive
        indices = jnp.arange(0, self.size, step=1) < self.position + 1
        distances = jax.vmap(lambda distance: jnp.where(indices, distance, jnp.inf))(
            distances
        )

        # Find k nearest neighbours
        _, indices = jax.lax.top_k(-distances, num_nearest_neighbors)

        # Compute novelty as the average distance to the k nearest neighbours
        distances = jnp.where(distances == jnp.inf, jnp.nan, distances)
        novelty = jnp.nanmean(jnp.take_along_axis(distances, indices, axis=1), axis=1)
        return novelty
update(self, descriptor)

Update the content of the novelty archive with newly generated descriptor.

Parameters:
  • descriptor (Descriptor) – new descriptor generated by MAP-Elites-ES

Returns:
  • NoveltyArchive – The updated NoveltyArchive

Source code in qdax/core/emitters/mees_emitter.py
@jax.jit
def update(
    self,
    descriptor: Descriptor,
) -> NoveltyArchive:
    """Update the content of the novelty archive with newly generated descriptor.

    Args:
        descriptor: new descriptor generated by MAP-Elites-ES
    Returns:
        The updated NoveltyArchive
    """

    new_archive = jax.lax.dynamic_update_slice_in_dim(
        self.archive,
        descriptor,
        self.position,
        axis=0,
    )
    new_position = (self.position + 1) % self.size
    return NoveltyArchive(
        archive=new_archive, size=self.size, position=new_position
    )
novelty(self, descriptors, num_nearest_neighbors)

Compute the novelty of the given descriptors as the average distance to the k nearest neighbours in the archive.

Parameters:
  • descriptors (Descriptor) – the descriptors to compute novelty for

  • num_nearest_neighbors (int) – k used to compute the k-nearest-neighbours

Returns:
  • jnp.ndarray – the novelty of each descriptor in descriptors.

Source code in qdax/core/emitters/mees_emitter.py
@partial(jax.jit, static_argnames=("num_nearest_neighbors",))
def novelty(
    self,
    descriptors: Descriptor,
    num_nearest_neighbors: int,
) -> jnp.ndarray:
    """Compute the novelty of the given descriptors as the average distance
    to the k nearest neighbours in the archive.

    Args:
        descriptors: the descriptors to compute novelty for
        num_nearest_neighbors: k used to compute the k-nearest-neighbours
    Returns:
        the novelty of each descriptor in descriptors.
    """

    # Compute all distances with archive content
    def distance(x: jnp.ndarray, y: jnp.ndarray) -> jnp.ndarray:
        return jnp.sqrt(jnp.sum(jnp.square(x - y)))

    distances = jax.vmap(
        jax.vmap(partial(distance), in_axes=(None, 0)), in_axes=(0, None)
    )(descriptors, self.archive)

    # Filter distance with empty slot of archive
    indices = jnp.arange(0, self.size, step=1) < self.position + 1
    distances = jax.vmap(lambda distance: jnp.where(indices, distance, jnp.inf))(
        distances
    )

    # Find k nearest neighbours
    _, indices = jax.lax.top_k(-distances, num_nearest_neighbors)

    # Compute novelty as the average distance to the k nearest neighbours
    distances = jnp.where(distances == jnp.inf, jnp.nan, distances)
    novelty = jnp.nanmean(jnp.take_along_axis(distances, indices, axis=1), axis=1)
    return novelty
replace(self, **updates)

"Returns a new object replacing the specified fields with new values.

Source code in qdax/core/emitters/mees_emitter.py
def replace(self, **updates):
  """ "Returns a new object replacing the specified fields with new values."""
  return dataclasses.replace(self, **updates)
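
A short usage sketch of this archive (toy sizes and descriptor values; the import path follows the source location shown above):

import jax.numpy as jnp

from qdax.core.emitters.mees_emitter import NoveltyArchive

archive = NoveltyArchive.init(size=100, num_descriptors=2)

# add one descriptor per generation (shape (1, num_descriptors))
archive = archive.update(jnp.array([[0.1, 0.2]]))
archive = archive.update(jnp.array([[0.8, 0.9]]))

# novelty of new candidates = average distance to the k nearest stored entries
candidates = jnp.array([[0.0, 0.0], [1.0, 1.0]])
novelty = archive.novelty(candidates, num_nearest_neighbors=2)
print(novelty.shape)  # (2,)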

MEESConfig dataclass

Configuration for the MAP-Elites-ES emitter.

Parameters:
  • sample_number (int) – number of samples used for the gradient estimate

  • sample_sigma (float) – standard deviation of the samples used for the gradient estimate

  • sample_mirror (bool) – if True, use mirrored sampling

  • sample_rank_norm (bool) – if True, use rank normalisation

  • num_optimizer_steps (int) – frequency of archive-sampling

  • adam_optimizer (bool) – if True, use ADAM, if False, use SGD

  • learning_rate (float) – learning rate used by the optimizer

  • l2_coefficient (float) – coefficient for regularisation

  • novelty_nearest_neighbors (int) – number of nearest neighbours used to compute novelty

  • last_updated_size (int) – number of last updated individuals used to choose parents from the repertoire

  • exploit_num_cell_sample (int) – number of highest-performing cells from which to choose parents, when using exploit

  • explore_num_cell_sample (int) – number of most-novel cells from which to choose parents, when using explore

  • use_explore (bool) – if False, use only fitness gradients

  • use_exploit (bool) – if False, use only novelty gradients

Source code in qdax/core/emitters/mees_emitter.py
@dataclass
class MEESConfig:
    """Configuration for the MAP-Elites-ES emitter.

    Args:
        sample_number: num of samples for gradient estimate
        sample_sigma: std to sample the samples for gradient estimate
        sample_mirror: if True, use mirroring sampling
        sample_rank_norm: if True, use normalisation
        num_optimizer_steps: frequency of archive-sampling
        adam_optimizer: if True, use ADAM, if False, use SGD
        learning_rate: learning rate used by the optimizer
        l2_coefficient: coefficient for regularisation
        novelty_nearest_neighbors: number of nearest neighbours used to
            compute novelty
        last_updated_size: number of last updated individuals used to
            choose parents from the repertoire
        exploit_num_cell_sample: number of highest-performing cells
            from which to choose parents, when using exploit
        explore_num_cell_sample: number of most-novel cells from
            which to choose parents, when using explore
        use_explore: if False, use only fitness gradient
        use_exploit: if False, use only novelty gradient
    """

    sample_number: int = 1000
    sample_sigma: float = 0.02
    sample_mirror: bool = True
    sample_rank_norm: bool = True
    num_optimizer_steps: int = 10
    adam_optimizer: bool = True
    learning_rate: float = 0.01
    l2_coefficient: float = 0.02
    novelty_nearest_neighbors: int = 10
    last_updated_size: int = 5
    exploit_num_cell_sample: int = 2
    explore_num_cell_sample: int = 5
    use_explore: bool = True
    use_exploit: bool = True
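
To make the configuration concrete, here is a small sketch instantiating it for the three ME-ES variants described in the MEESEmitter section below (all other fields keep their defaults; the values shown are only illustrative):

from qdax.core.emitters.mees_emitter import MEESConfig

# ME-ES exploit-explore: alternate fitness and novelty gradient steps
config_exploit_explore = MEESConfig(use_exploit=True, use_explore=True)

# ME-ES exploit: fitness gradients only
config_exploit = MEESConfig(use_exploit=True, use_explore=False)

# ME-ES explore: novelty gradients only
config_explore = MEESConfig(
    use_exploit=False,
    use_explore=True,
    novelty_nearest_neighbors=10,
)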

MEESEmitterState (EmitterState) dataclass

Emitter State for the MAP-Elites-ES emitter.

Parameters:
  • initial_optimizer_state (optax.OptState) – stored to re-initialise when sampling new parent

  • optimizer_state (optax.OptState) – current optimizer state

  • offspring (Genotype) – offspring generated through gradient estimate

  • generation_count (int) – generation counter used to update the novelty archive

  • novelty_archive (NoveltyArchive) – used to compute novelty for explore

  • last_updated_genotypes (Genotype) – used to choose parents from repertoire

  • last_updated_fitnesses (Fitness) – used to choose parents from repertoire

  • last_updated_position (jnp.ndarray) – used to choose parents from repertoire

  • random_key (RNGKey) – key to handle stochastic operations

Source code in qdax/core/emitters/mees_emitter.py
class MEESEmitterState(EmitterState):
    """Emitter State for the MAP-Elites-ES emitter.

    Args:
        initial_optimizer_state: stored to re-initialise when sampling new parent
        optimizer_state: current optimizer state
        offspring: offspring generated through gradient estimate
        generation_count: generation counter used to update the novelty archive
        novelty_archive: used to compute novelty for explore
        last_updated_genotypes: used to choose parents from repertoire
        last_updated_fitnesses: used to choose parents from repertoire
        last_updated_position: used to choose parents from repertoire
        random_key: key to handle stochastic operations
    """

    initial_optimizer_state: optax.OptState
    optimizer_state: optax.OptState
    offspring: Genotype
    generation_count: int
    novelty_archive: NoveltyArchive
    last_updated_genotypes: Genotype
    last_updated_fitnesses: Fitness
    last_updated_position: jnp.ndarray
    random_key: RNGKey
replace(self, **updates)

"Returns a new object replacing the specified fields with new values.

Source code in qdax/core/emitters/mees_emitter.py
def replace(self, **updates):
  """ "Returns a new object replacing the specified fields with new values."""
  return dataclasses.replace(self, **updates)

MEESEmitter (Emitter)

Emitter reproducing the MAP-Elites-ES algorithm from "Scaling MAP-Elites to Deep Neuroevolution" by Colas et al: https://dl.acm.org/doi/pdf/10.1145/3377930.3390217

One can choose between the three variants by setting use_explore and use_exploit:

  • ME-ES exploit-explore (use_exploit=True and use_explore=True): alternates between num_optimizer_steps of fitness gradients and num_optimizer_steps of novelty gradients, resampling the parent from the archive every num_optimizer_steps steps.

  • ME-ES exploit (use_exploit=True and use_explore=False): only uses fitness gradients, no novelty gradients, but resamples the parent from the archive every num_optimizer_steps steps.

  • ME-ES explore (use_exploit=False and use_explore=True): only uses novelty gradients, no fitness gradients, but resamples the parent from the archive every num_optimizer_steps steps.

Source code in qdax/core/emitters/mees_emitter.py
class MEESEmitter(Emitter):
    """
    Emitter reproducing the MAP-Elites-ES algorithm from
    "Scaling MAP-Elites to Deep Neuroevolution" by Colas et al:
    https://dl.acm.org/doi/pdf/10.1145/3377930.3390217

    One can choose between the three variants by setting use_explore and use_exploit:
        ME-ES exploit-explore: use_exploit=True and use_explore=True
            Alternates between num_optimizer_steps of fitness gradients and
            num_optimizer_steps of novelty gradients, resample parent from the archive
            every num_optimizer_steps steps.
        ME-ES exploit: use_exploit=True and use_explore=False
            Only uses fitness gradient, no novelty gradients, but resample parent from
            the archive every num_optimizer_steps steps.
        ME-ES explore: use_exploit=False and use_explore=True
            Only uses novelty gradient, no fitness gradients, but resample parent from
            the archive every num_optimizer_steps steps.
    """

    def __init__(
        self,
        config: MEESConfig,
        total_generations: int,
        scoring_fn: Callable[
            [Genotype, RNGKey], Tuple[Fitness, Descriptor, ExtraScores, RNGKey]
        ],
        num_descriptors: int,
    ) -> None:
        """Initialise the MAP-Elites-ES emitter.
        WARNING: total_generations is required to build the novelty archive.

        Args:
            config: algorithm config
            scoring_fn: used to evaluate the samples for the gradient estimate.
            total_generations: total number of generations for which the
                emitter will run, allow to initialise the novelty archive.
            num_descriptors: dimension of the descriptors, used to initialise
                the empty novelty archive.
        """
        self._config = config
        self._scoring_fn = scoring_fn
        self._total_generations = total_generations
        self._num_descriptors = num_descriptors

        # Initialise optimizer
        if self._config.adam_optimizer:
            self._optimizer = optax.adam(learning_rate=config.learning_rate)
        else:
            self._optimizer = optax.sgd(learning_rate=config.learning_rate)

    @property
    def batch_size(self) -> int:
        """
        Returns:
            the batch size emitted by the emitter.
        """
        return 1

    @partial(
        jax.jit,
        static_argnames=("self",),
    )
    def init(
        self, init_genotypes: Genotype, random_key: RNGKey
    ) -> Tuple[MEESEmitterState, RNGKey]:
        """Initializes the emitter state.

        Args:
            init_genotypes: The initial population.
            random_key: A random key.

        Returns:
            The initial state of the MEESEmitter, a new random key.
        """
        # Initialisation requires one initial genotype
        if jax.tree_util.tree_leaves(init_genotypes)[0].shape[0] > 1:
            init_genotypes = jax.tree_util.tree_map(
                lambda x: x[0],
                init_genotypes,
            )

        # Initialise optimizer
        initial_optimizer_state = self._optimizer.init(init_genotypes)

        # Create empty Novelty archive
        if self._config.use_explore:
            novelty_archive = NoveltyArchive.init(
                self._total_generations, self._num_descriptors
            )
        else:
            novelty_archive = NoveltyArchive.init(
                self._config.novelty_nearest_neighbors, self._num_descriptors
            )

        # Create empty updated genotypes and fitness
        last_updated_genotypes = jax.tree_util.tree_map(
            lambda x: jnp.zeros(shape=(self._config.last_updated_size,) + x.shape[1:]),
            init_genotypes,
        )
        last_updated_fitnesses = -jnp.inf * jnp.ones(
            shape=self._config.last_updated_size
        )

        return (
            MEESEmitterState(
                initial_optimizer_state=initial_optimizer_state,
                optimizer_state=initial_optimizer_state,
                offspring=init_genotypes,
                generation_count=0,
                novelty_archive=novelty_archive,
                last_updated_genotypes=last_updated_genotypes,
                last_updated_fitnesses=last_updated_fitnesses,
                last_updated_position=0,
                random_key=random_key,
            ),
            random_key,
        )

    @partial(
        jax.jit,
        static_argnames=("self",),
    )
    def emit(
        self,
        repertoire: MapElitesRepertoire,
        emitter_state: MEESEmitterState,
        random_key: RNGKey,
    ) -> Tuple[Genotype, RNGKey]:
        """Return the offspring generated through gradient update.

        Params:
            repertoire: the MAP-Elites repertoire to sample from
            emitter_state
            random_key: a jax PRNG random key

        Returns:
            a new gradient offspring
            a new jax PRNG key
        """

        return emitter_state.offspring, random_key

    @partial(
        jax.jit,
        static_argnames=("self",),
    )
    def _sample_exploit(
        self,
        emitter_state: MEESEmitterState,
        repertoire: MapElitesRepertoire,
        random_key: RNGKey,
    ) -> Tuple[Genotype, RNGKey]:
        """Sample half of the time uniformly from the exploit_num_cell_sample
        highest-performing cells of the repertoire and half of the time uniformly
        from the exploit_num_cell_sample highest-performing cells among the
        last updated cells.

        Args:
            emitter_state: current emitter_state
            repertoire: the current repertoire
            random_key: a jax PRNG random key

        Returns:
            samples: a genotype sampled in the repertoire
            random_key: an updated jax PRNG random key
        """

        def _sample(
            random_key: RNGKey,
            genotypes: Genotype,
            fitnesses: Fitness,
        ) -> Tuple[Genotype, RNGKey]:
            """Sample uniformly from the 2 highest fitness cells."""

            max_fitnesses, _ = jax.lax.top_k(
                fitnesses, self._config.exploit_num_cell_sample
            )
            min_fitness = jnp.nanmin(
                jnp.where(max_fitnesses > -jnp.inf, max_fitnesses, jnp.inf)
            )
            genotypes_empty = fitnesses < min_fitness
            p = (1.0 - genotypes_empty) / jnp.sum(1.0 - genotypes_empty)
            random_key, subkey = jax.random.split(random_key)
            samples = jax.tree_map(
                lambda x: jax.random.choice(subkey, x, shape=(1,), p=p),
                genotypes,
            )
            return samples, random_key

        random_key, subkey = jax.random.split(random_key)

        # Sample p uniformly
        p = jax.random.uniform(subkey)

        # Depending on the value of p, use one of the two sampling options
        repertoire_sample = partial(
            _sample, genotypes=repertoire.genotypes, fitnesses=repertoire.fitnesses
        )
        last_updated_sample = partial(
            _sample,
            genotypes=emitter_state.last_updated_genotypes,
            fitnesses=emitter_state.last_updated_fitnesses,
        )
        samples, random_key = jax.lax.cond(
            p < 0.5,
            repertoire_sample,
            last_updated_sample,
            random_key,
        )

        return samples, random_key

    @partial(
        jax.jit,
        static_argnames=("self",),
    )
    def _sample_explore(
        self,
        emitter_state: MEESEmitterState,
        repertoire: MapElitesRepertoire,
        random_key: RNGKey,
    ) -> Tuple[Genotype, RNGKey]:
        """Sample uniformly from the explore_num_cell_sample most-novel genotypes.

        Args:
            emitter_state: current emitter state
            repertoire: the current genotypes repertoire
            random_key: a jax PRNG random key

        Returns:
            samples: a genotype sampled in the repertoire
            random_key: an updated jax PRNG random key
        """

        # Compute the novelty of all indivs in the archive
        novelties = emitter_state.novelty_archive.novelty(
            repertoire.descriptors, self._config.novelty_nearest_neighbors
        )
        novelties = jnp.where(repertoire.fitnesses > -jnp.inf, novelties, -jnp.inf)

        # Sample uniformly for the explore_num_cell_sample most novel cells
        max_novelties, _ = jax.lax.top_k(
            novelties, self._config.explore_num_cell_sample
        )
        min_novelty = jnp.nanmin(
            jnp.where(max_novelties > -jnp.inf, max_novelties, jnp.inf)
        )
        repertoire_empty = novelties < min_novelty
        p = (1.0 - repertoire_empty) / jnp.sum(1.0 - repertoire_empty)
        random_key, subkey = jax.random.split(random_key)
        samples = jax.tree_map(
            lambda x: jax.random.choice(subkey, x, shape=(1,), p=p),
            repertoire.genotypes,
        )

        return samples, random_key

    @partial(
        jax.jit,
        static_argnames=("self", "scores_fn"),
    )
    def _es_emitter(
        self,
        parent: Genotype,
        optimizer_state: optax.OptState,
        random_key: RNGKey,
        scores_fn: Callable[[Fitness, Descriptor], jnp.ndarray],
    ) -> Tuple[Genotype, optax.OptState, RNGKey]:
        """Main es component, given a parent and a way to infer the score from
        the fitnesses and descriptors fo its es-samples, return its
        approximated-gradient-generated offspring.

        Args:
            parent: the considered parent.
            scores_fn: a function to infer the score of its es-samples from
                their fitness and descriptors.
            random_key

        Returns:
            The approximated-gradients-generated offspring and a new random_key.
        """

        random_key, subkey = jax.random.split(random_key)

        # Sampling mirror noise
        total_sample_number = self._config.sample_number
        if self._config.sample_mirror:

            sample_number = total_sample_number // 2
            half_sample_noise = jax.tree_util.tree_map(
                lambda x: jax.random.normal(
                    key=subkey,
                    shape=jnp.repeat(x, sample_number, axis=0).shape,
                ),
                parent,
            )
            sample_noise = jax.tree_util.tree_map(
                lambda x: jnp.concatenate(
                    [jnp.expand_dims(x, axis=1), jnp.expand_dims(-x, axis=1)], axis=1
                ).reshape(jnp.repeat(x, 2, axis=0).shape),
                half_sample_noise,
            )
            gradient_noise = half_sample_noise

        # Sampling non-mirror noise
        else:
            sample_number = total_sample_number
            sample_noise = jax.tree_map(
                lambda x: jax.random.normal(
                    key=subkey,
                    shape=jnp.repeat(x, sample_number, axis=0).shape,
                ),
                parent,
            )
            gradient_noise = sample_noise

        # Applying noise
        samples = jax.tree_map(
            lambda x: jnp.repeat(x, total_sample_number, axis=0),
            parent,
        )
        samples = jax.tree_map(
            lambda mean, noise: mean + self._config.sample_sigma * noise,
            samples,
            sample_noise,
        )

        # Evaluating samples
        fitnesses, descriptors, extra_scores, random_key = self._scoring_fn(
            samples, random_key
        )

        # Computing rank, with or without normalisation
        scores = scores_fn(fitnesses, descriptors)

        if self._config.sample_rank_norm:
            ranking_indices = jnp.argsort(scores, axis=0)
            ranks = jnp.argsort(ranking_indices, axis=0)
            ranks = (ranks / (total_sample_number - 1)) - 0.5

        else:
            ranks = scores

        # Reshaping rank to match shape of genotype_noise
        if self._config.sample_mirror:
            ranks = jnp.reshape(ranks, (sample_number, 2))
            ranks = jnp.apply_along_axis(lambda rank: rank[0] - rank[1], 1, ranks)
        ranks = jax.tree_map(
            lambda x: jnp.reshape(
                jnp.repeat(ranks.ravel(), x[0].ravel().shape[0], axis=0), x.shape
            ),
            gradient_noise,
        )

        # Computing the gradients
        gradient = jax.tree_map(
            lambda noise, rank: jnp.multiply(noise, rank),
            gradient_noise,
            ranks,
        )
        gradient = jax.tree_map(
            lambda x: jnp.reshape(x, (sample_number, -1)),
            gradient,
        )
        gradient = jax.tree_map(
            lambda g, p: jnp.reshape(
                -jnp.sum(g, axis=0) / (total_sample_number * self._config.sample_sigma),
                p.shape,
            ),
            gradient,
            parent,
        )

        # Adding regularisation
        gradient = jax.tree_map(
            lambda g, p: g + self._config.l2_coefficient * p,
            gradient,
            parent,
        )

        # Applying gradients
        (offspring_update, optimizer_state) = self._optimizer.update(
            gradient, optimizer_state
        )
        offspring = optax.apply_updates(parent, offspring_update)

        return offspring, optimizer_state, random_key

    @partial(
        jax.jit,
        static_argnames=("self",),
    )
    def _buffers_update(
        self,
        emitter_state: MEESEmitterState,
        repertoire: MapElitesRepertoire,
        genotypes: Genotype,
        fitnesses: Fitness,
        descriptors: Descriptor,
    ) -> MEESEmitterState:
        """Update the different buffers and archives in the emitter
        state to generate the offspring for the next generation.

        Args:
            emitter_state: current emitter state
            repertoire: the current genotypes repertoire
            genotypes: the genotypes of the batch of emitted offspring.
            fitnesses: the fitnesses of the batch of emitted offspring.
            descriptors: the descriptors of the emitted offspring.

        Returns:
            The modified emitter state.
        """

        # Updating novelty archive
        novelty_archive = emitter_state.novelty_archive.update(descriptors)

        # Check if genotype from previous iteration has been added to the grid
        indice = get_cells_indices(descriptors, repertoire.centroids)
        added_genotype = jnp.all(
            jnp.asarray(
                jax.tree_util.tree_leaves(
                    jax.tree_util.tree_map(
                        lambda new_gen, rep_gen: jnp.all(
                            jnp.equal(
                                jnp.ravel(new_gen), jnp.ravel(rep_gen.at[indice].get())
                            ),
                            axis=0,
                        ),
                        genotypes,
                        repertoire.genotypes,
                    ),
                )
            ),
            axis=0,
        )

        # Update last_updated buffers
        last_updated_position = jnp.where(
            added_genotype,
            emitter_state.last_updated_position,
            self._config.last_updated_size + 1,
        )
        last_updated_fitnesses = emitter_state.last_updated_fitnesses
        last_updated_fitnesses = last_updated_fitnesses.at[last_updated_position].set(
            fitnesses[0]
        )
        last_updated_genotypes = jax.tree_map(
            lambda last_gen, gen: last_gen.at[
                jnp.expand_dims(last_updated_position, axis=0)
            ].set(gen),
            emitter_state.last_updated_genotypes,
            genotypes,
        )
        last_updated_position = (
            emitter_state.last_updated_position + added_genotype
        ) % self._config.last_updated_size

        # Return new emitter_state
        return emitter_state.replace(  # type: ignore
            novelty_archive=novelty_archive,
            last_updated_genotypes=last_updated_genotypes,
            last_updated_fitnesses=last_updated_fitnesses,
            last_updated_position=last_updated_position,
        )

    @partial(
        jax.jit,
        static_argnames=("self",),
    )
    def state_update(
        self,
        emitter_state: MEESEmitterState,
        repertoire: MapElitesRepertoire,
        genotypes: Genotype,
        fitnesses: Fitness,
        descriptors: Descriptor,
        extra_scores: ExtraScores,
    ) -> MEESEmitterState:
        """Generate the gradient offspring for the next emitter call. Also
        update the novelty archive and generation count from current call.

        Args:
            emitter_state: current emitter state
            repertoire: the current genotypes repertoire
            genotypes: the genotypes of the batch of emitted offspring.
            fitnesses: the fitnesses of the batch of emitted offspring.
            descriptors: the descriptors of the emitted offspring.
            extra_scores: a dictionary with other values outputted by the
                scoring function.

        Returns:
            The modified emitter state.
        """

        assert jax.tree_util.tree_leaves(genotypes)[0].shape[0] == 1, (
            "ERROR: MAP-Elites-ES generates 1 offspring per generation, "
            + "batch_size should be 1, the inputed batch has size:"
            + str(jax.tree_util.tree_leaves(genotypes)[0].shape[0])
        )

        # Update all the buffers and archives of the emitter_state
        emitter_state = self._buffers_update(
            emitter_state, repertoire, genotypes, fitnesses, descriptors
        )

        # Use new or previous parents and exploitation or exploration
        generation_count = emitter_state.generation_count
        sample_new_parent = generation_count % self._config.num_optimizer_steps == 0
        use_exploration = (
            self._config.use_explore and not self._config.use_exploit
        ) or (
            self._config.use_explore
            and self._config.use_exploit
            and ((generation_count // self._config.num_optimizer_steps) % 2 == 0)
        )

        # Select parent and optimizer_state
        parent, random_key = jax.lax.cond(
            sample_new_parent,
            lambda emitter_state, repertoire, random_key: jax.lax.cond(
                use_exploration,
                self._sample_explore,
                self._sample_exploit,
                emitter_state,
                repertoire,
                random_key,
            ),
            lambda emitter_state, repertoire, random_key: (
                emitter_state.offspring,
                random_key,
            ),
            emitter_state,
            repertoire,
            emitter_state.random_key,
        )
        optimizer_state = jax.lax.cond(
            sample_new_parent,
            lambda _unused: emitter_state.initial_optimizer_state,
            lambda _unused: emitter_state.optimizer_state,
            (),
        )

        # Define scores for es process
        def exploration_exploitation_scores(
            fitnesses: Fitness, descriptors: Descriptor
        ) -> jnp.ndarray:
            scores = jax.lax.cond(
                use_exploration,
                lambda fitnesses, descriptors: emitter_state.novelty_archive.novelty(
                    descriptors, self._config.novelty_nearest_neighbors
                ),
                lambda fitnesses, descriptors: fitnesses,
                fitnesses,
                descriptors,
            )
            return scores

        # Run es process
        offspring, optimizer_state, random_key = self._es_emitter(
            parent=parent,
            optimizer_state=optimizer_state,
            random_key=random_key,
            scores_fn=exploration_exploitation_scores,
        )

        return emitter_state.replace(  # type: ignore
            optimizer_state=optimizer_state,
            offspring=offspring,
            generation_count=generation_count + 1,
            random_key=random_key,
        )
batch_size: int property readonly
Returns:
  • int – the batch size emitted by the emitter.

__init__(self, config, total_generations, scoring_fn, num_descriptors) special

Initialise the MAP-Elites-ES emitter. WARNING: total_generations is required to build the novelty archive.

Parameters:
  • config (MEESConfig) – algorithm config

  • scoring_fn (Callable[[Genotype, RNGKey], Tuple[Fitness, Descriptor, ExtraScores, RNGKey]]) – used to evaluate the samples for the gradient estimate.

  • total_generations (int) – total number of generations for which the emitter will run; used to initialise the novelty archive.

  • num_descriptors (int) – dimension of the descriptors, used to initialise the empty novelty archive.
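
For illustration, a hedged construction sketch; config, scoring_fn and num_iterations are placeholders for objects defined elsewhere in the user's code.

mees_emitter = MEESEmitter(
    config=config,
    total_generations=num_iterations,  # the emitter runs one generation per iteration
    scoring_fn=scoring_fn,             # (genotypes, key) -> (fitnesses, descriptors, extra_scores, key)
    num_descriptors=2,
)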

Source code in qdax/core/emitters/mees_emitter.py
def __init__(
    self,
    config: MEESConfig,
    total_generations: int,
    scoring_fn: Callable[
        [Genotype, RNGKey], Tuple[Fitness, Descriptor, ExtraScores, RNGKey]
    ],
    num_descriptors: int,
) -> None:
    """Initialise the MAP-Elites-ES emitter.
    WARNING: total_generations is required to build the novelty archive.

    Args:
        config: algorithm config
        scoring_fn: used to evaluate the samples for the gradient estimate.
        total_generations: total number of generations for which the
            emitter will run, allow to initialise the novelty archive.
        num_descriptors: dimension of the descriptors, used to initialise
            the empty novelty archive.
    """
    self._config = config
    self._scoring_fn = scoring_fn
    self._total_generations = total_generations
    self._num_descriptors = num_descriptors

    # Initialise optimizer
    if self._config.adam_optimizer:
        self._optimizer = optax.adam(learning_rate=config.learning_rate)
    else:
        self._optimizer = optax.sgd(learning_rate=config.learning_rate)
init(self, init_genotypes, random_key)

Initializes the emitter state.

Parameters:
  • init_genotypes (Genotype) – The initial population.

  • random_key (RNGKey) – A random key.

Returns:
  • Tuple[MEESEmitterState, RNGKey] – The initial state of the MEESEmitter, a new random key.

Source code in qdax/core/emitters/mees_emitter.py
@partial(
    jax.jit,
    static_argnames=("self",),
)
def init(
    self, init_genotypes: Genotype, random_key: RNGKey
) -> Tuple[MEESEmitterState, RNGKey]:
    """Initializes the emitter state.

    Args:
        init_genotypes: The initial population.
        random_key: A random key.

    Returns:
        The initial state of the MEESEmitter, a new random key.
    """
    # Initialisation requires one initial genotype
    if jax.tree_util.tree_leaves(init_genotypes)[0].shape[0] > 1:
        init_genotypes = jax.tree_util.tree_map(
            lambda x: x[0],
            init_genotypes,
        )

    # Initialise optimizer
    initial_optimizer_state = self._optimizer.init(init_genotypes)

    # Create empty Novelty archive
    if self._config.use_explore:
        novelty_archive = NoveltyArchive.init(
            self._total_generations, self._num_descriptors
        )
    else:
        novelty_archive = NoveltyArchive.init(
            self._config.novelty_nearest_neighbors, self._num_descriptors
        )

    # Create empty updated genotypes and fitness
    last_updated_genotypes = jax.tree_util.tree_map(
        lambda x: jnp.zeros(shape=(self._config.last_updated_size,) + x.shape[1:]),
        init_genotypes,
    )
    last_updated_fitnesses = -jnp.inf * jnp.ones(
        shape=self._config.last_updated_size
    )

    return (
        MEESEmitterState(
            initial_optimizer_state=initial_optimizer_state,
            optimizer_state=initial_optimizer_state,
            offspring=init_genotypes,
            generation_count=0,
            novelty_archive=novelty_archive,
            last_updated_genotypes=last_updated_genotypes,
            last_updated_fitnesses=last_updated_fitnesses,
            last_updated_position=0,
            random_key=random_key,
        ),
        random_key,
    )
emit(self, repertoire, emitter_state, random_key)

Return the offspring generated through gradient update.

Parameters:
  • repertoire (MapElitesRepertoire) – the MAP-Elites repertoire to sample from

  • emitter_state (MEESEmitterState) – the current emitter state, which stores the offspring to emit

  • random_key (RNGKey) – a jax PRNG random key

Returns:
  • Tuple[Genotype, RNGKey] – a new gradient offspring and a new jax PRNG key
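
Usage sketch: the offspring returned here was pre-computed during the previous state_update call, so emit is essentially a read of the emitter state.

offspring, random_key = mees_emitter.emit(repertoire, emitter_state, random_key)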

Source code in qdax/core/emitters/mees_emitter.py
@partial(
    jax.jit,
    static_argnames=("self",),
)
def emit(
    self,
    repertoire: MapElitesRepertoire,
    emitter_state: MEESEmitterState,
    random_key: RNGKey,
) -> Tuple[Genotype, RNGKey]:
    """Return the offspring generated through gradient update.

    Params:
        repertoire: the MAP-Elites repertoire to sample from
        emitter_state
        random_key: a jax PRNG random key

    Returns:
        a new gradient offspring
        a new jax PRNG key
    """

    return emitter_state.offspring, random_key
state_update(self, emitter_state, repertoire, genotypes, fitnesses, descriptors, extra_scores)

Generate the gradient offspring for the next emitter call. Also update the novelty archive and generation count from current call.

Parameters:
  • emitter_state (MEESEmitterState) – current emitter state

  • repertoire (MapElitesRepertoire) – the current genotypes repertoire

  • genotypes (Genotype) – the genotypes of the batch of emitted offspring.

  • fitnesses (Fitness) – the fitnesses of the batch of emitted offspring.

  • descriptors (Descriptor) – the descriptors of the emitted offspring.

  • extra_scores (ExtraScores) – a dictionary with other values outputted by the scoring function.

Returns:
  • MEESEmitterState – The modified emitter state.
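
As a worked example of the scheduling arithmetic used in the source below (exploit-explore variant, assuming num_optimizer_steps=10): a new parent is sampled whenever the generation count is a multiple of num_optimizer_steps, and the gradient type alternates every num_optimizer_steps generations.

num_optimizer_steps = 10
for generation_count in (0, 5, 10, 15, 20):
    sample_new_parent = generation_count % num_optimizer_steps == 0
    use_exploration = (generation_count // num_optimizer_steps) % 2 == 0
    print(generation_count, sample_new_parent, use_exploration)
# 0   True   True   (new parent, novelty gradients)
# 5   False  True
# 10  True   False  (new parent, fitness gradients)
# 15  False  False
# 20  True   True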

Source code in qdax/core/emitters/mees_emitter.py
@partial(
    jax.jit,
    static_argnames=("self",),
)
def state_update(
    self,
    emitter_state: MEESEmitterState,
    repertoire: MapElitesRepertoire,
    genotypes: Genotype,
    fitnesses: Fitness,
    descriptors: Descriptor,
    extra_scores: ExtraScores,
) -> MEESEmitterState:
    """Generate the gradient offspring for the next emitter call. Also
    update the novelty archive and generation count from current call.

    Args:
        emitter_state: current emitter state
        repertoire: the current genotypes repertoire
        genotypes: the genotypes of the batch of emitted offspring.
        fitnesses: the fitnesses of the batch of emitted offspring.
        descriptors: the descriptors of the emitted offspring.
        extra_scores: a dictionary with other values outputted by the
            scoring function.

    Returns:
        The modified emitter state.
    """

    assert jax.tree_util.tree_leaves(genotypes)[0].shape[0] == 1, (
        "ERROR: MAP-Elites-ES generates 1 offspring per generation, "
        + "batch_size should be 1, the inputed batch has size:"
        + str(jax.tree_util.tree_leaves(genotypes)[0].shape[0])
    )

    # Update all the buffers and archives of the emitter_state
    emitter_state = self._buffers_update(
        emitter_state, repertoire, genotypes, fitnesses, descriptors
    )

    # Use new or previous parents and exploitation or exploration
    generation_count = emitter_state.generation_count
    sample_new_parent = generation_count % self._config.num_optimizer_steps == 0
    use_exploration = (
        self._config.use_explore and not self._config.use_exploit
    ) or (
        self._config.use_explore
        and self._config.use_exploit
        and ((generation_count // self._config.num_optimizer_steps) % 2 == 0)
    )

    # Select parent and optimizer_state
    parent, random_key = jax.lax.cond(
        sample_new_parent,
        lambda emitter_state, repertoire, random_key: jax.lax.cond(
            use_exploration,
            self._sample_explore,
            self._sample_exploit,
            emitter_state,
            repertoire,
            random_key,
        ),
        lambda emitter_state, repertoire, random_key: (
            emitter_state.offspring,
            random_key,
        ),
        emitter_state,
        repertoire,
        emitter_state.random_key,
    )
    optimizer_state = jax.lax.cond(
        sample_new_parent,
        lambda _unused: emitter_state.initial_optimizer_state,
        lambda _unused: emitter_state.optimizer_state,
        (),
    )

    # Define scores for es process
    def exploration_exploitation_scores(
        fitnesses: Fitness, descriptors: Descriptor
    ) -> jnp.ndarray:
        scores = jax.lax.cond(
            use_exploration,
            lambda fitnesses, descriptors: emitter_state.novelty_archive.novelty(
                descriptors, self._config.novelty_nearest_neighbors
            ),
            lambda fitnesses, descriptors: fitnesses,
            fitnesses,
            descriptors,
        )
        return scores

    # Run es process
    offspring, optimizer_state, random_key = self._es_emitter(
        parent=parent,
        optimizer_state=optimizer_state,
        random_key=random_key,
        scores_fn=exploration_exploitation_scores,
    )

    return emitter_state.replace(  # type: ignore
        optimizer_state=optimizer_state,
        offspring=offspring,
        generation_count=generation_count + 1,
        random_key=random_key,
    )

multi_emitter

MultiEmitterState (EmitterState) dataclass

State of an emitter that uses multiple emitters in parallel.

WARNING: this is not the emitter state of Multi-Emitter MAP-Elites.

Parameters:
  • emitter_states (Tuple[qdax.core.emitters.emitter.EmitterState, ...]) – a tuple of emitter states

Source code in qdax/core/emitters/multi_emitter.py
class MultiEmitterState(EmitterState):
    """State of an emitter than use multiple emitters in a parallel manner.

    WARNING: this is not the emitter state of Multi-Emitter MAP-Elites.

    Args:
        emitter_states: a tuple of emitter states
    """

    emitter_states: Tuple[EmitterState, ...]
replace(self, **updates)

"Returns a new object replacing the specified fields with new values.

Source code in qdax/core/emitters/multi_emitter.py
def replace(self, **updates):
  """ "Returns a new object replacing the specified fields with new values."""
  return dataclasses.replace(self, **updates)

MultiEmitter (Emitter)

Emitter that mixes several emitters in parallel.

WARNING: this is not the emitter of Multi-Emitter MAP-Elites.
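
A minimal sketch, assuming two sub-emitters (emitter_1 and emitter_2) have already been constructed; the MultiEmitter simply concatenates their batches.

multi_emitter = MultiEmitter(emitters=(emitter_1, emitter_2))
# the total batch size is the sum of the sub-emitters' batch sizes
assert multi_emitter.batch_size == emitter_1.batch_size + emitter_2.batch_size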

Source code in qdax/core/emitters/multi_emitter.py
class MultiEmitter(Emitter):
    """Emitter that mixes several emitters in parallel.

    WARNING: this is not the emitter of Multi-Emitter MAP-Elites.
    """

    def __init__(
        self,
        emitters: Tuple[Emitter, ...],
    ):
        self.emitters = emitters
        indexes_separation_batches = self.get_indexes_separation_batches(emitters)
        self.indexes_start_batches = indexes_separation_batches[:-1]
        self.indexes_end_batches = indexes_separation_batches[1:]

    @staticmethod
    def get_indexes_separation_batches(
        emitters: Tuple[Emitter, ...]
    ) -> Tuple[int, ...]:
        """Get the indexes of the separation between batches of each emitter.

        Args:
            emitters: the emitters

        Returns:
            a tuple of tuples of indexes
        """
        indexes_separation_batches = np.cumsum(
            [0] + [emitter.batch_size for emitter in emitters]
        )
        return tuple(indexes_separation_batches)

    def init(
        self, init_genotypes: Optional[Genotype], random_key: RNGKey
    ) -> Tuple[Optional[EmitterState], RNGKey]:
        """
        Initialize the state of the emitter.

        Args:
            init_genotypes: The genotypes of the initial population.
            random_key: a random key to handle stochastic operations.

        Returns:
            The initial emitter state and a random key.
        """

        # prepare keys for each emitter
        random_key, subkey = jax.random.split(random_key)
        subkeys = jax.random.split(subkey, len(self.emitters))

        # init all emitter states - gather them
        emitter_states = []
        for emitter, subkey_emitter in zip(self.emitters, subkeys):
            emitter_state, _ = emitter.init(init_genotypes, subkey_emitter)
            emitter_states.append(emitter_state)

        return MultiEmitterState(tuple(emitter_states)), random_key

    @partial(jax.jit, static_argnames=("self",))
    def emit(
        self,
        repertoire: Optional[Repertoire],
        emitter_state: Optional[MultiEmitterState],
        random_key: RNGKey,
    ) -> Tuple[Genotype, RNGKey]:
        """Emit new population. Use all the sub emitters to emit subpopulation
        and gather them.

        Args:
            repertoire: a repertoire of genotypes.
            emitter_state: the current state of the emitter.
            random_key: key for random operations.

        Returns:
            Offsprings and a new random key.
        """
        assert emitter_state is not None
        assert len(emitter_state.emitter_states) == len(self.emitters)

        # prepare subkeys for each sub emitter
        random_key, subkey = jax.random.split(random_key)
        subkeys = jax.random.split(subkey, len(self.emitters))

        # emit from all emitters and gather offsprings
        all_offsprings = []
        for emitter, sub_emitter_state, subkey_emitter in zip(
            self.emitters,
            emitter_state.emitter_states,
            subkeys,
        ):
            genotype, _ = emitter.emit(repertoire, sub_emitter_state, subkey_emitter)
            batch_size = jax.tree_util.tree_leaves(genotype)[0].shape[0]
            assert batch_size == emitter.batch_size
            all_offsprings.append(genotype)

        # concatenate offsprings together
        offsprings = jax.tree_util.tree_map(
            lambda *x: jnp.concatenate(x, axis=0), *all_offsprings
        )
        return offsprings, random_key

    @partial(jax.jit, static_argnames=("self",))
    def state_update(
        self,
        emitter_state: Optional[MultiEmitterState],
        repertoire: Optional[Repertoire] = None,
        genotypes: Optional[Genotype] = None,
        fitnesses: Optional[Fitness] = None,
        descriptors: Optional[Descriptor] = None,
        extra_scores: Optional[ExtraScores] = None,
    ) -> Optional[MultiEmitterState]:
        """Update emitter state by updating all sub emitter states.

        Args:
            emitter_state: current emitter state.
            repertoire: current repertoire of genotypes. Defaults to None.
            genotypes: proposed genotypes. Defaults to None.
            fitnesses: associated fitnesses. Defaults to None.
            descriptors: associated descriptors. Defaults to None.
            extra_scores: associated extra_scores. Defaults to None.

        Returns:
            The updated global emitter state.
        """
        if emitter_state is None:
            return None

        # update all the sub emitter states
        emitter_states = []

        def _get_sub_pytree(pytree: ArrayTree, start: int, end: int) -> ArrayTree:
            return jax.tree_util.tree_map(lambda x: x[start:end], pytree)

        for emitter, sub_emitter_state, index_start, index_end in zip(
            self.emitters,
            emitter_state.emitter_states,
            self.indexes_start_batches,
            self.indexes_end_batches,
        ):
            # update with all genotypes, fitnesses, etc...
            if emitter.use_all_data:
                new_sub_emitter_state = emitter.state_update(
                    sub_emitter_state,
                    repertoire,
                    genotypes,
                    fitnesses,
                    descriptors,
                    extra_scores,
                )
                emitter_states.append(new_sub_emitter_state)
            # update only with the data of the emitted genotypes
            else:
                # extract relevant data
                sub_gen, sub_fit, sub_desc, sub_extra_scores = jax.tree_util.tree_map(
                    partial(_get_sub_pytree, start=index_start, end=index_end),
                    (
                        genotypes,
                        fitnesses,
                        descriptors,
                        extra_scores,
                    ),
                )
                # update only with the relevant data
                new_sub_emitter_state = emitter.state_update(
                    sub_emitter_state,
                    repertoire,
                    sub_gen,
                    sub_fit,
                    sub_desc,
                    sub_extra_scores,
                )
                emitter_states.append(new_sub_emitter_state)

        # return the update global emitter state
        return MultiEmitterState(tuple(emitter_states))

    @property
    def batch_size(self) -> int:
        """
        Returns:
            the batch size emitted by the emitter.
        """
        return sum(emitter.batch_size for emitter in self.emitters)
batch_size: int property readonly
Returns:
  • int – the batch size emitted by the emitter.

get_indexes_separation_batches(emitters) staticmethod

Get the indexes of the separation between batches of each emitter.

Parameters:
  • emitters (Tuple[qdax.core.emitters.emitter.Emitter, ...]) – the emitters

Returns:
  • Tuple[int, ...] – a tuple of cumulative indices marking the start and end of each sub-emitter's batch in the concatenated offspring
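
Worked example with hypothetical batch sizes: for two sub-emitters emitting 3 and 5 offspring respectively, the boundaries are (0, 3, 8), and sub-emitter i owns the slice [boundaries[i], boundaries[i + 1]) of the concatenated batch.

import numpy as np

boundaries = tuple(np.cumsum([0, 3, 5]))  # (0, 3, 8)
starts, ends = boundaries[:-1], boundaries[1:]  # (0, 3) and (3, 8)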

Source code in qdax/core/emitters/multi_emitter.py
@staticmethod
def get_indexes_separation_batches(
    emitters: Tuple[Emitter, ...]
) -> Tuple[int, ...]:
    """Get the indexes of the separation between batches of each emitter.

    Args:
        emitters: the emitters

    Returns:
        a tuple of tuples of indexes
    """
    indexes_separation_batches = np.cumsum(
        [0] + [emitter.batch_size for emitter in emitters]
    )
    return tuple(indexes_separation_batches)
init(self, init_genotypes, random_key)

Initialize the state of the emitter.

Parameters:
  • init_genotypes (Union[jax.Array, numpy.ndarray, numpy.bool_, numpy.number, Iterable[ArrayTree], Mapping[Any, ArrayTree]]) – The genotypes of the initial population.

  • random_key (Array) – a random key to handle stochastic operations.

Returns:
  • Tuple[Optional[qdax.core.emitters.emitter.EmitterState], jax.Array] – The initial emitter state and a random key.

Source code in qdax/core/emitters/multi_emitter.py
def init(
    self, init_genotypes: Optional[Genotype], random_key: RNGKey
) -> Tuple[Optional[EmitterState], RNGKey]:
    """
    Initialize the state of the emitter.

    Args:
        init_genotypes: The genotypes of the initial population.
        random_key: a random key to handle stochastic operations.

    Returns:
        The initial emitter state and a random key.
    """

    # prepare keys for each emitter
    random_key, subkey = jax.random.split(random_key)
    subkeys = jax.random.split(subkey, len(self.emitters))

    # init all emitter states - gather them
    emitter_states = []
    for emitter, subkey_emitter in zip(self.emitters, subkeys):
        emitter_state, _ = emitter.init(init_genotypes, subkey_emitter)
        emitter_states.append(emitter_state)

    return MultiEmitterState(tuple(emitter_states)), random_key
emit(self, repertoire, emitter_state, random_key)

Emit new population. Use all the sub emitters to emit subpopulation and gather them.

Parameters:
  • repertoire (Optional[qdax.core.containers.repertoire.Repertoire]) – a repertoire of genotypes.

  • emitter_state (Optional[qdax.core.emitters.multi_emitter.MultiEmitterState]) – the current state of the emitter.

  • random_key (Array) – key for random operations.

Returns:
  • Tuple[Union[jax.Array, numpy.ndarray, numpy.bool_, numpy.number, Iterable[ArrayTree], Mapping[Any, ArrayTree]], jax.Array] – Offsprings and a new random key.
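
Usage sketch; repertoire, multi_emitter_state and random_key are assumed to come from the surrounding MAP-Elites loop.

offspring, random_key = multi_emitter.emit(repertoire, multi_emitter_state, random_key)
# the leading dimension of each offspring leaf equals the summed batch size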

Source code in qdax/core/emitters/multi_emitter.py
@partial(jax.jit, static_argnames=("self",))
def emit(
    self,
    repertoire: Optional[Repertoire],
    emitter_state: Optional[MultiEmitterState],
    random_key: RNGKey,
) -> Tuple[Genotype, RNGKey]:
    """Emit new population. Use all the sub emitters to emit subpopulation
    and gather them.

    Args:
        repertoire: a repertoire of genotypes.
        emitter_state: the current state of the emitter.
        random_key: key for random operations.

    Returns:
        Offsprings and a new random key.
    """
    assert emitter_state is not None
    assert len(emitter_state.emitter_states) == len(self.emitters)

    # prepare subkeys for each sub emitter
    random_key, subkey = jax.random.split(random_key)
    subkeys = jax.random.split(subkey, len(self.emitters))

    # emit from all emitters and gather offsprings
    all_offsprings = []
    for emitter, sub_emitter_state, subkey_emitter in zip(
        self.emitters,
        emitter_state.emitter_states,
        subkeys,
    ):
        genotype, _ = emitter.emit(repertoire, sub_emitter_state, subkey_emitter)
        batch_size = jax.tree_util.tree_leaves(genotype)[0].shape[0]
        assert batch_size == emitter.batch_size
        all_offsprings.append(genotype)

    # concatenate offsprings together
    offsprings = jax.tree_util.tree_map(
        lambda *x: jnp.concatenate(x, axis=0), *all_offsprings
    )
    return offsprings, random_key
state_update(self, emitter_state, repertoire=None, genotypes=None, fitnesses=None, descriptors=None, extra_scores=None)

Update emitter state by updating all sub emitter states.

Parameters:
  • emitter_state (Optional[qdax.core.emitters.multi_emitter.MultiEmitterState]) – current emitter state.

  • repertoire (Optional[qdax.core.containers.repertoire.Repertoire]) – current repertoire of genotypes. Defaults to None.

  • genotypes (Union[jax.Array, numpy.ndarray, numpy.bool_, numpy.number, Iterable[ArrayTree], Mapping[Any, ArrayTree]]) – proposed genotypes. Defaults to None.

  • fitnesses (Optional[jax.Array]) – associated fitnesses. Defaults to None.

  • descriptors (Optional[jax.Array]) – associated descriptors. Defaults to None.

  • extra_scores (Optional[Dict[str, Union[jax.Array, numpy.ndarray, numpy.bool_, numpy.number, Iterable[ArrayTree], Mapping[Any, ArrayTree]]]]) – associated extra_scores. Defaults to None.

Returns:
  • Optional[qdax.core.emitters.multi_emitter.MultiEmitterState] – The updated global emitter state.

Source code in qdax/core/emitters/multi_emitter.py
@partial(jax.jit, static_argnames=("self",))
def state_update(
    self,
    emitter_state: Optional[MultiEmitterState],
    repertoire: Optional[Repertoire] = None,
    genotypes: Optional[Genotype] = None,
    fitnesses: Optional[Fitness] = None,
    descriptors: Optional[Descriptor] = None,
    extra_scores: Optional[ExtraScores] = None,
) -> Optional[MultiEmitterState]:
    """Update emitter state by updating all sub emitter states.

    Args:
        emitter_state: current emitter state.
        repertoire: current repertoire of genotypes. Defaults to None.
        genotypes: proposed genotypes. Defaults to None.
        fitnesses: associated fitnesses. Defaults to None.
        descriptors: associated descriptors. Defaults to None.
        extra_scores: associated extra_scores. Defaults to None.

    Returns:
        The updated global emitter state.
    """
    if emitter_state is None:
        return None

    # update all the sub emitter states
    emitter_states = []

    def _get_sub_pytree(pytree: ArrayTree, start: int, end: int) -> ArrayTree:
        return jax.tree_util.tree_map(lambda x: x[start:end], pytree)

    for emitter, sub_emitter_state, index_start, index_end in zip(
        self.emitters,
        emitter_state.emitter_states,
        self.indexes_start_batches,
        self.indexes_end_batches,
    ):
        # update with all genotypes, fitnesses, etc...
        if emitter.use_all_data:
            new_sub_emitter_state = emitter.state_update(
                sub_emitter_state,
                repertoire,
                genotypes,
                fitnesses,
                descriptors,
                extra_scores,
            )
            emitter_states.append(new_sub_emitter_state)
        # update only with the data of the emitted genotypes
        else:
            # extract relevant data
            sub_gen, sub_fit, sub_desc, sub_extra_scores = jax.tree_util.tree_map(
                partial(_get_sub_pytree, start=index_start, end=index_end),
                (
                    genotypes,
                    fitnesses,
                    descriptors,
                    extra_scores,
                ),
            )
            # update only with the relevant data
            new_sub_emitter_state = emitter.state_update(
                sub_emitter_state,
                repertoire,
                sub_gen,
                sub_fit,
                sub_desc,
                sub_extra_scores,
            )
            emitter_states.append(new_sub_emitter_state)

    # return the update global emitter state
    return MultiEmitterState(tuple(emitter_states))

mutation_operators

File defining mutation and crossover functions.

polynomial_mutation(x, random_key, proportion_to_mutate, eta, minval, maxval)

Polynomial mutation over several genotypes

Parameters:
  • x (Union[jax.Array, numpy.ndarray, numpy.bool_, numpy.number, Iterable[ArrayTree], Mapping[Any, ArrayTree]]) – array of genotypes to transform (real values only); assumed to be of shape (batch_size, genotype_dim)

  • random_key (Array) – RNG key for reproducibility

  • proportion_to_mutate (float) – proportion of variables to mutate in each genotype (must be in [0, 1]).

  • eta (float) – scaling parameter, the larger the more spread the new values will be.

  • minval (float) – minimum value to clip the genotypes.

  • maxval (float) – maximum value to clip the genotypes.

Returns:
  • Tuple[Union[jax.Array, numpy.ndarray, numpy.bool_, numpy.number, Iterable[ArrayTree], Mapping[Any, ArrayTree]], jax.Array] – New genotypes - same shape as input and a new RNG key
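
Usage sketch on a simple array of genotypes in [0, 1]; the hyper-parameter values are illustrative only.

import jax
from qdax.core.emitters.mutation_operators import polynomial_mutation

random_key = jax.random.PRNGKey(0)
random_key, subkey = jax.random.split(random_key)
genotypes = jax.random.uniform(subkey, shape=(64, 8))  # (batch_size, genotype_dim)
mutated, random_key = polynomial_mutation(
    genotypes,
    random_key,
    proportion_to_mutate=0.5,
    eta=20.0,
    minval=0.0,
    maxval=1.0,
)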

Source code in qdax/core/emitters/mutation_operators.py
def polynomial_mutation(
    x: Genotype,
    random_key: RNGKey,
    proportion_to_mutate: float,
    eta: float,
    minval: float,
    maxval: float,
) -> Tuple[Genotype, RNGKey]:
    """
    Polynomial mutation over several genotypes

    Parameters:
        x: array of genotypes to transform (real values only)
        random_key: RNG key for reproducibility.
            Assumed to be of shape (batch_size, genotype_dim)
        proportion_to_mutate (float): proportion of variables to mutate in
            each genotype (must be in [0, 1]).
        eta: scaling parameter, the larger the more spread the new
            values will be.
        minval: minimum value to clip the genotypes.
        maxval: maximum value to clip the genotypes.

    Returns:
        New genotypes - same shape as input and a new RNG key
    """
    random_key, subkey = jax.random.split(random_key)
    batch_size = jax.tree_util.tree_leaves(x)[0].shape[0]
    mutation_key = jax.random.split(subkey, num=batch_size)
    mutation_fn = partial(
        _polynomial_mutation,
        proportion_to_mutate=proportion_to_mutate,
        eta=eta,
        minval=minval,
        maxval=maxval,
    )
    mutation_fn = jax.vmap(mutation_fn)
    x = jax.tree_util.tree_map(lambda x_: mutation_fn(x_, mutation_key), x)
    return x, random_key

polynomial_crossover(x1, x2, random_key, proportion_var_to_change)

Crossover over a set of pairs of genotypes.

Batched version of _simple_crossover_function. x1 and x2 should have the same shape. In this function we assume x1 and x2 to be of shape (batch_size, genotype_dim).

Parameters:
  • x1 (Union[jax.Array, numpy.ndarray, numpy.bool_, numpy.number, Iterable[ArrayTree], Mapping[Any, ArrayTree]]) – first batch of genotypes

  • x2 (Union[jax.Array, numpy.ndarray, numpy.bool_, numpy.number, Iterable[ArrayTree], Mapping[Any, ArrayTree]]) – second batch of genotypes

  • random_key (Array) – RNG key for reproducibility

  • proportion_var_to_change (float) – proportion of variables to exchange between genotypes (must be in [0, 1])

Returns:
  • Tuple[Union[jax.Array, numpy.ndarray, numpy.bool_, numpy.number, Iterable[ArrayTree], Mapping[Any, ArrayTree]], jax.Array] – New genotypes and a new RNG key
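
Usage sketch, following the same conventions as for polynomial_mutation above; x1 and x2 are two batches of genotypes with identical shapes.

offspring, random_key = polynomial_crossover(
    x1, x2, random_key, proportion_var_to_change=0.5
)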

Source code in qdax/core/emitters/mutation_operators.py
def polynomial_crossover(
    x1: Genotype,
    x2: Genotype,
    random_key: RNGKey,
    proportion_var_to_change: float,
) -> Tuple[Genotype, RNGKey]:
    """
    Crossover over a set of pairs of genotypes.

    Batched version of _simple_crossover_function
    x1 and x2 should have the same shape
    In this function we assume x1 shape and x2 shape to be
    (batch_size, genotype_dim)

    Parameters:
        x1: first batch of genotypes
        x2: second batch of genotypes
        random_key: RNG key for reproducibility
        proportion_var_to_change: proportion of variables to exchange
            between genotypes (must be [0, 1])

    Returns:
        New genotypes and a new RNG key
    """

    random_key, subkey = jax.random.split(random_key)
    batch_size = jax.tree_util.tree_leaves(x2)[0].shape[0]
    crossover_keys = jax.random.split(subkey, num=batch_size)
    crossover_fn = partial(
        _polynomial_crossover,
        proportion_var_to_change=proportion_var_to_change,
    )
    crossover_fn = jax.vmap(crossover_fn)
    # TODO: check that key usage is correct
    x = jax.tree_util.tree_map(
        lambda x1_, x2_: crossover_fn(x1_, x2_, crossover_keys), x1, x2
    )
    return x, random_key

isoline_variation(x1, x2, random_key, iso_sigma, line_sigma, minval=None, maxval=None)

Iso+Line-DD Variation Operator [1] over a set of pairs of genotypes

Parameters:
  • x1 (Genotypes) – first batch of genotypes

  • x2 (Genotypes) – second batch of genotypes

  • random_key (RNGKey) – RNG key for reproducibility

  • iso_sigma (float) – spread parameter (noise)

  • line_sigma (float) – line parameter (direction of the new genotype)

  • minval (float, Optional) – minimum value to clip the genotypes

  • maxval (float, Optional) – maximum value to clip the genotypes

Returns:
  • x (Genotypes) – new genotypes

  • random_key (RNGKey) – new RNG key

[1] Vassiliades, Vassilis, and Jean-Baptiste Mouret. "Discovering the elite hypervolume by leveraging interspecies correlation." Proceedings of the Genetic and Evolutionary Computation Conference. 2018.
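
In other words, each offspring is x1 + iso_noise + line_noise * (x2 - x1), where iso_noise is drawn element-wise from N(0, iso_sigma^2) and line_noise is drawn per genotype from N(0, line_sigma^2), as visible in the source below. Usage sketch with illustrative sigma values:

offspring, random_key = isoline_variation(
    x1, x2, random_key, iso_sigma=0.005, line_sigma=0.05
)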

Source code in qdax/core/emitters/mutation_operators.py
def isoline_variation(
    x1: Genotype,
    x2: Genotype,
    random_key: RNGKey,
    iso_sigma: float,
    line_sigma: float,
    minval: Optional[float] = None,
    maxval: Optional[float] = None,
) -> Tuple[Genotype, RNGKey]:
    """
    Iso+Line-DD Variation Operator [1] over a set of pairs of genotypes

    Parameters:
        x1 (Genotypes): first batch of genotypes
        x2 (Genotypes): second batch of genotypes
        random_key (RNGKey): RNG key for reproducibility
        iso_sigma (float): spread parameter (noise)
        line_sigma (float): line parameter (direction of the new genotype)
        minval (float, Optional): minimum value to clip the genotypes
        maxval (float, Optional): maximum value to clip the genotypes

    Returns:
        x (Genotypes): new genotypes
        random_key (RNGKey): new RNG key

    [1] Vassiliades, Vassilis, and Jean-Baptiste Mouret. "Discovering the elite
    hypervolume by leveraging interspecies correlation." Proceedings of the Genetic and
    Evolutionary Computation Conference. 2018.
    """

    # Computing line_noise
    random_key, key_line_noise = jax.random.split(random_key)
    batch_size = jax.tree_util.tree_leaves(x1)[0].shape[0]
    line_noise = jax.random.normal(key_line_noise, shape=(batch_size,)) * line_sigma

    def _variation_fn(
        x1: jnp.ndarray, x2: jnp.ndarray, random_key: RNGKey
    ) -> jnp.ndarray:
        iso_noise = jax.random.normal(random_key, shape=x1.shape) * iso_sigma
        x = (x1 + iso_noise) + jax.vmap(jnp.multiply)((x2 - x1), line_noise)

        # Back in bounds if necessary (floating point issues)
        if (minval is not None) or (maxval is not None):
            x = jnp.clip(x, minval, maxval)
        return x

    # create a tree with random keys
    nb_leaves = len(jax.tree_util.tree_leaves(x1))
    random_key, subkey = jax.random.split(random_key)
    subkeys = jax.random.split(subkey, num=nb_leaves)
    keys_tree = jax.tree_util.tree_unflatten(jax.tree_util.tree_structure(x1), subkeys)

    # apply isolinedd to each branch of the tree
    x = jax.tree_util.tree_map(
        lambda y1, y2, key: _variation_fn(y1, y2, key), x1, x2, keys_tree
    )

    return x, random_key

omg_mega_emitter

OMGMEGAEmitterState (EmitterState) dataclass

Emitter state for the OMG-MEGA emitter.

Parameters:
  • gradients_repertoire (MapElitesRepertoire) – MapElites repertoire containing the gradients of the individuals.

Source code in qdax/core/emitters/omg_mega_emitter.py
class OMGMEGAEmitterState(EmitterState):
    """
    Emitter state for the CMA-MEGA emitter.

    Args:
        gradients_repertoire: MapElites repertoire containing the gradients
            of the indivuals.
    """

    gradients_repertoire: MapElitesRepertoire
replace(self, **updates)

"Returns a new object replacing the specified fields with new values.

Source code in qdax/core/emitters/omg_mega_emitter.py
def replace(self, **updates):
  """ "Returns a new object replacing the specified fields with new values."""
  return dataclasses.replace(self, **updates)

OMGMEGAEmitter (Emitter)

Class for the emitter of OMG Mega from "Differentiable Quality Diversity" by Fontaine et al.

NOTE: in order to implement this emitter while staying in the MAPElites framework, we had to make two temporary design choices:

  • in the emit function, we use the same random key to sample from the genotypes and gradients repertoires, in order to get the gradients that correspond to the right genotypes. Although acceptable, this is definitely not the best coding practice and we would prefer to get rid of it in a future version. A solution under discussion with the development team is to decompose the sampling function of the repertoire into two phases: one sampling the indices, the other retrieving the corresponding elements. This would make it possible to reuse the indices instead of sampling twice.

  • in state_update, we have to insert the gradients into the gradients repertoire in the same way the individuals were inserted. Once again, this is slightly suboptimal because the same addition mechanism has to be computed twice. One solution under discussion, very similar to the first one above, would be to decompose the addition mechanism into two phases: one outputting the indices at which individuals will be added, and then the actual insertion step. This would make it possible to reuse the same indices to add the gradients instead of recomputing them.

The two design choices seem acceptable and make OMG MEGA compatible with the current implementation of the MAPElites and MAPElitesRepertoire classes.

Our suggested solutions are quite simple and are likely to be useful for the implementation of other variants. They will be further discussed with the development team and potentially added in a future version of the package.
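
Construction sketch; centroids is assumed to come from a prior centroid computation (for instance a CVT), and the numeric values are illustrative only. Note that sigma_g is the variance, not the standard deviation, of the coefficient distribution, as stated in the constructor docstring below.

omg_mega_emitter = OMGMEGAEmitter(
    batch_size=64,
    sigma_g=10.0,
    num_descriptors=2,
    centroids=centroids,
)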

Source code in qdax/core/emitters/omg_mega_emitter.py
class OMGMEGAEmitter(Emitter):
    """
    Class for the emitter of OMG Mega from "Differentiable Quality Diversity" by
    Fontaine et al.

    NOTE: in order to implement this emitter while staying in the MAPElites
    framework, we had to make two temporary design choices:
    - in the emit function, we use the same random key to sample from the
    genotypes and gradients repertoire, in order to get the gradients that
    correspond to the right genotypes. Although acceptable, this is definitely
    not the best coding practice and we would prefer to get rid of this in a
    future version. A solution that we are discussing with the development team
    is to decompose the sampling function of the repertoire into two phases: one
    sampling the indices to be sampled, the other one retrieving the corresponding
    elements. This would enable to reuse the indices instead of doing this double
    sampling.
    - in the state_update, we have to insert the gradients in the gradients
    repertoire in the same way the individuals were inserted. Once again, this is
    slightly unoptimal because the same addition mecanism has to be computed two
    times. One solution that we are discussing and that is very similar to the first
    solution discussed above, would be to decompose the addition mecanism in two
    phases: one outputing the indices at which individuals will be added, and then
    the actual insertion step. This would enable to re-use the same indices to add
    the gradients instead of having to recompute them.

    The two design choices seem acceptable and enable to have OMG MEGA compatible
    with the current implementation of the MAPElites and MAPElitesRepertoire classes.

    Our suggested solutions seem quite simple and are likely to be useful for other
    variants implementation. They will be further discussed with the development team
    and potentially added in a future version of the package.
    """

    def __init__(
        self,
        batch_size: int,
        sigma_g: float,
        num_descriptors: int,
        centroids: Centroid,
    ):
        """Creates an instance of the OMGMEGAEmitter class.

        Args:
            batch_size: number of solutions sampled at each iteration
            sigma_g: CAUTION - square of the standard deviation for the coefficients.
                This notation can be misleading as, although it's called sigma, it
                refers to the variance and not the standard deviation.
            num_descriptors: number of descriptors
            centroids: centroids used to create the repertoire of solutions.
                This will be used to create the repertoire of gradients.
        """
        # set the mean of the coeff distribution to zero
        self._mu = jnp.zeros(num_descriptors + 1)

        # set the cov matrix to sigma * I
        self._sigma = jnp.eye(num_descriptors + 1) * sigma_g

        # define other parameters of the distribution
        self._batch_size = batch_size
        self._centroids = centroids
        self._num_descriptors = num_descriptors

    def init(
        self, init_genotypes: Genotype, random_key: RNGKey
    ) -> Tuple[OMGMEGAEmitterState, RNGKey]:
        """Initialises the state of the emitter. Creates an empty repertoire
        that will later contain the gradients of the individuals.

        Args:
            init_genotypes: The genotypes of the initial population.
            random_key: a random key to handle stochastic operations.

        Returns:
            The initial emitter state.
        """
        # retrieve one genotype from the population
        first_genotype = jax.tree_util.tree_map(lambda x: x[0], init_genotypes)

        # add a dimension of size num descriptors + 1
        gradient_genotype = jax.tree_util.tree_map(
            lambda x: jnp.repeat(
                jnp.expand_dims(x, axis=-1), repeats=self._num_descriptors + 1, axis=-1
            ),
            first_genotype,
        )

        # create the gradients repertoire
        gradients_repertoire = MapElitesRepertoire.init_default(
            genotype=gradient_genotype, centroids=self._centroids
        )

        return (
            OMGMEGAEmitterState(gradients_repertoire=gradients_repertoire),
            random_key,
        )

    @partial(
        jax.jit,
        static_argnames=("self",),
    )
    def emit(
        self,
        repertoire: MapElitesRepertoire,
        emitter_state: OMGMEGAEmitterState,
        random_key: RNGKey,
    ) -> Tuple[Genotype, RNGKey]:
        """
        OMG emitter function that samples elements in the repertoire and does a gradient
        update with random coefficients to create new candidates.

        Args:
            repertoire: current repertoire
            emitter_state: current emitter state, contains the gradients
            random_key: random key

        Returns:
            new_genotypes: new candidates to be added to the grid
            random_key: updated random key
        """
        # sample genotypes
        (
            genotypes,
            _,
        ) = repertoire.sample(random_key, num_samples=self._batch_size)

        # sample gradients - use the same random key for sampling
        # See class docstrings for discussion about this choice
        gradients, random_key = emitter_state.gradients_repertoire.sample(
            random_key, num_samples=self._batch_size
        )

        fitness_gradients = jax.tree_util.tree_map(
            lambda x: jnp.expand_dims(x[:, :, 0], axis=-1), gradients
        )
        descriptors_gradients = jax.tree_util.tree_map(lambda x: x[:, :, 1:], gradients)

        # Normalize the gradients
        norm_fitness_gradients = jnp.linalg.norm(
            fitness_gradients, axis=1, keepdims=True
        )

        fitness_gradients = fitness_gradients / norm_fitness_gradients

        norm_descriptors_gradients = jnp.linalg.norm(
            descriptors_gradients, axis=1, keepdims=True
        )
        descriptors_gradients = descriptors_gradients / norm_descriptors_gradients

        # Draw random coefficients
        random_key, subkey = jax.random.split(random_key)
        coeffs = jax.random.multivariate_normal(
            subkey,
            shape=(self._batch_size,),
            mean=self._mu,
            cov=self._sigma,
        )
        coeffs = coeffs.at[:, 0].set(jnp.abs(coeffs[:, 0]))
        grads = jax.tree_util.tree_map(
            lambda x, y: jnp.concatenate((x, y), axis=-1),
            fitness_gradients,
            descriptors_gradients,
        )
        update_grad = jnp.sum(jax.vmap(lambda x, y: x * y)(coeffs, grads), axis=-1)

        # update the genotypes
        new_genotypes = jax.tree_util.tree_map(
            lambda x, y: x + y, genotypes, update_grad
        )

        return new_genotypes, random_key

    @partial(
        jax.jit,
        static_argnames=("self",),
    )
    def state_update(
        self,
        emitter_state: OMGMEGAEmitterState,
        repertoire: MapElitesRepertoire,
        genotypes: Genotype,
        fitnesses: Fitness,
        descriptors: Descriptor,
        extra_scores: ExtraScores,
    ) -> OMGMEGAEmitterState:
        """Update the gradients repertoire to have the right gradients.

        NOTE: see discussion in the class docstrings to see how this could
        be improved.

        Args:
            emitter_state: current emitter state
            repertoire: the current genotypes repertoire
            genotypes: the genotypes of the batch of emitted offspring.
            fitnesses: the fitnesses of the batch of emitted offspring.
            descriptors: the descriptors of the emitted offspring.
            extra_scores: a dictionary with other values outputted by the
                scoring function.

        Returns:
            The modified emitter state.
        """

        # get gradients out of the extra scores
        assert "gradients" in extra_scores.keys(), "Missing gradients or wrong key"
        gradients = extra_scores["gradients"]

        # update the gradients repertoire
        gradients_repertoire = emitter_state.gradients_repertoire.add(
            gradients,
            descriptors,
            fitnesses,
            extra_scores,
        )

        return emitter_state.replace(  # type: ignore
            gradients_repertoire=gradients_repertoire
        )

    @property
    def batch_size(self) -> int:
        """
        Returns:
            the batch size emitted by the emitter.
        """
        return self._batch_size
batch_size: int property readonly
Returns:
  • int – the batch size emitted by the emitter.

__init__(self, batch_size, sigma_g, num_descriptors, centroids) special

Creates an instance of the OMGMEGAEmitter class.

Parameters:
  • batch_size (int) – number of solutions sampled at each iteration

  • sigma_g (float) – CAUTION - square of the standard deviation for the coefficients. This notation can be misleading as, although it's called sigma, it refers to the variance and not the standard deviation.

  • num_descriptors (int) – number of descriptors

  • centroids (Array) – centroids used to create the repertoire of solutions. This will be used to create the repertoire of gradients.

Source code in qdax/core/emitters/omg_mega_emitter.py
def __init__(
    self,
    batch_size: int,
    sigma_g: float,
    num_descriptors: int,
    centroids: Centroid,
):
    """Creates an instance of the OMGMEGAEmitter class.

    Args:
        batch_size: number of solutions sampled at each iteration
        sigma_g: CAUTION - square of the standard deviation for the coefficients.
            This notation can be misleading as, although it's called sigma, it
            refers to the variance and not the standard deviation.
        num_descriptors: number of descriptors
        centroids: centroids used to create the repertoire of solutions.
            This will be used to create the repertoire of gradients.
    """
    # set the mean of the coeff distribution to zero
    self._mu = jnp.zeros(num_descriptors + 1)

    # set the cov matrix to sigma * I
    self._sigma = jnp.eye(num_descriptors + 1) * sigma_g

    # define other parameters of the distribution
    self._batch_size = batch_size
    self._centroids = centroids
    self._num_descriptors = num_descriptors
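
For illustration, a minimal usage sketch is given below. It assumes that grid centroids are built with compute_euclidean_centroids from qdax.core.containers.mapelites_repertoire; the numeric values are arbitrary.

from qdax.core.containers.mapelites_repertoire import compute_euclidean_centroids
from qdax.core.emitters.omg_mega_emitter import OMGMEGAEmitter

# grid centroids over a 2-d descriptor space in [0, 1]^2 (assumed helper)
centroids = compute_euclidean_centroids(grid_shape=(16, 16), minval=0.0, maxval=1.0)

emitter = OMGMEGAEmitter(
    batch_size=64,
    sigma_g=10.0,  # caution: interpreted as a variance, not a standard deviation
    num_descriptors=2,
    centroids=centroids,
)
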
init(self, init_genotypes, random_key)

Initialises the state of the emitter. Creates an empty repertoire that will later contain the gradients of the individuals.

Parameters:
  • init_genotypes (Union[jax.Array, numpy.ndarray, numpy.bool_, numpy.number, Iterable[ArrayTree], Mapping[Any, ArrayTree]]) – The genotypes of the initial population.

  • random_key (Array) – a random key to handle stochastic operations.

Returns:
  • Tuple[qdax.core.emitters.omg_mega_emitter.OMGMEGAEmitterState, jax.Array] – The initial emitter state.

Source code in qdax/core/emitters/omg_mega_emitter.py
def init(
    self, init_genotypes: Genotype, random_key: RNGKey
) -> Tuple[OMGMEGAEmitterState, RNGKey]:
    """Initialises the state of the emitter. Creates an empty repertoire
    that will later contain the gradients of the individuals.

    Args:
        init_genotypes: The genotypes of the initial population.
        random_key: a random key to handle stochastic operations.

    Returns:
        The initial emitter state.
    """
    # retrieve one genotype from the population
    first_genotype = jax.tree_util.tree_map(lambda x: x[0], init_genotypes)

    # add a dimension of size num descriptors + 1
    gradient_genotype = jax.tree_util.tree_map(
        lambda x: jnp.repeat(
            jnp.expand_dims(x, axis=-1), repeats=self._num_descriptors + 1, axis=-1
        ),
        first_genotype,
    )

    # create the gradients repertoire
    gradients_repertoire = MapElitesRepertoire.init_default(
        genotype=gradient_genotype, centroids=self._centroids
    )

    return (
        OMGMEGAEmitterState(gradients_repertoire=gradients_repertoire),
        random_key,
    )
emit(self, repertoire, emitter_state, random_key)

OMG emitter function that samples elements in the repertoire and does a gradient update with random coefficients to create new candidates.

Parameters:
  • repertoire (MapElitesRepertoire) – current repertoire

  • emitter_state (OMGMEGAEmitterState) – current emitter state, contains the gradients

  • random_key (Array) – random key

Returns:
  • new_genotypes – new candidates to be added to the grid

  • random_key – updated random key

Source code in qdax/core/emitters/omg_mega_emitter.py
@partial(
    jax.jit,
    static_argnames=("self",),
)
def emit(
    self,
    repertoire: MapElitesRepertoire,
    emitter_state: OMGMEGAEmitterState,
    random_key: RNGKey,
) -> Tuple[Genotype, RNGKey]:
    """
    OMG emitter function that samples elements in the repertoire and does a gradient
    update with random coefficients to create new candidates.

    Args:
        repertoire: current repertoire
        emitter_state: current emitter state, contains the gradients
        random_key: random key

    Returns:
        new_genotypes: new candidates to be added to the grid
        random_key: updated random key
    """
    # sample genotypes
    (
        genotypes,
        _,
    ) = repertoire.sample(random_key, num_samples=self._batch_size)

    # sample gradients - use the same random key for sampling
    # See class docstrings for discussion about this choice
    gradients, random_key = emitter_state.gradients_repertoire.sample(
        random_key, num_samples=self._batch_size
    )

    fitness_gradients = jax.tree_util.tree_map(
        lambda x: jnp.expand_dims(x[:, :, 0], axis=-1), gradients
    )
    descriptors_gradients = jax.tree_util.tree_map(lambda x: x[:, :, 1:], gradients)

    # Normalize the gradients
    norm_fitness_gradients = jnp.linalg.norm(
        fitness_gradients, axis=1, keepdims=True
    )

    fitness_gradients = fitness_gradients / norm_fitness_gradients

    norm_descriptors_gradients = jnp.linalg.norm(
        descriptors_gradients, axis=1, keepdims=True
    )
    descriptors_gradients = descriptors_gradients / norm_descriptors_gradients

    # Draw random coefficients
    random_key, subkey = jax.random.split(random_key)
    coeffs = jax.random.multivariate_normal(
        subkey,
        shape=(self._batch_size,),
        mean=self._mu,
        cov=self._sigma,
    )
    coeffs = coeffs.at[:, 0].set(jnp.abs(coeffs[:, 0]))
    grads = jax.tree_util.tree_map(
        lambda x, y: jnp.concatenate((x, y), axis=-1),
        fitness_gradients,
        descriptors_gradients,
    )
    update_grad = jnp.sum(jax.vmap(lambda x, y: x * y)(coeffs, grads), axis=-1)

    # update the genotypes
    new_genotypes = jax.tree_util.tree_map(
        lambda x, y: x + y, genotypes, update_grad
    )

    return new_genotypes, random_key
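
Concretely, for every sampled elite the emitter draws a coefficient vector from N(0, sigma_g * I), forces the fitness coefficient to be positive, and moves the genotype along the corresponding linear combination of normalized gradients. A minimal single-genotype sketch, assuming a flat genotype of shape (genotype_dim,) and a gradient array of shape (genotype_dim, num_descriptors + 1) whose first column is the fitness gradient:

import jax
import jax.numpy as jnp

def omg_mega_update(genotype, gradients, random_key, sigma_g, num_descriptors):
    # normalize each gradient column (fitness gradient and descriptor gradients)
    gradients = gradients / jnp.linalg.norm(gradients, axis=0, keepdims=True)

    # draw coefficients from N(0, sigma_g * I); the fitness coefficient is
    # forced to be positive so that fitness is never explicitly degraded
    coeffs = jax.random.multivariate_normal(
        random_key,
        mean=jnp.zeros(num_descriptors + 1),
        cov=jnp.eye(num_descriptors + 1) * sigma_g,
    )
    coeffs = coeffs.at[0].set(jnp.abs(coeffs[0]))

    # move the genotype along the linear combination of normalized gradients
    return genotype + gradients @ coeffs
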
state_update(self, emitter_state, repertoire, genotypes, fitnesses, descriptors, extra_scores)

Update the gradients repertoire to have the right gradients.

NOTE: see discussion in the class docstrings to see how this could be improved.

Parameters:
  • emitter_state (OMGMEGAEmitterState) – current emitter state

  • repertoire (MapElitesRepertoire) – the current genotypes repertoire

  • genotypes (Union[jax.Array, numpy.ndarray, numpy.bool_, numpy.number, Iterable[ArrayTree], Mapping[Any, ArrayTree]]) – the genotypes of the batch of emitted offspring.

  • fitnesses (Array) – the fitnesses of the batch of emitted offspring.

  • descriptors (Array) – the descriptors of the emitted offspring.

  • extra_scores (Dict[str, Union[jax.Array, numpy.ndarray, numpy.bool_, numpy.number, Iterable[ArrayTree], Mapping[Any, ArrayTree]]]) – a dictionary with other values outputted by the scoring function.

Returns:
  • OMGMEGAEmitterState – The modified emitter state.

Source code in qdax/core/emitters/omg_mega_emitter.py
@partial(
    jax.jit,
    static_argnames=("self",),
)
def state_update(
    self,
    emitter_state: OMGMEGAEmitterState,
    repertoire: MapElitesRepertoire,
    genotypes: Genotype,
    fitnesses: Fitness,
    descriptors: Descriptor,
    extra_scores: ExtraScores,
) -> OMGMEGAEmitterState:
    """Update the gradients repertoire to have the right gradients.

    NOTE: see discussion in the class docstrings to see how this could
    be improved.

    Args:
        emitter_state: current emitter state
        repertoire: the current genotypes repertoire
        genotypes: the genotypes of the batch of emitted offspring.
        fitnesses: the fitnesses of the batch of emitted offspring.
        descriptors: the descriptors of the emitted offspring.
        extra_scores: a dictionary with other values outputted by the
            scoring function.

    Returns:
        The modified emitter state.
    """

    # get gradients out of the extra scores
    assert "gradients" in extra_scores.keys(), "Missing gradients or wrong key"
    gradients = extra_scores["gradients"]

    # update the gradients repertoire
    gradients_repertoire = emitter_state.gradients_repertoire.add(
        gradients,
        descriptors,
        fitnesses,
        extra_scores,
    )

    return emitter_state.replace(  # type: ignore
        gradients_repertoire=gradients_repertoire
    )

pbt_me_emitter

PBTEmitterState (EmitterState) dataclass

The PBT emitter state contains the replay buffers that will be used by the population, as well as the population agents' training states and their starting environment states.

Source code in qdax/core/emitters/pbt_me_emitter.py
class PBTEmitterState(EmitterState):
    """
    PBT emitter state contains the replay buffers that will be used by the population as
    well as the population agents training states and their starting environment state.
    """

    replay_buffers: ReplayBuffer
    env_states: EnvState
    training_states: PBTTrainingState
    random_key: RNGKey
replace(self, **updates)

"Returns a new object replacing the specified fields with new values.

Source code in qdax/core/emitters/pbt_me_emitter.py
def replace(self, **updates):
  """ "Returns a new object replacing the specified fields with new values."""
  return dataclasses.replace(self, **updates)

PBTEmitterConfig (PyTreeNode) dataclass

Config for the PBT-ME emitter. This mainly corresponds to the hyperparameters of the PBT-ME algorithm.

Source code in qdax/core/emitters/pbt_me_emitter.py
class PBTEmitterConfig(PyTreeNode):
    """
    Config for the PBT-ME emitter. This mainly corresponds to the hyperparameters
    of the PBT-ME algorithm.
    """

    buffer_size: int
    num_training_iterations: int
    env_batch_size: int
    grad_updates_per_step: int
    pg_population_size_per_device: int
    ga_population_size_per_device: int
    num_devices: int

    fraction_best_to_replace_from: float
    fraction_to_replace_from_best: float
    fraction_to_replace_from_samples: float
    # this fraction is used only for transfer between devices
    fraction_sort_exchange: float
replace(self, **updates)

"Returns a new object replacing the specified fields with new values.

Source code in qdax/core/emitters/pbt_me_emitter.py
def replace(self, **updates):
  """ "Returns a new object replacing the specified fields with new values."""
  return dataclasses.replace(self, **updates)
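
A hedged configuration sketch is shown below; the values are purely illustrative. The fraction_* fields are turned into integer counts by the PBTEmitter constructor shown further down, e.g. int(pg_population_size * fraction_to_replace_from_best).

from qdax.core.emitters.pbt_me_emitter import PBTEmitterConfig

# illustrative values only; suitable settings depend on the task and hardware
config = PBTEmitterConfig(
    buffer_size=100_000,
    num_training_iterations=1_000,
    env_batch_size=250,
    grad_updates_per_step=1,
    pg_population_size_per_device=10,
    ga_population_size_per_device=30,
    num_devices=1,
    fraction_best_to_replace_from=0.1,
    fraction_to_replace_from_best=0.2,
    fraction_to_replace_from_samples=0.4,
    fraction_sort_exchange=0.1,
)

# with these values, the emitter keeps int(10 * 1 * 0.1) = 1 best agent as the
# parent pool and replaces int(10 * 1 * 0.2) = 2 of the worst agents from it
# at every state update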

PBTEmitter (Emitter)

An emitter that trains a population of policy gradient agents with Population Based Training (PBT); it is used to implement the PBT-MAP-Elites algorithm.

Source code in qdax/core/emitters/pbt_me_emitter.py
class PBTEmitter(Emitter):
    """
    An emitter training a population of policy gradient agents with Population
    Based Training (PBT), used to implement the PBT-MAP-Elites algorithm.
    """

    def __init__(
        self,
        pbt_agent: Union[PBTSAC, PBTTD3],
        config: PBTEmitterConfig,
        env: QDEnv,
        variation_fn: Callable[[Params, Params, RNGKey], Tuple[Params, RNGKey]],
    ) -> None:

        # Parameters internalization
        self._env = env
        self._variation_fn = variation_fn
        self._config = config
        self._agent = pbt_agent
        self._train_fn = self._agent.get_train_fn(
            env=env,
            num_iterations=config.num_training_iterations,
            env_batch_size=config.env_batch_size,
            grad_updates_per_step=config.grad_updates_per_step,
        )

        # Compute numbers from fractions
        pg_population_size = config.pg_population_size_per_device * config.num_devices
        self._num_best_to_replace_from = int(
            pg_population_size * config.fraction_best_to_replace_from
        )
        self._num_to_replace_from_best = int(
            pg_population_size * config.fraction_to_replace_from_best
        )
        self._num_to_replace_from_samples = int(
            pg_population_size * config.fraction_to_replace_from_samples
        )
        self._num_to_exchange = int(
            config.pg_population_size_per_device * config.fraction_sort_exchange
        )

    def init(
        self, init_genotypes: Genotype, random_key: RNGKey
    ) -> Tuple[PBTEmitterState, RNGKey]:
        """Initializes the emitter state.

        Args:
            init_genotypes: The initial population.
            random_key: A random key.

        Returns:
            The initial state of the PBTEmitter, a new random key.
        """

        observation_size = self._env.observation_size
        action_size = self._env.action_size

        # Initialise replay buffers
        init_dummy_transition = partial(
            Transition.init_dummy,
            observation_dim=observation_size,
            action_dim=action_size,
        )
        init_dummy_transition = jax.vmap(
            init_dummy_transition, axis_size=self._config.pg_population_size_per_device
        )
        dummy_transitions = init_dummy_transition()

        replay_buffer_init = partial(
            ReplayBuffer.init,
            buffer_size=self._config.buffer_size,
        )
        replay_buffer_init = jax.vmap(replay_buffer_init)
        replay_buffers = replay_buffer_init(transition=dummy_transitions)

        # Initialise env states
        (random_key, subkey1, subkey2) = jax.random.split(random_key, num=3)
        env_states = jax.jit(self._env.reset)(rng=subkey1)

        reshape_fn = jax.jit(
            lambda tree: jax.tree_util.tree_map(
                lambda x: jnp.reshape(
                    x,
                    (
                        self._config.pg_population_size_per_device,
                        self._config.env_batch_size,
                    )
                    + x.shape[1:],
                ),
                tree,
            ),
        )
        env_states = reshape_fn(env_states)

        # Create emitter state
        # keep only pg population size training states if more are provided
        init_genotypes = jax.tree_util.tree_map(
            lambda x: x[: self._config.pg_population_size_per_device], init_genotypes
        )
        emitter_state = PBTEmitterState(
            replay_buffers=replay_buffers,
            env_states=env_states,
            training_states=init_genotypes,
            random_key=subkey2,
        )

        return emitter_state, random_key

    @partial(
        jax.jit,
        static_argnames=("self",),
    )
    def emit(
        self,
        repertoire: Repertoire,
        emitter_state: PBTEmitterState,
        random_key: RNGKey,
    ) -> Tuple[Genotype, RNGKey]:
        """Do a single PGA-ME iteration: train critics and greedy policy,
        make mutations (evo and pg), score solution, fill replay buffer and insert back
        in the MAP-Elites grid.

        Args:
            repertoire: the current repertoire of genotypes
            emitter_state: the state of the emitter used
            random_key: a random key

        Returns:
            A batch of offspring, the new emitter state and a new key.
        """

        # Mutation PG (the mutation has already been performed during the state update)
        x_mutation_pg = emitter_state.training_states

        # Mutation evo
        if self._config.ga_population_size_per_device > 0:
            mutation_ga_batch_size = self._config.ga_population_size_per_device
            x1, random_key = repertoire.sample(random_key, mutation_ga_batch_size)
            x2, random_key = repertoire.sample(random_key, mutation_ga_batch_size)
            x_mutation_ga, random_key = self._variation_fn(x1, x2, random_key)

            # Gather offspring
            genotypes = jax.tree_util.tree_map(
                lambda x, y: jnp.concatenate([x, y], axis=0),
                x_mutation_ga,
                x_mutation_pg,
            )
        else:
            genotypes = x_mutation_pg

        return genotypes, random_key

    @property
    def batch_size(self) -> int:
        """
        Returns:
            the batch size emitted by the emitter.
        """
        mutation_pg_batch_size = self._config.pg_population_size_per_device
        mutation_ga_batch_size = self._config.ga_population_size_per_device
        return mutation_pg_batch_size + mutation_ga_batch_size

    @partial(jax.jit, static_argnames=("self",))
    def state_update(
        self,
        emitter_state: PBTEmitterState,
        repertoire: Repertoire,
        genotypes: Optional[Genotype],
        fitnesses: Fitness,
        descriptors: Optional[Descriptor],
        extra_scores: ExtraScores,
    ) -> PBTEmitterState:
        """
        Update the internal emitter state. I.e. update the population replay buffers and
        agents.

        Args:
            emitter_state: current emitter state.
            repertoire: the current genotypes repertoire
            genotypes: unused here - but compulsory in the signature.
            fitnesses: unused here - but compulsory in the signature.
            descriptors: unused here - but compulsory in the signature.
            extra_scores: extra information coming from the scoring function,
                this contains the transitions added to the replay buffer.

        Returns:
            New emitter state where the replay buffer has been filled with
            the new experienced transitions.
        """
        # Look only at the fitness corresponding to emitter state individuals
        fitnesses = fitnesses[self._config.ga_population_size_per_device :]
        fitnesses = jnp.ravel(fitnesses)
        training_states = emitter_state.training_states
        replay_buffers = emitter_state.replay_buffers
        genotypes = (training_states, replay_buffers)

        # Incremental algorithm to gather top best among the population on each device
        # First exchange
        indices_to_share = jnp.arange(self._config.pg_population_size_per_device)
        num_best_local = int(
            self._config.pg_population_size_per_device
            * self._config.fraction_best_to_replace_from
        )
        indices_to_share = indices_to_share[:num_best_local]
        genotypes_to_share, fitnesses_to_share = jax.tree_util.tree_map(
            lambda x: x[indices_to_share], (genotypes, fitnesses)
        )
        gathered_genotypes, gathered_fitnesses = jax.tree_util.tree_map(
            lambda x: jnp.concatenate(jax.lax.all_gather(x, axis_name="p"), axis=0),
            (genotypes_to_share, fitnesses_to_share),
        )

        genotypes_stacked, fitnesses_stacked = gathered_genotypes, gathered_fitnesses
        best_indices_stacked = jnp.argsort(-fitnesses_stacked)
        best_indices_stacked = best_indices_stacked[: self._num_best_to_replace_from]
        best_genotypes_local, best_fitnesses_local = jax.tree_util.tree_map(
            lambda x: x[best_indices_stacked], (genotypes_stacked, fitnesses_stacked)
        )

        # Define loop fn for the other exchanges
        def _loop_fn(i, val):  # type: ignore
            best_genotypes_local, best_fitnesses_local = val
            indices_to_share = jax.lax.dynamic_slice(
                jnp.arange(self._config.pg_population_size_per_device),
                [i * self._num_to_exchange],
                [self._num_to_exchange],
            )
            genotypes_to_share, fitnesses_to_share = jax.tree_util.tree_map(
                lambda x: x[indices_to_share], (genotypes, fitnesses)
            )
            gathered_genotypes, gathered_fitnesses = jax.tree_util.tree_map(
                lambda x: jnp.concatenate(jax.lax.all_gather(x, axis_name="p"), axis=0),
                (genotypes_to_share, fitnesses_to_share),
            )

            genotypes_stacked, fitnesses_stacked = jax.tree_util.tree_map(
                lambda x, y: jnp.concatenate([x, y], axis=0),
                (gathered_genotypes, gathered_fitnesses),
                (best_genotypes_local, best_fitnesses_local),
            )

            best_indices_stacked = jnp.argsort(-fitnesses_stacked)
            best_indices_stacked = best_indices_stacked[
                : self._num_best_to_replace_from
            ]
            best_genotypes_local, best_fitnesses_local = jax.tree_util.tree_map(
                lambda x: x[best_indices_stacked],
                (genotypes_stacked, fitnesses_stacked),
            )
            return (best_genotypes_local, best_fitnesses_local)  # type: ignore

        # Incrementally get the top fraction_best_to_replace_from best individuals
        # on each device
        (best_genotypes_local, best_fitnesses_local) = jax.lax.fori_loop(
            lower=1,
            upper=int(1.0 // self._config.fraction_sort_exchange) + 1,
            body_fun=_loop_fn,
            init_val=(best_genotypes_local, best_fitnesses_local),
        )

        # Gather fitnesses from all devices to rank locally against it
        all_fitnesses = jax.tree_util.tree_map(
            lambda x: jnp.concatenate(jax.lax.all_gather(x, axis_name="p"), axis=0),
            fitnesses,
        )
        all_fitnesses = jnp.ravel(all_fitnesses)
        all_fitnesses = -jnp.sort(-all_fitnesses)
        random_key = emitter_state.random_key
        random_key, sub_key = jax.random.split(random_key)
        best_genotypes = jax.tree_util.tree_map(
            lambda x: jax.random.choice(
                sub_key, x, shape=(len(fitnesses),), replace=True
            ),
            best_genotypes_local,
        )
        best_training_states, best_replay_buffers = best_genotypes

        # Resample hyper-params
        best_training_states = jax.vmap(
            best_training_states.__class__.resample_hyperparams
        )(best_training_states)

        # Replace by individuals from the best
        lower_bound = all_fitnesses[-self._num_to_replace_from_best]
        cond = fitnesses <= lower_bound

        training_states = jax.tree_util.tree_map(
            lambda x, y: jnp.where(
                jnp.expand_dims(
                    cond, axis=tuple([-(i + 1) for i in range(x.ndim - 1)])
                ),
                x,
                y,
            ),
            best_training_states,
            training_states,
        )
        replay_buffers = jax.tree_util.tree_map(
            lambda x, y: jnp.where(
                jnp.expand_dims(
                    cond, axis=tuple([-(i + 1) for i in range(x.ndim - 1)])
                ),
                x,
                y,
            ),
            best_replay_buffers,
            replay_buffers,
        )

        # Replacing with samples from the ME repertoire
        if self._num_to_replace_from_samples > 0:
            me_samples, random_key = repertoire.sample(
                random_key, self._config.pg_population_size_per_device
            )
            # Resample hyper-params
            me_samples = jax.vmap(me_samples.__class__.resample_hyperparams)(me_samples)
            upper_bound = all_fitnesses[
                -self._num_to_replace_from_best - self._num_to_replace_from_samples
            ]
            cond = jnp.logical_and(fitnesses <= upper_bound, fitnesses >= lower_bound)
            training_states = jax.tree_util.tree_map(
                lambda x, y: jnp.where(
                    jnp.expand_dims(
                        cond, axis=tuple([-(i + 1) for i in range(x.ndim - 1)])
                    ),
                    x,
                    y,
                ),
                me_samples,
                training_states,
            )

        # Train the agents
        env_states = emitter_state.env_states
        # Init optimizers state before training the population
        training_states = jax.vmap(training_states.__class__.init_optimizers_states)(
            training_states
        )
        (training_states, env_states, replay_buffers), metrics = self._train_fn(
            training_states, env_states, replay_buffers
        )
        # Empty optimizers states to avoid storing the info in the RAM
        # and having too heavy repertoires
        training_states = jax.vmap(training_states.__class__.empty_optimizers_states)(
            training_states
        )

        # Update emitter state
        emitter_state = emitter_state.replace(
            training_states=training_states,
            replay_buffers=replay_buffers,
            env_states=env_states,
            random_key=random_key,
        )
        return emitter_state  # type: ignore
batch_size: int property readonly
Returns:
  • int – the batch size emitted by the emitter.

init(self, init_genotypes, random_key)

Initializes the emitter state.

Parameters:
  • init_genotypes (Union[jax.Array, numpy.ndarray, numpy.bool_, numpy.number, Iterable[ArrayTree], Mapping[Any, ArrayTree]]) – The initial population.

  • random_key (Array) – A random key.

Returns:
  • Tuple[qdax.core.emitters.pbt_me_emitter.PBTEmitterState, jax.Array] – The initial PBTEmitterState and a new random key.

Source code in qdax/core/emitters/pbt_me_emitter.py
def init(
    self, init_genotypes: Genotype, random_key: RNGKey
) -> Tuple[PBTEmitterState, RNGKey]:
    """Initializes the emitter state.

    Args:
        init_genotypes: The initial population.
        random_key: A random key.

    Returns:
        The initial state of the PBTEmitter, a new random key.
    """

    observation_size = self._env.observation_size
    action_size = self._env.action_size

    # Initialise replay buffers
    init_dummy_transition = partial(
        Transition.init_dummy,
        observation_dim=observation_size,
        action_dim=action_size,
    )
    init_dummy_transition = jax.vmap(
        init_dummy_transition, axis_size=self._config.pg_population_size_per_device
    )
    dummy_transitions = init_dummy_transition()

    replay_buffer_init = partial(
        ReplayBuffer.init,
        buffer_size=self._config.buffer_size,
    )
    replay_buffer_init = jax.vmap(replay_buffer_init)
    replay_buffers = replay_buffer_init(transition=dummy_transitions)

    # Initialise env states
    (random_key, subkey1, subkey2) = jax.random.split(random_key, num=3)
    env_states = jax.jit(self._env.reset)(rng=subkey1)

    reshape_fn = jax.jit(
        lambda tree: jax.tree_util.tree_map(
            lambda x: jnp.reshape(
                x,
                (
                    self._config.pg_population_size_per_device,
                    self._config.env_batch_size,
                )
                + x.shape[1:],
            ),
            tree,
        ),
    )
    env_states = reshape_fn(env_states)

    # Create emitter state
    # keep only pg population size training states if more are provided
    init_genotypes = jax.tree_util.tree_map(
        lambda x: x[: self._config.pg_population_size_per_device], init_genotypes
    )
    emitter_state = PBTEmitterState(
        replay_buffers=replay_buffers,
        env_states=env_states,
        training_states=init_genotypes,
        random_key=subkey2,
    )

    return emitter_state, random_key
emit(self, repertoire, emitter_state, random_key)

Emit a new batch of offspring: gather the agents trained during the last state update (the policy gradient mutation has already been applied there) and, when a GA population is configured, offspring obtained by applying the GA variation to parents sampled from the repertoire.

Parameters:
  • repertoire (Repertoire) – the current repertoire of genotypes

  • emitter_state (PBTEmitterState) – the state of the emitter used

  • random_key (Array) – a random key

Returns:
  • Tuple[Union[jax.Array, numpy.ndarray, numpy.bool_, numpy.number, Iterable[ArrayTree], Mapping[Any, ArrayTree]], jax.Array] – A batch of offspring and an updated random key.

Source code in qdax/core/emitters/pbt_me_emitter.py
@partial(
    jax.jit,
    static_argnames=("self",),
)
def emit(
    self,
    repertoire: Repertoire,
    emitter_state: PBTEmitterState,
    random_key: RNGKey,
) -> Tuple[Genotype, RNGKey]:
    """Do a single PGA-ME iteration: train critics and greedy policy,
    make mutations (evo and pg), score solution, fill replay buffer and insert back
    in the MAP-Elites grid.

    Args:
        repertoire: the current repertoire of genotypes
        emitter_state: the state of the emitter used
        random_key: a random key

    Returns:
        A batch of offspring, the new emitter state and a new key.
    """

    # Mutation PG (the mutation has already been performed during the state update)
    x_mutation_pg = emitter_state.training_states

    # Mutation evo
    if self._config.ga_population_size_per_device > 0:
        mutation_ga_batch_size = self._config.ga_population_size_per_device
        x1, random_key = repertoire.sample(random_key, mutation_ga_batch_size)
        x2, random_key = repertoire.sample(random_key, mutation_ga_batch_size)
        x_mutation_ga, random_key = self._variation_fn(x1, x2, random_key)

        # Gather offspring
        genotypes = jax.tree_util.tree_map(
            lambda x, y: jnp.concatenate([x, y], axis=0),
            x_mutation_ga,
            x_mutation_pg,
        )
    else:
        genotypes = x_mutation_pg

    return genotypes, random_key
state_update(self, emitter_state, repertoire, genotypes, fitnesses, descriptors, extra_scores)

Update the internal emitter state, i.e. the population replay buffers and agents.

Parameters:
  • emitter_state (PBTEmitterState) – current emitter state.

  • repertoire (Repertoire) – the current genotypes repertoire

  • genotypes (Union[jax.Array, numpy.ndarray, numpy.bool_, numpy.number, Iterable[ArrayTree], Mapping[Any, ArrayTree]]) – unused here - but compulsory in the signature.

  • fitnesses (Array) – unused here - but compulsory in the signature.

  • descriptors (Optional[jax.Array]) – unused here - but compulsory in the signature.

  • extra_scores (Dict[str, Union[jax.Array, numpy.ndarray, numpy.bool_, numpy.number, Iterable[ArrayTree], Mapping[Any, ArrayTree]]]) – extra information coming from the scoring function, this contains the transitions added to the replay buffer.

Returns:
  • PBTEmitterState – New emitter state where the replay buffer has been filled with the new experienced transitions.

Source code in qdax/core/emitters/pbt_me_emitter.py
@partial(jax.jit, static_argnames=("self",))
def state_update(
    self,
    emitter_state: PBTEmitterState,
    repertoire: Repertoire,
    genotypes: Optional[Genotype],
    fitnesses: Fitness,
    descriptors: Optional[Descriptor],
    extra_scores: ExtraScores,
) -> PBTEmitterState:
    """
    Update the internal emitter state. I.e. update the population replay buffers and
    agents.

    Args:
        emitter_state: current emitter state.
        repertoire: the current genotypes repertoire
        genotypes: unused here - but compulsory in the signature.
        fitnesses: unused here - but compulsory in the signature.
        descriptors: unused here - but compulsory in the signature.
        extra_scores: extra information coming from the scoring function,
            this contains the transitions added to the replay buffer.

    Returns:
        New emitter state where the replay buffer has been filled with
        the new experienced transitions.
    """
    # Look only at the fitness corresponding to emitter state individuals
    fitnesses = fitnesses[self._config.ga_population_size_per_device :]
    fitnesses = jnp.ravel(fitnesses)
    training_states = emitter_state.training_states
    replay_buffers = emitter_state.replay_buffers
    genotypes = (training_states, replay_buffers)

    # Incremental algorithm to gather top best among the population on each device
    # First exchange
    indices_to_share = jnp.arange(self._config.pg_population_size_per_device)
    num_best_local = int(
        self._config.pg_population_size_per_device
        * self._config.fraction_best_to_replace_from
    )
    indices_to_share = indices_to_share[:num_best_local]
    genotypes_to_share, fitnesses_to_share = jax.tree_util.tree_map(
        lambda x: x[indices_to_share], (genotypes, fitnesses)
    )
    gathered_genotypes, gathered_fitnesses = jax.tree_util.tree_map(
        lambda x: jnp.concatenate(jax.lax.all_gather(x, axis_name="p"), axis=0),
        (genotypes_to_share, fitnesses_to_share),
    )

    genotypes_stacked, fitnesses_stacked = gathered_genotypes, gathered_fitnesses
    best_indices_stacked = jnp.argsort(-fitnesses_stacked)
    best_indices_stacked = best_indices_stacked[: self._num_best_to_replace_from]
    best_genotypes_local, best_fitnesses_local = jax.tree_util.tree_map(
        lambda x: x[best_indices_stacked], (genotypes_stacked, fitnesses_stacked)
    )

    # Define loop fn for the other exchanges
    def _loop_fn(i, val):  # type: ignore
        best_genotypes_local, best_fitnesses_local = val
        indices_to_share = jax.lax.dynamic_slice(
            jnp.arange(self._config.pg_population_size_per_device),
            [i * self._num_to_exchange],
            [self._num_to_exchange],
        )
        genotypes_to_share, fitnesses_to_share = jax.tree_util.tree_map(
            lambda x: x[indices_to_share], (genotypes, fitnesses)
        )
        gathered_genotypes, gathered_fitnesses = jax.tree_util.tree_map(
            lambda x: jnp.concatenate(jax.lax.all_gather(x, axis_name="p"), axis=0),
            (genotypes_to_share, fitnesses_to_share),
        )

        genotypes_stacked, fitnesses_stacked = jax.tree_util.tree_map(
            lambda x, y: jnp.concatenate([x, y], axis=0),
            (gathered_genotypes, gathered_fitnesses),
            (best_genotypes_local, best_fitnesses_local),
        )

        best_indices_stacked = jnp.argsort(-fitnesses_stacked)
        best_indices_stacked = best_indices_stacked[
            : self._num_best_to_replace_from
        ]
        best_genotypes_local, best_fitnesses_local = jax.tree_util.tree_map(
            lambda x: x[best_indices_stacked],
            (genotypes_stacked, fitnesses_stacked),
        )
        return (best_genotypes_local, best_fitnesses_local)  # type: ignore

    # Incrementally get the top fraction_best_to_replace_from best individuals
    # on each device
    (best_genotypes_local, best_fitnesses_local) = jax.lax.fori_loop(
        lower=1,
        upper=int(1.0 // self._config.fraction_sort_exchange) + 1,
        body_fun=_loop_fn,
        init_val=(best_genotypes_local, best_fitnesses_local),
    )

    # Gather fitnesses from all devices to rank locally against it
    all_fitnesses = jax.tree_util.tree_map(
        lambda x: jnp.concatenate(jax.lax.all_gather(x, axis_name="p"), axis=0),
        fitnesses,
    )
    all_fitnesses = jnp.ravel(all_fitnesses)
    all_fitnesses = -jnp.sort(-all_fitnesses)
    random_key = emitter_state.random_key
    random_key, sub_key = jax.random.split(random_key)
    best_genotypes = jax.tree_util.tree_map(
        lambda x: jax.random.choice(
            sub_key, x, shape=(len(fitnesses),), replace=True
        ),
        best_genotypes_local,
    )
    best_training_states, best_replay_buffers = best_genotypes

    # Resample hyper-params
    best_training_states = jax.vmap(
        best_training_states.__class__.resample_hyperparams
    )(best_training_states)

    # Replace by individuals from the best
    lower_bound = all_fitnesses[-self._num_to_replace_from_best]
    cond = fitnesses <= lower_bound

    training_states = jax.tree_util.tree_map(
        lambda x, y: jnp.where(
            jnp.expand_dims(
                cond, axis=tuple([-(i + 1) for i in range(x.ndim - 1)])
            ),
            x,
            y,
        ),
        best_training_states,
        training_states,
    )
    replay_buffers = jax.tree_util.tree_map(
        lambda x, y: jnp.where(
            jnp.expand_dims(
                cond, axis=tuple([-(i + 1) for i in range(x.ndim - 1)])
            ),
            x,
            y,
        ),
        best_replay_buffers,
        replay_buffers,
    )

    # Replacing with samples from the ME repertoire
    if self._num_to_replace_from_samples > 0:
        me_samples, random_key = repertoire.sample(
            random_key, self._config.pg_population_size_per_device
        )
        # Resample hyper-params
        me_samples = jax.vmap(me_samples.__class__.resample_hyperparams)(me_samples)
        upper_bound = all_fitnesses[
            -self._num_to_replace_from_best - self._num_to_replace_from_samples
        ]
        cond = jnp.logical_and(fitnesses <= upper_bound, fitnesses >= lower_bound)
        training_states = jax.tree_util.tree_map(
            lambda x, y: jnp.where(
                jnp.expand_dims(
                    cond, axis=tuple([-(i + 1) for i in range(x.ndim - 1)])
                ),
                x,
                y,
            ),
            me_samples,
            training_states,
        )

    # Train the agents
    env_states = emitter_state.env_states
    # Init optimizers state before training the population
    training_states = jax.vmap(training_states.__class__.init_optimizers_states)(
        training_states
    )
    (training_states, env_states, replay_buffers), metrics = self._train_fn(
        training_states, env_states, replay_buffers
    )
    # Empty optimizers states to avoid storing the info in the RAM
    # and having too heavy repertoires
    training_states = jax.vmap(training_states.__class__.empty_optimizers_states)(
        training_states
    )

    # Update emitter state
    emitter_state = emitter_state.replace(
        training_states=training_states,
        replay_buffers=replay_buffers,
        env_states=env_states,
        random_key=random_key,
    )
    return emitter_state  # type: ignore
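
The replacement rule above is rank-based: the fitnesses of the whole PG population are gathered from all devices and sorted, agents whose fitness falls below a lower threshold are overwritten by (hyperparameter-resampled) copies of the best agents, and agents between the two thresholds are overwritten by samples from the MAP-Elites repertoire. A simplified sketch of the two boolean masks, assuming a flat fitness vector:

import jax.numpy as jnp

def replacement_masks(fitnesses, num_from_best, num_from_samples):
    # sort in descending order, as done with all_fitnesses in state_update
    sorted_fitnesses = -jnp.sort(-fitnesses)
    lower_bound = sorted_fitnesses[-num_from_best]
    upper_bound = sorted_fitnesses[-num_from_best - num_from_samples]

    # agents replaced by copies of the best agents
    replace_with_best = fitnesses <= lower_bound
    # agents replaced by samples from the MAP-Elites repertoire
    replace_with_samples = jnp.logical_and(
        fitnesses <= upper_bound, fitnesses >= lower_bound
    )
    return replace_with_best, replace_with_samples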

pbt_variation_operators

sac_pbt_variation_fn(training_state1, training_state2, random_key, iso_sigma, line_sigma)

This operator runs a cross-over between two SAC agents. It is used as the variation operator in the SAC-PBT-Map-Elites algorithm. An isoline-dd variation is applied to the policy networks, critic networks and entropy alpha coefficients.

Parameters:
  • training_state1 (PBTSacTrainingState) – Training state of first SAC agent.

  • training_state2 (PBTSacTrainingState) – Training state of second SAC agent.

  • random_key (Array) – Random key.

  • iso_sigma (float) – Spread parameter (noise).

  • line_sigma (float) – Line parameter (direction of the new genotype).

Returns:
  • Tuple[qdax.baselines.sac_pbt.PBTSacTrainingState, jax.Array] – A new SAC training state obtained from cross-over and an updated random key.

Source code in qdax/core/emitters/pbt_variation_operators.py
def sac_pbt_variation_fn(
    training_state1: PBTSacTrainingState,
    training_state2: PBTSacTrainingState,
    random_key: RNGKey,
    iso_sigma: float,
    line_sigma: float,
) -> Tuple[PBTSacTrainingState, RNGKey]:
    """
    This operator runs a cross-over between two SAC agents. It is used as variation
    operator in the SAC-PBT-Map-Elites algorithm. An isoline-dd variation is applied
    to policy networks, critic networks and entropy alpha coefficients.

    Args:
        training_state1: Training state of first SAC agent.
        training_state2: Training state of second SAC agent.
        random_key: Random key.
        iso_sigma: Spread parameter (noise).
        line_sigma: Line parameter (direction of the new genotype).

    Returns:
        A new SAC training state obtained from cross-over and an updated random key.

    """

    policy_params1, policy_params2 = (
        training_state1.policy_params,
        training_state2.policy_params,
    )
    critic_params1, critic_params2 = (
        training_state1.critic_params,
        training_state2.critic_params,
    )
    alpha_params1, alpha_params2 = (
        training_state1.alpha_params,
        training_state2.alpha_params,
    )
    (policy_params, critic_params, alpha_params), random_key = isoline_variation(
        x1=(policy_params1, critic_params1, alpha_params1),
        x2=(policy_params2, critic_params2, alpha_params2),
        random_key=random_key,
        iso_sigma=iso_sigma,
        line_sigma=line_sigma,
    )

    new_training_state = training_state1.replace(
        policy_params=policy_params,
        critic_params=critic_params,
        alpha_params=alpha_params,
    )

    return (
        new_training_state,
        random_key,
    )
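
The isoline-dd (iso+line) operator used by the variation functions in this module perturbs the first parent with isotropic Gaussian noise and adds a random step along the line joining the two parents. A minimal sketch for a single parameter array (the actual operator is qdax's isoline_variation, which is applied jointly to whole parameter trees):

import jax
import jax.numpy as jnp

def isoline_dd(x1, x2, random_key, iso_sigma, line_sigma):
    random_key, subkey_iso, subkey_line = jax.random.split(random_key, num=3)
    # isotropic Gaussian noise around the first parent
    iso_noise = jax.random.normal(subkey_iso, shape=x1.shape) * iso_sigma
    # random step along the direction joining the two parents
    line_noise = jax.random.normal(subkey_line, shape=()) * line_sigma
    return x1 + iso_noise + line_noise * (x2 - x1), random_key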

td3_pbt_variation_fn(training_state1, training_state2, random_key, iso_sigma, line_sigma)

This operator runs a cross-over between two TD3 agents. It is used as the variation operator in the TD3-PBT-Map-Elites algorithm. An isoline-dd variation is applied to the policy networks and critic networks.

Parameters:
  • training_state1 (PBTTD3TrainingState) – Training state of first TD3 agent.

  • training_state2 (PBTTD3TrainingState) – Training state of second TD3 agent.

  • random_key (Array) – Random key.

  • iso_sigma (float) – Spread parameter (noise).

  • line_sigma (float) – Line parameter (direction of the new genotype).

Returns:
  • Tuple[qdax.baselines.td3_pbt.PBTTD3TrainingState, jax.Array] – A new TD3 training state obtained from cross-over and an updated random key.

Source code in qdax/core/emitters/pbt_variation_operators.py
def td3_pbt_variation_fn(
    training_state1: PBTTD3TrainingState,
    training_state2: PBTTD3TrainingState,
    random_key: RNGKey,
    iso_sigma: float,
    line_sigma: float,
) -> Tuple[PBTTD3TrainingState, RNGKey]:
    """
    This operator runs a cross-over between two TD3 agents. It is used as variation
    operator in the TD3-PBT-Map-Elites algorithm. An isoline-dd variation is applied
    to policy networks and critic networks.

    Args:
        training_state1: Training state of first TD3 agent.
        training_state2: Training state of second TD3 agent.
        random_key: Random key.
        iso_sigma: Spread parameter (noise).
        line_sigma: Line parameter (direction of the new genotype).

    Returns:
        A new TD3 training state obtained from cross-over and an updated random key.

    """

    policy_params1, policy_params2 = (
        training_state1.policy_params,
        training_state2.policy_params,
    )
    critic_params1, critic_params2 = (
        training_state1.critic_params,
        training_state2.critic_params,
    )
    (policy_params, critic_params,), random_key = isoline_variation(
        x1=(policy_params1, critic_params1),
        x2=(policy_params2, critic_params2),
        random_key=random_key,
        iso_sigma=iso_sigma,
        line_sigma=line_sigma,
    )
    new_training_state = training_state1.replace(
        policy_params=policy_params,
        critic_params=critic_params,
    )

    return (
        new_training_state,
        random_key,
    )

pga_me_emitter

PGAMEConfig dataclass

Configuration for PGAME Algorithm

Source code in qdax/core/emitters/pga_me_emitter.py
@dataclass
class PGAMEConfig:
    """Configuration for PGAME Algorithm"""

    env_batch_size: int = 100
    proportion_mutation_ga: float = 0.5
    num_critic_training_steps: int = 300
    num_pg_training_steps: int = 100

    # TD3 params
    replay_buffer_size: int = 1000000
    critic_hidden_layer_size: Tuple[int, ...] = (256, 256)
    critic_learning_rate: float = 3e-4
    greedy_learning_rate: float = 3e-4
    policy_learning_rate: float = 1e-3
    noise_clip: float = 0.5
    policy_noise: float = 0.2
    discount: float = 0.99
    reward_scaling: float = 1.0
    batch_size: int = 256
    soft_tau_update: float = 0.005
    policy_delay: int = 2
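
As an illustration, proportion_mutation_ga controls how the env_batch_size offspring are split between the GA variation and the policy-gradient mutation inside the PGAME emitter; the sketch below only reproduces that arithmetic with the default values shown above and is not part of the library API.

from qdax.core.emitters.pga_me_emitter import PGAMEConfig

config = PGAMEConfig(env_batch_size=100, proportion_mutation_ga=0.5)

# roughly proportion_mutation_ga of the offspring come from the GA variation,
# the remainder from the policy-gradient mutation (plus the greedy actor)
ga_batch_size = int(config.proportion_mutation_ga * config.env_batch_size)  # 50
pg_batch_size = config.env_batch_size - ga_batch_size                       # 50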

qdpg_emitter

Implementation of an updated version of the algorithm QDPG presented in the paper https://arxiv.org/abs/2006.08505.

QDPG has been updated to fit into the container + emitter framework of QD. Furthermore, it has been updated to run more efficiently with JAX. These changes have been made in agreement with the authors of the algorithm.

QDPGEmitterConfig dataclass

QDPGEmitterConfig(qpg_config: qdax.core.emitters.qpg_emitter.QualityPGConfig, dpg_config: qdax.core.emitters.dpg_emitter.DiversityPGConfig, iso_sigma: float, line_sigma: float, ga_batch_size: int)

Source code in qdax/core/emitters/qdpg_emitter.py
@dataclass
class QDPGEmitterConfig:
    qpg_config: QualityPGConfig
    dpg_config: DiversityPGConfig
    iso_sigma: float
    line_sigma: float
    ga_batch_size: int
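
A hedged construction sketch for this config: QualityPGConfig is documented in the qpg_emitter section below, and DiversityPGConfig is assumed to be importable from qdax.core.emitters.dpg_emitter as indicated by the signature above; the values are illustrative.

from qdax.core.emitters.dpg_emitter import DiversityPGConfig  # assumed import path
from qdax.core.emitters.qdpg_emitter import QDPGEmitterConfig
from qdax.core.emitters.qpg_emitter import QualityPGConfig

config = QDPGEmitterConfig(
    qpg_config=QualityPGConfig(env_batch_size=50),
    dpg_config=DiversityPGConfig(env_batch_size=50),
    iso_sigma=0.005,
    line_sigma=0.05,
    ga_batch_size=128,
)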

qpg_emitter

Implements the PG emitter from the PGA-ME algorithm in JAX for Brax environments, based on: https://hal.archives-ouvertes.fr/hal-03135723v2/file/PGA_MAP_Elites_GECCO.pdf

QualityPGConfig dataclass

Configuration for QualityPG Emitter

Source code in qdax/core/emitters/qpg_emitter.py
@dataclass
class QualityPGConfig:
    """Configuration for QualityPG Emitter"""

    env_batch_size: int = 100
    num_critic_training_steps: int = 300
    num_pg_training_steps: int = 100

    # TD3 params
    replay_buffer_size: int = 1000000
    critic_hidden_layer_size: Tuple[int, ...] = (256, 256)
    critic_learning_rate: float = 3e-4
    actor_learning_rate: float = 3e-4
    policy_learning_rate: float = 1e-3
    noise_clip: float = 0.5
    policy_noise: float = 0.2
    discount: float = 0.99
    reward_scaling: float = 1.0
    batch_size: int = 256
    soft_tau_update: float = 0.005
    policy_delay: int = 2

QualityPGEmitterState (EmitterState) dataclass

Contains training state for the learner.

Source code in qdax/core/emitters/qpg_emitter.py
class QualityPGEmitterState(EmitterState):
    """Contains training state for the learner."""

    critic_params: Params
    critic_optimizer_state: optax.OptState
    actor_params: Params
    actor_opt_state: optax.OptState
    target_critic_params: Params
    target_actor_params: Params
    replay_buffer: ReplayBuffer
    random_key: RNGKey
    steps: jnp.ndarray
replace(self, **updates)

"Returns a new object replacing the specified fields with new values.

Source code in qdax/core/emitters/qpg_emitter.py
def replace(self, **updates):
  """ "Returns a new object replacing the specified fields with new values."""
  return dataclasses.replace(self, **updates)

QualityPGEmitter (Emitter)

A policy gradient emitter used to implement the Policy Gradient Assisted MAP-Elites (PGA-Map-Elites) algorithm.

Source code in qdax/core/emitters/qpg_emitter.py
class QualityPGEmitter(Emitter):
    """
    A policy gradient emitter used to implement the Policy Gradient Assisted MAP-Elites
    (PGA-Map-Elites) algorithm.
    """

    def __init__(
        self,
        config: QualityPGConfig,
        policy_network: nn.Module,
        env: QDEnv,
    ) -> None:
        self._config = config
        self._env = env
        self._policy_network = policy_network

        # Init Critics
        critic_network = QModule(
            n_critics=2, hidden_layer_sizes=self._config.critic_hidden_layer_size
        )
        self._critic_network = critic_network

        # Set up the losses and optimizers - return the opt states
        self._policy_loss_fn, self._critic_loss_fn = make_td3_loss_fn(
            policy_fn=policy_network.apply,
            critic_fn=critic_network.apply,
            reward_scaling=self._config.reward_scaling,
            discount=self._config.discount,
            noise_clip=self._config.noise_clip,
            policy_noise=self._config.policy_noise,
        )

        # Init optimizers
        self._actor_optimizer = optax.adam(
            learning_rate=self._config.actor_learning_rate
        )
        self._critic_optimizer = optax.adam(
            learning_rate=self._config.critic_learning_rate
        )
        self._policies_optimizer = optax.adam(
            learning_rate=self._config.policy_learning_rate
        )

    @property
    def batch_size(self) -> int:
        """
        Returns:
            the batch size emitted by the emitter.
        """
        return self._config.env_batch_size

    @property
    def use_all_data(self) -> bool:
        """Whether to use all data or not when used along other emitters.

        QualityPGEmitter uses the transitions from the genotypes that were generated
        by other emitters.
        """
        return True

    def init(
        self, init_genotypes: Genotype, random_key: RNGKey
    ) -> Tuple[QualityPGEmitterState, RNGKey]:
        """Initializes the emitter state.

        Args:
            init_genotypes: The initial population.
            random_key: A random key.

        Returns:
            The initial state of the PGAMEEmitter, a new random key.
        """

        observation_size = self._env.observation_size
        action_size = self._env.action_size
        descriptor_size = self._env.state_descriptor_length

        # Initialise critic, greedy actor and population
        random_key, subkey = jax.random.split(random_key)
        fake_obs = jnp.zeros(shape=(observation_size,))
        fake_action = jnp.zeros(shape=(action_size,))
        critic_params = self._critic_network.init(
            subkey, obs=fake_obs, actions=fake_action
        )
        target_critic_params = jax.tree_util.tree_map(lambda x: x, critic_params)

        actor_params = jax.tree_util.tree_map(lambda x: x[0], init_genotypes)
        target_actor_params = jax.tree_util.tree_map(lambda x: x[0], init_genotypes)

        # Prepare init optimizer states
        critic_optimizer_state = self._critic_optimizer.init(critic_params)
        actor_optimizer_state = self._actor_optimizer.init(actor_params)

        # Initialize replay buffer
        dummy_transition = QDTransition.init_dummy(
            observation_dim=observation_size,
            action_dim=action_size,
            descriptor_dim=descriptor_size,
        )

        replay_buffer = ReplayBuffer.init(
            buffer_size=self._config.replay_buffer_size, transition=dummy_transition
        )

        # Initial training state
        random_key, subkey = jax.random.split(random_key)
        emitter_state = QualityPGEmitterState(
            critic_params=critic_params,
            critic_optimizer_state=critic_optimizer_state,
            actor_params=actor_params,
            actor_opt_state=actor_optimizer_state,
            target_critic_params=target_critic_params,
            target_actor_params=target_actor_params,
            random_key=subkey,
            steps=jnp.array(0),
            replay_buffer=replay_buffer,
        )

        return emitter_state, random_key

    @partial(
        jax.jit,
        static_argnames=("self",),
    )
    def emit(
        self,
        repertoire: Repertoire,
        emitter_state: QualityPGEmitterState,
        random_key: RNGKey,
    ) -> Tuple[Genotype, RNGKey]:
        """Do a step of PG emission.

        Args:
            repertoire: the current repertoire of genotypes
            emitter_state: the state of the emitter used
            random_key: a random key

        Returns:
            A batch of offspring and a new random key.
        """

        batch_size = self._config.env_batch_size

        # sample parents
        mutation_pg_batch_size = int(batch_size - 1)
        parents, random_key = repertoire.sample(random_key, mutation_pg_batch_size)

        # apply the pg mutation
        offsprings_pg = self.emit_pg(emitter_state, parents)

        # get the actor (greedy actor)
        offspring_actor = self.emit_actor(emitter_state)

        # add dimension for concatenation
        offspring_actor = jax.tree_util.tree_map(
            lambda x: jnp.expand_dims(x, axis=0), offspring_actor
        )

        # gather offspring
        genotypes = jax.tree_util.tree_map(
            lambda x, y: jnp.concatenate([x, y], axis=0),
            offsprings_pg,
            offspring_actor,
        )

        return genotypes, random_key

    @partial(
        jax.jit,
        static_argnames=("self",),
    )
    def emit_pg(
        self, emitter_state: QualityPGEmitterState, parents: Genotype
    ) -> Genotype:
        """Emit the offsprings generated through pg mutation.

        Args:
            emitter_state: current emitter state, contains critic and
                replay buffer.
            parents: the parents selected to be applied gradients in order
                to mutate towards better performance.

        Returns:
            A new set of offsprings.
        """
        mutation_fn = partial(
            self._mutation_function_pg,
            emitter_state=emitter_state,
        )
        offsprings = jax.vmap(mutation_fn)(parents)

        return offsprings

    @partial(
        jax.jit,
        static_argnames=("self",),
    )
    def emit_actor(self, emitter_state: QualityPGEmitterState) -> Genotype:
        """Emit the greedy actor.

        Simply needs to be retrieved from the emitter state.

        Args:
            emitter_state: the current emitter state; it stores the
                greedy actor.

        Returns:
            The parameters of the actor.
        """
        return emitter_state.actor_params

    @partial(jax.jit, static_argnames=("self",))
    def state_update(
        self,
        emitter_state: QualityPGEmitterState,
        repertoire: Optional[Repertoire],
        genotypes: Optional[Genotype],
        fitnesses: Optional[Fitness],
        descriptors: Optional[Descriptor],
        extra_scores: ExtraScores,
    ) -> QualityPGEmitterState:
        """This function gives an opportunity to update the emitter state
        after the genotypes have been scored.

        Here it is used to fill the Replay Buffer with the transitions
        from the scoring of the genotypes, and then the training of the
        critic/actor happens. Hence the params of critic/actor are updated,
        as well as their optimizer states.

        Args:
            emitter_state: current emitter state.
            repertoire: the current genotypes repertoire
            genotypes: unused here - but compulsory in the signature.
            fitnesses: unused here - but compulsory in the signature.
            descriptors: unused here - but compulsory in the signature.
            extra_scores: extra information coming from the scoring function,
                this contains the transitions added to the replay buffer.

        Returns:
            New emitter state where the replay buffer has been filled with
            the new experienced transitions.
        """
        # get the transitions out of the dictionary
        assert "transitions" in extra_scores.keys(), "Missing transitions or wrong key"
        transitions = extra_scores["transitions"]

        # add transitions in the replay buffer
        replay_buffer = emitter_state.replay_buffer.insert(transitions)
        emitter_state = emitter_state.replace(replay_buffer=replay_buffer)

        def scan_train_critics(
            carry: QualityPGEmitterState, unused: Any
        ) -> Tuple[QualityPGEmitterState, Any]:
            emitter_state = carry
            new_emitter_state = self._train_critics(emitter_state)
            return new_emitter_state, ()

        # Train critics and greedy actor
        emitter_state, _ = jax.lax.scan(
            scan_train_critics,
            emitter_state,
            (),
            length=self._config.num_critic_training_steps,
        )

        return emitter_state  # type: ignore

    @partial(jax.jit, static_argnames=("self",))
    def _train_critics(
        self, emitter_state: QualityPGEmitterState
    ) -> QualityPGEmitterState:
        """Apply one gradient step to critics and to the greedy actor
        (contained in carry in training_state), then soft update target critics
        and target actor.

        Those updates are very similar to those made in TD3.

        Args:
            emitter_state: current emitter state

        Returns:
            New emitter state where the critic and the greedy actor have been
            updated. Optimizer states have also been updated in the process.
        """

        # Sample a batch of transitions in the buffer
        random_key = emitter_state.random_key
        replay_buffer = emitter_state.replay_buffer
        transitions, random_key = replay_buffer.sample(
            random_key, sample_size=self._config.batch_size
        )

        # Update Critic
        (
            critic_optimizer_state,
            critic_params,
            target_critic_params,
            random_key,
        ) = self._update_critic(
            critic_params=emitter_state.critic_params,
            target_critic_params=emitter_state.target_critic_params,
            target_actor_params=emitter_state.target_actor_params,
            critic_optimizer_state=emitter_state.critic_optimizer_state,
            transitions=transitions,
            random_key=random_key,
        )

        # Update greedy actor
        (actor_optimizer_state, actor_params, target_actor_params,) = jax.lax.cond(
            emitter_state.steps % self._config.policy_delay == 0,
            lambda x: self._update_actor(*x),
            lambda _: (
                emitter_state.actor_opt_state,
                emitter_state.actor_params,
                emitter_state.target_actor_params,
            ),
            operand=(
                emitter_state.actor_params,
                emitter_state.actor_opt_state,
                emitter_state.target_actor_params,
                emitter_state.critic_params,
                transitions,
            ),
        )

        # Create new training state
        new_emitter_state = emitter_state.replace(
            critic_params=critic_params,
            critic_optimizer_state=critic_optimizer_state,
            actor_params=actor_params,
            actor_opt_state=actor_optimizer_state,
            target_critic_params=target_critic_params,
            target_actor_params=target_actor_params,
            random_key=random_key,
            steps=emitter_state.steps + 1,
            replay_buffer=replay_buffer,
        )

        return new_emitter_state  # type: ignore

    @partial(jax.jit, static_argnames=("self",))
    def _update_critic(
        self,
        critic_params: Params,
        target_critic_params: Params,
        target_actor_params: Params,
        critic_optimizer_state: Params,
        transitions: QDTransition,
        random_key: RNGKey,
    ) -> Tuple[Params, Params, Params, RNGKey]:

        # compute loss and gradients
        random_key, subkey = jax.random.split(random_key)
        critic_loss, critic_gradient = jax.value_and_grad(self._critic_loss_fn)(
            critic_params,
            target_actor_params,
            target_critic_params,
            transitions,
            subkey,
        )
        critic_updates, critic_optimizer_state = self._critic_optimizer.update(
            critic_gradient, critic_optimizer_state
        )

        # update critic
        critic_params = optax.apply_updates(critic_params, critic_updates)

        # Soft update of target critic network
        target_critic_params = jax.tree_util.tree_map(
            lambda x1, x2: (1.0 - self._config.soft_tau_update) * x1
            + self._config.soft_tau_update * x2,
            target_critic_params,
            critic_params,
        )

        return critic_optimizer_state, critic_params, target_critic_params, random_key

    @partial(jax.jit, static_argnames=("self",))
    def _update_actor(
        self,
        actor_params: Params,
        actor_opt_state: optax.OptState,
        target_actor_params: Params,
        critic_params: Params,
        transitions: QDTransition,
    ) -> Tuple[optax.OptState, Params, Params]:

        # Update greedy actor
        policy_loss, policy_gradient = jax.value_and_grad(self._policy_loss_fn)(
            actor_params,
            critic_params,
            transitions,
        )
        (
            policy_updates,
            actor_optimizer_state,
        ) = self._actor_optimizer.update(policy_gradient, actor_opt_state)
        actor_params = optax.apply_updates(actor_params, policy_updates)

        # Soft update of target greedy actor
        target_actor_params = jax.tree_util.tree_map(
            lambda x1, x2: (1.0 - self._config.soft_tau_update) * x1
            + self._config.soft_tau_update * x2,
            target_actor_params,
            actor_params,
        )

        return (
            actor_optimizer_state,
            actor_params,
            target_actor_params,
        )

    @partial(jax.jit, static_argnames=("self",))
    def _mutation_function_pg(
        self,
        policy_params: Genotype,
        emitter_state: QualityPGEmitterState,
    ) -> Genotype:
        """Apply pg mutation to a policy via multiple steps of gradient descent.
        First, update the rewards to be diversity rewards, then apply the gradient
        steps.

        Args:
            policy_params: a policy, supposed to be a differentiable neural
                network.
            emitter_state: the current state of the emitter, containing among others,
                the replay buffer, the critic.

        Returns:
            The updated params of the neural network.
        """

        # Define new policy optimizer state
        policy_optimizer_state = self._policies_optimizer.init(policy_params)

        def scan_train_policy(
            carry: Tuple[QualityPGEmitterState, Genotype, optax.OptState],
            unused: Any,
        ) -> Tuple[Tuple[QualityPGEmitterState, Genotype, optax.OptState], Any]:
            emitter_state, policy_params, policy_optimizer_state = carry
            (
                new_emitter_state,
                new_policy_params,
                new_policy_optimizer_state,
            ) = self._train_policy(
                emitter_state,
                policy_params,
                policy_optimizer_state,
            )
            return (
                new_emitter_state,
                new_policy_params,
                new_policy_optimizer_state,
            ), ()

        (emitter_state, policy_params, policy_optimizer_state,), _ = jax.lax.scan(
            scan_train_policy,
            (emitter_state, policy_params, policy_optimizer_state),
            (),
            length=self._config.num_pg_training_steps,
        )

        return policy_params

    @partial(jax.jit, static_argnames=("self",))
    def _train_policy(
        self,
        emitter_state: QualityPGEmitterState,
        policy_params: Params,
        policy_optimizer_state: optax.OptState,
    ) -> Tuple[QualityPGEmitterState, Params, optax.OptState]:
        """Apply one gradient step to a policy (called policy_params).

        Args:
            emitter_state: current state of the emitter.
            policy_params: parameters corresponding to the weights and bias of
                the neural network that defines the policy.

        Returns:
            The new emitter state, the updated policy params and the policy
            optimizer state.
        """

        # Sample a batch of transitions in the buffer
        random_key = emitter_state.random_key
        replay_buffer = emitter_state.replay_buffer
        transitions, random_key = replay_buffer.sample(
            random_key, sample_size=self._config.batch_size
        )

        # update policy
        policy_optimizer_state, policy_params = self._update_policy(
            critic_params=emitter_state.critic_params,
            policy_optimizer_state=policy_optimizer_state,
            policy_params=policy_params,
            transitions=transitions,
        )

        # Create new training state
        new_emitter_state = emitter_state.replace(
            random_key=random_key,
            replay_buffer=replay_buffer,
        )

        return new_emitter_state, policy_params, policy_optimizer_state

    @partial(jax.jit, static_argnames=("self",))
    def _update_policy(
        self,
        critic_params: Params,
        policy_optimizer_state: optax.OptState,
        policy_params: Params,
        transitions: QDTransition,
    ) -> Tuple[optax.OptState, Params]:

        # compute loss
        _policy_loss, policy_gradient = jax.value_and_grad(self._policy_loss_fn)(
            policy_params,
            critic_params,
            transitions,
        )
        # Compute gradient and update policies
        (
            policy_updates,
            policy_optimizer_state,
        ) = self._policies_optimizer.update(policy_gradient, policy_optimizer_state)
        policy_params = optax.apply_updates(policy_params, policy_updates)

        return policy_optimizer_state, policy_params
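For orientation, here is a minimal construction sketch. It is a hedged example, not part of the library source shown above: env is assumed to be a QDEnv and policy_network a Flax module defined elsewhere, and every numeric value is purely illustrative. The config fields are exactly the ones read by the class above.

# Hedged sketch: constructing a QualityPGEmitter. `env` and `policy_network`
# are assumed to exist; the numeric values are illustrative only.
from qdax.core.emitters.qpg_emitter import QualityPGConfig, QualityPGEmitter

qpg_config = QualityPGConfig(
    env_batch_size=100,                  # offspring emitted per iteration
    num_critic_training_steps=300,       # critic/actor steps per state_update
    num_pg_training_steps=100,           # gradient steps per PG-mutated policy
    batch_size=256,                      # transitions sampled from the buffer
    replay_buffer_size=1_000_000,
    discount=0.99,
    reward_scaling=1.0,
    critic_hidden_layer_size=(256, 256),
    critic_learning_rate=3e-4,
    actor_learning_rate=3e-4,
    policy_learning_rate=1e-3,
    noise_clip=0.5,
    policy_noise=0.2,
    soft_tau_update=0.005,
    policy_delay=2,
)

pg_emitter = QualityPGEmitter(
    config=qpg_config,
    policy_network=policy_network,
    env=env,
)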
batch_size: int property readonly
Returns:
  • int – the batch size emitted by the emitter.

use_all_data: bool property readonly

Whether to use all data or not when used alongside other emitters.

QualityPGEmitter uses the transitions from the genotypes that were generated by other emitters.

init(self, init_genotypes, random_key)

Initializes the emitter state.

Parameters:
  • init_genotypes (Union[jax.Array, numpy.ndarray, numpy.bool_, numpy.number, Iterable[ArrayTree], Mapping[Any, ArrayTree]]) – The initial population.

  • random_key (Array) – A random key.

Returns:
  • Tuple[qdax.core.emitters.qpg_emitter.QualityPGEmitterState, jax.Array] – The initial state of the QualityPGEmitter and a new random key.

Source code in qdax/core/emitters/qpg_emitter.py
def init(
    self, init_genotypes: Genotype, random_key: RNGKey
) -> Tuple[QualityPGEmitterState, RNGKey]:
    """Initializes the emitter state.

    Args:
        init_genotypes: The initial population.
        random_key: A random key.

    Returns:
        The initial state of the QualityPGEmitter and a new random key.
    """

    observation_size = self._env.observation_size
    action_size = self._env.action_size
    descriptor_size = self._env.state_descriptor_length

    # Initialise critic, greedy actor and population
    random_key, subkey = jax.random.split(random_key)
    fake_obs = jnp.zeros(shape=(observation_size,))
    fake_action = jnp.zeros(shape=(action_size,))
    critic_params = self._critic_network.init(
        subkey, obs=fake_obs, actions=fake_action
    )
    target_critic_params = jax.tree_util.tree_map(lambda x: x, critic_params)

    actor_params = jax.tree_util.tree_map(lambda x: x[0], init_genotypes)
    target_actor_params = jax.tree_util.tree_map(lambda x: x[0], init_genotypes)

    # Prepare init optimizer states
    critic_optimizer_state = self._critic_optimizer.init(critic_params)
    actor_optimizer_state = self._actor_optimizer.init(actor_params)

    # Initialize replay buffer
    dummy_transition = QDTransition.init_dummy(
        observation_dim=observation_size,
        action_dim=action_size,
        descriptor_dim=descriptor_size,
    )

    replay_buffer = ReplayBuffer.init(
        buffer_size=self._config.replay_buffer_size, transition=dummy_transition
    )

    # Initial training state
    random_key, subkey = jax.random.split(random_key)
    emitter_state = QualityPGEmitterState(
        critic_params=critic_params,
        critic_optimizer_state=critic_optimizer_state,
        actor_params=actor_params,
        actor_opt_state=actor_optimizer_state,
        target_critic_params=target_critic_params,
        target_actor_params=target_actor_params,
        random_key=subkey,
        steps=jnp.array(0),
        replay_buffer=replay_buffer,
    )

    return emitter_state, random_key
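Usage sketch (hedged): assuming the pg_emitter and qpg_config from the construction example above, the initial population is typically obtained by vmapping the policy initialisation over a batch of PRNG keys; the first genotype of this population becomes the greedy actor.

# Hedged sketch: building an initial population and the emitter state.
import jax
import jax.numpy as jnp

random_key = jax.random.PRNGKey(0)
random_key, subkey = jax.random.split(random_key)

# one init key and one fake observation per member of the population
keys = jax.random.split(subkey, num=qpg_config.env_batch_size)
fake_obs = jnp.zeros(shape=(qpg_config.env_batch_size, env.observation_size))
init_genotypes = jax.vmap(policy_network.init)(keys, fake_obs)

emitter_state, random_key = pg_emitter.init(init_genotypes, random_key)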
emit(self, repertoire, emitter_state, random_key)

Do a step of PG emission.

Parameters:
  • repertoire (Repertoire) – the current repertoire of genotypes

  • emitter_state (QualityPGEmitterState) – the state of the emitter used

  • random_key (Array) – a random key

Returns:
  • Tuple[Union[jax.Array, numpy.ndarray, numpy.bool_, numpy.number, Iterable[ArrayTree], Mapping[Any, ArrayTree]], jax.Array] – A batch of offspring and a new random key.

Source code in qdax/core/emitters/qpg_emitter.py
@partial(
    jax.jit,
    static_argnames=("self",),
)
def emit(
    self,
    repertoire: Repertoire,
    emitter_state: QualityPGEmitterState,
    random_key: RNGKey,
) -> Tuple[Genotype, RNGKey]:
    """Do a step of PG emission.

    Args:
        repertoire: the current repertoire of genotypes
        emitter_state: the state of the emitter used
        random_key: a random key

    Returns:
        A batch of offspring and a new random key.
    """

    batch_size = self._config.env_batch_size

    # sample parents
    mutation_pg_batch_size = int(batch_size - 1)
    parents, random_key = repertoire.sample(random_key, mutation_pg_batch_size)

    # apply the pg mutation
    offsprings_pg = self.emit_pg(emitter_state, parents)

    # get the actor (greedy actor)
    offspring_actor = self.emit_actor(emitter_state)

    # add dimension for concatenation
    offspring_actor = jax.tree_util.tree_map(
        lambda x: jnp.expand_dims(x, axis=0), offspring_actor
    )

    # gather offspring
    genotypes = jax.tree_util.tree_map(
        lambda x, y: jnp.concatenate([x, y], axis=0),
        offsprings_pg,
        offspring_actor,
    )

    return genotypes, random_key
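Call sketch (hedged, with repertoire assumed to be an existing MAP-Elites repertoire): one call produces env_batch_size offspring, namely env_batch_size - 1 PG-mutated parents plus the greedy actor.

# Hedged sketch: one emission step.
offspring, random_key = pg_emitter.emit(repertoire, emitter_state, random_key)
# leaves of `offspring` have leading dimension env_batch_size: the first
# env_batch_size - 1 entries are PG-mutated parents, the last is the greedy actor.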
emit_pg(self, emitter_state, parents)

Emit the offspring generated through pg mutation.

Parameters:
  • emitter_state (QualityPGEmitterState) – current emitter state, contains the critic and the replay buffer.

  • parents (Union[jax.Array, numpy.ndarray, numpy.bool_, numpy.number, Iterable[ArrayTree], Mapping[Any, ArrayTree]]) – the parents to which gradient steps are applied in order to mutate them towards better performance.

Returns:
  • Union[jax.Array, numpy.ndarray, numpy.bool_, numpy.number, Iterable[ArrayTree], Mapping[Any, ArrayTree]] – A new set of offspring.

Source code in qdax/core/emitters/qpg_emitter.py
@partial(
    jax.jit,
    static_argnames=("self",),
)
def emit_pg(
    self, emitter_state: QualityPGEmitterState, parents: Genotype
) -> Genotype:
    """Emit the offsprings generated through pg mutation.

    Args:
        emitter_state: current emitter state, contains critic and
            replay buffer.
        parents: the parents selected to be applied gradients in order
            to mutate towards better performance.

    Returns:
        A new set of offsprings.
    """
    mutation_fn = partial(
        self._mutation_function_pg,
        emitter_state=emitter_state,
    )
    offsprings = jax.vmap(mutation_fn)(parents)

    return offsprings
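A hedged sketch of calling emit_pg directly, with parents sampled from an assumed existing repertoire:

# Hedged sketch: PG mutation on explicitly sampled parents.
parents, random_key = repertoire.sample(random_key, 16)
pg_offspring = pg_emitter.emit_pg(emitter_state, parents)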
emit_actor(self, emitter_state)

Emit the greedy actor.

Simply needs to be retrieved from the emitter state.

Parameters:
  • emitter_state (QualityPGEmitterState) – the current emitter state; it stores the greedy actor.

Returns:
  • Union[jax.Array, numpy.ndarray, numpy.bool_, numpy.number, Iterable[ArrayTree], Mapping[Any, ArrayTree]] – The parameters of the actor.

Source code in qdax/core/emitters/qpg_emitter.py
@partial(
    jax.jit,
    static_argnames=("self",),
)
def emit_actor(self, emitter_state: QualityPGEmitterState) -> Genotype:
    """Emit the greedy actor.

    Simply needs to be retrieved from the emitter state.

    Args:
        emitter_state: the current emitter state; it stores the
            greedy actor.

    Returns:
        The parameters of the actor.
    """
    return emitter_state.actor_params
state_update(self, emitter_state, repertoire, genotypes, fitnesses, descriptors, extra_scores)

This function gives an opportunity to update the emitter state after the genotypes have been scored.

Here it is used to fill the Replay Buffer with the transitions from the scoring of the genotypes, and then the training of the critic/actor happens. Hence the params of critic/actor are updated, as well as their optimizer states.

Parameters:
  • emitter_state (QualityPGEmitterState) – current emitter state.

  • repertoire (Optional[qdax.core.containers.repertoire.Repertoire]) – the current genotypes repertoire

  • genotypes (Union[jax.Array, numpy.ndarray, numpy.bool_, numpy.number, Iterable[ArrayTree], Mapping[Any, ArrayTree]]) – unused here - but compulsory in the signature.

  • fitnesses (Optional[jax.Array]) – unused here - but compulsory in the signature.

  • descriptors (Optional[jax.Array]) – unused here - but compulsory in the signature.

  • extra_scores (Dict[str, Union[jax.Array, numpy.ndarray, numpy.bool_, numpy.number, Iterable[ArrayTree], Mapping[Any, ArrayTree]]]) – extra information coming from the scoring function, this contains the transitions added to the replay buffer.

Returns:
  • QualityPGEmitterState – New emitter state where the replay buffer has been filled with the new experienced transitions.

Source code in qdax/core/emitters/qpg_emitter.py
@partial(jax.jit, static_argnames=("self",))
def state_update(
    self,
    emitter_state: QualityPGEmitterState,
    repertoire: Optional[Repertoire],
    genotypes: Optional[Genotype],
    fitnesses: Optional[Fitness],
    descriptors: Optional[Descriptor],
    extra_scores: ExtraScores,
) -> QualityPGEmitterState:
    """This function gives an opportunity to update the emitter state
    after the genotypes have been scored.

    Here it is used to fill the Replay Buffer with the transitions
    from the scoring of the genotypes, and then the training of the
    critic/actor happens. Hence the params of critic/actor are updated,
    as well as their optimizer states.

    Args:
        emitter_state: current emitter state.
        repertoire: the current genotypes repertoire
        genotypes: unused here - but compulsory in the signature.
        fitnesses: unused here - but compulsory in the signature.
        descriptors: unused here - but compulsory in the signature.
        extra_scores: extra information coming from the scoring function,
            this contains the transitions added to the replay buffer.

    Returns:
        New emitter state where the replay buffer has been filled with
        the new experienced transitions.
    """
    # get the transitions out of the dictionary
    assert "transitions" in extra_scores.keys(), "Missing transitions or wrong key"
    transitions = extra_scores["transitions"]

    # add transitions in the replay buffer
    replay_buffer = emitter_state.replay_buffer.insert(transitions)
    emitter_state = emitter_state.replace(replay_buffer=replay_buffer)

    def scan_train_critics(
        carry: QualityPGEmitterState, unused: Any
    ) -> Tuple[QualityPGEmitterState, Any]:
        emitter_state = carry
        new_emitter_state = self._train_critics(emitter_state)
        return new_emitter_state, ()

    # Train critics and greedy actor
    emitter_state, _ = jax.lax.scan(
        scan_train_critics,
        emitter_state,
        (),
        length=self._config.num_critic_training_steps,
    )

    return emitter_state  # type: ignore
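A hedged sketch of how this fits into one loop iteration: scoring_fn is an assumed scoring function that returns fitnesses, descriptors, and an extra_scores dictionary holding the episode transitions under the "transitions" key, as required above.

# Hedged sketch: score the offspring, then train the critic and greedy actor.
fitnesses, descriptors, extra_scores, random_key = scoring_fn(offspring, random_key)
emitter_state = pg_emitter.state_update(
    emitter_state=emitter_state,
    repertoire=repertoire,
    genotypes=offspring,
    fitnesses=fitnesses,
    descriptors=descriptors,
    extra_scores=extra_scores,  # must contain the "transitions" entry
)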

standard_emitters

MixingEmitter (Emitter)

Source code in qdax/core/emitters/standard_emitters.py
class MixingEmitter(Emitter):
    def __init__(
        self,
        mutation_fn: Callable[[Genotype, RNGKey], Tuple[Genotype, RNGKey]],
        variation_fn: Callable[[Genotype, Genotype, RNGKey], Tuple[Genotype, RNGKey]],
        variation_percentage: float,
        batch_size: int,
    ) -> None:
        self._mutation_fn = mutation_fn
        self._variation_fn = variation_fn
        self._variation_percentage = variation_percentage
        self._batch_size = batch_size

    @partial(
        jax.jit,
        static_argnames=("self",),
    )
    def emit(
        self,
        repertoire: Repertoire,
        emitter_state: Optional[EmitterState],
        random_key: RNGKey,
    ) -> Tuple[Genotype, RNGKey]:
        """
        Emitter that performs both mutation and variation. Two batches of
        variation_percentage * batch_size genotypes are sampled from the repertoire,
        copied and crossed over to obtain new offspring. One batch of
        (1.0 - variation_percentage) * batch_size genotypes is sampled from the
        repertoire, copied and mutated.

        Note: this emitter has no state. A fake None state must be added
        through a function redefinition to make this emitter usable with MAP-Elites.

        Args:
            repertoire: the MAP-Elites repertoire to sample from
            emitter_state: void
            random_key: a jax PRNG random key

        Returns:
            a batch of offspring
            a new jax PRNG key
        """
        n_variation = int(self._batch_size * self._variation_percentage)
        n_mutation = self._batch_size - n_variation

        if n_variation > 0:
            x1, random_key = repertoire.sample(random_key, n_variation)
            x2, random_key = repertoire.sample(random_key, n_variation)

            x_variation, random_key = self._variation_fn(x1, x2, random_key)

        if n_mutation > 0:
            x1, random_key = repertoire.sample(random_key, n_mutation)
            x_mutation, random_key = self._mutation_fn(x1, random_key)

        if n_variation == 0:
            genotypes = x_mutation
        elif n_mutation == 0:
            genotypes = x_variation
        else:
            genotypes = jax.tree_util.tree_map(
                lambda x_1, x_2: jnp.concatenate([x_1, x_2], axis=0),
                x_variation,
                x_mutation,
            )

        return genotypes, random_key

    @property
    def batch_size(self) -> int:
        """
        Returns:
            the batch size emitted by the emitter.
        """
        return self._batch_size
batch_size: int property readonly
Returns:
  • int – the batch size emitted by the emitter.

emit(self, repertoire, emitter_state, random_key)

Emitter that performs both mutation and variation. Two batches of variation_percentage * batch_size genotypes are sampled from the repertoire, copied and crossed over to obtain new offspring. One batch of (1.0 - variation_percentage) * batch_size genotypes is sampled from the repertoire, copied and mutated.

Note: this emitter has no state. A fake None state must be added through a function redefinition to make this emitter usable with MAP-Elites.

Parameters:
  • repertoire (Repertoire) – the MAP-Elites repertoire to sample from

  • emitter_state (Optional[qdax.core.emitters.emitter.EmitterState]) – void

  • random_key (Array) – a jax PRNG random key

Returns:
  • Tuple[Union[jax.Array, numpy.ndarray, numpy.bool_, numpy.number, Iterable[ArrayTree], Mapping[Any, ArrayTree]], jax.Array] – a batch of offspring and a new jax PRNG key

Source code in qdax/core/emitters/standard_emitters.py
@partial(
    jax.jit,
    static_argnames=("self",),
)
def emit(
    self,
    repertoire: Repertoire,
    emitter_state: Optional[EmitterState],
    random_key: RNGKey,
) -> Tuple[Genotype, RNGKey]:
    """
    Emitter that performs both mutation and variation. Two batches of
    variation_percentage * batch_size genotypes are sampled from the repertoire,
    copied and crossed over to obtain new offspring. One batch of
    (1.0 - variation_percentage) * batch_size genotypes is sampled from the
    repertoire, copied and mutated.

    Note: this emitter has no state. A fake None state must be added
    through a function redefinition to make this emitter usable with MAP-Elites.

    Args:
        repertoire: the MAP-Elites repertoire to sample from
        emitter_state: void
        random_key: a jax PRNG random key

    Returns:
        a batch of offspring
        a new jax PRNG key
    """
    n_variation = int(self._batch_size * self._variation_percentage)
    n_mutation = self._batch_size - n_variation

    if n_variation > 0:
        x1, random_key = repertoire.sample(random_key, n_variation)
        x2, random_key = repertoire.sample(random_key, n_variation)

        x_variation, random_key = self._variation_fn(x1, x2, random_key)

    if n_mutation > 0:
        x1, random_key = repertoire.sample(random_key, n_mutation)
        x_mutation, random_key = self._mutation_fn(x1, random_key)

    if n_variation == 0:
        genotypes = x_mutation
    elif n_mutation == 0:
        genotypes = x_variation
    else:
        genotypes = jax.tree_util.tree_map(
            lambda x_1, x_2: jnp.concatenate([x_1, x_2], axis=0),
            x_variation,
            x_mutation,
        )

    return genotypes, random_key
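As a hedged construction sketch for the MixingEmitter, the operators below are written inline with plain JAX so that nothing beyond the class shown above is assumed; both functions follow the (genotypes, random_key) -> (new_genotypes, random_key) convention expected by the constructor, and the constants are illustrative only.

# Hedged sketch: a MixingEmitter with simple inline Gaussian mutation and a
# blend-style variation operator.
import jax
import jax.numpy as jnp
from qdax.core.emitters.standard_emitters import MixingEmitter

def gaussian_mutation(genotypes, random_key):
    """Add small Gaussian noise to every leaf of the genotype pytree."""
    random_key, subkey = jax.random.split(random_key)
    num_leaves = len(jax.tree_util.tree_leaves(genotypes))
    keys = list(jax.random.split(subkey, num=num_leaves))
    keys_tree = jax.tree_util.tree_unflatten(
        jax.tree_util.tree_structure(genotypes), keys
    )
    mutated = jax.tree_util.tree_map(
        lambda x, k: x + 0.01 * jax.random.normal(k, shape=x.shape),
        genotypes,
        keys_tree,
    )
    return mutated, random_key

def blend_variation(x1, x2, random_key):
    """Average two sampled parents leaf by leaf (a simple crossover stand-in)."""
    offspring = jax.tree_util.tree_map(lambda a, b: 0.5 * (a + b), x1, x2)
    return offspring, random_key

mixing_emitter = MixingEmitter(
    mutation_fn=gaussian_mutation,
    variation_fn=blend_variation,
    variation_percentage=0.5,
    batch_size=64,
)

# With an existing repertoire and key:
# offspring, random_key = mixing_emitter.emit(repertoire, None, random_key)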