MAP-Elites class

This class implements the base mechanism of MAP-Elites. It must be used with an emitter; to get the usual MAP-Elites algorithm, use the mixing emitter.

The MAP-Elites class can be used with other emitters to create variants, such as PGAME, DCRL-ME, CMA-MEGA and OMG-MEGA.
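At its core, the mechanism is an insertion rule: each candidate is assigned to the cell of its nearest centroid and replaces that cell's elite only if its fitness is higher. The sketch below shows this rule in plain JAX on a dense toy grid. It is an illustration under stated assumptions, not QDax's `MapElitesRepertoire.add`: `add_to_repertoire` is a hypothetical helper, empty cells are assumed to hold a fitness of `-inf`, and ties within a batch are resolved arbitrarily.

```python
import jax
import jax.numpy as jnp


def add_to_repertoire(
    elite_genotypes: jnp.ndarray,  # (num_cells, genotype_dim)
    elite_fitnesses: jnp.ndarray,  # (num_cells,), -inf marks an empty cell
    centroids: jnp.ndarray,  # (num_cells, num_descriptors)
    genotypes: jnp.ndarray,  # (batch_size, genotype_dim)
    fitnesses: jnp.ndarray,  # (batch_size,)
    descriptors: jnp.ndarray,  # (batch_size, num_descriptors)
):
    """Hypothetical helper: the MAP-Elites insertion rule on a dense toy grid."""
    num_cells = centroids.shape[0]
    # each candidate goes to the cell of its nearest centroid
    sq_dists = jnp.sum(
        (descriptors[:, None, :] - centroids[None, :, :]) ** 2, axis=-1
    )
    cells = jnp.argmin(sq_dists, axis=1)
    # best candidate fitness per cell within this batch (-inf for cells with none)
    batch_best = jax.ops.segment_max(fitnesses, cells, num_segments=num_cells)
    # a candidate is kept if it is the batch best for its cell and beats the elite
    keep = (fitnesses == batch_best[cells]) & (fitnesses > elite_fitnesses[cells])
    # rejected candidates get an out-of-range index, which mode="drop" ignores
    target = jnp.where(keep, cells, num_cells)
    new_genotypes = elite_genotypes.at[target].set(genotypes, mode="drop")
    new_fitnesses = elite_fitnesses.at[target].set(fitnesses, mode="drop")
    return new_genotypes, new_fitnesses


# toy usage: two cells, three candidates, empty repertoire
centroids = jnp.array([[0.0, 0.0], [1.0, 1.0]])
elite_genotypes = jnp.zeros((2, 3))
elite_fitnesses = jnp.full((2,), -jnp.inf)
genotypes = jnp.array([[1.0, 1.0, 1.0], [2.0, 2.0, 2.0], [3.0, 3.0, 3.0]])
fitnesses = jnp.array([0.5, 0.9, 0.1])
descriptors = jnp.array([[0.1, 0.0], [0.0, 0.2], [0.9, 1.0]])
elite_genotypes, elite_fitnesses = add_to_repertoire(
    elite_genotypes, elite_fitnesses, centroids, genotypes, fitnesses, descriptors
)
```

The first two candidates compete for the same cell and only the fitter one survives; a later, weaker candidate for an occupied cell leaves the repertoire unchanged.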

Core elements of the MAP-Elites algorithm.

Note: although it is very similar to the GeneticAlgorithm class, we decided to keep the MAPElites class independent of it for now, to keep elements explicit.

Parameters:
  • scoring_function (Optional[Callable[[Genotype, RNGKey], Tuple[Fitness, Descriptor, ExtraScores]]]) –

    a function that takes a batch of genotypes and a random key, and computes their fitnesses, descriptors and extra scores

  • emitter (Emitter) –

    an emitter suggests offspring given a MAP-Elites repertoire. It has two compulsory functions: one that emits a new population, and one that updates the internal state of the emitter.

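  • metrics_function (Callable[[MapElitesRepertoire], Metrics]) –

    a function that takes a MAP-Elites repertoire and computes any useful metric to track its evolution

  • repertoire_init (Callable[[Genotype, Fitness, Descriptor, Centroid, Optional[ExtraScores]], MapElitesRepertoire], default: MapElitesRepertoire.init ) –

    a function that builds the initial repertoire from the first evaluated genotypes and the centroids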
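The parameter descriptions above fix the expected signatures: `scoring_function(genotypes, key)` returns `(fitnesses, descriptors, extra_scores)`, and `metrics_function(repertoire)` returns a dictionary of metrics. The toy functions below only illustrate these shapes. The sphere objective, the descriptor choice and the `ToyRepertoire` container are hypothetical stand-ins, not QDax objects; a real metrics function receives a `MapElitesRepertoire`.

```python
from typing import Dict, NamedTuple, Tuple

import jax
import jax.numpy as jnp


def scoring_function(
    genotypes: jnp.ndarray, key: jax.Array
) -> Tuple[jnp.ndarray, jnp.ndarray, Dict[str, jnp.ndarray]]:
    """Toy scoring: negated sphere fitness, first two genes as descriptors."""
    del key  # deterministic toy task; a noisy task would use the key here
    fitnesses = -jnp.sum(genotypes**2, axis=-1)  # (batch_size,)
    descriptors = jnp.clip(genotypes[:, :2], -1.0, 1.0)  # (batch_size, 2)
    extra_scores = {"norm": jnp.linalg.norm(genotypes, axis=-1)}
    return fitnesses, descriptors, extra_scores


class ToyRepertoire(NamedTuple):
    """Hypothetical container: only the per-cell fitnesses."""

    fitnesses: jnp.ndarray  # (num_cells,), -inf marks an empty cell


def metrics_function(repertoire: ToyRepertoire) -> Dict[str, jnp.ndarray]:
    filled = repertoire.fitnesses > -jnp.inf
    return {
        "coverage": 100.0 * jnp.mean(filled),
        "max_fitness": jnp.max(repertoire.fitnesses),
        # QD-score is often computed with an offset making fitnesses positive;
        # the offset is omitted in this toy version
        "qd_score": jnp.sum(jnp.where(filled, repertoire.fitnesses, 0.0)),
    }


genotypes = jnp.array([[0.5, -2.0, 1.0], [0.0, 0.0, 0.0]])
fitnesses, descriptors, extra_scores = scoring_function(
    genotypes, jax.random.PRNGKey(0)
)
metrics = metrics_function(ToyRepertoire(jnp.array([-1.0, -jnp.inf, -0.5, -jnp.inf])))
```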

Source code in qdax/core/map_elites.py
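# imports required by the listing below (module paths assumed from recent QDax releases)
from typing import Any, Callable, Optional, Tuple

import jax

from qdax.core.containers.mapelites_repertoire import MapElitesRepertoire
from qdax.core.emitters.emitter import Emitter, EmitterState
from qdax.custom_types import (
    Centroid,
    Descriptor,
    ExtraScores,
    Fitness,
    Genotype,
    Metrics,
    RNGKey,
)


class MAPElites: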
    """Core elements of the MAP-Elites algorithm.

    Note: Although very similar to the GeneticAlgorithm, we decided to keep the
    MAPElites class independent of the GeneticAlgorithm class at the moment to keep
    elements explicit.

    Args:
        scoring_function: a function that takes a batch of genotypes and a random
            key, and computes their fitnesses, descriptors and extra scores
        emitter: an emitter suggests offspring given a MAP-Elites repertoire. It
            has two compulsory functions: one that emits a new population, and
            one that updates the internal state of the emitter.
        metrics_function: a function that takes a MAP-Elites repertoire and
            computes any useful metric to track its evolution
        repertoire_init: a function that builds the initial repertoire from the
            first evaluated genotypes and the centroids
    """

    def __init__(
        self,
        scoring_function: Optional[
            Callable[[Genotype, RNGKey], Tuple[Fitness, Descriptor, ExtraScores]]
        ],
        emitter: Emitter,
        metrics_function: Callable[[MapElitesRepertoire], Metrics],
        repertoire_init: Callable[
            [Genotype, Fitness, Descriptor, Centroid, Optional[ExtraScores]],
            MapElitesRepertoire,
        ] = MapElitesRepertoire.init,
    ) -> None:
        self._scoring_function = scoring_function
        self._emitter = emitter
        self._metrics_function = metrics_function
        self._repertoire_init = repertoire_init

    def init(
        self,
        genotypes: Genotype,
        centroids: Centroid,
        key: RNGKey,
    ) -> Tuple[MapElitesRepertoire, Optional[EmitterState], Metrics]:
        """
        Initialize a MAP-Elites repertoire with an initial population of genotypes.
        Requires the definition of centroids that can be computed with any method
        such as CVT or Euclidean mapping.

        Args:
            genotypes: initial genotypes, pytree in which leaves
                have shape (batch_size, num_features)
            centroids: tessellation centroids of shape
                (num_centroids, num_descriptors)
            key: a random key used for stochastic operations.

        Returns:
            The initialized MAP-Elites repertoire, the initial state of the
            emitter and the initial metrics.
        """
        if self._scoring_function is None:
            raise ValueError("Scoring function is not set.")

        # score initial genotypes
        key, subkey = jax.random.split(key)
        fitnesses, descriptors, extra_scores = self._scoring_function(genotypes, subkey)

        repertoire, emitter_state, metrics = self.init_ask_tell(
            genotypes=genotypes,
            fitnesses=fitnesses,
            descriptors=descriptors,
            centroids=centroids,
            key=key,
            extra_scores=extra_scores,
        )
        return repertoire, emitter_state, metrics

    def init_ask_tell(
        self,
        genotypes: Genotype,
        fitnesses: Fitness,
        descriptors: Descriptor,
        centroids: Centroid,
        key: RNGKey,
        extra_scores: Optional[ExtraScores] = None,
    ) -> Tuple[MapElitesRepertoire, Optional[EmitterState], Metrics]:
        """
        Initialize a MAP-Elites repertoire with an initial population of genotypes
        and their evaluations.
        Requires the definition of centroids that can be computed with any method
        such as CVT or Euclidean mapping.

        Args:
            genotypes: initial genotypes, pytree in which leaves
                have shape (batch_size, num_features)
            fitnesses: initial fitnesses of the genotypes
            descriptors: initial descriptors of the genotypes
            centroids: tessellation centroids of shape
                (num_centroids, num_descriptors)
            key: a random key used for stochastic operations.
            extra_scores: extra scores of the initial genotypes (optional)

        Returns:
            The initialized MAP-Elites repertoire, the initial state of the
            emitter and the initial metrics.
        """
        if extra_scores is None:
            extra_scores = {}
        # init the repertoire
        repertoire = self._repertoire_init(
            genotypes,
            fitnesses,
            descriptors,
            centroids,
            extra_scores,
        )

        # get initial state of the emitter
        key, subkey = jax.random.split(key)
        emitter_state = self._emitter.init(
            key=subkey,
            repertoire=repertoire,
            genotypes=genotypes,
            fitnesses=fitnesses,
            descriptors=descriptors,
            extra_scores=extra_scores,
        )

        # calculate the initial metrics
        metrics = self._metrics_function(repertoire)

        return repertoire, emitter_state, metrics

    def update(
        self,
        repertoire: MapElitesRepertoire,
        emitter_state: Optional[EmitterState],
        key: RNGKey,
    ) -> Tuple[MapElitesRepertoire, Optional[EmitterState], Metrics]:
        """
        Performs one iteration of the MAP-Elites algorithm.
        1. A batch of genotypes is sampled from the repertoire and the genotypes
            are copied.
        2. The copies are mutated and crossed over.
        3. The obtained offspring are scored and then added to the repertoire.

        Args:
            repertoire: the MAP-Elites repertoire
            emitter_state: state of the emitter
            key: a jax PRNG random key

        Returns:
            the updated MAP-Elites repertoire
            the updated (if needed) emitter state
            metrics about the updated repertoire
        """
        if self._scoring_function is None:
            raise ValueError("Scoring function is not set.")

        # generate offspring with the emitter
        key, subkey = jax.random.split(key)
        genotypes, extra_info = self.ask(repertoire, emitter_state, subkey)

        # score the offspring
        key, subkey = jax.random.split(key)
        (fitnesses, descriptors, extra_scores) = self._scoring_function(
            genotypes, subkey
        )

        repertoire, emitter_state, metrics = self.tell(
            genotypes=genotypes,
            fitnesses=fitnesses,
            descriptors=descriptors,
            repertoire=repertoire,
            emitter_state=emitter_state,
            extra_scores=extra_scores,
            extra_info=extra_info,
        )
        return repertoire, emitter_state, metrics

    def scan_update(
        self,
        carry: Tuple[MapElitesRepertoire, Optional[EmitterState], RNGKey],
        _: Any,
    ) -> Tuple[Tuple[MapElitesRepertoire, Optional[EmitterState], RNGKey], Metrics]:
        """Rewrites the update function in a way that makes it compatible with the
        jax.lax.scan primitive.

        Args:
            carry: a tuple containing the repertoire, the emitter state and a
                random key.
            _: unused element, necessary to respect jax.lax.scan API.

        Returns:
            The updated repertoire and emitter state, with a new random key and metrics.
        """
        repertoire, emitter_state, key = carry
        key, subkey = jax.random.split(key)
        (
            repertoire,
            emitter_state,
            metrics,
        ) = self.update(
            repertoire,
            emitter_state,
            subkey,
        )

        return (repertoire, emitter_state, key), metrics

    def ask(
        self,
        repertoire: MapElitesRepertoire,
        emitter_state: Optional[EmitterState],
        key: RNGKey,
    ) -> Tuple[Genotype, ExtraScores]:
        """
        Ask the emitter to generate a new batch of genotypes.

        Args:
            repertoire: the MAP-Elites repertoire
            emitter_state: state of the emitter
            key: a jax PRNG random key

        Returns:
            the new genotypes and the extra information returned by the emitter
        """
        key, subkey = jax.random.split(key)
        genotypes, extra_info = self._emitter.emit(repertoire, emitter_state, subkey)
        return genotypes, extra_info

    def tell(
        self,
        genotypes: Genotype,
        fitnesses: Fitness,
        descriptors: Descriptor,
        repertoire: MapElitesRepertoire,
        emitter_state: Optional[EmitterState],
        extra_scores: Optional[ExtraScores] = None,
        extra_info: Optional[ExtraScores] = None,
    ) -> Tuple[MapElitesRepertoire, Optional[EmitterState], Metrics]:
        """
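        Add new genotypes to the repertoire and update the emitter state.

        Args:
            genotypes: new genotypes to add to the repertoire
            fitnesses: fitnesses of the new genotypes
            descriptors: descriptors of the new genotypes
            repertoire: the MAP-Elites repertoire
            emitter_state: state of the emitter
            extra_scores: extra scores of the new genotypes
            extra_info: extra information returned by the emitter in ask

        Returns:
            the updated MAP-Elites repertoire
            the updated emitter state
            metrics about the updated repertoire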
        """
        if extra_scores is None:
            extra_scores = {}
        if extra_info is None:
            extra_info = {}
        # add genotypes to the repertoire
        repertoire = repertoire.add(genotypes, descriptors, fitnesses, extra_scores)

        # update emitter state after scoring is made
        emitter_state = self._emitter.state_update(
            emitter_state=emitter_state,
            repertoire=repertoire,
            genotypes=genotypes,
            fitnesses=fitnesses,
            descriptors=descriptors,
            extra_scores={**extra_scores, **extra_info},
        )

        # update the metrics
        metrics = self._metrics_function(repertoire)

        return repertoire, emitter_state, metrics

ask(repertoire, emitter_state, key)

Ask the emitter to generate a new batch of genotypes.

Parameters:
  • repertoire (MapElitesRepertoire) –

    the MAP-Elites repertoire

  • emitter_state (Optional[EmitterState]) –

    state of the emitter

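  • key (RNGKey) –

    a jax PRNG random key

Returns:
  • Tuple[Genotype, ExtraScores]

    the new genotypes and the extra information returned by the emitter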
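With the mixing emitter, asking for a new batch amounts to selecting parents among the filled cells and varying copies of them. The sketch below shows a simplified version in plain JAX: parents are drawn uniformly among filled cells and perturbed with Gaussian noise. The `emit` helper, the `-inf` convention for empty cells and the Gaussian mutation are assumptions for illustration; the actual mixing emitter also applies crossover. The sketch assumes at least one filled cell.

```python
import jax
import jax.numpy as jnp


def emit(elite_genotypes, elite_fitnesses, key, batch_size=8, sigma=0.1):
    """Hypothetical emitter: uniform parent selection plus Gaussian mutation."""
    key_select, key_noise = jax.random.split(key)
    filled = elite_fitnesses > -jnp.inf
    # uniform probability over filled cells, zero for empty ones
    p = filled / jnp.sum(filled)
    parent_idx = jax.random.choice(
        key_select, elite_genotypes.shape[0], shape=(batch_size,), p=p
    )
    parents = elite_genotypes[parent_idx]  # copies of the selected elites
    offspring = parents + sigma * jax.random.normal(key_noise, parents.shape)
    # extra information travels back to the emitter through tell
    extra_info = {"parent_idx": parent_idx}
    return offspring, extra_info


elite_genotypes = jnp.array([[0.0, 0.0], [5.0, 5.0], [0.0, 0.0]])
elite_fitnesses = jnp.array([1.0, -jnp.inf, 0.5])  # cell 1 is empty
offspring, extra_info = emit(elite_genotypes, elite_fitnesses, jax.random.PRNGKey(0))
```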

init(genotypes, centroids, key)

Initialize a MAP-Elites repertoire with an initial population of genotypes. Requires the definition of centroids that can be computed with any method such as CVT or Euclidean mapping.

Parameters:
  • genotypes (Genotype) –

    initial genotypes, pytree in which leaves have shape (batch_size, num_features)

  • centroids (Centroid) –

    tessellation centroids of shape (num_centroids, num_descriptors)

  • key (RNGKey) –

    a random key used for stochastic operations.

Returns:
  • Tuple[MapElitesRepertoire, Optional[EmitterState], Metrics]

    The initialized MAP-Elites repertoire, the initial state of the emitter and the initial metrics.
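The centroids passed to `init` hold one row per cell, so their shape is (num_centroids, num_descriptors). In practice the library's own CVT or Euclidean tessellation helpers should be used; the sketch below, with a hypothetical `grid_centroids` helper, only shows how the centres of a regular (Euclidean) grid could be computed in plain JAX and what shape results.

```python
import jax.numpy as jnp


def grid_centroids(grid_shape, minval, maxval):
    """Hypothetical helper: centres of a regular grid over [minval, maxval]^d.

    Returns an array of shape (prod(grid_shape), len(grid_shape)).
    """
    axes = [
        # cell centres along one descriptor dimension
        minval + (jnp.arange(n) + 0.5) * (maxval - minval) / n
        for n in grid_shape
    ]
    mesh = jnp.meshgrid(*axes, indexing="ij")
    # one row per cell, one column per descriptor dimension
    return jnp.stack([m.ravel() for m in mesh], axis=-1)


# 2 x 3 grid over the unit square: 6 cells, 2 descriptor dimensions
centroids = grid_centroids((2, 3), minval=0.0, maxval=1.0)
```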

init_ask_tell(genotypes, fitnesses, descriptors, centroids, key, extra_scores=None)

Initialize a MAP-Elites repertoire with an initial population of genotypes and their evaluations. Requires the definition of centroids that can be computed with any method such as CVT or Euclidean mapping.

Parameters:
  • genotypes (Genotype) –

    initial genotypes, pytree in which leaves have shape (batch_size, num_features)

  • fitnesses (Fitness) –

    initial fitnesses of the genotypes

  • descriptors (Descriptor) –

    initial descriptors of the genotypes

  • centroids (Centroid) –

    tessellation centroids of shape (num_centroids, num_descriptors)

  • key (RNGKey) –

    a random key used for stochastic operations.

  • extra_scores (Optional[ExtraScores], default: None ) –

    extra scores of the initial genotypes (optional)

Returns:
  • Tuple[MapElitesRepertoire, Optional[EmitterState], Metrics]

    The initialized MAP-Elites repertoire, the initial state of the emitter and the initial metrics.

scan_update(carry, _)

Rewrites the update function in a way that makes it compatible with the jax.lax.scan primitive.

Parameters:
  • carry (Tuple[MapElitesRepertoire, Optional[EmitterState], RNGKey]) –

    a tuple containing the repertoire, the emitter state and a random key.

  • _ (Any) –

    unused element, necessary to respect jax.lax.scan API.

Returns:
  • Tuple[Tuple[MapElitesRepertoire, Optional[EmitterState], RNGKey], Metrics]

    The updated repertoire and emitter state, with a new random key and metrics.
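`scan_update` exists because `jax.lax.scan` expects a body of the form `f(carry, x) -> (new_carry, y)`: the carry holds the state and a key, `x` is unused, and the per-step output (here the metrics) is stacked along a new leading axis. The sketch below reproduces that carry layout with a hypothetical `toy_update` in place of `MAPElites.update`, so it runs without QDax; with the real class one would scan `map_elites.scan_update` over a `(repertoire, emitter_state, key)` carry instead.

```python
import jax
import jax.numpy as jnp


def toy_update(x, key):
    """Hypothetical stand-in for MAPElites.update: one noisy step plus a metric."""
    x = x + jax.random.normal(key, x.shape)
    return x, {"norm": jnp.linalg.norm(x)}


def scan_update(carry, _):
    # same carry layout as scan_update: (state, key); `_` is the unused x
    x, key = carry
    key, subkey = jax.random.split(key)
    x, metrics = toy_update(x, subkey)
    return (x, key), metrics


num_iterations = 5
init_carry = (jnp.zeros(3), jax.random.PRNGKey(0))
# xs=None with an explicit length: the body ignores its second argument
(final_x, final_key), metrics = jax.lax.scan(
    scan_update, init_carry, None, length=num_iterations
)
```

Each leaf of the metrics dictionary comes back with shape `(num_iterations,)`, which makes it easy to log the whole run after a single compiled call.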

tell(genotypes, fitnesses, descriptors, repertoire, emitter_state, extra_scores=None, extra_info=None)

Add new genotypes to the repertoire and update the emitter state.

Parameters:
  • genotypes (Genotype) –

    new genotypes to add to the repertoire

  • fitnesses (Fitness) –

    fitnesses of the new genotypes

  • descriptors (Descriptor) –

    descriptors of the new genotypes

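  • repertoire (MapElitesRepertoire) –

    the MAP-Elites repertoire

  • emitter_state (Optional[EmitterState]) –

    state of the emitter

  • extra_scores (Optional[ExtraScores], default: None ) –

    extra scores of the new genotypes

  • extra_info (Optional[ExtraScores], default: None ) –

    extra information returned by the emitter in ask

Returns:
  • Tuple[MapElitesRepertoire, Optional[EmitterState], Metrics]

    the updated MAP-Elites repertoire, the updated emitter state and metrics about the updated repertoire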
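`tell` passes `{**extra_scores, **extra_info}` to the emitter's `state_update`. Because later entries win in dictionary unpacking, a key present in both dictionaries takes the value from `extra_info` (what the emitter returned from `ask`), not the one from the scoring function. A minimal illustration with made-up keys:

```python
# made-up keys: "source" appears in both dictionaries
extra_scores = {"source": "scoring_function", "episode_length": 100}
extra_info = {"source": "emitter"}

# same unpacking order as in tell: entries from extra_info override
merged = {**extra_scores, **extra_info}
```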

update(repertoire, emitter_state, key)

Performs one iteration of the MAP-Elites algorithm:

  1. A batch of genotypes is sampled from the repertoire and the genotypes are copied.
  2. The copies are mutated and crossed over.
  3. The obtained offspring are scored and then added to the repertoire.

Steps 1 and 2 describe the mixing emitter; other emitters generate offspring in their own way.

Parameters:
  • repertoire (MapElitesRepertoire) –

    the MAP-Elites repertoire

  • emitter_state (Optional[EmitterState]) –

    state of the emitter

  • key (RNGKey) –

    a jax PRNG random key

Returns:
  • Tuple[MapElitesRepertoire, Optional[EmitterState], Metrics]

    the updated MAP-Elites repertoire, the updated (if needed) emitter state and metrics about the updated repertoire
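Structurally, `update` is `ask`, then scoring, then `tell`, with one subkey split off for emission and another for scoring. The sketch below mirrors that control flow with hypothetical callables (`toy_ask`, `toy_score`, `toy_tell`) acting on a one-cell toy repertoire that only stores the best genotype found so far. It is not QDax code, but the key handling and the `{**extra_scores, **extra_info}` merge follow `MAPElites.update` and `tell`.

```python
import jax
import jax.numpy as jnp


def update(state, key, ask_fn, score_fn, tell_fn):
    """ask -> score -> tell, with independent subkeys for emission and scoring."""
    key, ask_key = jax.random.split(key)
    genotypes, extra_info = ask_fn(state, ask_key)
    key, score_key = jax.random.split(key)
    fitnesses, descriptors, extra_scores = score_fn(genotypes, score_key)
    return tell_fn(
        state, genotypes, fitnesses, descriptors, {**extra_scores, **extra_info}
    )


def toy_ask(state, key):
    best_genotype, _ = state
    # four perturbed copies of the current elite
    noise = 0.5 * jax.random.normal(key, (4,) + best_genotype.shape)
    return best_genotype + noise, {}


def toy_score(genotypes, key):
    del key  # deterministic toy task
    return -jnp.sum(genotypes**2, axis=-1), genotypes[:, :1], {}


def toy_tell(state, genotypes, fitnesses, descriptors, extra):
    best_genotype, best_fitness = state
    i = jnp.argmax(fitnesses)
    improved = fitnesses[i] > best_fitness  # replace the elite only if fitter
    new_state = (
        jnp.where(improved, genotypes[i], best_genotype),
        jnp.where(improved, fitnesses[i], best_fitness),
    )
    return new_state, {"max_fitness": new_state[1]}


state = (jnp.ones(2), jnp.array(-2.0))  # genotype [1, 1] has fitness -2
key = jax.random.PRNGKey(0)
for _ in range(20):
    key, subkey = jax.random.split(key)
    state, metrics = update(state, subkey, toy_ask, toy_score, toy_tell)
```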

We also provide DistributedMAPElites, a subclass of MAPElites that runs the algorithm efficiently distributed over several devices.

Bases: MAPElites
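What the distributed variant adds is a gather step: each device scores its own shard of offspring, then `jax.lax.all_gather` over the mapped axis `"p"` collects every shard, so that each device adds the same full batch to its replica of the repertoire. The sketch below isolates that step under `jax.pmap`. `gather_batch` is a hypothetical name, and the example runs on however many local devices exist (a single CPU device is enough).

```python
import jax
import jax.numpy as jnp

num_devices = jax.local_device_count()
batch_per_device = 4


def gather_batch(fitnesses):
    # fitnesses: (batch_per_device,) on each device
    gathered = jax.lax.all_gather(fitnesses, axis_name="p")
    # gathered: (num_devices, batch_per_device); flatten the device axis
    return jnp.concatenate(gathered, axis=0)


# one shard of fitnesses per device, stacked along the leading axis
sharded = jnp.arange(num_devices * batch_per_device, dtype=jnp.float32).reshape(
    num_devices, batch_per_device
)
full = jax.pmap(gather_batch, axis_name="p")(sharded)
# every device now holds the full batch of num_devices * batch_per_device values
```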

Source code in qdax/core/distributed_map_elites.py
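# imports required by the listing below (module paths assumed from recent QDax releases)
from functools import partial
from typing import Any, Callable, List, Optional, Tuple

import jax
import jax.numpy as jnp

from qdax.core.containers.mapelites_repertoire import MapElitesRepertoire
from qdax.core.emitters.emitter import EmitterState
from qdax.core.map_elites import MAPElites
from qdax.custom_types import Centroid, Genotype, Metrics, RNGKey


class DistributedMAPElites(MAPElites):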
    def init(
        self,
        genotypes: Genotype,
        centroids: Centroid,
        key: RNGKey,
    ) -> Tuple[MapElitesRepertoire, Optional[EmitterState], Metrics]:
        """
        Initialize a MAP-Elites repertoire with an initial population of genotypes.
        Requires the definition of centroids that can be computed with any method
        such as CVT or Euclidean mapping.

        Before the repertoire is initialised, individuals are gathered from all the
        devices.

        Args:
            genotypes: initial genotypes, pytree in which leaves
                have shape (batch_size, num_features)
            centroids: tessellation centroids of shape
                (num_centroids, num_descriptors)
            key: a random key used for stochastic operations.

        Returns:
            The initialized MAP-Elites repertoire, the initial state of the
            emitter and the initial metrics.
        """

        if self._scoring_function is None:
            raise ValueError("Scoring function is not set.")

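        # score initial genotypes with a subkey, so that the emitter
        # initialization below does not reuse the same key
        key, subkey = jax.random.split(key)
        (fitnesses, descriptors, extra_scores) = self._scoring_function(
            genotypes, subkey
        )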

        # gather across all devices
        (
            gathered_genotypes,
            gathered_fitnesses,
            gathered_descriptors,
        ) = jax.tree.map(
            lambda x: jnp.concatenate(jax.lax.all_gather(x, axis_name="p"), axis=0),
            (genotypes, fitnesses, descriptors),
        )

        # init the repertoire
        repertoire = MapElitesRepertoire.init(
            genotypes=gathered_genotypes,
            fitnesses=gathered_fitnesses,
            descriptors=gathered_descriptors,
            centroids=centroids,
        )

        # get initial state of the emitter
        emitter_state = self._emitter.init(
            key=key,
            repertoire=repertoire,
            genotypes=genotypes,
            fitnesses=fitnesses,
            descriptors=descriptors,
            extra_scores=extra_scores,
        )

        # update emitter state
        emitter_state = self._emitter.state_update(
            emitter_state=emitter_state,
            repertoire=repertoire,
            genotypes=genotypes,
            fitnesses=fitnesses,
            descriptors=descriptors,
            extra_scores=extra_scores,
        )

        # calculate the initial metrics
        metrics = self._metrics_function(repertoire)

        return repertoire, emitter_state, metrics

    def update(
        self,
        repertoire: MapElitesRepertoire,
        emitter_state: Optional[EmitterState],
        key: RNGKey,
    ) -> Tuple[MapElitesRepertoire, Optional[EmitterState], Metrics]:
        """Performs one iteration of the MAP-Elites algorithm.

        1. A batch of genotypes is sampled from the repertoire and the genotypes
            are copied.
        2. The copies are mutated and crossed over.
        3. The obtained offspring are scored and then added to the repertoire.

        Before the repertoire is updated, individuals are gathered from all the
        devices.

        Args:
            repertoire: the MAP-Elites repertoire
            emitter_state: state of the emitter
            key: a jax PRNG random key

        Returns:
            the updated MAP-Elites repertoire
            the updated (if needed) emitter state
            metrics about the updated repertoire
        """

        if self._scoring_function is None:
            raise ValueError("Scoring function is not set.")

        # generate offspring with the emitter
        key, subkey = jax.random.split(key)
        genotypes, extra_info = self._emitter.emit(repertoire, emitter_state, subkey)

        # score the offspring
        key, subkey = jax.random.split(key)
        (fitnesses, descriptors, extra_scores) = self._scoring_function(
            genotypes, subkey
        )

        # gather across all devices
        (
            gathered_genotypes,
            gathered_fitnesses,
            gathered_descriptors,
        ) = jax.tree.map(
            lambda x: jnp.concatenate(jax.lax.all_gather(x, axis_name="p"), axis=0),
            (genotypes, fitnesses, descriptors),
        )

        # add genotypes to the repertoire
        repertoire = repertoire.add(
            gathered_genotypes, gathered_descriptors, gathered_fitnesses
        )

        # update emitter state after scoring is made
        emitter_state = self._emitter.state_update(
            emitter_state=emitter_state,
            repertoire=repertoire,
            genotypes=genotypes,
            fitnesses=fitnesses,
            descriptors=descriptors,
            extra_scores={**extra_scores, **extra_info},
        )

        # update the metrics
        metrics = self._metrics_function(repertoire)

        return repertoire, emitter_state, metrics

    def get_distributed_init_fn(
        self, centroids: Centroid, devices: List[Any]
    ) -> Callable[
        [Genotype, RNGKey],
        Tuple[MapElitesRepertoire, Optional[EmitterState], Metrics],
    ]:
        """Create a function that initializes MAP-Elites in a distributed way.

        Args:
            centroids: centroids that structure the repertoire.
            devices: hardware devices.

        Returns:
            A callable function that initializes the MAP-Elites algorithm in a
            distributed way.
        """
        return jax.pmap(  # type: ignore
            partial(self.init, centroids=centroids),
            devices=devices,
            axis_name="p",
        )

    def get_distributed_update_fn(
        self, num_iterations: int, devices: List[Any]
    ) -> Callable[
        [MapElitesRepertoire, Optional[EmitterState], RNGKey],
        Tuple[MapElitesRepertoire, Optional[EmitterState], Metrics],
    ]:
        """Create a function that can do a certain number of updates of
        MAP-Elites in a way that is distributed on several devices.

        Args:
            num_iterations: number of iterations to run.
            devices: hardware devices to distribute on.

        Returns:
            The update function that can be called directly to apply a sequence
            of MAP-Elites updates.
        """

        @jax.jit
        def _scan_update(
            carry: Tuple[MapElitesRepertoire, Optional[EmitterState], RNGKey],
            _: Any,
        ) -> Tuple[Tuple[MapElitesRepertoire, Optional[EmitterState], RNGKey], Metrics]:
            """Rewrites the update function in a way that makes it compatible with the
            jax.lax.scan primitive."""
            # unwrap the input
            repertoire, emitter_state, key = carry

            # apply one step of update
            key, subkey = jax.random.split(key)
            (
                repertoire,
                emitter_state,
                metrics,
            ) = self.update(
                repertoire,
                emitter_state,
                subkey,
            )

            return (repertoire, emitter_state, key), metrics

        def update_fn(
            repertoire: MapElitesRepertoire,
            emitter_state: Optional[EmitterState],
            key: RNGKey,
        ) -> Tuple[MapElitesRepertoire, Optional[EmitterState], Metrics]:
            """Apply num_iterations of update."""
            (
                repertoire,
                emitter_state,
                key,
            ), metrics = jax.lax.scan(
                _scan_update,
                (repertoire, emitter_state, key),
                (),
                length=num_iterations,
            )
            return repertoire, emitter_state, metrics

        return jax.pmap(update_fn, devices=devices, axis_name="p")  # type: ignore

get_distributed_init_fn(centroids, devices)

Create a function that initializes MAP-Elites in a distributed way.

Parameters:
  • centroids (Centroid) –

    centroids that structure the repertoire.

  • devices (List[Any]) –

    hardware devices to distribute on.

Returns:
  • Callable[[Genotype, RNGKey], Tuple[MapElitesRepertoire, Optional[EmitterState], Metrics]]

    A callable that initializes the MAP-Elites algorithm in a distributed way. It is built with jax.pmap, so its genotypes and key arguments need a leading axis of size len(devices), and it returns the repertoire, the emitter state and the initial metrics produced by init.
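A minimal, self-contained sketch of the pattern this method relies on: jax.pmap over a function whose centroids argument is bound with functools.partial, combined with jax.lax.all_gather on the "p" axis. The function toy_init is a hypothetical stand-in for the real init method (its toy fitness and descriptors are assumptions); what it illustrates is the input layout, with a leading axis of size num_devices on the genotypes and one key per device, and the fact that every device ends up holding the same gathered batch.

```python
from functools import partial

import jax
import jax.numpy as jnp


def toy_init(genotypes, key, centroids):
    """Hypothetical stand-in for DistributedMAPElites.init (not the QDax code).

    genotypes: (per_device_batch, num_features) on each device.
    key: this device's PRNG key (unused by this deterministic toy scorer).
    centroids: (num_centroids, 2), bound with partial and shared by all devices.
    """
    fitnesses = -jnp.sum(genotypes**2, axis=-1)
    descriptors = genotypes[:, :2]
    # index of the nearest centroid for each local individual
    d2 = jnp.sum((descriptors[:, None, :] - centroids[None, :, :]) ** 2, axis=-1)
    cells = jnp.argmin(d2, axis=1)

    def _gather(x):
        # all_gather stacks per-device arrays on a new leading axis of size
        # num_devices; concatenating on axis 0 turns it into one global batch
        return jnp.concatenate(jax.lax.all_gather(x, axis_name="p"), axis=0)

    return _gather(fitnesses), _gather(cells)


devices = jax.local_devices()
num_devices = len(devices)
per_device_batch, num_features = 4, 3
centroids = jnp.array([[0.0, 0.0], [1.0, 1.0], [-1.0, -1.0]])

init_fn = jax.pmap(
    partial(toy_init, centroids=centroids), devices=devices, axis_name="p"
)

key = jax.random.PRNGKey(0)
key, subkey = jax.random.split(key)
# pmapped inputs carry a leading axis of size num_devices
genotypes = jax.random.normal(subkey, (num_devices, per_device_batch, num_features))
keys = jax.random.split(key, num_devices)  # one key per device
fitnesses, cells = init_fn(genotypes, keys)
# fitnesses has shape (num_devices, num_devices * per_device_batch),
# and every row (one per device) holds the same gathered values
```

Binding centroids with partial keeps them out of the mapped arguments, so they are replicated on every device instead of being split along the device axis.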

Source code in qdax/core/distributed_map_elites.py
def get_distributed_init_fn(
    self, centroids: Centroid, devices: List[Any]
) -> Callable[
    [Genotype, RNGKey],
    Tuple[MapElitesRepertoire, Optional[EmitterState], Metrics],
]:
    """Create a function that initializes MAP-Elites in a distributed way.

    Args:
        centroids: centroids that structure the repertoire.
        devices: hardware devices to distribute on.

    Returns:
        A callable that initializes the MAP-Elites algorithm in a
        distributed way and returns the repertoire, the emitter state
        and the initial metrics.
    """
    return jax.pmap(  # type: ignore
        partial(self.init, centroids=centroids),
        devices=devices,
        axis_name="p",
    )

get_distributed_update_fn(num_iterations, devices)

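Create a function that applies num_iterations MAP-Elites updates, distributed across several devices.

Parameters:
  • num_iterations (int) –

    number of update iterations to perform.

  • devices (List[Any]) –

    hardware devices to distribute on.

Returns:
  • Callable[[MapElitesRepertoire, Optional[EmitterState], RNGKey], Tuple[MapElitesRepertoire, Optional[EmitterState], Metrics]]

    The update function that can be called directly to apply a sequence of MAP-Elites updates. It is built with jax.pmap, so the key argument needs one key per device. The num_iterations updates are chained with jax.lax.scan, so each leaf of the returned metrics has shape (num_devices, num_iterations, ...).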
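The structure of the returned function, a jax.lax.scan over a step that splits the key, wrapped in jax.pmap, can be reproduced with a toy update. toy_update and make_distributed_update_fn are hypothetical names; toy_update only mimics the "(state, key) in, (state, metrics) out" shape of update. The sketch shows how the per-iteration metrics are stacked by the scan.

```python
import jax
import jax.numpy as jnp


def toy_update(state, key):
    """Hypothetical stand-in for one MAP-Elites update (not the QDax code)."""
    new_state = state + jax.random.uniform(key, state.shape)
    return new_state, {"total": jnp.sum(new_state)}


def make_distributed_update_fn(num_iterations, devices):
    def _scan_update(carry, _):
        state, key = carry
        # fresh subkey for this iteration; the carried key moves forward
        key, subkey = jax.random.split(key)
        state, metrics = toy_update(state, subkey)
        return (state, key), metrics

    def update_fn(state, key):
        # xs=() with an explicit length: the step ignores its second argument
        (state, key), metrics = jax.lax.scan(
            _scan_update, (state, key), (), length=num_iterations
        )
        return state, metrics

    return jax.pmap(update_fn, devices=devices, axis_name="p")


devices = jax.local_devices()
num_devices, num_iterations = len(devices), 5
update_fn = make_distributed_update_fn(num_iterations, devices)

state = jnp.zeros((num_devices, 3))  # one replica of the state per device
keys = jax.random.split(jax.random.PRNGKey(42), num_devices)  # one key per device
state, metrics = update_fn(state, keys)
# metrics["total"] has shape (num_devices, num_iterations)
```

Scanning inside the pmapped function keeps the whole loop on device, so the host only dispatches one call per batch of num_iterations updates.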
Source code in qdax/core/distributed_map_elites.py
def get_distributed_update_fn(
    self, num_iterations: int, devices: List[Any]
) -> Callable[
    [MapElitesRepertoire, Optional[EmitterState], RNGKey],
    Tuple[MapElitesRepertoire, Optional[EmitterState], Metrics],
]:
    """Create a function that applies num_iterations MAP-Elites updates,
    distributed across several devices.

    Args:
        num_iterations: number of update iterations to perform.
        devices: hardware devices to distribute on.

    Returns:
        The update function that can be called directly to apply a sequence
        of MAP-Elites updates.
    """

    @jax.jit
    def _scan_update(
        carry: Tuple[MapElitesRepertoire, Optional[EmitterState], RNGKey],
        _: Any,
    ) -> Tuple[Tuple[MapElitesRepertoire, Optional[EmitterState], RNGKey], Metrics]:
        """Rewrites the update function in a way that makes it compatible with the
        jax.lax.scan primitive."""
        # unwrap the input
        repertoire, emitter_state, key = carry

        # apply one step of update
        key, subkey = jax.random.split(key)
        (
            repertoire,
            emitter_state,
            metrics,
        ) = self.update(
            repertoire,
            emitter_state,
            subkey,
        )

        return (repertoire, emitter_state, key), metrics

    def update_fn(
        repertoire: MapElitesRepertoire,
        emitter_state: Optional[EmitterState],
        key: RNGKey,
    ) -> Tuple[MapElitesRepertoire, Optional[EmitterState], Metrics]:
        """Apply num_iterations of update."""
        (
            repertoire,
            emitter_state,
            key,
        ), metrics = jax.lax.scan(
            _scan_update,
            (repertoire, emitter_state, key),
            (),
            length=num_iterations,
        )
        return repertoire, emitter_state, metrics

    return jax.pmap(update_fn, devices=devices, axis_name="p")  # type: ignore

init(genotypes, centroids, key)

Initialize a MAP-Elites repertoire with an initial population of genotypes. Requires the definition of centroids, which can be computed with any method such as CVT or Euclidean mapping.

Before the repertoire is initialized, the genotypes, fitnesses and descriptors are gathered from all the devices, so every device builds the same repertoire. The emitter state, in contrast, is initialized and updated with the local (per-device) batch only.

Parameters:
  • genotypes (Genotype) –

    initial genotypes, pytree in which leaves have shape (batch_size, num_features), where batch_size is the per-device batch size

  • centroids (Centroid) –

    tessellation centroids of shape (num_centroids, num_descriptors)

  • key (RNGKey) –

    a random key used for stochastic operations.

Returns:
  • MapElitesRepertoire

    an initialized MAP-Elites repertoire,

  • Optional[EmitterState]

    the initial state of the emitter,

  • Metrics

    and the metrics of the initial repertoire.
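The centroids argument only needs shape (num_centroids, num_descriptors). As an illustration of the "Euclidean mapping" option, the hypothetical helper grid_centroids below (not QDax's own centroid function, whose name and signature are not given here) places one centroid at the center of each cell of a regular grid spanning the descriptor bounds.

```python
import jax.numpy as jnp


def grid_centroids(grid_shape, minval, maxval):
    """Hypothetical helper: one centroid at the center of each grid cell.

    Returns an array of shape (prod(grid_shape), len(grid_shape)).
    """
    minval = jnp.asarray(minval, dtype=jnp.float32)
    maxval = jnp.asarray(maxval, dtype=jnp.float32)
    axes = []
    for dim, num_cells in enumerate(grid_shape):
        width = (maxval[dim] - minval[dim]) / num_cells
        # centers of num_cells equal-width intervals along this dimension
        axes.append(minval[dim] + width * (jnp.arange(num_cells) + 0.5))
    mesh = jnp.meshgrid(*axes, indexing="ij")
    return jnp.stack([m.ravel() for m in mesh], axis=-1)


centroids = grid_centroids((2, 3), minval=[0.0, 0.0], maxval=[1.0, 1.0])
print(centroids.shape)  # (6, 2)
```

A CVT tessellation would return an array of the same shape; only the placement of the centroids differs.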

Source code in qdax/core/distributed_map_elites.py
def init(
    self,
    genotypes: Genotype,
    centroids: Centroid,
    key: RNGKey,
) -> Tuple[MapElitesRepertoire, Optional[EmitterState], Metrics]:
    """
    Initialize a MAP-Elites repertoire with an initial population of genotypes.
    Requires the definition of centroids, which can be computed with any method
    such as CVT or Euclidean mapping.

    Before the repertoire is initialised, individuals are gathered from all the
    devices.

    Args:
        genotypes: initial genotypes, pytree in which leaves
            have shape (batch_size, num_features)
        centroids: tessellation centroids of shape
            (num_centroids, num_descriptors)
        key: a random key used for stochastic operations.

    Returns:
        An initialized MAP-Elites repertoire, the initial state of the emitter,
        and the metrics of the initial repertoire.
    """

    if self._scoring_function is None:
        raise ValueError("Scoring function is not set.")

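    # score initial genotypes with their own subkey, so that the scoring
    # function and the emitter initialization below do not reuse the same key
    key, subkey = jax.random.split(key)
    (fitnesses, descriptors, extra_scores) = self._scoring_function(
        genotypes, subkey
    )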

    # gather across all devices
    (
        gathered_genotypes,
        gathered_fitnesses,
        gathered_descriptors,
    ) = jax.tree.map(
        lambda x: jnp.concatenate(jax.lax.all_gather(x, axis_name="p"), axis=0),
        (genotypes, fitnesses, descriptors),
    )

    # init the repertoire
    repertoire = MapElitesRepertoire.init(
        genotypes=gathered_genotypes,
        fitnesses=gathered_fitnesses,
        descriptors=gathered_descriptors,
        centroids=centroids,
    )

    # get initial state of the emitter
    emitter_state = self._emitter.init(
        key=key,
        repertoire=repertoire,
        genotypes=genotypes,
        fitnesses=fitnesses,
        descriptors=descriptors,
        extra_scores=extra_scores,
    )

    # update emitter state
    emitter_state = self._emitter.state_update(
        emitter_state=emitter_state,
        repertoire=repertoire,
        genotypes=genotypes,
        fitnesses=fitnesses,
        descriptors=descriptors,
        extra_scores=extra_scores,
    )

    # calculate the initial metrics
    metrics = self._metrics_function(repertoire)

    return repertoire, emitter_state, metrics

update(repertoire, emitter_state, key)

Performs one iteration of the MAP-Elites algorithm.

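  1. The emitter generates a batch of offspring from the repertoire. With the mixing emitter, which gives the usual MAP-Elites algorithm, genotypes are sampled from the repertoire and copied, and the copies are mutated and crossed over.
  2. The offspring are scored.
  3. The scored offspring are added to the repertoire.

Before the repertoire is updated, the offspring genotypes, fitnesses and descriptors are gathered from all the devices, so every device adds the same full batch. The emitter state is then updated with the local (per-device) batch only.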

Parameters:
  • repertoire (MapElitesRepertoire) –

    the MAP-Elites repertoire

  • emitter_state (Optional[EmitterState]) –

    state of the emitter

  • key (RNGKey) –

    a jax PRNG random key

Returns:
  • MapElitesRepertoire

    the updated MAP-Elites repertoire

  • Optional[EmitterState]

    the updated (if needed) emitter state

  • Metrics

    metrics about the updated repertoire
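Stripped of the multi-device gathering, the three steps above can be sketched on a toy problem. Everything here is an assumption for illustration: 2-D genotypes in [-1, 1]^2 used directly as descriptors, fitness -||g||^2, a 5x5 grid of centroids, and the hypothetical helpers score, add and emit. The emit function is a simplified stand-in (uniform choice among filled cells plus Gaussian mutation, no crossover), not QDax's mixing emitter, and add implements a simplified "keep the best individual per cell" rule rather than MapElitesRepertoire.add itself.

```python
import jax
import jax.numpy as jnp

# Toy problem (assumption): descriptor = genotype, fitness = -||g||^2
grid = jnp.linspace(-1.0, 1.0, 5)
centroids = jnp.stack(jnp.meshgrid(grid, grid, indexing="ij"), axis=-1).reshape(-1, 2)
num_cells = centroids.shape[0]


def score(genotypes):
    return -jnp.sum(genotypes**2, axis=-1), genotypes


def add(rep_g, rep_f, genotypes, fitnesses, descriptors):
    """Keep, per cell, the best individual seen so far (simplified rule)."""
    d2 = jnp.sum((descriptors[:, None, :] - centroids[None, :, :]) ** 2, axis=-1)
    cells = jnp.argmin(d2, axis=1)
    # best fitness per cell within this batch
    batch_best = jax.ops.segment_max(fitnesses, cells, num_segments=num_cells)
    keep = (fitnesses == batch_best[cells]) & (fitnesses > rep_f[cells])
    # rejected offspring get an out-of-range index and are dropped by the scatter
    target = jnp.where(keep, cells, num_cells)
    rep_f = rep_f.at[target].set(fitnesses, mode="drop")
    rep_g = rep_g.at[target].set(genotypes, mode="drop")
    return rep_g, rep_f


def emit(key, rep_g, rep_f, batch_size, sigma=0.1):
    """Simplified emitter: copy random elites and apply Gaussian mutation."""
    filled = jnp.isfinite(rep_f)
    key_select, key_mutate = jax.random.split(key)
    parents = jax.random.choice(
        key_select, num_cells, shape=(batch_size,), p=filled / jnp.sum(filled)
    )
    noise = sigma * jax.random.normal(key_mutate, (batch_size, rep_g.shape[1]))
    return jnp.clip(rep_g[parents] + noise, -1.0, 1.0)


key = jax.random.PRNGKey(0)
key, subkey = jax.random.split(key)
init_genotypes = jax.random.uniform(subkey, (16, 2), minval=-1.0, maxval=1.0)
rep_g = jnp.zeros((num_cells, 2))
rep_f = jnp.full((num_cells,), -jnp.inf)  # -inf marks an empty cell
rep_g, rep_f = add(rep_g, rep_f, init_genotypes, *score(init_genotypes))
initial_f = rep_f

for _ in range(20):
    key, subkey = jax.random.split(key)
    offspring = emit(subkey, rep_g, rep_f, batch_size=16)  # 1. emit
    fitnesses, descriptors = score(offspring)  # 2. score
    rep_g, rep_f = add(rep_g, rep_f, offspring, fitnesses, descriptors)  # 3. add
```

Because a cell is only overwritten by a strictly better individual, the fitness stored in every cell can only increase from one iteration to the next.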

Source code in qdax/core/distributed_map_elites.py
def update(
    self,
    repertoire: MapElitesRepertoire,
    emitter_state: Optional[EmitterState],
    key: RNGKey,
) -> Tuple[MapElitesRepertoire, Optional[EmitterState], Metrics]:
    """Performs one iteration of the MAP-Elites algorithm.

    1. A batch of genotypes is sampled in the repertoire and the genotypes
        are copied.
    2. The copies are mutated and crossed over.
    3. The obtained offspring are scored and then added to the repertoire.

    Before the repertoire is updated, individuals are gathered from all the
    devices.

    Args:
        repertoire: the MAP-Elites repertoire
        emitter_state: state of the emitter
        key: a jax PRNG random key

    Returns:
        the updated MAP-Elites repertoire
        the updated (if needed) emitter state
        metrics about the updated repertoire
    """

    if self._scoring_function is None:
        raise ValueError("Scoring function is not set.")

    # generate offspring with the emitter
    key, subkey = jax.random.split(key)
    genotypes, extra_info = self._emitter.emit(repertoire, emitter_state, subkey)

    # score the offspring
    key, subkey = jax.random.split(key)
    (fitnesses, descriptors, extra_scores) = self._scoring_function(
        genotypes, subkey
    )

    # gather across all devices
    (
        gathered_genotypes,
        gathered_fitnesses,
        gathered_descriptors,
    ) = jax.tree.map(
        lambda x: jnp.concatenate(jax.lax.all_gather(x, axis_name="p"), axis=0),
        (genotypes, fitnesses, descriptors),
    )

    # add genotypes in the repertoire
    repertoire = repertoire.add(
        gathered_genotypes, gathered_descriptors, gathered_fitnesses
    )

    # update emitter state after scoring is made
    emitter_state = self._emitter.state_update(
        emitter_state=emitter_state,
        repertoire=repertoire,
        genotypes=genotypes,
        fitnesses=fitnesses,
        descriptors=descriptors,
        extra_scores={**extra_scores, **extra_info},
    )

    # update the metrics
    metrics = self._metrics_function(repertoire)

    return repertoire, emitter_state, metrics