CMAES class

qdax.core.cmaes.CMAES

Class to run the CMA-ES algorithm.

Source code in qdax/core/cmaes.py
class CMAES:
    """
    Class to run the CMA-ES algorithm.
    """

    def __init__(
        self,
        population_size: int,
        search_dim: int,
        fitness_function: Callable[[Genotype], Fitness],
        num_best: Optional[int] = None,
        init_sigma: float = 1e-3,
        mean_init: Optional[jnp.ndarray] = None,
        bias_weights: bool = True,
        delay_eigen_decomposition: bool = False,
    ):
        """Instantiate a CMA-ES optimizer.

        Args:
            population_size: size of the running population.
            search_dim: number of dimensions in the search space.
            fitness_function: fitness function that is being optimized.
            num_best: number of best individuals in the population being considered
                for the update of the distributions. Defaults to None.
            init_sigma: Initial value of the step size. Defaults to 1e-3.
            mean_init: Initial value of the distribution mean. Defaults to None.
            bias_weights: Should the weights be biased towards best individuals.
                Defaults to True.
            delay_eigen_decomposition: should the update of the inverse of the
                cov matrix be delayed. As this operation is a time bottleneck, having
                it delayed improves the time perfs by a significant margin.
                Defaults to False.
        """
        self._population_size = population_size
        self._search_dim = search_dim
        self._fitness_function = fitness_function
        self._init_sigma = init_sigma

        # Default values if values are not provided
        if num_best is None:
            self._num_best = population_size // 2
        else:
            self._num_best = num_best

        if mean_init is None:
            self._mean_init = jnp.zeros(shape=(search_dim,))
        else:
            self._mean_init = mean_init

        # weights parameters
        if bias_weights:
            # heuristic from Nicolas Hansen original implementation
            self._weights = jnp.log(
                (self._num_best + 0.5) / jnp.arange(start=1, stop=(self._num_best + 1))
            )
        else:
            self._weights = jnp.ones(self._num_best)

        # scale weights so they sum to one, then compute the effective
        # number of selected parents (mu_eff in Hansen's notation)
        self._weights = self._weights / (self._weights.sum())
        self._parents_eff = 1 / (self._weights**2).sum()

        # adaptation parameters (learning rates of the two evolution paths)
        self._c_s = (self._parents_eff + 2) / (self._search_dim + self._parents_eff + 5)
        self._c_c = (4 + self._parents_eff / self._search_dim) / (
            self._search_dim + 4 + 2 * self._parents_eff / self._search_dim
        )

        # learning rate for rank-1 update of C
        self._c_1 = 2 / (self._parents_eff + (self._search_dim + jnp.sqrt(2)) ** 2)

        # learning rate for rank-(num best) updates
        tmp = 2 * (self._parents_eff - 2 + 1 / self._parents_eff)
        self._c_cov = min(
            1 - self._c_1, tmp / (self._parents_eff + (self._search_dim + 2) ** 2)
        )

        # damping for sigma
        self._d_s = (
            1
            + 2 * max(0, jnp.sqrt((self._parents_eff - 1) / (self._search_dim + 1)) - 1)
            + self._c_s
        )
        # approximation of the expected norm of a N(0, Id) random vector
        self._chi = jnp.sqrt(self._search_dim) * (
            1 - 1 / (4 * self._search_dim) + 1 / (21 * self._search_dim**2)
        )

        # threshold for new eigen decomposition - from pyribs
        self._eigen_comput_period = 1
        if delay_eigen_decomposition:
            self._eigen_comput_period = (
                0.5
                * self._population_size
                / (self._search_dim * (self._c_1 + self._c_cov))
            )

    def init(self) -> CMAESState:
        """
        Init the CMA-ES algorithm.

        Returns:
            an initial state for the algorithm
        """

        # initial cov matrix: isotropic unit covariance
        cov_matrix = jnp.eye(self._search_dim)

        # initial inv sqrt of the cov matrix - cov is already diag
        # (for the identity matrix this is the identity as well)
        invsqrt_cov = jnp.diag(1 / jnp.sqrt(jnp.diag(cov_matrix)))

        return CMAESState(
            mean=self._mean_init,
            cov_matrix=cov_matrix,
            sigma=self._init_sigma,
            num_updates=0,
            # both evolution paths start at zero
            p_c=jnp.zeros(shape=(self._search_dim,)),
            p_s=jnp.zeros(shape=(self._search_dim,)),
            eigen_updates=0,
            eigenvalues=jnp.ones(shape=(self._search_dim,)),
            invsqrt_cov=invsqrt_cov,
        )

    @partial(jax.jit, static_argnames=("self",))
    def sample(
        self, cmaes_state: CMAESState, random_key: RNGKey
    ) -> Tuple[Genotype, RNGKey]:
        """
        Sample a population.

        Args:
            cmaes_state: current state of the algorithm
            random_key: jax random key

        Returns:
            A tuple that contains a batch of population size genotypes and
            a new random key.
        """
        random_key, subkey = jax.random.split(random_key)
        # draw the population from N(mean, sigma^2 * C)
        samples = jax.random.multivariate_normal(
            subkey,
            shape=(self._population_size,),
            mean=cmaes_state.mean,
            cov=(cmaes_state.sigma**2) * cmaes_state.cov_matrix,
        )
        return samples, random_key

    @partial(jax.jit, static_argnames=("self",))
    def update_state(
        self,
        cmaes_state: CMAESState,
        sorted_candidates: Genotype,
    ) -> CMAESState:
        """Update the state using the default (unmasked) recombination weights.

        Args:
            cmaes_state: current state of the algorithm
            sorted_candidates: a batch of sorted and selected genotypes

        Returns:
            An updated algorithm state.
        """
        return self._update_state(  # type: ignore
            cmaes_state=cmaes_state,
            sorted_candidates=sorted_candidates,
            weights=self._weights,
        )

    @partial(jax.jit, static_argnames=("self",))
    def update_state_with_mask(
        self, cmaes_state: CMAESState, sorted_candidates: Genotype, mask: Mask
    ) -> CMAESState:
        """Update weights with a mask, then update the state.

        Convention: 1 means the individual is kept, 0 means it is removed.
        """

        # zero-out the weights of masked individuals and renormalize the
        # remaining weights so they sum to one
        weights = jnp.multiply(self._weights, mask)
        weights = weights / (weights.sum())

        return self._update_state(  # type: ignore
            cmaes_state=cmaes_state,
            sorted_candidates=sorted_candidates,
            weights=weights,
        )

    @partial(jax.jit, static_argnames=("self",))
    def _update_state(
        self,
        cmaes_state: CMAESState,
        sorted_candidates: Genotype,
        weights: jnp.ndarray,
    ) -> CMAESState:
        """Updates the state when candidates have already been
        sorted and selected.

        Args:
            cmaes_state: current state of the algorithm
            sorted_candidates: a batch of sorted and selected genotypes
            weights: weights used to recombine the candidates

        Returns:
            An updated algorithm state
        """

        # retrieve elements from the current state
        p_c = cmaes_state.p_c
        p_s = cmaes_state.p_s
        sigma = cmaes_state.sigma
        num_updates = cmaes_state.num_updates
        cov = cmaes_state.cov_matrix
        mean = cmaes_state.mean

        eigen_updates = cmaes_state.eigen_updates
        eigenvalues = cmaes_state.eigenvalues
        invsqrt_cov = cmaes_state.invsqrt_cov

        # update mean by recombination
        old_mean = mean
        mean = weights @ sorted_candidates

        def update_eigen(
            operand: Tuple[jnp.ndarray, int]
        ) -> Tuple[int, jnp.ndarray, jnp.ndarray]:
            """Recompute the eigen decomposition of cov and C^(-1/2)."""

            # unpack data
            cov, num_updates = operand

            # enforce symmetry of the covariance matrix
            cov = jnp.triu(cov) + jnp.triu(cov, 1).T

            # get eigen decomposition: eigenvalues, eigenvectors
            eig, u = jnp.linalg.eigh(cov)

            # compute new invsqrt: C^(-1/2) = U diag(1/sqrt(eig)) U^T
            invsqrt = u @ jnp.diag(1 / jnp.sqrt(eig)) @ u.T

            # update the eigen value decomposition tracker
            eigen_updates = num_updates

            return eigen_updates, eig, invsqrt

        # recompute the eigen decomposition only periodically, as it is a
        # time bottleneck (see delay_eigen_decomposition in __init__)
        eigen_condition = (num_updates - eigen_updates) >= self._eigen_comput_period

        # decomposition of cov
        eigen_updates, eigenvalues, invsqrt = jax.lax.cond(
            eigen_condition,
            update_eigen,
            lambda _: (eigen_updates, eigenvalues, invsqrt_cov),
            operand=(cov, num_updates),
        )

        # normalized displacement of the mean
        z = (1 / sigma) * (mean - old_mean)
        z_w = invsqrt @ z

        # update evolution paths - cumulation
        p_s = (1 - self._c_s) * p_s + jnp.sqrt(
            self._c_s * (2 - self._c_s) * self._parents_eff
        ) * z_w

        # stall indicator (h_sigma in Hansen's tutorial): False when the
        # step-size path is too long, which pauses the rank-1 cumulation.
        # NOTE(review): when num_updates == 0 the denominator is 0, so the
        # ratio is inf and tmp_1 is False on the very first update - confirm
        # this matches the intended 1-based generation counting.
        tmp_1 = jnp.linalg.norm(p_s) / jnp.sqrt(
            1 - (1 - self._c_s) ** (2 * num_updates)
        ) <= self._chi * (1.4 + 2 / (self._search_dim + 1))

        p_c = (1 - self._c_c) * p_c + tmp_1 * jnp.sqrt(
            self._c_c * (2 - self._c_c) * self._parents_eff
        ) * z

        # rank-1 update term of the covariance matrix
        pp_c = jnp.expand_dims(p_c, axis=1)

        # rank-mu update term built from the selected candidates
        coeff_tmp = (sorted_candidates - old_mean) / sigma
        cov_rank = coeff_tmp.T @ jnp.diag(weights.squeeze()) @ coeff_tmp

        cov = (
            (1 - self._c_cov - self._c_1) * cov
            + self._c_1
            * (pp_c @ pp_c.T + (1 - tmp_1) * self._c_c * (2 - self._c_c) * cov)
            + self._c_cov * cov_rank
        )

        # update step size
        sigma = sigma * jnp.exp(
            (self._c_s / self._d_s) * (jnp.linalg.norm(p_s) / self._chi - 1)
        )

        cmaes_state = CMAESState(
            mean=mean,
            cov_matrix=cov,
            sigma=sigma,
            num_updates=num_updates + 1,
            p_c=p_c,
            p_s=p_s,
            eigen_updates=eigen_updates,
            eigenvalues=eigenvalues,
            invsqrt_cov=invsqrt,
        )

        return cmaes_state

    @partial(jax.jit, static_argnames=("self",))
    def update(self, cmaes_state: CMAESState, samples: Genotype) -> CMAESState:
        """Updates the distribution.

        Args:
            cmaes_state: current state of the algorithm
            samples: a batch of genotypes

        Returns:
            an updated algorithm state
        """

        # negate fitnesses so an ascending argsort ranks the best
        # (highest-fitness) candidates first
        fitnesses = -self._fitness_function(samples)
        idx_sorted = jnp.argsort(fitnesses)
        # keep only the num_best candidates for the distribution update
        sorted_candidates = samples[idx_sorted[: self._num_best]]

        new_state = self.update_state(cmaes_state, sorted_candidates)

        return new_state  # type: ignore

    @partial(jax.jit, static_argnames=("self",))
    def stop_condition(self, cmaes_state: CMAESState) -> bool:
        """Determines if the current optimization path must be stopped.

        A set of 5 conditions are computed, one condition is enough to
        stop the process. This function does not stop the process but simply
        retrieves the value. It is not called in the update function but can be
        used to manually stop the process (see example in CMA ME emitter).

        Args:
            cmaes_state: current CMAES state

        Returns:
            A boolean stating if the process should be stopped.
        """

        # NaNs appear when the float precision limit is reached
        nan_condition = jnp.sum(jnp.isnan(cmaes_state.eigenvalues)) > 0

        # condition number of the covariance matrix is too large
        eig_dispersion = jnp.max(cmaes_state.eigenvalues) / jnp.min(
            cmaes_state.eigenvalues
        )
        first_condition = eig_dispersion > 1e14

        # the search area has collapsed to a negligible size
        area = cmaes_state.sigma * jnp.sqrt(jnp.max(cmaes_state.eigenvalues))
        second_condition = area < 1e-11

        # eigenvalues too small or too large for numerically stable updates
        third_condition = jnp.max(cmaes_state.eigenvalues) < 1e-7
        fourth_condition = jnp.min(cmaes_state.eigenvalues) > 1e7

        # combine the conditions with a logical OR so the result is a
        # boolean, matching the declared return type (the previous sum of
        # booleans produced an integer)
        return (  # type: ignore
            nan_condition
            | first_condition
            | second_condition
            | third_condition
            | fourth_condition
        )

