Covariance Matrix Adaptation MAP-Elites via Gradient Arborescence (CMA-MEGA)

To create an instance of CMA-MEGA, one needs to use an instance of MAP-Elites with the CMAMEGAEmitter, detailed below.

Bases: Emitter

Source code in qdax/core/emitters/cma_mega_emitter.py
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
class CMAMEGAEmitter(Emitter):
    """CMA-MEGA emitter ("Differentiable Quality Diversity", Fontaine et al.).

    Offspring are produced by moving a single solution point ``theta`` along
    linear combinations of its normalized gradients. The combination
    coefficients are sampled, and adapted over time, by a CMA-ES instance.
    """

    def __init__(
        self,
        scoring_function: Callable[
            [Genotype, RNGKey], Tuple[Fitness, Descriptor, ExtraScores]
        ],
        batch_size: int,
        learning_rate: float,
        num_descriptors: int,
        centroids: Centroid,
        sigma_g: float,
        selector: Optional[Selector] = None,
    ):
        """
        Class for the emitter of CMA Mega from "Differentiable Quality Diversity" by
        Fontaine et al.

        Args:
            scoring_function: a function to score individuals, outputting fitness,
                descriptors and extra scores. With this emitter, the extra score
                contains gradients and normalized gradients (the emitter reads
                the "normalized_grads" entry).
            batch_size: number of solutions sampled at each iteration
            learning_rate: rate at which the mean of the distribution is updated.
            num_descriptors: number of descriptors
            centroids: centroids of the repertoire used to store the genotypes
            sigma_g: standard deviation for the coefficients
            selector: optional selector forwarded to ``repertoire.select`` when
                theta is re-sampled from the repertoire.
        """

        self._scoring_function = scoring_function
        self._batch_size = batch_size
        self._learning_rate = learning_rate

        # weights used to update the gradient direction through a linear
        # combination: they decrease with the rank (index 0 = best candidate)
        # and are normalized to sum to one.
        self._weights = jnp.expand_dims(
            jnp.log(batch_size + 0.5) - jnp.log(jnp.arange(1, batch_size + 1)), axis=-1
        )
        self._weights = self._weights / (self._weights.sum())

        # define a CMAES instance - used to update the coeffs. The search space
        # has one coefficient for the fitness gradient plus one per descriptor.
        self._cmaes = CMAES(
            population_size=batch_size,
            search_dim=num_descriptors + 1,
            # no need for fitness function in that specific case
            fitness_function=None,  # type: ignore
            num_best=batch_size,
            init_sigma=sigma_g,
            bias_weights=True,
            delay_eigen_decomposition=True,
        )

        self._centroids = centroids

        # kept around so that the CMA-ES state can be reset in state_update
        self._cma_initial_state = self._cmaes.init()

        self._selector = selector

    def init(
        self,
        key: RNGKey,
        repertoire: MapElitesRepertoire,
        genotypes: Genotype,
        fitnesses: Fitness,
        descriptors: Descriptor,
        extra_scores: ExtraScores,
    ) -> CMAMEGAState:
        """
        Initializes the CMA-MEGA emitter.

        Theta starts at zero, with the shape of a single genotype, and is
        scored once to obtain its normalized gradients.

        Args:
            key: a random key to handle stochastic operations.
            repertoire: unused.
            genotypes: initial genotypes; only used as a template for the
                shape and dtype of theta.
            fitnesses: unused.
            descriptors: unused.
            extra_scores: unused.

        Returns:
            The initial state of the emitter.
        """

        # Split up front: the key handed to the scoring function must not be
        # reused afterwards to derive the emitter state's key.
        score_key, state_key = jax.random.split(key)

        # define init theta as 0 - keep a leading batch dimension of size 1
        theta = jax.tree.map(
            lambda x: jnp.zeros_like(x[:1, ...]),
            genotypes,
        )

        # score it
        _, _, extra_score = self._scoring_function(theta, score_key)
        theta_grads = extra_score["normalized_grads"]

        # No cell has been visited yet: every previous fitness is -inf so that
        # the first solution landing in a cell counts as an infinite improvement
        num_centroids = self._centroids.shape[0]
        default_fitnesses = -jnp.inf * jnp.ones(shape=(num_centroids, 1))

        # return the initial state
        emitter_state = CMAMEGAState(
            theta=theta,
            theta_grads=theta_grads,
            key=state_key,
            cmaes_state=self._cma_initial_state,
            previous_fitnesses=default_fitnesses,
        )
        return emitter_state

    def emit(  # type: ignore
        self,
        repertoire: Optional[MapElitesRepertoire],
        emitter_state: CMAMEGAState,
        key: RNGKey,
    ) -> Tuple[Genotype, ExtraScores]:
        """
        Emits new individuals. Interestingly, this method does not directly
        modify individuals from the repertoire but samples from a distribution.
        Hence the repertoire is not used in the emit function.

        The coefficients are drawn from the key stored in the emitter state
        (not from `key`) so that `state_update` can regenerate exactly the
        same coefficients once the offspring have been evaluated.

        Args:
            repertoire: a repertoire of genotypes (unused).
            emitter_state: the state of the CMA-MEGA emitter.
            key: a random key, kept for interface compatibility (unused).

        Returns:
            New genotypes and an empty dictionary of extra information.
        """

        # retrieve elements from the emitter state
        theta = jnp.nan_to_num(emitter_state.theta)
        cmaes_state = emitter_state.cmaes_state

        # get grads - remove nan and first dimension
        grads = jnp.nan_to_num(emitter_state.theta_grads.squeeze(axis=0))

        # Draw random coefficients - use the emitter state key, split the same
        # way as in state_update so that both methods see identical coeffs
        _, subkey = jax.random.split(emitter_state.key)
        coeffs = self._cmaes.sample(cmaes_state=cmaes_state, key=subkey)

        # make sure the fitness coefficient is positive
        coeffs = coeffs.at[:, 0].set(jnp.abs(coeffs[:, 0]))
        update_grad = coeffs @ grads.T

        # Compute new candidates
        new_thetas = jax.tree.map(lambda x, y: x + y, theta, update_grad)

        return new_thetas, {}

    def state_update(  # type: ignore
        self,
        emitter_state: CMAMEGAState,
        repertoire: MapElitesRepertoire,
        genotypes: Genotype,
        fitnesses: Fitness,
        descriptors: Descriptor,
        extra_scores: Optional[ExtraScores] = None,
    ) -> Optional[EmitterState]:
        """
        Updates the CMA-MEGA emitter state.

        Note: in order to recover the coeffs that were used to sample the
        genotypes, we reuse the emitter state's random key in this function.

        Note: we use the update_state function from CMAES, a function that
        supposes that the candidates are already sorted. We do this because we
        have to sort them in this function anyway, in order to apply the right
        weights to the terms when updating theta.

        Args:
            emitter_state: current emitter state
            repertoire: the current genotypes repertoire
            genotypes: the genotypes of the batch of emitted offspring (unused).
            fitnesses: the fitnesses of the batch of emitted offspring.
            descriptors: the descriptors of the emitted offspring.
            extra_scores: unused

        Returns:
            The updated emitter state.
        """
        key = emitter_state.key

        # retrieve elements from the emitter state
        cmaes_state = emitter_state.cmaes_state
        theta = jnp.nan_to_num(emitter_state.theta)
        grads = jnp.nan_to_num(emitter_state.theta_grads[0])

        # Update the archive and compute the improvements
        indices = get_cells_indices(descriptors, repertoire.centroids)
        improvements = (
            fitnesses - emitter_state.previous_fitnesses.squeeze(axis=1)[indices]
        )

        # condition for being a new cell (previous fitness was -inf)
        condition = improvements == jnp.inf

        # criteria: fitness if new cell, improvement else
        ranking_criteria = jnp.where(condition, fitnesses, improvements)

        # make sure to have all the new cells first
        new_cell_offset = jnp.max(ranking_criteria) - jnp.min(ranking_criteria)

        ranking_criteria = jnp.where(
            condition, ranking_criteria + new_cell_offset, ranking_criteria
        )

        # sort indices according to the criteria (best candidate first)
        sorted_indices = jnp.flip(jnp.argsort(ranking_criteria))

        # Draw the coeffs - reuse the emitter state key to get same coeffs
        key, subkey = jax.random.split(key)
        coeffs = self._cmaes.sample(cmaes_state=cmaes_state, key=subkey)
        # make sure the fitness coeff is positive
        coeffs = coeffs.at[:, 0].set(jnp.abs(coeffs[:, 0]))

        # get the gradients that must be applied
        update_grad = coeffs @ grads.T

        # weight terms - based on improvement rank. The weights are indexed by
        # rank, so the candidates (not the weights) are reordered: row r of
        # `update_grad[sorted_indices]` is the r-th best candidate.
        gradient_step = jnp.sum(self._weights * update_grad[sorted_indices], axis=0)

        # update theta
        theta = jax.tree.map(
            lambda x, y: x + self._learning_rate * y, theta, gradient_step
        )

        # Update CMA Parameters
        sorted_candidates = coeffs[sorted_indices]
        cmaes_state = self._cmaes.update_state(cmaes_state, sorted_candidates)

        # If no improvement draw randomly and re-initialize parameters
        # (the sum combines the two conditions - presumably both are booleans,
        # in which case it acts as a logical OR)
        reinitialize = jnp.all(improvements < 0) + self._cmaes.stop_condition(
            cmaes_state
        )

        # re-sample
        key, subkey = jax.random.split(key)
        random_theta = repertoire.select(
            subkey, num_samples=1, selector=self._selector
        ).genotypes

        # update theta in case of reinit
        theta = jax.tree.map(
            lambda x, y: jnp.where(reinitialize, x, y), random_theta, theta
        )

        # update cmaes state in case of reinit
        cmaes_state = jax.tree.map(
            lambda x, y: jnp.where(reinitialize, x, y),
            self._cma_initial_state,
            cmaes_state,
        )

        # score theta
        key, subkey = jax.random.split(key)
        _, _, extra_score = self._scoring_function(theta, subkey)

        # create new emitter state
        emitter_state = CMAMEGAState(
            theta=theta,
            theta_grads=extra_score["normalized_grads"],
            key=key,
            cmaes_state=cmaes_state,
            previous_fitnesses=repertoire.fitnesses,
        )

        return emitter_state

    @property
    def batch_size(self) -> int:
        """
        Returns:
            the batch size emitted by the emitter.
        """
        return self._batch_size

