Covariance Matrix Adaptation MAP-Elites (CMA-ME)

To create an instance of CMA-ME, use an instance of MAP-Elites with the desired CMA emitter (optimizing, random direction or improvement), detailed below. To use the pool-of-emitters mechanism, use the CMAPoolEmitter.

Three emitter types:

Bases: Emitter, ABC

Source code in qdax/core/emitters/cma_emitter.py
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
class CMAEmitter(Emitter, ABC):
    """Abstract base emitter for CMA-ME.

    Samples offspring from a CMA-ES search distribution, updates that
    distribution with the offspring ranked by `_ranking_criteria` (defined by
    subclasses), and restarts it from an individual of the repertoire when a
    re-initialisation condition is met.
    """

    def __init__(
        self,
        batch_size: int,
        genotype_dim: int,
        centroids: Centroid,
        sigma_g: float,
        min_count: Optional[int] = None,
        max_count: Optional[float] = None,
        selector: Optional[Selector] = None,
    ):
        """
        Class for the emitter of CMA ME from "Covariance Matrix Adaptation for the
        Rapid Illumination of Descriptor Space" by Fontaine et al.

        Args:
            batch_size: number of solutions sampled at each iteration. Also used
                as the CMA-ES population size and number of selected parents.
            genotype_dim: dimension of the genotype space.
            centroids: centroids used for the repertoire.
            sigma_g: standard deviation for the coefficients - called step size.
                Used as the initial CMA-ES step size.
            min_count: minimum number of CMA-ES optimisation steps before the
                emitter can be re-initialised for lack of improvement.
                Defaults to 0.
            max_count: maximum number of CMA-ES optimisation steps allowed
                before a forced re-initialisation. Defaults to +inf.
            selector: selector passed to `repertoire.select` when a new CMA-ES
                mean is drawn on re-initialisation.
        """
        self._batch_size = batch_size

        # define a CMAES instance
        self._cmaes = CMAES(
            population_size=batch_size,
            search_dim=genotype_dim,
            # no need for fitness function in that specific case
            fitness_function=None,  # type: ignore
            num_best=batch_size,
            init_sigma=sigma_g,
            mean_init=None,  # will be init at zeros in cmaes
            bias_weights=True,
            delay_eigen_decomposition=True,
        )

        # minimum number of CMA-ES steps (state updates) before an emitter can
        # be re-initialized for lack of improvement
        if min_count is None:
            min_count = 0

        self._min_count = min_count

        # no upper bound on the number of CMA-ES steps by default
        if max_count is None:
            max_count = jnp.inf

        self._max_count = max_count

        self._centroids = centroids

        # pristine CMA-ES state, reused (with a new mean) at each re-init
        self._cma_initial_state = self._cmaes.init()

        self._selector = selector

    @property
    def batch_size(self) -> int:
        """
        Returns:
            the batch size emitted by the emitter.
        """
        return self._batch_size

    def init(
        self,
        key: RNGKey,
        repertoire: MapElitesRepertoire,
        genotypes: Genotype,
        fitnesses: Fitness,
        descriptors: Descriptor,
        extra_scores: ExtraScores,
    ) -> CMAEmitterState:
        """
        Initializes the CMA-ME emitter state.

        Only `key` is used; the other arguments are part of the emitter
        interface and are ignored here.

        Args:
            key: a random key to handle stochastic operations.
            repertoire: initial repertoire (unused).
            genotypes: initial genotypes to add to the grid (unused).
            fitnesses: fitnesses of the initial genotypes (unused).
            descriptors: descriptors of the initial genotypes (unused).
            extra_scores: extra scores of the initial genotypes (unused).

        Returns:
            The initial state of the emitter.
        """

        # Initialize repertoire with default values: every cell starts at
        # -inf, so the first offspring with a finite fitness reaching a cell
        # gets an improvement of +inf in `state_update`.
        num_centroids = self._centroids.shape[0]
        default_fitnesses = -jnp.inf * jnp.ones(shape=(num_centroids, 1))

        # return the initial state
        key, subkey = jax.random.split(key)
        emitter_state = CMAEmitterState(
            key=subkey,
            cmaes_state=self._cma_initial_state,
            previous_fitnesses=default_fitnesses,
            emit_count=0,
        )
        return emitter_state

    def emit(  # type: ignore
        self,
        repertoire: Optional[MapElitesRepertoire],
        emitter_state: CMAEmitterState,
        key: RNGKey,
    ) -> Tuple[Genotype, ExtraScores]:
        """
        Emits new individuals. This method does not modify individuals from
        the repertoire directly; instead, it samples from the CMA-ES
        distribution stored in the emitter state. Hence the repertoire is not
        used in the emit function.

        Args:
            repertoire: a repertoire of genotypes (unused).
            emitter_state: the state of the CMA-ME emitter.
            key: a random key to handle random operations.

        Returns:
            New genotypes and an empty dictionary of extra scores.
        """
        # emit from CMA-ES
        offsprings = self._cmaes.sample(cmaes_state=emitter_state.cmaes_state, key=key)

        return offsprings, {}

    def state_update(  # type: ignore
        self,
        emitter_state: CMAEmitterState,
        repertoire: MapElitesRepertoire,
        genotypes: Genotype,
        fitnesses: Fitness,
        descriptors: Descriptor,
        extra_scores: Optional[ExtraScores] = None,
    ) -> CMAEmitterState:
        """
        Updates the CMA-ME emitter state.

        Note: we use the `update_state_with_mask` function from CMAES, which
        assumes that the candidates are already sorted. We do this because we
        have to sort them in this function anyway, in order to apply the right
        weights to the terms when updating theta.

        Args:
            emitter_state: current emitter state
            repertoire: the current genotypes repertoire
            genotypes: the genotypes of the batch of emitted offspring; they
                are sorted by the ranking criteria and used to update CMA-ES.
            fitnesses: the fitnesses of the batch of emitted offspring.
            descriptors: the descriptors of the emitted offspring.
            extra_scores: unused by this base implementation; forwarded to
                `_ranking_criteria`.

        Returns:
            The updated emitter state.
        """

        # retrieve elements from the emitter state
        cmaes_state = emitter_state.cmaes_state

        # Compute the improvements - needed for re-init condition.
        # `previous_fitnesses` holds the repertoire fitnesses seen at the
        # previous update (initialised with shape (num_centroids, 1) in
        # `init`); it is squeezed to one value per cell and gathered at the
        # cells of the offspring.
        # NOTE(review): this assumes `fitnesses` is 1-D (batch_size,); with a
        # trailing singleton axis the subtraction would broadcast to
        # (batch_size, batch_size) - confirm against the scoring function.
        indices = get_cells_indices(descriptors, repertoire.centroids)
        improvements = (
            fitnesses - emitter_state.previous_fitnesses.squeeze(axis=1)[indices]
        )

        ranking_criteria = self._ranking_criteria(
            emitter_state=emitter_state,
            repertoire=repertoire,
            genotypes=genotypes,
            fitnesses=fitnesses,
            descriptors=descriptors,
            extra_scores=extra_scores,
            improvements=improvements,
        )

        # get the indices, sorted by decreasing ranking criterion
        sorted_indices = jnp.flip(jnp.argsort(ranking_criteria))

        # sort the candidates
        sorted_candidates = jax.tree.map(lambda x: x[sorted_indices], genotypes)
        sorted_improvements = improvements[sorted_indices]

        # compute reinitialize condition
        emit_count = emitter_state.emit_count + 1

        # check if the criteria are too similar: if the best and worst
        # criteria are (numerically) equal, the ranking carries no information
        sorted_criteria = ranking_criteria[sorted_indices]
        flat_criteria_condition = (
            jnp.linalg.norm(sorted_criteria[0] - sorted_criteria[-1]) < 1e-12
        )

        # check all conditions. Intended reading: reinitialize if
        # (no offspring improved AND min_count exceeded) OR max_count exceeded
        # OR the CMA-ES stop condition holds OR the criteria are flat.
        reinitialize = (
            jnp.all(improvements < 0) * (emit_count > self._min_count)
            + (emit_count > self._max_count)
            + self._cmaes.stop_condition(cmaes_state)
            + flat_criteria_condition
        )

        # If true, draw randomly and re-initialize parameters
        def update_and_reinit(
            operand: Tuple[
                CMAESState, CMAEmitterState, MapElitesRepertoire, int, RNGKey
            ],
        ) -> CMAEmitterState:
            return self._update_and_init_emitter_state(*operand)

        def update_wo_reinit(
            operand: Tuple[
                CMAESState, CMAEmitterState, MapElitesRepertoire, int, RNGKey
            ],
        ) -> CMAEmitterState:
            """Update the emitter when no reinit event happened.

            Here lies a divergence compared to the original implementation. We
            are getting better results when using no mask and doing the update
            with the whole batch of individuals rather than keeping only the ones
            that were added to the archive.

            Interestingly, keeping the best half was not doing better. We think that
            this might be due to the small batch size used.

            This applies for the setting from the paper CMA-ME. Those facts might
            not be true with other problems and hyperparameters.

            To replicate the code described in the paper, replace:
            `mask = jnp.ones_like(sorted_improvements)`

            by:
            ```
            mask = sorted_improvements >= 0
            mask = mask + 1e-6
            ```

            RMQ: the addition of 1e-6 is here to fix a numerical
            instability.
            """

            # the repertoire and the key are not needed without a reinit
            cmaes_state, emitter_state, _, emit_count, _ = operand

            # Update CMA Parameters
            mask = jnp.ones_like(sorted_improvements)

            cmaes_state = self._cmaes.update_state_with_mask(
                cmaes_state, sorted_candidates, mask=mask
            )

            emitter_state = emitter_state.replace(
                cmaes_state=cmaes_state,
                emit_count=emit_count,
            )

            return emitter_state  # type: ignore

        # Update CMA Parameters (both branches receive the same operand tuple)
        key = emitter_state.key
        key, subkey = jax.random.split(key)
        emitter_state = jax.lax.cond(
            reinitialize,
            update_and_reinit,
            update_wo_reinit,
            operand=(
                cmaes_state,
                emitter_state,
                repertoire,
                emit_count,
                subkey,
            ),
        )

        # update the emitter state: remember the current repertoire fitnesses
        # to compute the improvements at the next update
        emitter_state = emitter_state.replace(
            previous_fitnesses=repertoire.fitnesses,
            key=key,
        )

        return emitter_state

    def _update_and_init_emitter_state(
        self,
        cmaes_state: CMAESState,
        emitter_state: CMAEmitterState,
        repertoire: MapElitesRepertoire,
        emit_count: int,
        key: RNGKey,
    ) -> CMAEmitterState:
        """Update the emitter state in the case of a reinit event.
        Reinit the cmaes state and use an individual from the repertoire
        as the starting mean.

        Args:
            cmaes_state: current cmaes state (unused here; the initial state
                is reused instead).
            emitter_state: current cmame state
            repertoire: most recent repertoire
            emit_count: counter of the emitter (unused here; it is reset to 0).
            key: key to handle stochastic events

        Returns:
            The updated emitter state.
        """

        # re-sample
        random_genotype = repertoire.select(
            key, num_samples=1, selector=self._selector
        ).genotypes

        # remove the batch dim
        new_mean = jax.tree.map(lambda x: x.squeeze(0), random_genotype)

        cmaes_init_state = self._cma_initial_state.replace(mean=new_mean, num_updates=0)

        emitter_state = emitter_state.replace(
            cmaes_state=cmaes_init_state, emit_count=0
        )

        return emitter_state  # type: ignore

    @abstractmethod
    def _ranking_criteria(
        self,
        emitter_state: CMAEmitterState,
        repertoire: MapElitesRepertoire,
        genotypes: Genotype,
        fitnesses: Fitness,
        descriptors: Descriptor,
        extra_scores: Optional[ExtraScores],
        improvements: jax.Array,
    ) -> jax.Array:
        """Defines how the genotypes should be sorted. Impacts the update
        of the CMAES state. In the end, this defines the type of CMAES emitter
        used (optimizing, random direction or improvement).

        Args:
            emitter_state: current state of the emitter.
            repertoire: latest repertoire of genotypes.
            genotypes: emitted genotypes.
            fitnesses: corresponding fitnesses.
            descriptors: corresponding descriptors.
            extra_scores: corresponding extra scores.
            improvements: improvements of the emitted genotypes. This corresponds
                to the difference between their fitness and the fitness of the
                individual previously occupying the cell they fall into
                (+inf when that cell was empty).

        Returns:
            One value per emitted genotype; genotypes are ranked by decreasing
            value. Subclasses choose the criterion (e.g. the improvement, or
            the fitness when the cell was previously unoccupied), and may give
            genotypes that discovered a new cell an offset so that they are
            ranked in front of the other genotypes.
        """

        pass