__init__(self, population_size, search_dim, fitness_function, num_best=None, init_sigma=0.001, mean_init=None, bias_weights=True, delay_eigen_decomposition=False) special

Instantiate a CMA-ES optimizer.

Parameters:
  • population_size (int) – size of the running population.

  • search_dim (int) – number of dimensions in the search space.

  • fitness_function (Callable[[Union[jax.Array, numpy.ndarray, numpy.bool_, numpy.number, Iterable[ArrayTree], Mapping[Any, ArrayTree]]], jax.Array]) – fitness function that is being optimized.

  • num_best (Optional[int]) – number of best individuals in the population being considered for the update of the distributions. Defaults to None.

  • init_sigma (float) – Initial value of the step size. Defaults to 1e-3.

  • mean_init (Optional[jax.Array]) – Initial value of the distribution mean. Defaults to None.

  • bias_weights (bool) – Should the weights be biased towards best individuals. Defaults to True.

  • delay_eigen_decomposition (bool) – should the update of the inverse of the cov matrix be delayed. As this operation is a time bottleneck, having it delayed improves the time perfs by a significant margin. Defaults to False.

Source code in qdax/core/cmaes.py
def __init__(
    self,
    population_size: int,
    search_dim: int,
    fitness_function: Callable[[Genotype], Fitness],
    num_best: Optional[int] = None,
    init_sigma: float = 1e-3,
    mean_init: Optional[jnp.ndarray] = None,
    bias_weights: bool = True,
    delay_eigen_decomposition: bool = False,
):
    """Instantiate a CMA-ES optimizer.

    Args:
        population_size: size of the running population.
        search_dim: number of dimensions in the search space.
        fitness_function: fitness function that is being optimized.
        num_best: number of best individuals in the population being considered
            for the update of the distributions. Defaults to None.
        init_sigma: Initial value of the step size. Defaults to 1e-3.
        mean_init: Initial value of the distribution mean. Defaults to None.
        bias_weights: Should the weights be biased towards best individuals.
            Defaults to True.
        delay_eigen_decomposition: should the update of the inverse of the
            cov matrix be delayed. As this operation is a time bottleneck, having
            it delayed improves the time perfs by a significant margin.
            Defaults to False.
    """
    self._population_size = population_size
    self._search_dim = search_dim
    self._fitness_function = fitness_function
    self._init_sigma = init_sigma

    # Default values if values are not provided
    if num_best is None:
        self._num_best = population_size // 2
    else:
        self._num_best = num_best

    if mean_init is None:
        self._mean_init = jnp.zeros(shape=(search_dim,))
    else:
        self._mean_init = mean_init

    # recombination weights: decreasing with rank when biased, uniform otherwise
    if bias_weights:
        # heuristic from Nicolas Hansen original implementation
        self._weights = jnp.log(
            (self._num_best + 0.5) / jnp.arange(start=1, stop=(self._num_best + 1))
        )
    else:
        self._weights = jnp.ones(self._num_best)

    # scale weights so they sum to one, then compute the effective
    # number of selected parents (mu_eff in Hansen's notation)
    self._weights = self._weights / (self._weights.sum())
    self._parents_eff = 1 / (self._weights**2).sum()

    # adaptation parameters (learning rates of the two evolution paths)
    self._c_s = (self._parents_eff + 2) / (self._search_dim + self._parents_eff + 5)
    self._c_c = (4 + self._parents_eff / self._search_dim) / (
        self._search_dim + 4 + 2 * self._parents_eff / self._search_dim
    )

    # learning rate for rank-1 update of C
    self._c_1 = 2 / (self._parents_eff + (self._search_dim + jnp.sqrt(2)) ** 2)

    # learning rate for rank-(num best) updates
    tmp = 2 * (self._parents_eff - 2 + 1 / self._parents_eff)
    self._c_cov = min(
        1 - self._c_1, tmp / (self._parents_eff + (self._search_dim + 2) ** 2)
    )

    # damping for sigma
    self._d_s = (
        1
        + 2 * max(0, jnp.sqrt((self._parents_eff - 1) / (self._search_dim + 1)) - 1)
        + self._c_s
    )
    # approximation of the expected norm of a N(0, Id) random vector
    self._chi = jnp.sqrt(self._search_dim) * (
        1 - 1 / (4 * self._search_dim) + 1 / (21 * self._search_dim**2)
    )

    # threshold for new eigen decomposition - from pyribs
    self._eigen_comput_period = 1
    if delay_eigen_decomposition:
        self._eigen_comput_period = (
            0.5
            * self._population_size
            / (self._search_dim * (self._c_1 + self._c_cov))
        )