batch_size property

Returns:
  • int

    the batch size emitted by the emitter.

__init__(scoring_function, batch_size, learning_rate, num_descriptors, centroids, sigma_g, selector=None)

Class for the emitter of CMA Mega from "Differentiable Quality Diversity" by Fontaine et al.

Parameters:
  • scoring_function (Callable[[Genotype, RNGKey], Tuple[Fitness, Descriptor, ExtraScores]]) –

    a function to score individuals, outputting fitness, descriptors and extra scores. With this emitter, the extra score contains gradients and normalized gradients.

  • batch_size (int) –

    number of solutions sampled at each iteration

  • learning_rate (float) –

    rate at which the mean of the distribution is updated.

  • num_descriptors (int) –

    number of descriptors

  • centroids (Centroid) –

    centroids of the repertoire used to store the genotypes

  • sigma_g (float) –

    standard deviation for the coefficients

Source code in qdax/core/emitters/cma_mega_emitter.py
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
def __init__(
    self,
    scoring_function: Callable[
        [Genotype, RNGKey], Tuple[Fitness, Descriptor, ExtraScores]
    ],
    batch_size: int,
    learning_rate: float,
    num_descriptors: int,
    centroids: Centroid,
    sigma_g: float,
    selector: Optional[Selector] = None,
):
    """
    Build the CMA-MEGA emitter described in "Differentiable Quality Diversity"
    (Fontaine et al.).

    Args:
        scoring_function: scores individuals and returns fitness, descriptors
            and extra scores. With this emitter, the extra score contains
            gradients and normalized gradients.
        batch_size: number of solutions sampled at each iteration.
        learning_rate: rate at which the mean of the distribution is updated.
        num_descriptors: number of descriptors.
        centroids: centroids of the repertoire used to store the genotypes.
        sigma_g: standard deviation for the coefficients.
        selector: optional selector, stored for later use by the emitter.
    """

    # plain configuration, stored as-is
    self._scoring_function = scoring_function
    self._batch_size = batch_size
    self._learning_rate = learning_rate
    self._centroids = centroids
    self._selector = selector

    # rank-based weights for the linear combination of gradient steps:
    # decreasing with the rank, then normalized so that they sum to one
    ranks = jnp.arange(1, batch_size + 1)
    raw_weights = jnp.expand_dims(
        jnp.log(batch_size + 0.5) - jnp.log(ranks), axis=-1
    )
    self._weights = raw_weights / (raw_weights.sum())

    # CMA-ES instance in charge of the coefficients: one coefficient for the
    # fitness plus one per descriptor
    self._cmaes = CMAES(
        population_size=batch_size,
        search_dim=num_descriptors + 1,
        # no need for fitness function in that specific case
        fitness_function=None,  # type: ignore
        num_best=batch_size,
        init_sigma=sigma_g,
        bias_weights=True,
        delay_eigen_decomposition=True,
    )

    # pristine CMA-ES state, kept as a reference initial state
    self._cma_initial_state = self._cmaes.init()