batch_size property

Returns:
  • int

    the batch size emitted by the emitter.

__init__(batch_size, genotype_dim, centroids, sigma_g, min_count=None, max_count=None, selector=None)

Class for the emitter of CMA ME from "Covariance Matrix Adaptation for the Rapid Illumination of Descriptor Space" by Fontaine et al.

Parameters:
  • batch_size (int) –

    number of solutions sampled at each iteration

  • genotype_dim (int) –

    dimension of the genotype space.

  • centroids (Centroid) –

    centroids used for the repertoire.

  • sigma_g (float) –

    standard deviation for the coefficients - called step size.

  • min_count (Optional[int], default: None ) –

    minimum number of CMA-ES optimisation steps before the emitter can be considered for reinitialisation.

  • max_count (Optional[float], default: None ) –

    maximum number of CMA-ES optimisation steps allowed.

Source code in qdax/core/emitters/cma_emitter.py
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
def __init__(
    self,
    batch_size: int,
    genotype_dim: int,
    centroids: Centroid,
    sigma_g: float,
    min_count: Optional[int] = None,
    max_count: Optional[float] = None,
    selector: Optional[Selector] = None,
):
    """
    Create a CMA-ME emitter, as introduced in "Covariance Matrix Adaptation
    for the Rapid Illumination of Descriptor Space" by Fontaine et al.

    Args:
        batch_size: number of solutions sampled per iteration; it is also the
            CMA-ES population size and its number of selected parents.
        genotype_dim: dimension of the genotype (search) space.
        centroids: centroids of the repertoire cells.
        sigma_g: initial CMA-ES step size (standard deviation of the sampling
            distribution).
        min_count: minimum number of CMA-ES optimisation steps before the
            emitter may be re-initialised. Defaults to 0.
        max_count: maximum number of CMA-ES optimisation steps allowed.
            Defaults to +inf, i.e. no limit.
        selector: optional selector, stored for use when a new CMA-ES mean is
            drawn from the repertoire.
    """
    self._batch_size = batch_size
    self._centroids = centroids
    self._selector = selector

    # Re-initialisation bounds; missing values mean "no lower bound" (0) and
    # "no upper bound" (+inf).
    self._min_count = 0 if min_count is None else min_count
    self._max_count = jnp.inf if max_count is None else max_count

    # CMA-ES instance driving the sampling. Its fitness function is never
    # evaluated by this emitter (fitnesses are supplied externally), so None
    # is passed.
    self._cmaes = CMAES(
        population_size=batch_size,
        search_dim=genotype_dim,
        num_best=batch_size,
        init_sigma=sigma_g,
        fitness_function=None,  # type: ignore
        mean_init=None,  # CMAES initialises the mean itself (at zeros)
        bias_weights=True,
        delay_eigen_decomposition=True,
    )

    # Pristine CMA-ES state, reused whenever the emitter is re-initialised.
    self._cma_initial_state = self._cmaes.init()

