45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
class CMAES:
    """Covariance Matrix Adaptation Evolution Strategy (CMA-ES) optimizer.

    The optimizer itself is stateless with respect to the search: all
    quantities that evolve during optimization (mean, covariance matrix,
    step size, evolution paths, cached eigen decomposition) live in a
    ``CMAESState`` that is created by ``init`` and threaded through
    ``sample`` / ``update``.
    """

    def __init__(
        self,
        population_size: int,
        search_dim: int,
        fitness_function: Callable[[Genotype], Fitness],
        num_best: Optional[int] = None,
        init_sigma: float = 1e-3,
        mean_init: Optional[jax.Array] = None,
        bias_weights: bool = True,
        delay_eigen_decomposition: bool = False,
    ):
        """Instantiate a CMA-ES optimizer.

        Args:
            population_size: size of the running population.
            search_dim: number of dimensions in the search space.
            fitness_function: fitness function that is being optimized.
            num_best: number of best individuals in the population being
                considered for the update of the distributions. Defaults to
                half of the population (and never less than one).
            init_sigma: initial value of the step size. Defaults to 1e-3.
            mean_init: initial value of the distribution mean. Defaults to
                a vector of zeros of size ``search_dim``.
            bias_weights: should the weights be biased towards best
                individuals. Defaults to True.
            delay_eigen_decomposition: should the update of the inverse of the
                cov matrix be delayed. As this operation is a time bottleneck,
                having it delayed improves the time perfs by a significant
                margin. Defaults to False.

        Raises:
            ValueError: if ``num_best`` is provided and is smaller than one.
        """
        self._population_size = population_size
        self._search_dim = search_dim
        self._fitness_function = fitness_function
        self._init_sigma = init_sigma

        # Default values if values are not provided
        if num_best is None:
            # keep at least one parent: with zero parents the weights are
            # empty and every derived constant below becomes NaN
            num_best = max(1, population_size // 2)
        elif num_best < 1:
            raise ValueError(f"num_best must be at least 1, got {num_best}.")
        self._num_best = num_best

        if mean_init is None:
            mean_init = jnp.zeros(shape=(search_dim,))
        self._mean_init = mean_init

        # recombination weights
        if bias_weights:
            # heuristic from Nikolaus Hansen's original implementation
            ranks = jnp.arange(start=1, stop=(self._num_best + 1))
            raw_weights = jnp.log((self._num_best + 0.5) / ranks)
        else:
            raw_weights = jnp.ones(self._num_best)

        # scale weights so that they sum to one
        self._weights = raw_weights / raw_weights.sum()
        self._parents_eff = 1 / (self._weights**2).sum()

        # adaptation parameters
        self._c_s = (self._parents_eff + 2) / (self._search_dim + self._parents_eff + 5)
        self._c_c = (4 + self._parents_eff / self._search_dim) / (
            self._search_dim + 4 + 2 * self._parents_eff / self._search_dim
        )

        # learning rate for rank-1 update of C
        self._c_1 = 2 / (self._parents_eff + (self._search_dim + jnp.sqrt(2)) ** 2)

        # learning rate for rank-(num best) updates, capped so that the
        # coefficient applied to the previous covariance stays non-negative
        rank_mu_numerator = 2 * (self._parents_eff - 2 + 1 / self._parents_eff)
        self._c_cov = jnp.minimum(
            1 - self._c_1,
            rank_mu_numerator / (self._parents_eff + (self._search_dim + 2) ** 2),
        )

        # damping for sigma
        self._d_s = (
            1
            + 2
            * jnp.maximum(
                0.0, jnp.sqrt((self._parents_eff - 1) / (self._search_dim + 1)) - 1
            )
            + self._c_s
        )

        # approximation used in place of E||N(0, I)||
        self._chi = jnp.sqrt(self._search_dim) * (
            1 - 1 / (4 * self._search_dim) + 1 / (21 * self._search_dim**2)
        )

        # threshold for new eigen decomposition - from pyribs
        self._eigen_comput_period = 1
        if delay_eigen_decomposition:
            self._eigen_comput_period = (
                0.5
                * self._population_size
                / (self._search_dim * (self._c_1 + self._c_cov))
            )

    def init(self) -> CMAESState:
        """Init the CMA-ES algorithm.

        Returns:
            an initial state for the algorithm: identity covariance, zeroed
            evolution paths, and the configured initial mean and step size.
        """
        # initial cov matrix
        cov_matrix = jnp.eye(self._search_dim)

        # initial inv sqrt of the cov matrix - cov is already diag
        invsqrt_cov = jnp.diag(1 / jnp.sqrt(jnp.diag(cov_matrix)))

        return CMAESState(
            mean=self._mean_init,
            cov_matrix=cov_matrix,
            sigma=self._init_sigma,
            num_updates=0,
            p_c=jnp.zeros(shape=(self._search_dim,)),
            p_s=jnp.zeros(shape=(self._search_dim,)),
            eigen_updates=0,
            eigenvalues=jnp.ones(shape=(self._search_dim,)),
            invsqrt_cov=invsqrt_cov,
        )

    def sample(self, cmaes_state: CMAESState, key: RNGKey) -> Genotype:
        """Sample a population.

        Args:
            cmaes_state: current state of the algorithm
            key: jax random key

        Returns:
            A batch of ``population_size`` genotypes drawn from the normal
            distribution with mean ``cmaes_state.mean`` and covariance
            ``sigma**2 * cov_matrix``. The key is consumed, no new key is
            returned.
        """
        samples = jax.random.multivariate_normal(
            key,
            shape=(self._population_size,),
            mean=cmaes_state.mean,
            cov=(cmaes_state.sigma**2) * cmaes_state.cov_matrix,
        )
        return samples

    def update_state(
        self,
        cmaes_state: CMAESState,
        sorted_candidates: Genotype,
    ) -> CMAESState:
        """Update the state using the optimizer's own recombination weights.

        Args:
            cmaes_state: current state of the algorithm
            sorted_candidates: a batch of sorted and selected genotypes

        Returns:
            An updated algorithm state
        """
        return self._update_state(  # type: ignore
            cmaes_state=cmaes_state,
            sorted_candidates=sorted_candidates,
            weights=self._weights,
        )

    def update_state_with_mask(
        self, cmaes_state: CMAESState, sorted_candidates: Genotype, mask: Mask
    ) -> CMAESState:
        """Update weights with a mask, then update the state.

        Convention: 1 stays, 0 is removed. The mask must keep at least one
        candidate, otherwise the renormalization divides by zero.

        Args:
            cmaes_state: current state of the algorithm
            sorted_candidates: a batch of sorted and selected genotypes
            mask: multiplicative mask applied to the recombination weights

        Returns:
            An updated algorithm state
        """
        # update weights by multiplying by a mask, then renormalize
        weights = jnp.multiply(self._weights, mask)
        weights = weights / (weights.sum())

        return self._update_state(  # type: ignore
            cmaes_state=cmaes_state,
            sorted_candidates=sorted_candidates,
            weights=weights,
        )

    def _update_state(
        self,
        cmaes_state: CMAESState,
        sorted_candidates: Genotype,
        weights: jax.Array,
    ) -> CMAESState:
        """Updates the state when candidates have already been
        sorted and selected.

        Args:
            cmaes_state: current state of the algorithm
            sorted_candidates: a batch of sorted and selected genotypes
            weights: weights used to recombine the candidates

        Returns:
            An updated algorithm state
        """
        # retrieve elements from the current state
        p_c = cmaes_state.p_c
        p_s = cmaes_state.p_s
        sigma = cmaes_state.sigma
        num_updates = cmaes_state.num_updates
        cov = cmaes_state.cov_matrix
        old_mean = cmaes_state.mean
        eigen_updates = cmaes_state.eigen_updates
        eigenvalues = cmaes_state.eigenvalues
        invsqrt_cov = cmaes_state.invsqrt_cov

        # flatten once so that a single weight stays a 1-d vector
        weights = jnp.ravel(weights)

        # update mean by recombination
        mean = weights @ sorted_candidates

        def update_eigen(
            operand: Tuple[jax.Array, int]
        ) -> Tuple[int, jax.Array, jax.Array]:
            # unpack data
            cov, num_updates = operand

            # enforce symmetry by mirroring the upper triangle
            cov = jnp.triu(cov) + jnp.triu(cov, 1).T

            # get eigen decomposition: eigenvalues, eigenvectors
            eig, u = jnp.linalg.eigh(cov)

            # compute new invsqrt: U diag(eig^-1/2) U^T, scaling the columns
            # of U instead of materializing the diagonal matrix
            invsqrt = (u * (1 / jnp.sqrt(eig))) @ u.T

            # the decomposition is now in sync with this update count
            return num_updates, eig, invsqrt

        # condition for recomputing the eig decomposition
        eigen_condition = (num_updates - eigen_updates) >= self._eigen_comput_period

        # decomposition of cov (or reuse of the cached one)
        eigen_updates, eigenvalues, invsqrt = jax.lax.cond(
            eigen_condition,
            update_eigen,
            lambda _: (eigen_updates, eigenvalues, invsqrt_cov),
            operand=(cov, num_updates),
        )

        # mean displacement in units of sigma, raw and whitened
        z = (1 / sigma) * (mean - old_mean)
        z_w = invsqrt @ z

        # update evolution paths - cumulation
        p_s = (1 - self._c_s) * p_s + jnp.sqrt(
            self._c_s * (2 - self._c_s) * self._parents_eff
        ) * z_w

        # h_sigma stalls the p_c update when ||p_s|| is too large. The
        # normalization uses the 1-based generation index: num_updates starts
        # at 0, and using it directly would divide by zero on the first
        # update and always switch the indicator off.
        generation = num_updates + 1
        p_s_norm = jnp.linalg.norm(p_s)
        h_sigma = p_s_norm / jnp.sqrt(
            1 - (1 - self._c_s) ** (2 * generation)
        ) <= self._chi * (1.4 + 2 / (self._search_dim + 1))

        p_c = (1 - self._c_c) * p_c + h_sigma * jnp.sqrt(
            self._c_c * (2 - self._c_c) * self._parents_eff
        ) * z

        # update covariance matrix
        pp_c = jnp.expand_dims(p_c, axis=1)
        steps = (sorted_candidates - old_mean) / sigma
        # weighted sum of outer products of the selected steps
        cov_rank = (steps.T * weights) @ steps

        cov = (
            (1 - self._c_cov - self._c_1) * cov
            + self._c_1
            * (pp_c @ pp_c.T + (1 - h_sigma) * self._c_c * (2 - self._c_c) * cov)
            + self._c_cov * cov_rank
        )

        # update step size
        sigma = sigma * jnp.exp(
            (self._c_s / self._d_s) * (p_s_norm / self._chi - 1)
        )

        return CMAESState(
            mean=mean,
            cov_matrix=cov,
            sigma=sigma,
            num_updates=num_updates + 1,
            p_c=p_c,
            p_s=p_s,
            eigen_updates=eigen_updates,
            eigenvalues=eigenvalues,
            invsqrt_cov=invsqrt,
        )

    def update(self, cmaes_state: CMAESState, samples: Genotype) -> CMAESState:
        """Updates the distribution.

        The fitness is negated before sorting, so the candidates with the
        highest fitness come first and the ``num_best`` first ones are kept.

        Args:
            cmaes_state: current state of the algorithm
            samples: a batch of genotypes

        Returns:
            an updated algorithm state
        """
        fitnesses = -self._fitness_function(samples)
        idx_sorted = jnp.argsort(fitnesses)
        sorted_candidates = samples[idx_sorted[: self._num_best]]

        new_state = self.update_state(cmaes_state, sorted_candidates)

        return new_state  # type: ignore

    def stop_condition(self, cmaes_state: CMAESState) -> bool:
        """Determines if the current optimization path must be stopped.

        A set of 5 conditions are computed, one condition is enough to
        stop the process. This function does not stop the process but simply
        retrieves the value. It is not called in the update function but can be
        used to manually stop the process (see example in CMA ME emitter).

        Args:
            cmaes_state: current CMAES state

        Returns:
            A boolean stating if the process should be stopped.
        """
        eigenvalues = cmaes_state.eigenvalues
        max_eig = jnp.max(eigenvalues)
        min_eig = jnp.min(eigenvalues)

        # NaN appears when float precision is reached
        nan_condition = jnp.any(jnp.isnan(eigenvalues))

        # covariance matrix is too ill-conditioned
        first_condition = max_eig / min_eig > 1e14

        # sampling region has collapsed
        area = cmaes_state.sigma * jnp.sqrt(max_eig)
        second_condition = area < 1e-11

        # all eigenvalues are tiny / all eigenvalues are huge
        third_condition = max_eig < 1e-7
        fourth_condition = min_eig > 1e7

        return (  # type: ignore
            nan_condition
            | first_condition
            | second_condition
            | third_condition
            | fourth_condition
        )
|