init(self)

Init the CMA-ES algorithm.

Returns:
  • CMAESState – an initial state for the algorithm

Source code in qdax/core/cmaes.py
def init(self) -> CMAESState:
    """
    Init the CMA-ES algorithm.

    Returns:
        an initial state for the algorithm
    """

    # initial cov matrix: isotropic unit covariance
    cov_matrix = jnp.eye(self._search_dim)

    # initial inv sqrt of the cov matrix - cov is already diag
    # (for the identity matrix this is the identity as well)
    invsqrt_cov = jnp.diag(1 / jnp.sqrt(jnp.diag(cov_matrix)))

    return CMAESState(
        mean=self._mean_init,
        cov_matrix=cov_matrix,
        sigma=self._init_sigma,
        num_updates=0,
        # both evolution paths start at zero
        p_c=jnp.zeros(shape=(self._search_dim,)),
        p_s=jnp.zeros(shape=(self._search_dim,)),
        eigen_updates=0,
        eigenvalues=jnp.ones(shape=(self._search_dim,)),
        invsqrt_cov=invsqrt_cov,
    )

sample(self, cmaes_state, random_key)

Sample a population.

Parameters:
  • cmaes_state (CMAESState) – current state of the algorithm

  • random_key (Array) – jax random key

Returns:
  • Tuple[Union[jax.Array, numpy.ndarray, numpy.bool_, numpy.number, Iterable[ArrayTree], Mapping[Any, ArrayTree]], jax.Array] – A tuple that contains a batch of population size genotypes and a new random key.