emit(repertoire, emitter_state, key)

Emits new individuals. This method does not modify individuals from the repertoire directly; instead, it samples from a distribution. Hence the repertoire is not used in the emit function.

Parameters:
  • repertoire (Optional[MapElitesRepertoire]) –

    a repertoire of genotypes (unused).

  • emitter_state (CMAEmitterState) –

    the state of the CMA-ME emitter.

  • key (RNGKey) –

    a random key to handle random operations.

Returns:
  • Tuple[Genotype, ExtraScores]

    New genotypes and an empty dictionary of extra scores.

Source code in qdax/core/emitters/cma_emitter.py
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
def emit(  # type: ignore
    self,
    repertoire: Optional[MapElitesRepertoire],
    emitter_state: CMAEmitterState,
    key: RNGKey,
) -> Tuple[Genotype, ExtraScores]:
    """
    Sample a new batch of offspring from the current CMA-ES distribution.

    Offspring are not obtained by mutating repertoire members: they are drawn
    from the search distribution held in the emitter state, which is why
    `repertoire` is ignored.

    Args:
        repertoire: repertoire of genotypes (ignored).
        emitter_state: state of the CMA-ME emitter; its `cmaes_state` defines
            the sampling distribution.
        key: random key used for sampling.

    Returns:
        A tuple made of the sampled genotypes and an empty extra-scores dict.
    """
    search_state = emitter_state.cmaes_state
    samples = self._cmaes.sample(cmaes_state=search_state, key=key)
    extra = {}
    return samples, extra

init(key, repertoire, genotypes, fitnesses, descriptors, extra_scores)

Initializes the CMA-ME emitter.

Parameters:
  • genotypes (Genotype) –

    initial genotypes to add to the grid.

  • key (RNGKey) –

    a random key to handle stochastic operations.

Returns:
  • CMAEmitterState

    The initial state of the emitter.

