CMAES class

Class to run the CMA-ES algorithm.

Source code in qdax/baselines/cmaes.py
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
class CMAES:
    """
    Class to run the CMA-ES algorithm.

    Fitness is maximized: `update` ranks the samples by decreasing fitness and
    the best-ranked candidates receive the largest recombination weights.
    """

    def __init__(
        self,
        population_size: int,
        search_dim: int,
        fitness_function: Callable[[Genotype], Fitness],
        num_best: Optional[int] = None,
        init_sigma: float = 1e-3,
        mean_init: Optional[jax.Array] = None,
        bias_weights: bool = True,
        delay_eigen_decomposition: bool = False,
    ):
        """Instantiate a CMA-ES optimizer.

        Args:
            population_size: size of the running population.
            search_dim: number of dimensions in the search space.
            fitness_function: fitness function that is being optimized
                (maximized).
            num_best: number of best individuals in the population being considered
                for the update of the distributions. Defaults to None, in which
                case half of the population (rounded down) is used.
            init_sigma: Initial value of the step size. Defaults to 1e-3.
            mean_init: Initial value of the distribution mean. Defaults to None,
                in which case the mean starts at the origin.
            bias_weights: Should the weights be biased towards best individuals.
                Defaults to True. When False, all selected individuals get the
                same weight.
            delay_eigen_decomposition: should the update of the inverse of the
                cov matrix be delayed. As this operation is a time bottleneck, having
                it delayed improves the time perfs by a significant margin.
                Defaults to False.
        """
        self._population_size = population_size
        self._search_dim = search_dim
        self._fitness_function = fitness_function
        self._init_sigma = init_sigma

        # Default values if values are not provided
        if num_best is None:
            self._num_best = population_size // 2
        else:
            self._num_best = num_best

        if mean_init is None:
            self._mean_init = jnp.zeros(shape=(search_dim,))
        else:
            self._mean_init = mean_init

        # weights parameters
        if bias_weights:
            # heuristic from Nicolas Hansen original implementation
            self._weights = jnp.log(
                (self._num_best + 0.5) / jnp.arange(start=1, stop=(self._num_best + 1))
            )
        else:
            self._weights = jnp.ones(self._num_best)

        # scale weights so that they sum to one
        self._weights = self._weights / (self._weights.sum())
        # variance effective selection mass
        self._parents_eff = 1 / (self._weights**2).sum()

        # adaptation parameters (cumulation constants for p_s and p_c)
        self._c_s = (self._parents_eff + 2) / (self._search_dim + self._parents_eff + 5)
        self._c_c = (4 + self._parents_eff / self._search_dim) / (
            self._search_dim + 4 + 2 * self._parents_eff / self._search_dim
        )

        # learning rate for rank-1 update of C
        self._c_1 = 2 / (self._parents_eff + (self._search_dim + jnp.sqrt(2)) ** 2)

        # learning rate for rank-(num best) updates
        tmp = 2 * (self._parents_eff - 2 + 1 / self._parents_eff)
        self._c_cov = min(
            1 - self._c_1, tmp / (self._parents_eff + (self._search_dim + 2) ** 2)
        )

        # damping for sigma
        self._d_s = (
            1
            + 2 * max(0, jnp.sqrt((self._parents_eff - 1) / (self._search_dim + 1)) - 1)
            + self._c_s
        )
        # approximation of E||N(0, I)||
        self._chi = jnp.sqrt(self._search_dim) * (
            1 - 1 / (4 * self._search_dim) + 1 / (21 * self._search_dim**2)
        )

        # threshold for new eigen decomposition - from pyribs.
        # pyribs expresses this lazy gap in *evaluations*:
        #   0.5 * population_size / (search_dim * (c_1 + c_cov)).
        # Here it is compared against the number of updates (generations) in
        # `_update_state`, so it is divided by the population size.
        self._eigen_comput_period = 1
        if delay_eigen_decomposition:
            self._eigen_comput_period = 0.5 / (
                self._search_dim * (self._c_1 + self._c_cov)
            )

    def init(self) -> CMAESState:
        """
        Init the CMA-ES algorithm.

        Returns:
            an initial state for the algorithm
        """

        # initial cov matrix
        cov_matrix = jnp.eye(self._search_dim)

        # initial inv sqrt of the cov matrix - cov is already diag
        invsqrt_cov = jnp.diag(1 / jnp.sqrt(jnp.diag(cov_matrix)))

        return CMAESState(
            mean=self._mean_init,
            cov_matrix=cov_matrix,
            sigma=self._init_sigma,
            num_updates=0,
            p_c=jnp.zeros(shape=(self._search_dim,)),
            p_s=jnp.zeros(shape=(self._search_dim,)),
            eigen_updates=0,
            eigenvalues=jnp.ones(shape=(self._search_dim,)),
            invsqrt_cov=invsqrt_cov,
        )

    def sample(self, cmaes_state: CMAESState, key: RNGKey) -> Genotype:
        """
        Sample a population.

        Args:
            cmaes_state: current state of the algorithm
            key: jax random key

        Returns:
            A batch of `population_size` genotypes drawn from the multivariate
            normal N(mean, sigma^2 * C) of the current state. The key is not
            split nor returned; callers manage their own keys.
        """
        samples = jax.random.multivariate_normal(
            key,
            shape=(self._population_size,),
            mean=cmaes_state.mean,
            cov=(cmaes_state.sigma**2) * cmaes_state.cov_matrix,
        )
        return samples

    def update_state(
        self,
        cmaes_state: CMAESState,
        sorted_candidates: Genotype,
    ) -> CMAESState:
        """Update the state from already sorted and selected candidates.

        Uses the default recombination weights computed at construction.

        Args:
            cmaes_state: current state of the algorithm
            sorted_candidates: the `num_best` selected genotypes, best first

        Returns:
            An updated algorithm state
        """
        return self._update_state(  # type: ignore
            cmaes_state=cmaes_state,
            sorted_candidates=sorted_candidates,
            weights=self._weights,
        )

    def update_state_with_mask(
        self, cmaes_state: CMAESState, sorted_candidates: Genotype, mask: Mask
    ) -> CMAESState:
        """Update weights with a mask, then update the state.

        Convention: 1 keeps the candidate, 0 removes it. The masked weights are
        renormalized to sum to one; an all-zero mask therefore yields NaN
        weights.
        """

        # update weights by multiplying by a mask
        weights = jnp.multiply(self._weights, mask)
        weights = weights / (weights.sum())

        return self._update_state(  # type: ignore
            cmaes_state=cmaes_state,
            sorted_candidates=sorted_candidates,
            weights=weights,
        )

    def _compute_hsig(self, p_s: jax.Array, num_updates: int) -> jax.Array:
        """Heaviside indicator gating the update of the rank-1 path p_c.

        Following Hansen's CMA-ES tutorial, the cumulation of p_c is stalled
        when the (bias-corrected) norm of p_s is large.

        Args:
            p_s: conjugate evolution path, already updated for this generation.
            num_updates: number of updates performed before the current one.

        Returns:
            A boolean array, True when p_c may be updated.
        """
        # p_s has now received num_updates + 1 cumulation steps. Using
        # num_updates here makes the correction 0 on the first update, which
        # always forced the indicator to False.
        bias_correction = jnp.sqrt(1 - (1 - self._c_s) ** (2 * (num_updates + 1)))
        threshold = self._chi * (1.4 + 2 / (self._search_dim + 1))
        return jnp.linalg.norm(p_s) / bias_correction <= threshold

    def _update_state(
        self,
        cmaes_state: CMAESState,
        sorted_candidates: Genotype,
        weights: jax.Array,
    ) -> CMAESState:
        """Updates the state when candidates have already been
        sorted and selected.

        Args:
            cmaes_state: current state of the algorithm
            sorted_candidates: a batch of sorted and selected genotypes
            weights: weights used to recombine the candidates

        Returns:
            An updated algorithm state
        """

        # retrieve elements from the current state
        p_c = cmaes_state.p_c
        p_s = cmaes_state.p_s
        sigma = cmaes_state.sigma
        num_updates = cmaes_state.num_updates
        cov = cmaes_state.cov_matrix
        mean = cmaes_state.mean

        eigen_updates = cmaes_state.eigen_updates
        eigenvalues = cmaes_state.eigenvalues
        invsqrt_cov = cmaes_state.invsqrt_cov

        # update mean by recombination
        old_mean = mean
        mean = weights @ sorted_candidates

        def update_eigen(
            operand: Tuple[jax.Array, int]
        ) -> Tuple[int, jax.Array, jax.Array]:

            # unpack data
            cov, num_updates = operand

            # enforce symmetry - did not change anything
            cov = jnp.triu(cov) + jnp.triu(cov, 1).T

            # get eigen decomposition: eigenvalues, eigenvectors
            eig, u = jnp.linalg.eigh(cov)

            # compute new invsqrt
            invsqrt = u @ jnp.diag(1 / jnp.sqrt(eig)) @ u.T

            # update the eigen value decomposition tracker
            eigen_updates = num_updates

            return eigen_updates, eig, invsqrt

        # condition for recomputing the eig decomposition (both counters are
        # in generations)
        eigen_condition = (num_updates - eigen_updates) >= self._eigen_comput_period

        # decomposition of cov (the one from before this update); otherwise
        # reuse the previously stored decomposition
        eigen_updates, eigenvalues, invsqrt = jax.lax.cond(
            eigen_condition,
            update_eigen,
            lambda _: (eigen_updates, eigenvalues, invsqrt_cov),
            operand=(cov, num_updates),
        )

        # mean displacement in units of the step size, and its whitened version
        z = (1 / sigma) * (mean - old_mean)
        z_w = invsqrt @ z

        # update evolution paths - cumulation
        p_s = (1 - self._c_s) * p_s + jnp.sqrt(
            self._c_s * (2 - self._c_s) * self._parents_eff
        ) * z_w

        h_sig = self._compute_hsig(p_s, num_updates)

        p_c = (1 - self._c_c) * p_c + h_sig * jnp.sqrt(
            self._c_c * (2 - self._c_c) * self._parents_eff
        ) * z

        # update covariance matrix
        pp_c = jnp.expand_dims(p_c, axis=1)

        # rank-mu term: weighted sum of outer products of the selected steps,
        # computed by broadcasting instead of building diag(weights)
        steps = (sorted_candidates - old_mean) / sigma
        cov_rank = (steps.T * jnp.ravel(weights)) @ steps

        cov = (
            (1 - self._c_cov - self._c_1) * cov
            + self._c_1
            * (pp_c @ pp_c.T + (1 - h_sig) * self._c_c * (2 - self._c_c) * cov)
            + self._c_cov * cov_rank
        )

        # update step size
        sigma = sigma * jnp.exp(
            (self._c_s / self._d_s) * (jnp.linalg.norm(p_s) / self._chi - 1)
        )

        cmaes_state = CMAESState(
            mean=mean,
            cov_matrix=cov,
            sigma=sigma,
            num_updates=num_updates + 1,
            p_c=p_c,
            p_s=p_s,
            eigen_updates=eigen_updates,
            eigenvalues=eigenvalues,
            invsqrt_cov=invsqrt,
        )

        return cmaes_state

    def update(self, cmaes_state: CMAESState, samples: Genotype) -> CMAESState:
        """Updates the distribution.

        Args:
            cmaes_state: current state of the algorithm
            samples: a batch of genotypes

        Returns:
            an updated algorithm state
        """

        # negate so that the ascending argsort puts the fittest samples first
        fitnesses = -self._fitness_function(samples)
        idx_sorted = jnp.argsort(fitnesses)
        sorted_candidates = samples[idx_sorted[: self._num_best]]

        new_state = self.update_state(cmaes_state, sorted_candidates)

        return new_state  # type: ignore

    def stop_condition(self, cmaes_state: CMAESState) -> bool:
        """Determines if the current optimization path must be stopped.

        A set of 5 conditions is computed; any one of them is enough to
        stop the process. This function does not stop the process but simply
        returns the value. It is not called in the update function but can be
        used to manually stop the process (see example in CMA ME emitter).

        Args:
            cmaes_state: current CMAES state

        Returns:
            A boolean stating if the process should be stopped.
        """
        eigenvalues = cmaes_state.eigenvalues
        max_eig = jnp.max(eigenvalues)
        min_eig = jnp.min(eigenvalues)

        # NaN appears because of float precision is reached
        nan_condition = jnp.any(jnp.isnan(eigenvalues))

        # condition number of the covariance matrix is too large
        first_condition = max_eig / min_eig > 1e14

        # step size along the largest axis is tiny
        area = cmaes_state.sigma * jnp.sqrt(max_eig)
        second_condition = area < 1e-11

        third_condition = max_eig < 1e-7
        fourth_condition = min_eig > 1e7

        return (  # type: ignore
            nan_condition
            | first_condition
            | second_condition
            | third_condition
            | fourth_condition
        )