emit(repertoire, emitter_state, key)

Emits new individuals. Interestingly, this method does not directly modify individuals from the repertoire but samples from a distribution. Hence the repertoire is not used in the emit function.

Parameters:
  • repertoire (Optional[MapElitesRepertoire]) –

    a repertoire of genotypes (unused).

  • emitter_state (CMAMEGAState) –

    the state of the CMA-MEGA emitter.

  • key (RNGKey) –

    a random key to handle random operations.

Returns:
  • Tuple[Genotype, ExtraScores]

    New genotypes.

Source code in qdax/core/emitters/cma_mega_emitter.py
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
def emit(  # type: ignore
    self,
    repertoire: Optional[MapElitesRepertoire],
    emitter_state: CMAMEGAState,
    key: RNGKey,
) -> Tuple[Genotype, ExtraScores]:
    """
    Emits new individuals. Interestingly, this method does not directly
    modify individuals from the repertoire but samples from a distribution.
    Hence the repertoire is not used in the emit function.

    The coefficients are drawn from the key stored in the emitter state (not
    from `key`), using the first split of that key. The emitter's state
    update re-draws its coefficients from that same subkey, so both methods
    work with identical coefficients.

    Args:
        repertoire: a repertoire of genotypes (unused).
        emitter_state: the state of the CMA-MEGA emitter.
        key: a random key, kept for interface compatibility (unused).

    Returns:
        New genotypes and an empty dictionary of extra information.
    """

    # retrieve elements from the emitter state
    theta = jnp.nan_to_num(emitter_state.theta)
    cmaes_state = emitter_state.cmaes_state

    # get grads - remove nan and first dimension
    grads = jnp.nan_to_num(emitter_state.theta_grads.squeeze(axis=0))

    # Draw random coefficients - use the emitter state key, split the same
    # way as in the state update so that both see identical coefficients
    _, subkey = jax.random.split(emitter_state.key)
    coeffs = self._cmaes.sample(cmaes_state=cmaes_state, key=subkey)

    # make sure the fitness coefficient is positive
    coeffs = coeffs.at[:, 0].set(jnp.abs(coeffs[:, 0]))
    update_grad = coeffs @ grads.T

    # Compute new candidates
    new_thetas = jax.tree.map(lambda x, y: x + y, theta, update_grad)

    return new_thetas, {}