Source code in qdax/core/emitters/cma_emitter.py
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
def init(
    self,
    key: RNGKey,
    repertoire: MapElitesRepertoire,
    genotypes: Genotype,
    fitnesses: Fitness,
    descriptors: Descriptor,
    extra_scores: ExtraScores,
) -> CMAEmitterState:
    """
    Build the initial state of the CMA-ME emitter.

    Only `key` is used; the remaining arguments belong to the emitter
    interface and are ignored.

    Args:
        key: random key; a sub-key derived from it is stored in the state.
        repertoire: initial repertoire (ignored).
        genotypes: initial genotypes (ignored).
        fitnesses: initial fitnesses (ignored).
        descriptors: initial descriptors (ignored).
        extra_scores: initial extra scores (ignored).

    Returns:
        The initial emitter state: the pristine CMA-ES state, an emission
        counter at 0 and a fitness of -inf for every centroid.
    """
    _, state_key = jax.random.split(key)

    # Every cell starts with fitness -inf, so the first offspring with a
    # finite fitness reaching a cell yields an improvement of +inf.
    n_cells = self._centroids.shape[0]
    empty_fitnesses = -jnp.inf * jnp.ones(shape=(n_cells, 1))

    return CMAEmitterState(
        key=state_key,
        cmaes_state=self._cma_initial_state,
        previous_fitnesses=empty_fitnesses,
        emit_count=0,
    )

state_update(emitter_state, repertoire, genotypes, fitnesses, descriptors, extra_scores=None)

Updates the CMA-ME emitter state.

Note: we use the `update_state_with_mask` function from CMAES, which assumes that the candidates are already sorted. We do this because we have to sort them in this function anyway, in order to apply the right weights to the terms when updating theta.

Parameters:
  • emitter_state (CMAEmitterState) –

    current emitter state

  • repertoire (MapElitesRepertoire) –

    the current genotypes repertoire

  • genotypes (Genotype) –

    the genotypes of the batch of emitted offspring; they are sorted by the ranking criteria and used to update the CMA-ES state.

  • fitnesses (Fitness) –

    the fitnesses of the batch of emitted offspring.

  • descriptors (Descriptor) –

    the descriptors of the emitted offspring.

  • extra_scores (Optional[ExtraScores], default: None ) –

    unused by this base implementation; forwarded to the ranking criteria.

Returns:
  • CMAEmitterState

    The updated emitter state.

Source code in qdax/core/emitters/cma_emitter.py
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
def state_update(  # type: ignore
    self,
    emitter_state: CMAEmitterState,
    repertoire: MapElitesRepertoire,
    genotypes: Genotype,
    fitnesses: Fitness,
    descriptors: Descriptor,
    extra_scores: Optional[ExtraScores] = None,
) -> CMAEmitterState:
    """
    Updates the CMA-ME emitter state.

    Note: we use the `update_state_with_mask` function from CMAES, which
    assumes that the candidates are already sorted. We do this because we
    have to sort them in this function anyway, in order to apply the right
    weights to the terms when updating theta.

    Args:
        emitter_state: current emitter state
        repertoire: the current genotypes repertoire
        genotypes: the genotypes of the batch of emitted offspring; they are
            sorted by the ranking criteria and used to update CMA-ES.
        fitnesses: the fitnesses of the batch of emitted offspring.
        descriptors: the descriptors of the emitted offspring.
        extra_scores: unused by this base implementation; forwarded to
            `_ranking_criteria`.

    Returns:
        The updated emitter state.
    """

    # retrieve elements from the emitter state
    cmaes_state = emitter_state.cmaes_state

    # Compute the improvements - needed for re-init condition.
    # `previous_fitnesses` holds the repertoire fitnesses seen at the previous
    # update (initialised with shape (num_centroids, 1) in `init`); it is
    # squeezed to one value per cell and gathered at the offspring's cells.
    # NOTE(review): this assumes `fitnesses` is 1-D (batch_size,); with a
    # trailing singleton axis the subtraction would broadcast to
    # (batch_size, batch_size) - confirm against the scoring function.
    indices = get_cells_indices(descriptors, repertoire.centroids)
    improvements = (
        fitnesses - emitter_state.previous_fitnesses.squeeze(axis=1)[indices]
    )

    ranking_criteria = self._ranking_criteria(
        emitter_state=emitter_state,
        repertoire=repertoire,
        genotypes=genotypes,
        fitnesses=fitnesses,
        descriptors=descriptors,
        extra_scores=extra_scores,
        improvements=improvements,
    )

    # get the indices, sorted by decreasing ranking criterion
    sorted_indices = jnp.flip(jnp.argsort(ranking_criteria))

    # sort the candidates
    sorted_candidates = jax.tree.map(lambda x: x[sorted_indices], genotypes)
    sorted_improvements = improvements[sorted_indices]

    # compute reinitialize condition
    emit_count = emitter_state.emit_count + 1

    # check if the criteria are too similar: if the best and worst criteria
    # are (numerically) equal, the ranking carries no information
    sorted_criteria = ranking_criteria[sorted_indices]
    flat_criteria_condition = (
        jnp.linalg.norm(sorted_criteria[0] - sorted_criteria[-1]) < 1e-12
    )

    # check all conditions. Intended reading: reinitialize if
    # (no offspring improved AND min_count exceeded) OR max_count exceeded
    # OR the CMA-ES stop condition holds OR the criteria are flat.
    reinitialize = (
        jnp.all(improvements < 0) * (emit_count > self._min_count)
        + (emit_count > self._max_count)
        + self._cmaes.stop_condition(cmaes_state)
        + flat_criteria_condition
    )

    # If true, draw randomly and re-initialize parameters
    def update_and_reinit(
        operand: Tuple[
            CMAESState, CMAEmitterState, MapElitesRepertoire, int, RNGKey
        ],
    ) -> CMAEmitterState:
        return self._update_and_init_emitter_state(*operand)

    def update_wo_reinit(
        operand: Tuple[
            CMAESState, CMAEmitterState, MapElitesRepertoire, int, RNGKey
        ],
    ) -> CMAEmitterState:
        """Update the emitter when no reinit event happened.

        Here lies a divergence compared to the original implementation. We
        are getting better results when using no mask and doing the update
        with the whole batch of individuals rather than keeping only the ones
        that were added to the archive.

        Interestingly, keeping the best half was not doing better. We think that
        this might be due to the small batch size used.

        This applies for the setting from the paper CMA-ME. Those facts might
        not be true with other problems and hyperparameters.

        To replicate the code described in the paper, replace:
        `mask = jnp.ones_like(sorted_improvements)`

        by:
        ```
        mask = sorted_improvements >= 0
        mask = mask + 1e-6
        ```

        RMQ: the addition of 1e-6 is here to fix a numerical
        instability.
        """

        # the repertoire and the key are not needed without a reinit
        cmaes_state, emitter_state, _, emit_count, _ = operand

        # Update CMA Parameters
        mask = jnp.ones_like(sorted_improvements)

        cmaes_state = self._cmaes.update_state_with_mask(
            cmaes_state, sorted_candidates, mask=mask
        )

        emitter_state = emitter_state.replace(
            cmaes_state=cmaes_state,
            emit_count=emit_count,
        )

        return emitter_state  # type: ignore

    # Update CMA Parameters (both branches receive the same operand tuple)
    key = emitter_state.key
    key, subkey = jax.random.split(key)
    emitter_state = jax.lax.cond(
        reinitialize,
        update_and_reinit,
        update_wo_reinit,
        operand=(
            cmaes_state,
            emitter_state,
            repertoire,
            emit_count,
            subkey,
        ),
    )

    # update the emitter state: remember the current repertoire fitnesses to
    # compute the improvements at the next update
    emitter_state = emitter_state.replace(
        previous_fitnesses=repertoire.fitnesses,
        key=key,
    )

    return emitter_state

