22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
class MAPElites:
    """Core elements of the MAP-Elites algorithm.

    Note: Although very similar to the GeneticAlgorithm, we decided to keep the
    MAPElites class independent of the GeneticAlgorithm class at the moment to keep
    elements explicit.

    Two ways of driving the algorithm are exposed:

    - ``init`` / ``update`` / ``scan_update`` evaluate the genotypes internally
      and therefore require a scoring function;
    - ``init_ask_tell`` / ``ask`` / ``tell`` leave the evaluation to the caller
      and never read the scoring function, which may then be None.

    Args:
        scoring_function: a function that takes a batch of genotypes and a random
            key and computes their fitnesses, descriptors and extra scores. May
            be None when only the ask/tell methods are used.
        emitter: an emitter is used to suggest offsprings given a MAP-Elites
            repertoire. It has two compulsory functions: a function that emits
            a new population, and a function that updates the internal state
            of the emitter.
        metrics_function: a function that takes a MAP-Elites repertoire and
            computes any useful metric to track its evolution.
        repertoire_init: function used to build the initial repertoire. It is
            called positionally with the genotypes, fitnesses, descriptors,
            centroids and extra scores. Defaults to MapElitesRepertoire.init.
    """

    def __init__(
        self,
        scoring_function: Optional[
            Callable[[Genotype, RNGKey], Tuple[Fitness, Descriptor, ExtraScores]]
        ],
        emitter: Emitter,
        metrics_function: Callable[[MapElitesRepertoire], Metrics],
        repertoire_init: Callable[
            [Genotype, Fitness, Descriptor, Centroid, Optional[ExtraScores]],
            MapElitesRepertoire,
        ] = MapElitesRepertoire.init,
    ) -> None:
        self._scoring_function = scoring_function
        self._emitter = emitter
        self._metrics_function = metrics_function
        self._repertoire_init = repertoire_init

    def init(
        self,
        genotypes: Genotype,
        centroids: Centroid,
        key: RNGKey,
    ) -> Tuple[MapElitesRepertoire, Optional[EmitterState], Metrics]:
        """
        Initialize a Map-Elites repertoire with an initial population of genotypes.
        Requires the definition of centroids that can be computed with any method
        such as CVT or Euclidean mapping.

        The genotypes are scored with the scoring function and the result is
        forwarded to ``init_ask_tell``.

        Args:
            genotypes: initial genotypes, pytree in which leaves
                have shape (batch_size, num_features)
            centroids: tessellation centroids of shape (batch_size, num_descriptors)
            key: a random key used for stochastic operations.

        Returns:
            The initialized MAP-Elites repertoire, the initial state of the
            emitter, and the metrics computed on the initial repertoire.

        Raises:
            ValueError: if no scoring function was provided at construction.
        """
        if self._scoring_function is None:
            raise ValueError("Scoring function is not set.")

        # score initial genotypes
        key, subkey = jax.random.split(key)
        fitnesses, descriptors, extra_scores = self._scoring_function(genotypes, subkey)

        return self.init_ask_tell(
            genotypes=genotypes,
            fitnesses=fitnesses,
            descriptors=descriptors,
            centroids=centroids,
            key=key,
            extra_scores=extra_scores,
        )

    def init_ask_tell(
        self,
        genotypes: Genotype,
        fitnesses: Fitness,
        descriptors: Descriptor,
        centroids: Centroid,
        key: RNGKey,
        extra_scores: Optional[ExtraScores] = None,
    ) -> Tuple[MapElitesRepertoire, Optional[EmitterState], Metrics]:
        """
        Initialize a Map-Elites repertoire with an initial population of genotypes
        and their evaluations.
        Requires the definition of centroids that can be computed with any method
        such as CVT or Euclidean mapping.

        Args:
            genotypes: initial genotypes, pytree in which leaves
                have shape (batch_size, num_features)
            fitnesses: initial fitnesses of the genotypes
            descriptors: initial descriptors of the genotypes
            centroids: tessellation centroids of shape (batch_size, num_descriptors)
            key: a random key used for stochastic operations.
            extra_scores: extra scores of the initial genotypes (optional);
                an empty dict is used when None is given.

        Returns:
            The initialized MAP-Elites repertoire, the initial state of the
            emitter, and the metrics computed on the initial repertoire.
        """
        if extra_scores is None:
            extra_scores = {}

        # init the repertoire
        repertoire = self._repertoire_init(
            genotypes,
            fitnesses,
            descriptors,
            centroids,
            extra_scores,
        )

        # get initial state of the emitter
        key, subkey = jax.random.split(key)
        emitter_state = self._emitter.init(
            key=subkey,
            repertoire=repertoire,
            genotypes=genotypes,
            fitnesses=fitnesses,
            descriptors=descriptors,
            extra_scores=extra_scores,
        )

        # calculate the initial metrics
        metrics = self._metrics_function(repertoire)

        return repertoire, emitter_state, metrics

    def update(
        self,
        repertoire: MapElitesRepertoire,
        emitter_state: Optional[EmitterState],
        key: RNGKey,
    ) -> Tuple[MapElitesRepertoire, Optional[EmitterState], Metrics]:
        """
        Performs one iteration of the MAP-Elites algorithm.

        1. The emitter is asked for a batch of offsprings (see ``ask``); how they
           are sampled and varied is decided by the emitter.
        2. The offsprings are scored with the scoring function.
        3. The scored offsprings are added to the repertoire and the emitter
           state is updated (see ``tell``).

        Args:
            repertoire: the MAP-Elites repertoire
            emitter_state: state of the emitter
            key: a jax PRNG random key

        Returns:
            the updated MAP-Elites repertoire
            the updated (if needed) emitter state
            metrics about the updated repertoire

        Raises:
            ValueError: if no scoring function was provided at construction.
        """
        if self._scoring_function is None:
            raise ValueError("Scoring function is not set.")

        # generate offsprings with the emitter
        key, subkey = jax.random.split(key)
        genotypes, extra_info = self.ask(repertoire, emitter_state, subkey)

        # scores the offsprings
        key, subkey = jax.random.split(key)
        fitnesses, descriptors, extra_scores = self._scoring_function(
            genotypes, subkey
        )

        return self.tell(
            genotypes=genotypes,
            fitnesses=fitnesses,
            descriptors=descriptors,
            repertoire=repertoire,
            emitter_state=emitter_state,
            extra_scores=extra_scores,
            extra_info=extra_info,
        )

    def scan_update(
        self,
        carry: Tuple[MapElitesRepertoire, Optional[EmitterState], RNGKey],
        _: Any,
    ) -> Tuple[Tuple[MapElitesRepertoire, Optional[EmitterState], RNGKey], Metrics]:
        """Rewrites the update function in a way that makes it compatible with the
        jax.lax.scan primitive.

        Args:
            carry: a tuple containing the repertoire, the emitter state and a
                random key.
            _: unused element, necessary to respect jax.lax.scan API.

        Returns:
            A pair whose first element is the new carry (updated repertoire,
            updated emitter state, new random key) and whose second element is
            the metrics of this iteration.
        """
        repertoire, emitter_state, key = carry

        # one half of the split drives this iteration, the other is carried on
        key, subkey = jax.random.split(key)
        repertoire, emitter_state, metrics = self.update(
            repertoire,
            emitter_state,
            subkey,
        )

        return (repertoire, emitter_state, key), metrics

    def ask(
        self,
        repertoire: MapElitesRepertoire,
        emitter_state: Optional[EmitterState],
        key: RNGKey,
    ) -> Tuple[Genotype, ExtraScores]:
        """
        Ask the emitter to generate a new batch of genotypes.

        Args:
            repertoire: the MAP-Elites repertoire
            emitter_state: state of the emitter
            key: a jax PRNG random key

        Returns:
            the genotypes produced by the emitter
            the extra information the emitter returned alongside them
        """
        # Only the derived subkey is handed to the emitter. The split is kept so
        # that the values generated for a given input key stay unchanged.
        _, subkey = jax.random.split(key)
        genotypes, extra_info = self._emitter.emit(repertoire, emitter_state, subkey)
        return genotypes, extra_info

    def tell(
        self,
        genotypes: Genotype,
        fitnesses: Fitness,
        descriptors: Descriptor,
        repertoire: MapElitesRepertoire,
        emitter_state: Optional[EmitterState],
        extra_scores: Optional[ExtraScores] = None,
        extra_info: Optional[ExtraScores] = None,
    ) -> Tuple[MapElitesRepertoire, Optional[EmitterState], Metrics]:
        """
        Add new genotypes to the repertoire and update the emitter state.

        Args:
            genotypes: new genotypes to add to the repertoire
            fitnesses: fitnesses of the new genotypes
            descriptors: descriptors of the new genotypes
            repertoire: the MAP-Elites repertoire
            emitter_state: state of the emitter
            extra_scores: extra scores of the new genotypes (optional); an empty
                dict is used when None is given.
            extra_info: extra information returned by ``ask`` (optional); an
                empty dict is used when None is given. It is only passed to the
                emitter, merged into the extra scores.

        Returns:
            the updated MAP-Elites repertoire
            the updated emitter state
            metrics about the updated repertoire
        """
        if extra_scores is None:
            extra_scores = {}
        if extra_info is None:
            extra_info = {}

        # add genotypes in the repertoire (extra_info is not stored there)
        repertoire = repertoire.add(genotypes, descriptors, fitnesses, extra_scores)

        # update emitter state after scoring is made; the emitter sees the
        # repertoire that already contains the new genotypes, and on duplicate
        # keys the entries of extra_info override those of extra_scores
        emitter_state = self._emitter.state_update(
            emitter_state=emitter_state,
            repertoire=repertoire,
            genotypes=genotypes,
            fitnesses=fitnesses,
            descriptors=descriptors,
            extra_scores={**extra_scores, **extra_info},
        )

        # update the metrics
        metrics = self._metrics_function(repertoire)

        return repertoire, emitter_state, metrics
|