__init__(population_size, search_dim, fitness_function, num_best=None, init_sigma=0.001, mean_init=None, bias_weights=True, delay_eigen_decomposition=False)

Instantiate a CMA-ES optimizer.

Parameters:
  • population_size (int) –

    size of the running population.

  • search_dim (int) –

    number of dimensions in the search space.

  • fitness_function (Callable[[Genotype], Fitness]) –

    fitness function that is being optimized.

  • num_best (Optional[int], default: None ) –

    number of best individuals in the population being considered for the update of the distributions. Defaults to None.

  • init_sigma (float, default: 0.001 ) –

    Initial value of the step size. Defaults to 1e-3.

  • mean_init (Optional[Array], default: None ) –

    Initial value of the distribution mean. Defaults to None.

  • bias_weights (bool, default: True ) –

    Should the weights be biased towards best individuals. Defaults to True.

  • delay_eigen_decomposition (bool, default: False ) –

    should the update of the inverse of the cov matrix be delayed. As this operation is a time bottleneck, having it delayed improves the time perfs by a significant margin. Defaults to False.

Source code in qdax/baselines/cmaes.py
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
def __init__(
    self,
    population_size: int,
    search_dim: int,
    fitness_function: Callable[[Genotype], Fitness],
    num_best: Optional[int] = None,
    init_sigma: float = 1e-3,
    mean_init: Optional[jax.Array] = None,
    bias_weights: bool = True,
    delay_eigen_decomposition: bool = False,
):
    """Instantiate a CMA-ES optimizer.

    Args:
        population_size: size of the running population.
        search_dim: number of dimensions in the search space.
        fitness_function: fitness function that is being optimized
            (maximized).
        num_best: number of best individuals in the population being considered
            for the update of the distributions. Defaults to None, in which
            case half of the population (rounded down) is used.
        init_sigma: Initial value of the step size. Defaults to 1e-3.
        mean_init: Initial value of the distribution mean. Defaults to None,
            in which case the mean starts at the origin.
        bias_weights: Should the weights be biased towards best individuals.
            Defaults to True. When False, all selected individuals get the
            same weight.
        delay_eigen_decomposition: should the update of the inverse of the
            cov matrix be delayed. As this operation is a time bottleneck, having
            it delayed improves the time perfs by a significant margin.
            Defaults to False.
    """
    self._population_size = population_size
    self._search_dim = search_dim
    self._fitness_function = fitness_function
    self._init_sigma = init_sigma

    # Default values if values are not provided
    if num_best is None:
        self._num_best = population_size // 2
    else:
        self._num_best = num_best

    if mean_init is None:
        self._mean_init = jnp.zeros(shape=(search_dim,))
    else:
        self._mean_init = mean_init

    # weights parameters
    if bias_weights:
        # heuristic from Nicolas Hansen original implementation
        self._weights = jnp.log(
            (self._num_best + 0.5) / jnp.arange(start=1, stop=(self._num_best + 1))
        )
    else:
        self._weights = jnp.ones(self._num_best)

    # scale weights so that they sum to one
    self._weights = self._weights / (self._weights.sum())
    # variance effective selection mass
    self._parents_eff = 1 / (self._weights**2).sum()

    # adaptation parameters (cumulation constants for p_s and p_c)
    self._c_s = (self._parents_eff + 2) / (self._search_dim + self._parents_eff + 5)
    self._c_c = (4 + self._parents_eff / self._search_dim) / (
        self._search_dim + 4 + 2 * self._parents_eff / self._search_dim
    )

    # learning rate for rank-1 update of C
    self._c_1 = 2 / (self._parents_eff + (self._search_dim + jnp.sqrt(2)) ** 2)

    # learning rate for rank-(num best) updates
    tmp = 2 * (self._parents_eff - 2 + 1 / self._parents_eff)
    self._c_cov = min(
        1 - self._c_1, tmp / (self._parents_eff + (self._search_dim + 2) ** 2)
    )

    # damping for sigma
    self._d_s = (
        1
        + 2 * max(0, jnp.sqrt((self._parents_eff - 1) / (self._search_dim + 1)) - 1)
        + self._c_s
    )
    # approximation of E||N(0, I)||
    self._chi = jnp.sqrt(self._search_dim) * (
        1 - 1 / (4 * self._search_dim) + 1 / (21 * self._search_dim**2)
    )

    # threshold for new eigen decomposition - from pyribs.
    # pyribs expresses this lazy gap in *evaluations*:
    #   0.5 * population_size / (search_dim * (c_1 + c_cov)).
    # The state update compares it against a number of updates (generations),
    # so it is divided by the population size.
    self._eigen_comput_period = 1
    if delay_eigen_decomposition:
        self._eigen_comput_period = 0.5 / (
            self._search_dim * (self._c_1 + self._c_cov)
        )