Bases: CMAEmitter

Source code in qdax/core/emitters/cma_rnd_emitter.py
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
class CMARndEmitter(CMAEmitter):
    """CMA-ME "random direction" emitter.

    Offspring are ranked by the projection of their descriptor onto a random
    direction of the descriptor space; offspring that land in a previously
    empty cell are shifted towards the top of the ranking. A fresh direction
    is drawn every time the emitter is re-initialised.
    """

    def _sample_random_direction(self, key: RNGKey) -> jax.Array:
        """Draw a random direction of the descriptor space.

        The direction is sampled from an isotropic standard normal, so its
        orientation is uniformly distributed over the unit sphere; the ranking
        order only depends on the orientation, not on the norm. Sampling each
        component uniformly on [0, 1) instead would only ever produce
        directions in the positive orthant, biasing the emitter towards
        increasing every descriptor at once.

        Args:
            key: random key used for the draw.

        Returns:
            A vector with one entry per descriptor dimension.
        """
        return jax.random.normal(key, shape=(self._centroids.shape[-1],))

    def init(
        self,
        key: RNGKey,
        repertoire: MapElitesRepertoire,
        genotypes: Genotype,
        fitnesses: Fitness,
        descriptors: Descriptor,
        extra_scores: ExtraScores,
    ) -> CMARndEmitterState:
        """
        Initializes the CMA-ME random direction emitter state.

        Only `key` is used; the other arguments are part of the emitter
        interface and are ignored here.

        Args:
            key: a random key to handle stochastic operations.
            repertoire: initial repertoire (unused).
            genotypes: initial genotypes to add to the grid (unused).
            fitnesses: fitnesses of the initial genotypes (unused).
            descriptors: descriptors of the initial genotypes (unused).
            extra_scores: extra scores of the initial genotypes (unused).

        Returns:
            The initial state of the emitter.
        """

        # Initialize repertoire with default values: every cell starts at
        # -inf, so a new cell yields an improvement of +inf.
        num_centroids = self._centroids.shape[0]
        default_fitnesses = -jnp.inf * jnp.ones(shape=(num_centroids, 1))

        # take a random direction
        key, subkey = jax.random.split(key)
        random_direction = self._sample_random_direction(subkey)

        # return the initial state
        key, subkey = jax.random.split(key)

        emitter_state = CMARndEmitterState(
            key=subkey,
            cmaes_state=self._cma_initial_state,
            previous_fitnesses=default_fitnesses,
            emit_count=0,
            random_direction=random_direction,
        )

        return emitter_state

    def _update_and_init_emitter_state(
        self,
        cmaes_state: CMAESState,
        emitter_state: CMAEmitterState,
        repertoire: MapElitesRepertoire,
        emit_count: int,
        key: RNGKey,
    ) -> CMAEmitterState:
        """Update the emitter state in the case of a reinit event.
        Reinit the cmaes state, use an individual from the repertoire
        as the starting mean and draw a new random direction.

        Args:
            cmaes_state: current cmaes state (unused; the initial state is
                reused instead).
            emitter_state: current cmame state
            repertoire: most recent repertoire
            emit_count: counter of the emitter (unused; it is reset to 0).
            key: key to handle stochastic events

        Returns:
            The updated emitter state.
        """

        # re-sample
        key, subkey = jax.random.split(key)
        random_genotype = repertoire.select(
            subkey, num_samples=1, selector=self._selector
        ).genotypes

        # get new mean - remove the batch dim
        new_mean = jax.tree.map(lambda x: x.squeeze(0), random_genotype)

        # define the corresponding cmaes init state
        cmaes_init_state = self._cma_initial_state.replace(mean=new_mean, num_updates=0)

        # take a new random direction
        random_direction = self._sample_random_direction(key)

        emitter_state = emitter_state.replace(
            cmaes_state=cmaes_init_state,
            emit_count=0,
            random_direction=random_direction,
        )

        return emitter_state  # type: ignore

    def _ranking_criteria(
        self,
        emitter_state: CMARndEmitterState,
        repertoire: MapElitesRepertoire,
        genotypes: Genotype,
        fitnesses: Fitness,
        descriptors: Descriptor,
        extra_scores: Optional[ExtraScores],
        improvements: jax.Array,
    ) -> jax.Array:
        """Defines how the genotypes should be sorted. Impacts the update
        of the CMAES state. In the end, this defines the type of CMAES emitter
        used (optimizing, random direction or improvement).

        Args:
            emitter_state: current state of the emitter.
            repertoire: latest repertoire of genotypes.
            genotypes: emitted genotypes.
            fitnesses: corresponding fitnesses.
            descriptors: corresponding descriptors.
            extra_scores: corresponding extra scores.
            improvements: improvements of the emitted genotypes. This corresponds
                to the difference between their fitness and the fitness of the
                individual previously occupying the cell they fall into
                (+inf when that cell was empty).

        Returns:
            The values to take into account in order to rank the emitted genotypes.
            Here, it is the dot product of the descriptor with the current random
            direction, shifted up by the spread of these projections for
            genotypes that reached a previously empty cell.
        """

        # criteria: projection of the descriptors along the random direction
        ranking_criteria = jnp.dot(descriptors, emitter_state.random_direction)

        # make sure to have all the new cells first
        new_cell_offset = jnp.max(ranking_criteria) - jnp.min(ranking_criteria)

        # condition for being a new cell: previous fitness was -inf
        condition = improvements == jnp.inf

        ranking_criteria = jnp.where(
            condition, ranking_criteria + new_cell_offset, ranking_criteria
        )

        return ranking_criteria  # type: ignore