init(key, repertoire, genotypes, fitnesses, descriptors, extra_scores)

Initializes the CMA-MEGA emitter.

Parameters:
  • genotypes (Genotype) –

    initial genotypes to add to the grid.

  • key (RNGKey) –

    a random key to handle stochastic operations.

Returns:
  • CMAMEGAState

    The initial state of the emitter.

Source code in qdax/core/emitters/cma_mega_emitter.py
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
def init(
    self,
    key: RNGKey,
    repertoire: MapElitesRepertoire,
    genotypes: Genotype,
    fitnesses: Fitness,
    descriptors: Descriptor,
    extra_scores: ExtraScores,
) -> CMAMEGAState:
    """
    Initializes the CMA-MEGA emitter.

    Theta starts at zero, with the shape of a single genotype, and is scored
    once to obtain its normalized gradients.

    Args:
        key: a random key to handle stochastic operations.
        repertoire: unused.
        genotypes: initial genotypes; only used as a template for the shape
            and dtype of theta.
        fitnesses: unused.
        descriptors: unused.
        extra_scores: unused.

    Returns:
        The initial state of the emitter.
    """

    # Split up front: the key handed to the scoring function must not be
    # reused afterwards to derive the emitter state's key.
    score_key, state_key = jax.random.split(key)

    # define init theta as 0 - keep a leading batch dimension of size 1
    theta = jax.tree.map(
        lambda x: jnp.zeros_like(x[:1, ...]),
        genotypes,
    )

    # score it
    _, _, extra_score = self._scoring_function(theta, score_key)
    theta_grads = extra_score["normalized_grads"]

    # Default previous fitnesses: -inf for every centroid, i.e. no cell has
    # been filled yet
    num_centroids = self._centroids.shape[0]
    default_fitnesses = -jnp.inf * jnp.ones(shape=(num_centroids, 1))

    # return the initial state
    emitter_state = CMAMEGAState(
        theta=theta,
        theta_grads=theta_grads,
        key=state_key,
        cmaes_state=self._cma_initial_state,
        previous_fitnesses=default_fitnesses,
    )
    return emitter_state