init()

Init the CMA-ES algorithm.

Returns:
  • CMAESState

    an initial state for the algorithm

Source code in qdax/baselines/cmaes.py
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
def init(self) -> CMAESState:
    """
    Build the starting state of the CMA-ES algorithm.

    The distribution starts at the configured mean and step size. It uses an
    identity covariance matrix (with matching unit eigenvalues and inverse
    square root) and zeroed evolution paths. No update has been performed yet.

    Returns:
        an initial state for the algorithm
    """
    dim = self._search_dim
    identity = jnp.eye(dim)

    # The covariance is diagonal, so its inverse square root is simply the
    # element-wise inverse square root of its diagonal.
    inv_sqrt_identity = jnp.diag(1 / jnp.sqrt(jnp.diag(identity)))

    initial_state = CMAESState(
        mean=self._mean_init,
        cov_matrix=identity,
        sigma=self._init_sigma,
        num_updates=0,
        p_c=jnp.zeros(shape=(dim,)),
        p_s=jnp.zeros(shape=(dim,)),
        eigen_updates=0,
        eigenvalues=jnp.ones(shape=(dim,)),
        invsqrt_cov=inv_sqrt_identity,
    )
    return initial_state

sample(cmaes_state, key)

Sample a population.

Parameters:
  • cmaes_state (CMAESState) –

    current state of the algorithm

  • key (RNGKey) –

    jax random key