init(key, repertoire, genotypes, fitnesses, descriptors, extra_scores)

Initializes the CMA-ME random direction emitter.

Parameters:
  • genotypes (Genotype) –

    initial genotypes to add to the grid.

  • key (RNGKey) –

    a random key to handle stochastic operations.

Returns:
  • CMARndEmitterState

    The initial state of the emitter.

Source code in qdax/core/emitters/cma_rnd_emitter.py
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
def init(
    self,
    key: RNGKey,
    repertoire: MapElitesRepertoire,
    genotypes: Genotype,
    fitnesses: Fitness,
    descriptors: Descriptor,
    extra_scores: ExtraScores,
) -> CMARndEmitterState:
    """
    Initializes the random direction CMA emitter.

    Only `key` is consumed here; the other arguments are accepted so that the
    signature matches the common emitter interface.

    Args:
        key: a random key to handle stochastic operations.
        repertoire: the initial repertoire (unused here).
        genotypes: initial genotypes to add to the grid (unused here).
        fitnesses: fitnesses of the initial genotypes (unused here).
        descriptors: descriptors of the initial genotypes (unused here).
        extra_scores: extra scores of the initial genotypes (unused here).

    Returns:
        The initial state of the emitter.
    """

    # Previous fitnesses start at -inf for every centroid.
    num_centroids = self._centroids.shape[0]
    default_fitnesses = -jnp.inf * jnp.ones(shape=(num_centroids, 1))

    # Sample the direction in descriptor space from a standard normal so it
    # can point anywhere. Uniform sampling in [0, 1) would only ever produce
    # directions in the positive orthant, biasing the ranking toward larger
    # descriptor values along every axis.
    # NOTE(review): if the direction is re-sampled elsewhere (e.g. when the
    # emitter restarts), that site should use the same distribution — confirm.
    key, subkey = jax.random.split(key)
    random_direction = jax.random.normal(
        subkey,
        shape=(self._centroids.shape[-1],),
    )

    # Fresh subkey stored in the state for later stochastic operations.
    key, subkey = jax.random.split(key)

    emitter_state = CMARndEmitterState(
        key=subkey,
        cmaes_state=self._cma_initial_state,
        previous_fitnesses=default_fitnesses,
        emit_count=0,
        random_direction=random_direction,
    )

    return emitter_state

Bases: CMAEmitter

Source code in qdax/core/emitters/cma_opt_emitter.py
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
class CMAOptimizingEmitter(CMAEmitter):
    """CMA emitter that ranks emitted genotypes by their raw fitness."""

    def _ranking_criteria(
        self,
        emitter_state: CMAEmitterState,
        repertoire: MapElitesRepertoire,
        genotypes: Genotype,
        fitnesses: Fitness,
        descriptors: Descriptor,
        extra_scores: Optional[ExtraScores],
        improvements: jax.Array,
    ) -> jax.Array:
        """Defines how the genotypes should be sorted. Impacts the update
        of the CMAES state. In the end, this defines the type of CMAES emitter
        used (optimizing, random direction or improvement).

        Args:
            emitter_state: current state of the emitter.
            repertoire: latest repertoire of genotypes.
            genotypes: emitted genotypes.
            fitnesses: corresponding fitnesses.
            descriptors: corresponding descriptors.
            extra_scores: corresponding extra scores.
            improvements: improvements of the emitted genotypes. This corresponds
                to the difference between their fitness and the fitness of the
                individual occupying the cell associated with their descriptor.

        Returns:
            The values to take into account in order to rank the emitted genotypes.
            Here, it is the fitness of the genotype; every other argument is
            ignored by this emitter.
        """

        return fitnesses

Pool of homogeneous emitters:

Bases: Emitter