Source code in qdax/core/cmaes.py
@partial(jax.jit, static_argnames=("self",))
def sample(
    self, cmaes_state: CMAESState, random_key: RNGKey
) -> Tuple[Genotype, RNGKey]:
    """
    Sample a population.

    Args:
        cmaes_state: current state of the algorithm
        random_key: jax random key

    Returns:
        A tuple that contains a batch of population size genotypes and
        a new random key.
    """
    # split the key so a fresh key is returned to the caller
    random_key, subkey = jax.random.split(random_key)
    # draw the population from N(mean, sigma^2 * C)
    samples = jax.random.multivariate_normal(
        subkey,
        shape=(self._population_size,),
        mean=cmaes_state.mean,
        cov=(cmaes_state.sigma**2) * cmaes_state.cov_matrix,
    )
    return samples, random_key

update_state_with_mask(self, cmaes_state, sorted_candidates, mask)

Update weights with a mask, then update the state.

Convention: 1 means the individual is kept, 0 means it is removed.

Source code in qdax/core/cmaes.py
@partial(jax.jit, static_argnames=("self",))
def update_state_with_mask(
    self, cmaes_state: CMAESState, sorted_candidates: Genotype, mask: Mask
) -> CMAESState:
    """Update weights with a mask, then update the state.

    Convention: 1 means the individual is kept, 0 means it is removed.
    """

    # zero-out the weights of masked individuals and renormalize the
    # remaining weights so they sum to one
    weights = jnp.multiply(self._weights, mask)
    weights = weights / (weights.sum())

    return self._update_state(  # type: ignore
        cmaes_state=cmaes_state,
        sorted_candidates=sorted_candidates,
        weights=weights,
    )

