Objective and Measure Gradient MAP-Elites via Gradient Arborescence (OMGMEGA)

To create an instance of OMGMEGA, one needs to use an instance of MAP-Elites with the OMGMEGAEmitter, detailed below.

Bases: Emitter

Class for the emitter of OMG Mega from "Differentiable Quality Diversity" by Fontaine et al.
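Reading the emit source below, each new candidate theta' appears to be built from a sampled parent theta as a random combination of its normalized gradients, which were stored when theta was evaluated. Here f is the fitness, m_1, ..., m_k are the k = num_descriptors descriptors, and sigma_g is used as a variance (see the caution on that parameter):

```latex
\theta' = \theta
  + \lvert c_0 \rvert \, \frac{\nabla f(\theta)}{\lVert \nabla f(\theta) \rVert}
  + \sum_{j=1}^{k} c_j \, \frac{\nabla m_j(\theta)}{\lVert \nabla m_j(\theta) \rVert},
\qquad (c_0, \dots, c_k) \sim \mathcal{N}\!\left(0, \sigma_g I_{k+1}\right)
```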

NOTE: in order to implement this emitter while staying within the MAPElites framework, we had to make two temporary design choices:
- in the emit function, we use the same random key to sample from the genotypes and gradients repertoires, in order to get the gradients that correspond to the right genotypes. Although acceptable, this is definitely not the best coding practice and we would prefer to get rid of it in a future version. A solution that we are discussing with the development team is to decompose the sampling function of the repertoire into two phases: one sampling the indices, the other one retrieving the corresponding elements. This would make it possible to reuse the indices instead of sampling twice.
- in state_update, we have to insert the gradients into the gradients repertoire in the same way the individuals were inserted. Once again, this is suboptimal because the same addition mechanism has to be computed twice. One solution that we are discussing, very similar to the first one, would be to decompose the addition mechanism into two phases: one outputting the indices at which individuals will be added, and then the actual insertion step. This would make it possible to reuse the same indices to add the gradients instead of having to recompute them.

The two design choices seem acceptable and make OMG MEGA compatible with the current implementation of the MAPElites and MapElitesRepertoire classes.

Our suggested solutions seem quite simple and are likely to be useful for implementing other variants. They will be further discussed with the development team and may be added in a future version of the package.
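The two-phase sampling proposed in the note can be illustrated with a minimal sketch, assuming only JAX. The emit source below already follows a similar pattern: it selects indices from an index-valued copy of the repertoire and then gathers from both repertoires with them. `sample_indices` and `gather` are hypothetical helpers, not QDax functions, and uniform sampling over filled cells (marked by a fitness of -inf when empty) is an assumption standing in for the repertoire's selector.

```python
import jax
import jax.numpy as jnp


def sample_indices(key, fitnesses, num_samples):
    """Phase 1 (hypothetical helper): sample cell indices uniformly among
    filled cells. Empty cells are assumed to have a fitness of -inf."""
    filled = (fitnesses > -jnp.inf).astype(jnp.float32)
    probs = filled / jnp.sum(filled)
    return jax.random.choice(
        key, fitnesses.shape[0], shape=(num_samples,), replace=True, p=probs
    )


def gather(tree, indices):
    """Phase 2 (hypothetical helper): retrieve the elements stored at
    `indices` in every leaf of a pytree."""
    return jax.tree.map(lambda x: x[indices], tree)


# Toy repertoire: 5 cells, genotypes of dimension 3, num_descriptors = 2.
fitnesses = jnp.array([1.0, -jnp.inf, 0.5, -jnp.inf, 2.0])
genotypes = jnp.arange(15.0).reshape(5, 3)
# Each stored gradient has shape (3, num_descriptors + 1); every column is a
# copy of the genotype so that any misalignment would be easy to detect.
gradients = jnp.repeat(genotypes[..., None], 3, axis=-1)

# Sample the indices once and reuse them for both repertoires.
indices = sample_indices(jax.random.PRNGKey(0), fitnesses, num_samples=8)
parents = gather(genotypes, indices)
parent_gradients = gather(gradients, indices)
```

Because both gathers use the same `indices`, the gradients are aligned with their genotypes by construction, without relying on two samplings that happen to share a random key.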

Source code in qdax/core/emitters/omg_mega_emitter.py
class OMGMEGAEmitter(Emitter):
    """
    Class for the emitter of OMG Mega from "Differentiable Quality Diversity" by
    Fontaine et al.

    NOTE: in order to implement this emitter while staying within the MAPElites
    framework, we had to make two temporary design choices:
    - in the emit function, we use the same random key to sample from the
    genotypes and gradients repertoires, in order to get the gradients that
    correspond to the right genotypes. Although acceptable, this is definitely
    not the best coding practice and we would prefer to get rid of it in a
    future version. A solution that we are discussing with the development team
    is to decompose the sampling function of the repertoire into two phases: one
    sampling the indices, the other one retrieving the corresponding elements.
    This would make it possible to reuse the indices instead of sampling twice.
    - in state_update, we have to insert the gradients into the gradients
    repertoire in the same way the individuals were inserted. Once again, this is
    suboptimal because the same addition mechanism has to be computed twice.
    One solution that we are discussing, very similar to the first one, would be
    to decompose the addition mechanism into two phases: one outputting the
    indices at which individuals will be added, and then the actual insertion
    step. This would make it possible to reuse the same indices to add the
    gradients instead of having to recompute them.

    The two design choices seem acceptable and make OMG MEGA compatible with
    the current implementation of the MAPElites and MapElitesRepertoire classes.

    Our suggested solutions seem quite simple and are likely to be useful for
    implementing other variants. They will be further discussed with the
    development team and may be added in a future version of the package.
    """

    def __init__(
        self,
        batch_size: int,
        sigma_g: float,
        num_descriptors: int,
        centroids: Centroid,
        selector: Optional[Selector] = None,
    ):
        """Creates an instance of the OMGMEGAEmitter class.

        Args:
            batch_size: number of solutions sampled at each iteration
            sigma_g: CAUTION - square of the standard deviation for the coefficients.
                This notation can be misleading as, although it's called sigma, it
                refers to the variance and not the standard deviation.
            num_descriptors: number of descriptors
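            centroids: centroids used to create the repertoire of solutions.
                The same centroids are used to create the repertoire of gradients.
            selector: optional selector used to sample parents from the
                repertoire. If None, the repertoire's default selection is used.
        """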
        # set the mean of the coeff distribution to zero
        self._mu = jnp.zeros(num_descriptors + 1)

        # set the cov matrix to sigma * I
        self._sigma = jnp.eye(num_descriptors + 1) * sigma_g

        # store the remaining emitter parameters
        self._batch_size = batch_size
        self._centroids = centroids
        self._num_descriptors = num_descriptors

        self._selector = selector

    def init(
        self,
        key: RNGKey,
        repertoire: MapElitesRepertoire,
        genotypes: Genotype,
        fitnesses: Fitness,
        descriptors: Descriptor,
        extra_scores: ExtraScores,
    ) -> OMGMEGAEmitterState:
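        """Initialises the state of the emitter. Creates a repertoire that will
        contain the gradients of the individuals and fills it with the gradients
        of the initial population.

        Args:
            key: a random key to handle stochastic operations.
            repertoire: the repertoire of solutions (not used by this method).
            genotypes: The genotypes of the initial population.
            fitnesses: the fitnesses of the initial population.
            descriptors: the descriptors of the initial population.
            extra_scores: extra scores of the initial population; must contain
                the gradients under the "gradients" key.

        Returns:
            The initial emitter state.
        """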
        # retrieve one genotype from the population
        first_genotype = jax.tree.map(lambda x: x[0], genotypes)

        # add a dimension of size num descriptors + 1
        gradient_genotype = jax.tree.map(
            lambda x: jnp.repeat(
                jnp.expand_dims(x, axis=-1), repeats=self._num_descriptors + 1, axis=-1
            ),
            first_genotype,
        )

        # create the gradients repertoire
        gradients_repertoire = MapElitesRepertoire.init_default(
            genotype=gradient_genotype, centroids=self._centroids
        )

        # get gradients out of the extra scores
        assert "gradients" in extra_scores.keys(), "Missing gradients or wrong key"
        gradients = extra_scores["gradients"]

        # update the gradients repertoire
        gradients_repertoire = gradients_repertoire.add(
            gradients,
            descriptors,
            fitnesses,
            extra_scores,
        )

        return OMGMEGAEmitterState(gradients_repertoire=gradients_repertoire)

    def emit(  # type: ignore
        self,
        repertoire: MapElitesRepertoire,
        emitter_state: OMGMEGAEmitterState,
        key: RNGKey,
    ) -> Tuple[Genotype, ExtraScores]:
        """
        OMG emitter function that samples elements in the repertoire and does a gradient
        update with random coefficients to create new candidates.

        Args:
            repertoire: current repertoire
            emitter_state: current emitter state, contains the gradients
            key: random key

        Returns:
            new_genotypes: new candidates to be added to the grid, together with
                an empty dictionary of extra scores.
        """
        # sample parent indices once and use them for both repertoires
        key, subkey = jax.random.split(key)

        size_repertoire = repertoire.fitnesses.shape[0]
        repertoire_indexes = repertoire.replace(genotypes=jnp.arange(size_repertoire))
        indexes_selected = repertoire_indexes.select(
            subkey, num_samples=self._batch_size, selector=self._selector
        ).genotypes

        genotypes = jax.tree.map(lambda x: x[indexes_selected], repertoire.genotypes)
        gradients = jax.tree.map(
            lambda x: x[indexes_selected], emitter_state.gradients_repertoire.genotypes
        )

        fitness_gradients = jax.tree.map(
            lambda x: jnp.expand_dims(x[:, :, 0], axis=-1), gradients
        )
        descriptors_gradients = jax.tree.map(lambda x: x[:, :, 1:], gradients)

        # Normalize the gradients
        norm_fitness_gradients = jnp.linalg.norm(
            fitness_gradients, axis=1, keepdims=True
        )

        fitness_gradients = fitness_gradients / norm_fitness_gradients

        norm_descriptors_gradients = jnp.linalg.norm(
            descriptors_gradients, axis=1, keepdims=True
        )
        descriptors_gradients = descriptors_gradients / norm_descriptors_gradients

        # Draw random coefficients; the fitness coefficient c_0 is made non-negative
        coeffs = jax.random.multivariate_normal(
            key,
            shape=(self._batch_size,),
            mean=self._mu,
            cov=self._sigma,
        )
        coeffs = coeffs.at[:, 0].set(jnp.abs(coeffs[:, 0]))
        grads = jax.tree.map(
            lambda x, y: jnp.concatenate((x, y), axis=-1),
            fitness_gradients,
            descriptors_gradients,
        )
        update_grad = jnp.sum(jax.vmap(lambda x, y: x * y)(coeffs, grads), axis=-1)

        # update the genotypes
        new_genotypes = jax.tree.map(lambda x, y: x + y, genotypes, update_grad)

        return new_genotypes, {}

    def state_update(  # type: ignore
        self,
        emitter_state: OMGMEGAEmitterState,
        repertoire: MapElitesRepertoire,
        genotypes: Genotype,
        fitnesses: Fitness,
        descriptors: Descriptor,
        extra_scores: ExtraScores,
    ) -> OMGMEGAEmitterState:
        """Update the gradients repertoire to have the right gradients.

        NOTE: see the discussion in the class docstring to see how this could
        be improved.

        Args:
            emitter_state: current emitter state
            repertoire: the current genotypes repertoire
            genotypes: the genotypes of the batch of emitted offspring.
            fitnesses: the fitnesses of the batch of emitted offspring.
            descriptors: the descriptors of the emitted offspring.
            extra_scores: a dictionary with other values output by the
                scoring function.

        Returns:
            The modified emitter state.
        """

        # get gradients out of the extra scores
        assert "gradients" in extra_scores.keys(), "Missing gradients or wrong key"
        gradients = extra_scores["gradients"]

        # update the gradients repertoire
        gradients_repertoire = emitter_state.gradients_repertoire.add(
            gradients,
            descriptors,
            fitnesses,
            extra_scores,
        )

        return emitter_state.replace(  # type: ignore
            gradients_repertoire=gradients_repertoire
        )

    @property
    def batch_size(self) -> int:
        """
        Returns:
            the batch size emitted by the emitter.
        """
        return self._batch_size

batch_size property

Returns:
  • int

    the batch size emitted by the emitter.

__init__(batch_size, sigma_g, num_descriptors, centroids, selector=None)

Creates an instance of the OMGMEGAEmitter class.

Parameters:
  • batch_size (int) –

    number of solutions sampled at each iteration

  • sigma_g (float) –

    CAUTION - square of the standard deviation for the coefficients. This notation can be misleading as, although it's called sigma, it refers to the variance and not the standard deviation.

  • num_descriptors (int) –

    number of descriptors

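  • centroids (Centroid) –

    centroids used to create the repertoire of solutions. The same centroids are used to create the repertoire of gradients.

  • selector (Optional[Selector], default: None ) –

    optional selector used to sample parents from the repertoire. If None, the repertoire's default selection is used.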
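Because the name sigma_g invites misreading, the following minimal sketch, assuming only JAX, reproduces the coefficient sampling from the emit source and checks empirically that sigma_g acts as a variance. `sample_coefficients` is a hypothetical helper that mirrors the emitter's `_mu`, `_sigma`, and the absolute value applied to the first coefficient.

```python
import jax
import jax.numpy as jnp


def sample_coefficients(key, batch_size, num_descriptors, sigma_g):
    """Hypothetical helper mirroring emit: c ~ N(0, sigma_g * I), then the
    fitness coefficient c_0 is replaced by its absolute value."""
    mean = jnp.zeros(num_descriptors + 1)
    # sigma_g multiplies the identity of the covariance matrix: it is a variance.
    cov = jnp.eye(num_descriptors + 1) * sigma_g
    coeffs = jax.random.multivariate_normal(
        key, mean=mean, cov=cov, shape=(batch_size,)
    )
    return coeffs.at[:, 0].set(jnp.abs(coeffs[:, 0]))


coeffs = sample_coefficients(
    jax.random.PRNGKey(1), batch_size=20_000, num_descriptors=2, sigma_g=4.0
)
# With sigma_g = 4.0, the empirical standard deviation is close to 2.0, not 4.0.
descriptor_std = jnp.std(coeffs[:, 1:])
```

To obtain coefficients with standard deviation s, one would therefore pass sigma_g = s ** 2.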

emit(repertoire, emitter_state, key)

OMG emitter function that samples elements in the repertoire and does a gradient update with random coefficients to create new candidates.

Parameters:
  • repertoire (MapElitesRepertoire) –

    current repertoire

  • emitter_state (OMGMEGAEmitterState) –

    current emitter state, contains the gradients

  • key (RNGKey) –

    random key

Returns:
  • Tuple[Genotype, ExtraScores] –

    the new candidates to be added to the grid, and an empty dictionary of extra scores.
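The core of emit, once parents, gradients, and coefficients are available, can be sketched as follows, assuming array genotypes of shape (batch, genotype_dim) and gradients of shape (batch, genotype_dim, num_descriptors + 1) with the fitness gradient in column 0. `omg_mega_step` is a hypothetical helper, and `jnp.einsum` replaces the source's `vmap` of an elementwise product followed by a sum; both compute the same weighted sum.

```python
import jax.numpy as jnp


def omg_mega_step(genotypes, gradients, coeffs):
    """Hypothetical helper mirroring the update in emit.

    genotypes: (B, G) parents.
    gradients: (B, G, D + 1); column 0 is the fitness gradient, columns
        1..D are the descriptor gradients.
    coeffs: (B, D + 1) coefficients, with coeffs[:, 0] already non-negative.
    """
    # Normalize each gradient column by its own L2 norm over the genotype axis.
    normalized = gradients / jnp.linalg.norm(gradients, axis=1, keepdims=True)
    # Weighted sum of the normalized gradients, one step per parent.
    step = jnp.einsum("bgd,bd->bg", normalized, coeffs)
    return genotypes + step


# One parent of dimension 2 and one descriptor:
# the fitness gradient (3, 4) has norm 5, the descriptor gradient (0, 2) has norm 2.
genotypes = jnp.zeros((1, 2))
gradients = jnp.array([[[3.0, 0.0], [4.0, 2.0]]])
coeffs = jnp.array([[1.0, 2.0]])
new_genotypes = omg_mega_step(genotypes, gradients, coeffs)
# 1 * (0.6, 0.8) + 2 * (0.0, 1.0) = (0.6, 2.8)
```

Normalizing all columns at once is equivalent to the source, which normalizes the fitness and descriptor gradients separately but along the same axis. As in the source, a gradient with zero norm would produce NaNs.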

init(key, repertoire, genotypes, fitnesses, descriptors, extra_scores)

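Initialises the state of the emitter. Creates a repertoire that will contain the gradients of the individuals and fills it with the gradients of the initial population.

Parameters:
  • key (RNGKey) –

    a random key to handle stochastic operations.

  • repertoire (MapElitesRepertoire) –

    the repertoire of solutions (not used by this method).

  • genotypes (Genotype) –

    The genotypes of the initial population.

  • fitnesses (Fitness) –

    the fitnesses of the initial population.

  • descriptors (Descriptor) –

    the descriptors of the initial population.

  • extra_scores (ExtraScores) –

    extra scores of the initial population; must contain the gradients under the "gradients" key.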

Returns:
  • OMGMEGAEmitterState

    The initial emitter state.
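The first step of init builds a per-cell template for the gradients repertoire: one genotype from the population, with a trailing axis of size num_descriptors + 1 added to every leaf (one slot for the fitness gradient, one per descriptor gradient). A minimal sketch, assuming only JAX; `gradient_placeholder` is a hypothetical helper and the dictionary-shaped population is an illustrative layout.

```python
import jax
import jax.numpy as jnp


def gradient_placeholder(genotypes, num_descriptors):
    """Hypothetical helper mirroring init: take the first genotype of the
    population and add a trailing axis of size num_descriptors + 1."""
    first_genotype = jax.tree.map(lambda x: x[0], genotypes)
    return jax.tree.map(
        lambda x: jnp.repeat(
            jnp.expand_dims(x, axis=-1), repeats=num_descriptors + 1, axis=-1
        ),
        first_genotype,
    )


# Population of 10 genotypes stored as a pytree (illustrative layout).
population = {"w": jnp.zeros((10, 4, 2)), "b": jnp.zeros((10, 2))}
placeholder = gradient_placeholder(population, num_descriptors=2)
```

In the source, this template is passed to `MapElitesRepertoire.init_default` together with the emitter's centroids, and the gradients found in `extra_scores["gradients"]` are then added with the population's fitnesses and descriptors.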

state_update(emitter_state, repertoire, genotypes, fitnesses, descriptors, extra_scores)

Update the gradients repertoire to have the right gradients.

NOTE: see the discussion in the class docstring for how this could be improved.

Parameters:
  • emitter_state (OMGMEGAEmitterState) –

    current emitter state

  • repertoire (MapElitesRepertoire) –

    the current genotypes repertoire

  • genotypes (Genotype) –

    the genotypes of the batch of emitted offspring.

  • fitnesses (Fitness) –

    the fitnesses of the batch of emitted offspring.

  • descriptors (Descriptor) –

    the descriptors of the emitted offspring.

  • extra_scores (ExtraScores) –

    a dictionary with other values output by the scoring function.

Returns:
  • OMGMEGAEmitterState

    The modified emitter state.
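state_update relies on the fact that running the same addition rule on the gradients, with the same fitnesses and cells, places each gradient in the same cell as its genotype. The sketch below illustrates this with a simplified, self-contained grid, assuming only JAX. `add_to_grid` is a hypothetical stand-in for `MapElitesRepertoire.add` that receives precomputed cell indices instead of descriptors, and it assumes there are no fitness ties within a cell.

```python
import jax
import jax.numpy as jnp


def add_to_grid(stored, stored_fitnesses, batch, batch_fitnesses, cells):
    """Hypothetical, simplified MAP-Elites insertion: in each cell, keep the
    best individual of the batch if it beats the current occupant. Assumes no
    fitness ties inside a cell (ties would make the winner ambiguous)."""
    num_cells = stored_fitnesses.shape[0]
    best = jax.ops.segment_max(batch_fitnesses, cells, num_segments=num_cells)
    is_best = batch_fitnesses == best[cells]
    improves = batch_fitnesses > stored_fitnesses[cells]
    # Rejected individuals get an out-of-range index; mode="drop" discards them.
    target = jnp.where(is_best & improves, cells, num_cells)
    new_stored = stored.at[target].set(batch, mode="drop")
    new_fitnesses = stored_fitnesses.at[target].set(batch_fitnesses, mode="drop")
    return new_stored, new_fitnesses


num_cells, num_descriptors = 4, 1
empty = jnp.full((num_cells,), -jnp.inf)
genotype_grid = jnp.zeros((num_cells, 2))
gradient_grid = jnp.zeros((num_cells, 2, num_descriptors + 1))

# Offspring batch: two individuals compete for cell 0, one goes to cell 2.
offspring = jnp.array([[1.0, 1.0], [2.0, 2.0], [3.0, 3.0]])
offspring_gradients = jnp.repeat(offspring[..., None], num_descriptors + 1, axis=-1)
fitnesses = jnp.array([1.0, 0.5, 3.0])
cells = jnp.array([0, 2, 0])

# As in state_update, the same rule runs twice, once per repertoire.
genotype_grid, genotype_fit = add_to_grid(
    genotype_grid, empty, offspring, fitnesses, cells
)
gradient_grid, gradient_fit = add_to_grid(
    gradient_grid, empty, offspring_gradients, fitnesses, cells
)
```

Both calls recompute the same insertion indices, which is the duplicated work the class note proposes to remove by computing the indices once and reusing them for both insertions.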