state_update(emitter_state, repertoire, genotypes, fitnesses, descriptors, extra_scores=None)

Updates the CMA-MEGA emitter state.

Note: in order to recover the coeffs that were used to sample the genotypes, we reuse the emitter state's random key in this function.

Note: we use the update_state function from CMAES, a function that supposes that the candidates are already sorted. We do this because we have to sort them in this function anyway, in order to apply the right weights to the terms when updating theta.

Parameters:
  • emitter_state (CMAMEGAState) –

    current emitter state

  • repertoire (MapElitesRepertoire) –

    the current genotypes repertoire

  • genotypes (Genotype) –

    the genotypes of the batch of emitted offspring (unused).

  • fitnesses (Fitness) –

    the fitnesses of the batch of emitted offspring.

  • descriptors (Descriptor) –

    the descriptors of the emitted offspring.

  • extra_scores (Optional[ExtraScores], default: None ) –

    unused

Returns:
  • Optional[EmitterState]

    The updated emitter state.

Source code in qdax/core/emitters/cma_mega_emitter.py
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
def state_update(  # type: ignore
    self,
    emitter_state: CMAMEGAState,
    repertoire: MapElitesRepertoire,
    genotypes: Genotype,
    fitnesses: Fitness,
    descriptors: Descriptor,
    extra_scores: Optional[ExtraScores] = None,
) -> Optional[EmitterState]:
    """
    Updates the CMA-MEGA emitter state.

    Note: the coefficients are re-drawn here from a subkey split from the
    emitter state's random key. The intent is to recover the coeffs that were
    used to sample the genotypes, which holds only if the emission step drew
    them from that same subkey.

    Note: we use the update_state function from CMAES, a function that
    supposes that the candidates are already sorted. We do this because we
    have to sort them in this function anyway, in order to apply the right
    weights to the terms when updating theta.

    Args:
        emitter_state: current emitter state
        repertoire: the current genotypes repertoire
        genotypes: the genotypes of the batch of emitted offspring (unused).
        fitnesses: the fitnesses of the batch of emitted offspring.
        descriptors: the descriptors of the emitted offspring.
        extra_scores: unused

    Returns:
        The updated emitter state.
    """
    key = emitter_state.key

    # retrieve elements from the emitter state
    cmaes_state = emitter_state.cmaes_state
    theta = jnp.nan_to_num(emitter_state.theta)
    grads = jnp.nan_to_num(emitter_state.theta_grads[0])

    # Update the archive and compute the improvements
    indices = get_cells_indices(descriptors, repertoire.centroids)
    improvements = (
        fitnesses - emitter_state.previous_fitnesses.squeeze(axis=1)[indices]
    )

    # condition for being a new cell: the improvement is infinite, which
    # happens when the stored previous fitness of the cell is -inf
    condition = improvements == jnp.inf

    # criteria: fitness if new cell, improvement else
    ranking_criteria = jnp.where(condition, fitnesses, improvements)

    # make sure to have all the new cells first
    new_cell_offset = jnp.max(ranking_criteria) - jnp.min(ranking_criteria)

    ranking_criteria = jnp.where(
        condition, ranking_criteria + new_cell_offset, ranking_criteria
    )

    # sort indices according to the criteria (best candidate first)
    sorted_indices = jnp.flip(jnp.argsort(ranking_criteria))

    # Draw the coeffs from a subkey of the emitter state key
    key, subkey = jax.random.split(key)
    coeffs = self._cmaes.sample(cmaes_state=cmaes_state, key=subkey)
    # make sure the fitness coeff is positive
    coeffs = coeffs.at[:, 0].set(jnp.abs(coeffs[:, 0]))

    # get the gradients that must be applied
    update_grad = coeffs @ grads.T

    # weight terms - based on improvement rank. The weights are indexed by
    # rank, so the candidates (not the weights) are reordered: row r of
    # `update_grad[sorted_indices]` is the r-th best candidate.
    gradient_step = jnp.sum(self._weights * update_grad[sorted_indices], axis=0)

    # update theta
    theta = jax.tree.map(
        lambda x, y: x + self._learning_rate * y, theta, gradient_step
    )

    # Update CMA Parameters
    sorted_candidates = coeffs[sorted_indices]
    cmaes_state = self._cmaes.update_state(cmaes_state, sorted_candidates)

    # If no improvement draw randomly and re-initialize parameters
    # (the sum combines the two conditions - presumably both are booleans,
    # in which case it acts as a logical OR)
    reinitialize = jnp.all(improvements < 0) + self._cmaes.stop_condition(
        cmaes_state
    )

    # re-sample
    key, subkey = jax.random.split(key)
    random_theta = repertoire.select(
        subkey, num_samples=1, selector=self._selector
    ).genotypes

    # update theta in case of reinit
    theta = jax.tree.map(
        lambda x, y: jnp.where(reinitialize, x, y), random_theta, theta
    )

    # update cmaes state in case of reinit
    cmaes_state = jax.tree.map(
        lambda x, y: jnp.where(reinitialize, x, y),
        self._cma_initial_state,
        cmaes_state,
    )

    # score theta
    key, subkey = jax.random.split(key)
    _, _, extra_score = self._scoring_function(theta, subkey)

    # create new emitter state
    emitter_state = CMAMEGAState(
        theta=theta,
        theta_grads=extra_score["normalized_grads"],
        key=key,
        cmaes_state=cmaes_state,
        previous_fitnesses=repertoire.fitnesses,
    )

    return emitter_state