Returns:
  • Genotype

    A batch of population_size genotypes sampled from the current search distribution N(mean, sigma² · C). The random key is neither split nor returned.

Source code in qdax/baselines/cmaes.py
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
def sample(self, cmaes_state: CMAESState, key: RNGKey) -> Genotype:
    """
    Sample a population.

    Args:
        cmaes_state: current state of the algorithm
        key: jax random key

    Returns:
        A batch of `population_size` genotypes drawn from the multivariate
        normal N(mean, sigma^2 * C) of the current state. The key is neither
        split nor returned; callers manage their own keys.
    """
    samples = jax.random.multivariate_normal(
        key,
        shape=(self._population_size,),
        mean=cmaes_state.mean,
        cov=(cmaes_state.sigma**2) * cmaes_state.cov_matrix,
    )
    return samples

stop_condition(cmaes_state)

Determines if the current optimization path must be stopped.

A set of 5 conditions is computed; any one of them is enough to stop the process. This function does not stop the process but simply returns the value. It is not called in the update function but can be used to manually stop the process (see example in CMA ME emitter).

Parameters:
  • cmaes_state (CMAESState) –

    current CMAES state

Returns:
  • bool

    A boolean stating if the process should be stopped.

Source code in qdax/baselines/cmaes.py
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
def stop_condition(self, cmaes_state: CMAESState) -> bool:
    """Tell whether the current optimization run should be stopped.

    Five termination criteria are evaluated on the state, and the run should
    stop as soon as any of them holds. Nothing is stopped here; the flag is
    only returned. It is not used by the update function, but callers can use
    it to stop the process manually (see example in CMA ME emitter).

    Args:
        cmaes_state: current CMAES state

    Returns:
        A boolean stating if the process should be stopped.
    """
    eigs = cmaes_state.eigenvalues
    largest = jnp.max(eigs)
    smallest = jnp.min(eigs)

    criteria = (
        # NaNs show up once float precision has been exhausted
        jnp.sum(jnp.isnan(eigs)) > 0,
        # condition number of the covariance matrix is too large
        largest / smallest > 1e14,
        # step size along the largest axis is tiny
        cmaes_state.sigma * jnp.sqrt(largest) < 1e-11,
        largest < 1e-7,
        smallest > 1e7,
    )

    # boolean "sum" (logical or), accumulated left to right
    should_stop = criteria[0]
    for criterion in criteria[1:]:
        should_stop = should_stop + criterion
    return should_stop  # type: ignore

