31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
class OMGMEGAEmitter(Emitter):
    """
    Class for the emitter of OMG-MEGA from "Differentiable Quality Diversity" by
    Fontaine et al.

    New candidates are produced by selecting parents from the repertoire,
    gathering the fitness and descriptor gradients stored for those parents in a
    companion gradients repertoire, normalising each gradient, and moving every
    parent along a random linear combination of its normalised gradients.

    NOTE: in order to implement this emitter while staying in the MAPElites
    framework, two design choices were made:
    - In `emit`, parents are selected by running the repertoire's selection on a
      copy of the repertoire whose genotypes are replaced by their cell indices.
      The selected indices are then used to gather both the genotypes and their
      gradients, so the two always correspond.
    - In `state_update`, the gradients have to be inserted in the gradients
      repertoire in the same way the individuals were inserted (same descriptors
      and fitnesses). This is slightly suboptimal because the same addition
      mechanism is computed twice. One possible improvement is to decompose the
      addition mechanism in two phases: one outputting the indices at which
      individuals will be added, and then the actual insertion step, so that the
      same indices can be reused to add the gradients.
    These choices keep OMG-MEGA compatible with the current implementation of the
    MAPElites and MapElitesRepertoire classes.
    """

    def __init__(
        self,
        batch_size: int,
        sigma_g: float,
        num_descriptors: int,
        centroids: Centroid,
        selector: Optional[Selector] = None,
    ):
        """Creates an instance of the OMGMEGAEmitter class.

        Args:
            batch_size: number of solutions sampled at each iteration.
            sigma_g: CAUTION - square of the standard deviation for the
                coefficients. This notation can be misleading as, although it's
                called sigma, it refers to the variance and not the standard
                deviation.
            num_descriptors: number of descriptors.
            centroids: centroids used to create the repertoire of solutions.
                This will be used to create the repertoire of gradients.
            selector: selector forwarded to `repertoire.select` when choosing
                parents in `emit`. None is forwarded as-is (presumably the
                repertoire's default selection - confirm in
                MapElitesRepertoire.select).
        """
        # Coefficients ~ N(0, sigma_g * I) over 1 fitness + num_descriptors
        # descriptor gradients.
        self._mu = jnp.zeros(num_descriptors + 1)
        # sigma_g is used directly as the diagonal of the covariance matrix,
        # i.e. it is a variance (see the CAUTION in the docstring).
        self._sigma = jnp.eye(num_descriptors + 1) * sigma_g
        self._batch_size = batch_size
        self._centroids = centroids
        self._num_descriptors = num_descriptors
        self._selector = selector

    @staticmethod
    def _add_gradients(
        gradients_repertoire: MapElitesRepertoire,
        fitnesses: Fitness,
        descriptors: Descriptor,
        extra_scores: ExtraScores,
    ) -> MapElitesRepertoire:
        """Insert `extra_scores["gradients"]` into the gradients repertoire.

        The gradients are added with the same descriptors and fitnesses as the
        corresponding individuals, so they land in the same cells.

        Raises:
            AssertionError: if `extra_scores` has no "gradients" entry.
        """
        assert "gradients" in extra_scores, "Missing gradients or wrong key"
        return gradients_repertoire.add(
            extra_scores["gradients"],
            descriptors,
            fitnesses,
            extra_scores,
        )

    @staticmethod
    def _normalize(gradients: jnp.ndarray) -> jnp.ndarray:
        """Scale gradients to unit L2 norm along axis 1.

        NOTE(review): axis 1 is the genotype axis only for flat genotypes, i.e.
        gradients of shape (batch, genotype_dim, k) - the layout the indexing in
        `emit` was written for. Confirm before using with other genotype shapes.

        A gradient whose norm is exactly zero is left at zero instead of
        producing 0 / 0 = NaN, which would otherwise contaminate the offspring.
        """
        norms = jnp.linalg.norm(gradients, axis=1, keepdims=True)
        safe_norms = jnp.where(norms > 0, norms, jnp.ones_like(norms))
        return gradients / safe_norms

    def init(
        self,
        key: RNGKey,
        repertoire: MapElitesRepertoire,
        genotypes: Genotype,
        fitnesses: Fitness,
        descriptors: Descriptor,
        extra_scores: ExtraScores,
    ) -> OMGMEGAEmitterState:
        """Initialises the state of the emitter.

        Creates a gradients repertoire sharing the solutions' centroids and
        fills it with the gradients of the initial population.

        Args:
            key: a random key (unused here).
            repertoire: the initial repertoire of solutions (unused here).
            genotypes: the genotypes of the initial population; only the first
                one is used, as a template for the gradient genotype.
            fitnesses: the fitnesses of the initial population.
            descriptors: the descriptors of the initial population.
            extra_scores: extra scores of the initial population; must contain
                a "gradients" entry.

        Returns:
            The initial emitter state.

        Raises:
            AssertionError: if `extra_scores` has no "gradients" entry.
        """
        # A single genotype serves as the template for the gradient storage.
        first_genotype = jax.tree.map(lambda x: x[0], genotypes)
        # Add a trailing axis of size num_descriptors + 1: index 0 holds the
        # fitness gradient and indices 1.. hold the descriptor gradients.
        gradient_genotype = jax.tree.map(
            lambda x: jnp.repeat(
                jnp.expand_dims(x, axis=-1), repeats=self._num_descriptors + 1, axis=-1
            ),
            first_genotype,
        )
        gradients_repertoire = MapElitesRepertoire.init_default(
            genotype=gradient_genotype, centroids=self._centroids
        )
        gradients_repertoire = self._add_gradients(
            gradients_repertoire, fitnesses, descriptors, extra_scores
        )
        return OMGMEGAEmitterState(gradients_repertoire=gradients_repertoire)

    def emit(  # type: ignore
        self,
        repertoire: MapElitesRepertoire,
        emitter_state: OMGMEGAEmitterState,
        key: RNGKey,
    ) -> Tuple[Genotype, ExtraScores]:
        """
        OMG emitter function that samples elements in the repertoire and does a
        gradient update with random coefficients to create new candidates.

        Args:
            repertoire: current repertoire.
            emitter_state: current emitter state, contains the gradients.
            key: random key.

        Returns:
            new_genotypes: new candidates to be added to the grid.
            extra_info: an empty dictionary.
        """
        key, subkey = jax.random.split(key)

        # Select cell indices rather than genotypes, so that the same indices
        # gather both the parents and their stored gradients.
        size_repertoire = repertoire.fitnesses.shape[0]
        repertoire_indexes = repertoire.replace(genotypes=jnp.arange(size_repertoire))
        indexes_selected = repertoire_indexes.select(
            subkey, num_samples=self._batch_size, selector=self._selector
        ).genotypes
        genotypes = jax.tree.map(lambda x: x[indexes_selected], repertoire.genotypes)
        gradients = jax.tree.map(
            lambda x: x[indexes_selected], emitter_state.gradients_repertoire.genotypes
        )

        # Last axis: index 0 is the fitness gradient, 1.. the descriptor ones.
        fitness_gradients = jax.tree.map(lambda x: x[..., :1], gradients)
        descriptors_gradients = jax.tree.map(lambda x: x[..., 1:], gradients)

        # Normalize each gradient independently (zero gradients stay zero).
        fitness_gradients = self._normalize(fitness_gradients)
        descriptors_gradients = self._normalize(descriptors_gradients)

        # Draw random coefficients; the fitness coefficient is made
        # non-negative so that term never points down the fitness gradient.
        coeffs = jax.random.multivariate_normal(
            key,
            shape=(self._batch_size,),
            mean=self._mu,
            cov=self._sigma,
        )
        coeffs = coeffs.at[:, 0].set(jnp.abs(coeffs[:, 0]))

        grads = jax.tree.map(
            lambda x, y: jnp.concatenate((x, y), axis=-1),
            fitness_gradients,
            descriptors_gradients,
        )
        # For each individual: sum_j coeffs[j] * grads[..., j].
        update_grad = jnp.sum(jax.vmap(lambda x, y: x * y)(coeffs, grads), axis=-1)

        new_genotypes = jax.tree.map(lambda x, y: x + y, genotypes, update_grad)
        return new_genotypes, {}

    def state_update(  # type: ignore
        self,
        emitter_state: OMGMEGAEmitterState,
        repertoire: MapElitesRepertoire,
        genotypes: Genotype,
        fitnesses: Fitness,
        descriptors: Descriptor,
        extra_scores: ExtraScores,
    ) -> OMGMEGAEmitterState:
        """Update the gradients repertoire to have the right gradients.

        NOTE: the gradients are re-inserted with the same descriptors and
        fitnesses as the offspring, so the addition mechanism runs twice.

        Args:
            emitter_state: current emitter state.
            repertoire: the current genotypes repertoire (unused here).
            genotypes: the genotypes of the batch of emitted offspring (unused
                here).
            fitnesses: the fitnesses of the batch of emitted offspring.
            descriptors: the descriptors of the emitted offspring.
            extra_scores: a dictionary with other values outputted by the
                scoring function; must contain a "gradients" entry.

        Returns:
            The modified emitter state.

        Raises:
            AssertionError: if `extra_scores` has no "gradients" entry.
        """
        gradients_repertoire = self._add_gradients(
            emitter_state.gradients_repertoire, fitnesses, descriptors, extra_scores
        )
        return emitter_state.replace(  # type: ignore
            gradients_repertoire=gradients_repertoire
        )

    @property
    def batch_size(self) -> int:
        """
        Returns:
            the batch size emitted by the emitter.
        """
        return self._batch_size