update(self, cmaes_state, samples)

Updates the distribution.

Parameters:
  • cmaes_state (CMAESState) – current state of the algorithm

  • samples (Union[jax.Array, numpy.ndarray, numpy.bool_, numpy.number, Iterable[ArrayTree], Mapping[Any, ArrayTree]]) – a batch of genotypes

Returns:
  • CMAESState – an updated algorithm state

Source code in qdax/core/cmaes.py
@partial(jax.jit, static_argnames=("self",))
def update(self, cmaes_state: CMAESState, samples: Genotype) -> CMAESState:
    """Updates the distribution.

    Args:
        cmaes_state: current state of the algorithm
        samples: a batch of genotypes

    Returns:
        an updated algorithm state
    """

    # negate fitnesses so an ascending argsort ranks the best
    # (highest-fitness) candidates first
    fitnesses = -self._fitness_function(samples)
    idx_sorted = jnp.argsort(fitnesses)
    # keep only the num_best candidates for the distribution update
    sorted_candidates = samples[idx_sorted[: self._num_best]]

    new_state = self.update_state(cmaes_state, sorted_candidates)

    return new_state  # type: ignore

stop_condition(self, cmaes_state)

Determines if the current optimization path must be stopped.

A set of 5 conditions is computed; meeting any one of them is enough to stop the process. This function does not stop the process but simply retrieves the value. It is not called in the update function, but it can be used to manually stop the process (see the example in the CMA-ME emitter).