Source code in qdax/core/emitters/cma_pool_emitter.py
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
class CMAPoolEmitter(Emitter):
    """Pool of homogeneous CMA emitters driven by a single emitter object.

    The pool keeps one emitter object and a batched emitter state (one slice
    per member of the pool). At each step, only the state selected by
    ``current_index`` is used to emit and is then updated.
    """

    def __init__(self, num_states: int, emitter: CMAEmitter):
        """Instantiate a pool of homogeneous emitters.

        Args:
            num_states: the number of emitters to consider. We can use a
                single emitter object and a batched emitter state. Must be
                at least 1.
            emitter: the type of emitter for the pool.

        Raises:
            ValueError: if num_states is smaller than 1.
        """
        # An empty pool would only fail later, when indexing the batched
        # state in emit/state_update, with a much less helpful error.
        if num_states < 1:
            raise ValueError(f"num_states must be at least 1, got {num_states}.")
        self._num_states = num_states
        self._emitter = emitter

    @property
    def batch_size(self) -> int:
        """
        Returns:
            the batch size emitted by the emitter.
        """
        return self._emitter.batch_size  # type: ignore

    def init(
        self,
        key: RNGKey,
        repertoire: MapElitesRepertoire,
        genotypes: Genotype,
        fitnesses: Fitness,
        descriptors: Descriptor,
        extra_scores: ExtraScores,
    ) -> CMAPoolEmitterState:
        """
        Initializes the pool of CMA emitters.

        The underlying emitter's ``init`` is called ``num_states`` times, each
        time with a fresh subkey, and the resulting states are stacked along a
        leading axis by ``jax.lax.scan``.

        Args:
            key: a random key to handle stochastic operations.
            repertoire: the initial repertoire, forwarded to each emitter init.
            genotypes: initial genotypes to add to the grid.
            fitnesses: fitnesses of the initial genotypes.
            descriptors: descriptors of the initial genotypes.
            extra_scores: extra scores of the initial genotypes.

        Returns:
            The initial state of the pool, pointing at the first emitter state.
        """

        def scan_emitter_init(carry: RNGKey, _: Any) -> Tuple[RNGKey, CMAEmitterState]:
            key = carry
            key, subkey = jax.random.split(key)
            emitter_state = self._emitter.init(
                subkey,
                repertoire,
                genotypes,
                fitnesses,
                descriptors,
                extra_scores,
            )
            return key, emitter_state

        # init all the emitter states (stacked along the leading axis)
        key, emitter_states = jax.lax.scan(
            scan_emitter_init, key, (), length=self._num_states
        )

        # define the emitter state of the pool
        emitter_state = CMAPoolEmitterState(
            current_index=0, emitter_states=emitter_states
        )

        return emitter_state

    def emit(  # type: ignore
        self,
        repertoire: Optional[MapElitesRepertoire],
        emitter_state: CMAPoolEmitterState,
        key: RNGKey,
    ) -> Tuple[Genotype, ExtraScores]:
        """
        Emits new individuals using the currently selected emitter state.

        Args:
            repertoire: a repertoire of genotypes, forwarded to the
                underlying emitter.
            emitter_state: the state of the pool of CMA emitters.
            key: a random key to handle random operations.

        Returns:
            New genotypes and extra infos.
        """

        # retrieve the relevant emitter state (slice of the batched state)
        current_index = emitter_state.current_index
        used_emitter_state = jax.tree.map(
            lambda x: x[current_index], emitter_state.emitter_states
        )

        # use it to emit offsprings
        offsprings, extra_info = self._emitter.emit(repertoire, used_emitter_state, key)

        return offsprings, extra_info

    def state_update(  # type: ignore
        self,
        emitter_state: CMAPoolEmitterState,
        repertoire: MapElitesRepertoire,
        genotypes: Genotype,
        fitnesses: Fitness,
        descriptors: Descriptor,
        extra_scores: Optional[ExtraScores] = None,
    ) -> Optional[EmitterState]:
        """
        Updates the emitter state.

        Only the emitter state that was used to emit is updated; the pool then
        switches to the emitter state with the smallest emit count.

        Args:
            emitter_state: current emitter state
            repertoire: the current genotypes repertoire
            genotypes: the genotypes of the batch of emitted offspring,
                forwarded to the underlying emitter.
            fitnesses: the fitnesses of the batch of emitted offspring.
            descriptors: the descriptors of the emitted offspring.
            extra_scores: optional extra scores, forwarded to the underlying
                emitter.

        Returns:
            The updated emitter state.
        """

        # retrieve the emitter state that has been used, and its index
        current_index = emitter_state.current_index
        emitter_states = emitter_state.emitter_states

        used_emitter_state = jax.tree.map(lambda x: x[current_index], emitter_states)

        # update the used emitter state
        used_emitter_state = self._emitter.state_update(
            emitter_state=used_emitter_state,
            repertoire=repertoire,
            genotypes=genotypes,
            fitnesses=fitnesses,
            descriptors=descriptors,
            extra_scores=extra_scores,
        )

        # write the updated slice back into the batched state
        emitter_states = jax.tree.map(
            lambda x, y: x.at[current_index].set(y), emitter_states, used_emitter_state
        )

        # next emitter: the one that has emitted the least so far
        # (jnp.argmin returns the first index on ties)
        emit_counts = emitter_states.emit_count

        new_index = jnp.argmin(emit_counts)

        emitter_state = emitter_state.replace(
            current_index=new_index, emitter_states=emitter_states
        )

        return emitter_state  # type: ignore

batch_size property

Returns:
  • int

    the batch size emitted by the emitter.

__init__(num_states, emitter)

Instantiate a pool of homogeneous emitters.

Parameters:
  • num_states (int) –

    the number of emitters to consider. We can use a single emitter object and a batched emitter state.

  • emitter (CMAEmitter) –

    the type of emitter for the pool.

Source code in qdax/core/emitters/cma_pool_emitter.py
30
31
32
33
34
35
36
37
38
39
def __init__(self, num_states: int, emitter: CMAEmitter):
    """Instantiate a pool of homogeneous emitters.

    Args:
        num_states: the number of emitters to consider. We can use a
            single emitter object and a batched emitter state. Must be
            at least 1.
        emitter: the type of emitter for the pool.

    Raises:
        ValueError: if num_states is smaller than 1.
    """
    # An empty pool would only fail later, when indexing the batched
    # state in emit/state_update, with a much less helpful error.
    if num_states < 1:
        raise ValueError(f"num_states must be at least 1, got {num_states}.")
    self._num_states = num_states
    self._emitter = emitter

emit(repertoire, emitter_state, key)

Emits new individuals.