update(cmaes_state, samples)

Updates the distribution.

Parameters:
  • cmaes_state (CMAESState) –

    current state of the algorithm

  • samples (Genotype) –

    a batch of genotypes

Returns:
  • CMAESState

    an updated algorithm state

Source code in qdax/baselines/cmaes.py
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
def update(self, cmaes_state: CMAESState, samples: Genotype) -> CMAESState:
    """Rank the samples by fitness and update the search distribution.

    Only the `num_best` fittest samples are passed on, best first.

    Args:
        cmaes_state: current state of the algorithm
        samples: a batch of genotypes

    Returns:
        an updated algorithm state
    """
    # Negating the fitness makes the ascending argsort put the fittest first.
    ranking = jnp.argsort(-self._fitness_function(samples))
    elites = samples[ranking[: self._num_best]]
    return self.update_state(cmaes_state, elites)  # type: ignore

update_state_with_mask(cmaes_state, sorted_candidates, mask)

Update weights with a mask, then update the state.

Convention: 1 keeps the candidate, 0 removes it.

Source code in qdax/baselines/cmaes.py
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
def update_state_with_mask(
    self, cmaes_state: CMAESState, sorted_candidates: Genotype, mask: Mask
) -> CMAESState:
    """Mask the recombination weights, renormalize them, then update the state.

    Convention: a mask entry of 1 keeps the matching candidate, 0 removes it.
    The kept weights are rescaled to sum to one, so an all-zero mask divides
    by zero and produces NaN weights.
    """
    kept = jnp.multiply(self._weights, mask)
    total = kept.sum()
    return self._update_state(  # type: ignore
        cmaes_state=cmaes_state,
        sorted_candidates=sorted_candidates,
        weights=kept / total,
    )