Parameters:
  • cmaes_state (CMAESState) – current CMAES state

Returns:
  • bool – A boolean stating if the process should be stopped.

Source code in qdax/core/cmaes.py
@partial(jax.jit, static_argnames=("self",))
def stop_condition(self, cmaes_state: CMAESState) -> bool:
    """Determines if the current optimization path must be stopped.

    A set of 5 conditions are computed, one condition is enough to
    stop the process. This function does not stop the process but simply
    retrieves the value. It is not called in the update function but can be
    used to manually stop the process (see example in CMA ME emitter).

    Args:
        cmaes_state: current CMAES state

    Returns:
        A boolean stating if the process should be stopped.
    """

    # NaNs appear when the float precision limit is reached
    nan_condition = jnp.sum(jnp.isnan(cmaes_state.eigenvalues)) > 0

    # condition number of the covariance matrix is too large
    eig_dispersion = jnp.max(cmaes_state.eigenvalues) / jnp.min(
        cmaes_state.eigenvalues
    )
    first_condition = eig_dispersion > 1e14

    # the search area has collapsed to a negligible size
    area = cmaes_state.sigma * jnp.sqrt(jnp.max(cmaes_state.eigenvalues))
    second_condition = area < 1e-11

    # eigenvalues too small or too large for numerically stable updates
    third_condition = jnp.max(cmaes_state.eigenvalues) < 1e-7
    fourth_condition = jnp.min(cmaes_state.eigenvalues) > 1e7

    # summing booleans acts as a logical OR here: the result is truthy
    # iff at least one condition holds
    return (  # type: ignore
        nan_condition
        + first_condition
        + second_condition
        + third_condition
        + fourth_condition
    )