Parameters:
  • repertoire (Optional[MapElitesRepertoire]) –

    a repertoire of genotypes, forwarded to the underlying emitter.

  • emitter_state (CMAPoolEmitterState) –

    the state of the pool of CMA emitters.

  • key (RNGKey) –

    a random key to handle random operations.

Returns:
  • Tuple[Genotype, ExtraScores]

    New genotypes and extra infos.

Source code in qdax/core/emitters/cma_pool_emitter.py
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
def emit(  # type: ignore
    self,
    repertoire: Optional[MapElitesRepertoire],
    emitter_state: CMAPoolEmitterState,
    key: RNGKey,
) -> Tuple[Genotype, ExtraScores]:
    """
    Emits new individuals using the currently selected emitter state.

    Args:
        repertoire: a repertoire of genotypes, forwarded to the
            underlying emitter.
        emitter_state: the state of the pool of CMA emitters.
        key: a random key to handle random operations.

    Returns:
        New genotypes and extra infos.
    """

    # retrieve the relevant emitter state (slice of the batched state)
    current_index = emitter_state.current_index
    used_emitter_state = jax.tree.map(
        lambda x: x[current_index], emitter_state.emitter_states
    )

    # use it to emit offsprings
    offsprings, extra_info = self._emitter.emit(repertoire, used_emitter_state, key)

    return offsprings, extra_info

init(key, repertoire, genotypes, fitnesses, descriptors, extra_scores)

Initializes the pool of CMA emitters.

Parameters:
  • genotypes (Genotype) –

    initial genotypes to add to the grid.

  • key (RNGKey) –

    a random key to handle stochastic operations.

Returns:
  • CMAPoolEmitterState

    The initial state of the emitter.

Source code in qdax/core/emitters/cma_pool_emitter.py
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
def init(
    self,
    key: RNGKey,
    repertoire: MapElitesRepertoire,
    genotypes: Genotype,
    fitnesses: Fitness,
    descriptors: Descriptor,
    extra_scores: ExtraScores,
) -> CMAPoolEmitterState:
    """
    Initializes the pool of CMA emitters.

    The underlying emitter's ``init`` is called ``num_states`` times, each
    time with a fresh subkey, and the resulting states are stacked along a
    leading axis by ``jax.lax.scan``.

    Args:
        key: a random key to handle stochastic operations.
        repertoire: the initial repertoire, forwarded to each emitter init.
        genotypes: initial genotypes to add to the grid.
        fitnesses: fitnesses of the initial genotypes.
        descriptors: descriptors of the initial genotypes.
        extra_scores: extra scores of the initial genotypes.

    Returns:
        The initial state of the pool, pointing at the first emitter state.
    """

    def scan_emitter_init(carry: RNGKey, _: Any) -> Tuple[RNGKey, CMAEmitterState]:
        key = carry
        key, subkey = jax.random.split(key)
        emitter_state = self._emitter.init(
            subkey,
            repertoire,
            genotypes,
            fitnesses,
            descriptors,
            extra_scores,
        )
        return key, emitter_state

    # init all the emitter states (stacked along the leading axis)
    key, emitter_states = jax.lax.scan(
        scan_emitter_init, key, (), length=self._num_states
    )

    # define the emitter state of the pool
    emitter_state = CMAPoolEmitterState(
        current_index=0, emitter_states=emitter_states
    )

    return emitter_state

state_update(emitter_state, repertoire, genotypes, fitnesses, descriptors, extra_scores=None)

Updates the emitter state.

Parameters:
  • emitter_state (CMAPoolEmitterState) –

    current emitter state

  • repertoire (MapElitesRepertoire) –

    the current genotypes repertoire

  • genotypes (Genotype) –

    the genotypes of the batch of emitted offspring, forwarded to the underlying emitter.

  • fitnesses (Fitness) –

    the fitnesses of the batch of emitted offspring.

  • descriptors (Descriptor) –

    the descriptors of the emitted offspring.

  • extra_scores (Optional[ExtraScores], default: None ) –

    optional extra scores, forwarded to the underlying emitter.

Returns:
  • Optional[EmitterState]

    The updated emitter state.

Source code in qdax/core/emitters/cma_pool_emitter.py
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
def state_update(  # type: ignore
    self,
    emitter_state: CMAPoolEmitterState,
    repertoire: MapElitesRepertoire,
    genotypes: Genotype,
    fitnesses: Fitness,
    descriptors: Descriptor,
    extra_scores: Optional[ExtraScores] = None,
) -> Optional[EmitterState]:
    """
    Updates the emitter state.

    Only the emitter state that was used to emit is updated; the pool then
    switches to the emitter state with the smallest emit count.

    Args:
        emitter_state: current emitter state
        repertoire: the current genotypes repertoire
        genotypes: the genotypes of the batch of emitted offspring,
            forwarded to the underlying emitter.
        fitnesses: the fitnesses of the batch of emitted offspring.
        descriptors: the descriptors of the emitted offspring.
        extra_scores: optional extra scores, forwarded to the underlying
            emitter.

    Returns:
        The updated emitter state.
    """

    # retrieve the emitter state that has been used, and its index
    current_index = emitter_state.current_index
    emitter_states = emitter_state.emitter_states

    used_emitter_state = jax.tree.map(lambda x: x[current_index], emitter_states)

    # update the used emitter state
    used_emitter_state = self._emitter.state_update(
        emitter_state=used_emitter_state,
        repertoire=repertoire,
        genotypes=genotypes,
        fitnesses=fitnesses,
        descriptors=descriptors,
        extra_scores=extra_scores,
    )

    # write the updated slice back into the batched state
    emitter_states = jax.tree.map(
        lambda x, y: x.at[current_index].set(y), emitter_states, used_emitter_state
    )

    # next emitter: the one that has emitted the least so far
    # (jnp.argmin returns the first index on ties)
    emit_counts = emitter_states.emit_count

    new_index = jnp.argmin(emit_counts)

    emitter_state = emitter_state.replace(
        current_index=new_index, emitter_states=emitter_states
    )

    return emitter_state  # type: ignore