Containers

DominatedNoveltyRepertoire

Bases: GARepertoire

Repertoire that keeps the individuals with the highest dominated novelty.

Parameters:
  • genotypes

    population genotypes with shape (population_size, ...)

  • fitnesses

    population fitnesses with shape (population_size, fitness_dim)

  • descriptors

    population descriptors with shape (population_size, D)

  • k

    number of neighbors for novelty and dominated novelty

  • extra_scores

    extra scores resulting from the evaluation of the genotypes

  • keys_extra_scores

    keys of the extra scores to store in the repertoire

Source code in qdax/core/containers/dns_repertoire.py
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
class DominatedNoveltyRepertoire(GARepertoire):
    """Repertoire that keeps the individuals with the highest dominated novelty.

    Fixed-size population container for Dominated Novelty Search. On every
    ``add``, the stored population and the incoming batch are pooled, and the
    ``population_size`` candidates with the highest dominated novelty survive.
    Slots whose fitness is ``-inf`` (e.g. the default, never-filled slots) get a
    meta-fitness of ``-inf`` so they sort to the bottom of the selection.

    Args:
        genotypes: population genotypes with shape (population_size, ...)
        fitnesses: population fitnesses with shape (population_size, 1); incoming
            fitnesses are reshaped to a single column in ``add``.
        descriptors: population descriptors with shape (population_size, D)
        k: number of neighbors for novelty and dominated novelty
        extra_scores: extra scores resulting from the evaluation of the genotypes
        keys_extra_scores: keys of the extra scores to store in the repertoire
    """

    descriptors: Descriptor
    # Neighbourhood size for the novelty metrics. Declared non-pytree so jit
    # treats it as a static value rather than a traced array.
    k: int = flax.struct.field(pytree_node=False)

    @jax.jit
    def add(  # type: ignore
        self,
        batch_of_genotypes: Genotype,
        batch_of_descriptors: Descriptor,
        batch_of_fitnesses: Fitness,
        batch_of_extra_scores: Optional[ExtraScores] = None,
    ) -> DominatedNoveltyRepertoire:
        """Add a batch and keep the top individuals by dominated novelty.

        Parents and offspring are gathered and only the population_size
        best according to dominated novelty are kept.

        Args:
            batch_of_genotypes: genotypes of the new batch (batch_size, ...).
            batch_of_descriptors: descriptors of the new batch (batch_size, D).
            batch_of_fitnesses: fitnesses of the new batch; reshaped to
                (batch_size, 1), so only a single fitness value per individual
                is supported.
            batch_of_extra_scores: extra scores of the new batch; only keys in
                ``keys_extra_scores`` are stored.

        Returns:
            A new repertoire holding the survivors.
        """

        if batch_of_extra_scores is None:
            batch_of_extra_scores = {}

        filtered_batch_of_extra_scores = self.filter_extra_scores(batch_of_extra_scores)

        # Stored fitnesses are a single column; align the batch with them.
        batch_of_fitnesses = jnp.reshape(
            batch_of_fitnesses, (batch_of_fitnesses.shape[0], 1)
        )

        # Gather candidates: current population first, then the new batch.
        candidates_genotypes = jax.tree.map(
            lambda x, y: jnp.concatenate((x, y), axis=0),
            self.genotypes,
            batch_of_genotypes,
        )
        candidates_fitnesses = jnp.concatenate(
            (self.fitnesses, batch_of_fitnesses), axis=0
        )
        candidates_descriptors = jnp.concatenate(
            (self.descriptors, batch_of_descriptors), axis=0
        )
        candidates_extra_scores = jax.tree.map(
            lambda x, y: jnp.concatenate((x, y), axis=0),
            self.extra_scores,
            filtered_batch_of_extra_scores,
        )

        # Compute dominated novelty over the whole candidate pool.
        _, dominated_novelty = _novelty_and_dominated_novelty(
            fitness=candidates_fitnesses[:, 0],
            descriptor=candidates_descriptors,
            novelty_k=self.k,
            dominated_novelty_k=self.k,
        )

        # Use dominated novelty as meta-fitness, invalid individuals get -inf
        valid = candidates_fitnesses[:, 0] != -jnp.inf
        meta_fitness = jnp.where(valid, dominated_novelty, -jnp.inf)

        # Select survivors: descending meta-fitness, truncated to the capacity.
        indices = jnp.argsort(meta_fitness)[::-1]
        survivor_indices = indices[: self.size]

        new_genotypes = jax.tree.map(
            lambda x: x[survivor_indices], candidates_genotypes
        )
        new_fitnesses = candidates_fitnesses[survivor_indices]
        new_descriptors = candidates_descriptors[survivor_indices]
        new_extra_scores = jax.tree.map(
            lambda x: x[survivor_indices], candidates_extra_scores
        )

        return self.replace(  # type: ignore
            genotypes=new_genotypes,
            fitnesses=new_fitnesses,
            descriptors=new_descriptors,
            extra_scores=new_extra_scores,
        )

    @classmethod
    def init(  # type: ignore
        cls,
        genotypes: Genotype,
        fitnesses: Fitness,
        descriptors: Descriptor,
        population_size: int,
        k: int,
        *args,
        extra_scores: Optional[ExtraScores] = None,
        keys_extra_scores: Tuple[str, ...] = (),
        **kwargs,
    ) -> DominatedNoveltyRepertoire:
        """Initialize the repertoire and add the first batch.

        Args:
            genotypes: first batch of genotypes (batch_size, ...)
            fitnesses: fitnesses of shape (batch_size, 1) or (batch_size,)
            descriptors: descriptors of shape (batch_size, num_descriptors)
            population_size: maximum number of individuals kept
            k: number of neighbors for novelty metrics
            extra_scores: extra scores of the first batch
            keys_extra_scores: keys of extra scores to store

        Returns:
            A repertoire of size ``population_size`` containing the survivors
            of the first batch.
        """

        if extra_scores is None:
            extra_scores = {}

        # Only the stored extra scores need a prototype. Filtering before
        # slicing avoids indexing entries that are never stored (which may be
        # scalars or otherwise not batched along axis 0).
        stored_extra_scores = {
            key: value
            for key, value in extra_scores.items()
            if key in keys_extra_scores
        }

        # retrieve one genotype and one extra score prototype
        first_genotype = jax.tree.map(lambda x: x[0], genotypes)
        first_extra_scores = jax.tree.map(lambda x: x[0], stored_extra_scores)

        # create a repertoire with default values
        repertoire = cls.init_default(
            genotype=first_genotype,
            descriptor_dim=descriptors.shape[-1],
            population_size=population_size,
            one_extra_score=first_extra_scores,
            keys_extra_scores=keys_extra_scores,
            k=k,
        )

        # add initial population to the repertoire (add filters extra scores)
        return repertoire.add(  # type: ignore
            genotypes, descriptors, fitnesses, extra_scores
        )

    @classmethod
    def init_default(
        cls,
        genotype: Genotype,
        descriptor_dim: int,
        population_size: int,
        one_extra_score: Optional[ExtraScores] = None,
        keys_extra_scores: Tuple[str, ...] = (),
        k: int = 15,
    ) -> DominatedNoveltyRepertoire:
        """Create a DNS repertoire with default values.

        Args:
            genotype: a representative genotype PyTree (leaf shapes define storage).
            descriptor_dim: number of descriptor dimensions.
            population_size: maximum number of individuals kept.
            one_extra_score: a representative extra score PyTree to size buffers.
            keys_extra_scores: keys of extra scores to store in the repertoire.
            k: number of neighbors for novelty metrics.

        Returns:
            A repertoire filled with default values: zero genotypes and extra
            scores, ``-inf`` fitnesses and NaN descriptors.
        """
        if one_extra_score is None:
            one_extra_score = {}

        # Only allocate buffers for the extra scores that will be stored.
        one_extra_score = {
            key: value
            for key, value in one_extra_score.items()
            if key in keys_extra_scores
        }

        # default fitness is -inf (marks an empty slot in add)
        default_fitnesses = -jnp.inf * jnp.ones(shape=(population_size, 1))

        # default genotypes is all zeros, preserving leaf dtypes
        default_genotypes = jax.tree.map(
            lambda x: jnp.zeros(shape=(population_size,) + x.shape, dtype=x.dtype),
            genotype,
        )

        # default descriptors is NaN (uninitialized)
        default_descriptors = jnp.full(
            shape=(population_size, descriptor_dim), fill_value=jnp.nan
        )

        # default extra scores buffers
        default_extra_scores = jax.tree.map(
            lambda x: jnp.zeros(shape=(population_size,) + x.shape, dtype=x.dtype),
            one_extra_score,
        )

        return cls(
            genotypes=default_genotypes,
            fitnesses=default_fitnesses,
            descriptors=default_descriptors,
            extra_scores=default_extra_scores,
            keys_extra_scores=keys_extra_scores,
            k=k,
        )

add(batch_of_genotypes, batch_of_descriptors, batch_of_fitnesses, batch_of_extra_scores=None)

Add a batch and keep the top individuals by dominated novelty.

Parents and offspring are gathered and only the population_size best according to dominated novelty are kept.

Source code in qdax/core/containers/dns_repertoire.py
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
@jax.jit
def add(  # type: ignore
    self,
    batch_of_genotypes: Genotype,
    batch_of_descriptors: Descriptor,
    batch_of_fitnesses: Fitness,
    batch_of_extra_scores: Optional[ExtraScores] = None,
) -> DominatedNoveltyRepertoire:
    """Pool the population with a new batch and keep the best by dominated novelty.

    The stored population and the batch are concatenated, every candidate is
    scored by dominated novelty, and the ``self.size`` highest-scoring
    candidates become the new population. Candidates whose fitness is ``-inf``
    are scored ``-inf`` so they sort to the bottom.

    Args:
        batch_of_genotypes: genotypes of the new batch.
        batch_of_descriptors: descriptors of the new batch.
        batch_of_fitnesses: fitnesses of the new batch, reshaped to one column.
        batch_of_extra_scores: extra scores of the new batch; filtered by
            ``keys_extra_scores`` before being stored.

    Returns:
        A new repertoire holding the survivors.
    """
    extra = {} if batch_of_extra_scores is None else batch_of_extra_scores
    kept_extra = self.filter_extra_scores(extra)

    # Single-column layout so the batch stacks under the stored fitnesses.
    fitness_column = jnp.reshape(
        batch_of_fitnesses, (batch_of_fitnesses.shape[0], 1)
    )

    def stack(old, new):
        # Existing population rows first, new batch rows after.
        return jnp.concatenate((old, new), axis=0)

    pool_genotypes = jax.tree.map(stack, self.genotypes, batch_of_genotypes)
    pool_fitnesses = stack(self.fitnesses, fitness_column)
    pool_descriptors = stack(self.descriptors, batch_of_descriptors)
    pool_extra = jax.tree.map(stack, self.extra_scores, kept_extra)

    scalar_fitness = pool_fitnesses[:, 0]
    _, dn_scores = _novelty_and_dominated_novelty(
        fitness=scalar_fitness,
        descriptor=pool_descriptors,
        novelty_k=self.k,
        dominated_novelty_k=self.k,
    )

    # Empty slots (fitness -inf) are pushed to the bottom of the ranking.
    is_filled = scalar_fitness != -jnp.inf
    ranking_score = jnp.where(is_filled, dn_scores, -jnp.inf)

    # Descending order: reverse an ascending argsort, then truncate.
    keep = jnp.flip(jnp.argsort(ranking_score))[: self.size]

    def pick(leaf):
        return leaf[keep]

    return self.replace(  # type: ignore
        genotypes=jax.tree.map(pick, pool_genotypes),
        fitnesses=pick(pool_fitnesses),
        descriptors=pick(pool_descriptors),
        extra_scores=jax.tree.map(pick, pool_extra),
    )

init(genotypes, fitnesses, descriptors, population_size, k, *args, extra_scores=None, keys_extra_scores=(), **kwargs) classmethod

Initialize the repertoire and add the first batch.

Parameters:
  • genotypes (Genotype) –

    first batch of genotypes (batch_size, ...)

  • fitnesses (Fitness) –

    fitnesses of shape (batch_size, fitness_dim)

  • descriptors (Descriptor) –

    descriptors of shape (batch_size, num_descriptors)

  • population_size (int) –

    maximum number of individuals kept

  • k (int) –

    number of neighbors for novelty metrics

  • extra_scores (Optional[ExtraScores], default: None ) –

    extra scores of the first batch

  • keys_extra_scores (Tuple[str, ...], default: () ) –

    keys of extra scores to store

Source code in qdax/core/containers/dns_repertoire.py
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
@classmethod
def init(  # type: ignore
    cls,
    genotypes: Genotype,
    fitnesses: Fitness,
    descriptors: Descriptor,
    population_size: int,
    k: int,
    *args,
    extra_scores: Optional[ExtraScores] = None,
    keys_extra_scores: Tuple[str, ...] = (),
    **kwargs,
) -> DominatedNoveltyRepertoire:
    """Initialize the repertoire and add the first batch.

    Args:
        genotypes: first batch of genotypes (batch_size, ...)
        fitnesses: fitnesses of shape (batch_size, 1) or (batch_size,)
        descriptors: descriptors of shape (batch_size, num_descriptors)
        population_size: maximum number of individuals kept
        k: number of neighbors for novelty metrics
        extra_scores: extra scores of the first batch
        keys_extra_scores: keys of extra scores to store

    Returns:
        A repertoire of size ``population_size`` containing the survivors of
        the first batch.
    """

    if extra_scores is None:
        extra_scores = {}

    # Only the stored extra scores need a prototype. Filtering before slicing
    # avoids indexing entries that are never stored (which may be scalars or
    # otherwise not batched along axis 0).
    stored_extra_scores = {
        key: value
        for key, value in extra_scores.items()
        if key in keys_extra_scores
    }

    # retrieve one genotype and one extra score prototype
    first_genotype = jax.tree.map(lambda x: x[0], genotypes)
    first_extra_scores = jax.tree.map(lambda x: x[0], stored_extra_scores)

    # create a repertoire with default values
    repertoire = cls.init_default(
        genotype=first_genotype,
        descriptor_dim=descriptors.shape[-1],
        population_size=population_size,
        one_extra_score=first_extra_scores,
        keys_extra_scores=keys_extra_scores,
        k=k,
    )

    # add initial population to the repertoire (add filters extra scores)
    return repertoire.add(  # type: ignore
        genotypes, descriptors, fitnesses, extra_scores
    )

init_default(genotype, descriptor_dim, population_size, one_extra_score=None, keys_extra_scores=(), k=15) classmethod

Create a DNS repertoire with default values.

Parameters:
  • genotype (Genotype) –

    a representative genotype PyTree (leaf shapes define storage).

  • descriptor_dim (int) –

    number of descriptor dimensions.

  • population_size (int) –

    maximum number of individuals kept.

  • one_extra_score (Optional[ExtraScores], default: None ) –

    a representative extra score PyTree to size buffers.

  • keys_extra_scores (Tuple[str, ...], default: () ) –

    keys of extra scores to store in the repertoire.

  • k (int, default: 15 ) –

    number of neighbors for novelty metrics.

Returns:
  A repertoire filled with default values.
Source code in qdax/core/containers/dns_repertoire.py
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
@classmethod
def init_default(
    cls,
    genotype: Genotype,
    descriptor_dim: int,
    population_size: int,
    one_extra_score: Optional[ExtraScores] = None,
    keys_extra_scores: Tuple[str, ...] = (),
    k: int = 15,
) -> DominatedNoveltyRepertoire:
    """Build an empty DNS repertoire whose buffers are sized from prototypes.

    Args:
        genotype: a single representative genotype PyTree; each leaf's shape
            and dtype define one storage row.
        descriptor_dim: number of descriptor dimensions.
        population_size: maximum number of individuals kept.
        one_extra_score: a single representative extra-score PyTree; only its
            entries listed in ``keys_extra_scores`` get a buffer.
        keys_extra_scores: keys of extra scores to store in the repertoire.
        k: number of neighbors for novelty metrics.

    Returns:
        A repertoire with zero genotypes and extra scores, ``-inf`` fitnesses
        (empty-slot marker) and NaN descriptors.
    """
    prototype_extra = {} if one_extra_score is None else one_extra_score
    stored_prototype = {
        name: entry
        for name, entry in prototype_extra.items()
        if name in keys_extra_scores
    }

    def zeros_like_batched(x):
        # One zero-filled row per population slot, matching x's shape and dtype.
        return jnp.zeros(shape=(population_size,) + x.shape, dtype=x.dtype)

    return cls(
        genotypes=jax.tree.map(zeros_like_batched, genotype),
        fitnesses=-jnp.inf * jnp.ones(shape=(population_size, 1)),
        descriptors=jnp.full(
            shape=(population_size, descriptor_dim), fill_value=jnp.nan
        ),
        extra_scores=jax.tree.map(zeros_like_batched, stored_prototype),
        keys_extra_scores=keys_extra_scores,
        k=k,
    )

GARepertoire

Bases: Repertoire

Class for a simple repertoire for a simple genetic algorithm.

Parameters:
  • genotypes

    a PyTree containing the genotypes of the individuals in the population. Each leaf has the shape (population_size, num_features).

  • fitnesses

    an array containing the fitness of the individuals in the population, with shape (population_size, fitness_dim). The implementation of GARepertoire was designed for the case where fitness_dim equals 1, but the class can be inherited and its rules adapted for cases where fitness_dim is greater than 1.

  • extra_scores

    extra scores resulting from the evaluation of the genotypes

  • keys_extra_scores

    keys of the extra scores to store in the repertoire

Source code in qdax/core/containers/ga_repertoire.py
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
class GARepertoire(Repertoire):
    """Class for a simple repertoire for a simple genetic
    algorithm.

    The population has a fixed size. Each ``add`` pools the current
    population with a new batch and keeps the ``population_size``
    candidates with the highest summed fitness.

    Args:
        genotypes: a PyTree containing the genotypes of the
            individuals in the population. Each leaf has the
            shape (population_size, num_features).
        fitnesses: an array containing the fitness of the individuals
            in the population. With shape (population_size, fitness_dim).
            The implementation of GARepertoire was thought for the case
            where fitness_dim equals 1 but the class can be inherited and
            rules adapted for cases where fitness_dim is greater than 1.
        extra_scores: extra scores resulting from the evaluation of the
            genotypes, stored row-aligned with ``genotypes``.
        keys_extra_scores: keys of the extra scores to store in the repertoire
    """

    genotypes: Genotype
    fitnesses: Fitness
    extra_scores: ExtraScores
    # Static (non-pytree) field: names of the extra scores that are stored.
    keys_extra_scores: Tuple[str, ...] = flax.struct.field(
        pytree_node=False,
    )

    @property
    def size(self) -> int:
        """Gives the size of the population.

        Read from the leading axis of the first genotype leaf.
        """
        first_leaf = jax.tree.leaves(self.genotypes)[0]
        return int(first_leaf.shape[0])

    def select(
        self,
        key: RNGKey,
        num_samples: int,
        selector: Optional[Selector[GARepertoireT]] = None,
    ) -> GARepertoireT:
        """Sample individuals from the repertoire.

        Args:
            key: JAX random key.
            num_samples: number of individuals to sample.
            selector: selection strategy; defaults to a UniformSelector
                sampling with replacement.

        Returns:
            The repertoire returned by ``selector.select``.
        """
        if selector is None:
            selector = UniformSelector(select_with_replacement=True)
        repertoire = selector.select(self, key, num_samples)
        return repertoire

    def filter_extra_scores(self, extra_scores: ExtraScores) -> ExtraScores:
        """Keep only the top-level entries whose key is in keys_extra_scores."""
        filtered_extra_scores = {
            key: value
            for key, value in extra_scores.items()
            if key in self.keys_extra_scores
        }
        return filtered_extra_scores

    def add(  # type: ignore
        self,
        batch_of_genotypes: Genotype,
        batch_of_fitnesses: Fitness,
        batch_of_extra_scores: Optional[ExtraScores] = None,
    ) -> GARepertoire:
        """Implements the repertoire addition rules.

        Parents and offspring are gathered and only the population_size
        best individuals (ranked by summed fitness) are kept; the others are
        discarded.

        Args:
            batch_of_genotypes: new genotypes that we try to add.
            batch_of_fitnesses: fitness of those new genotypes.
            batch_of_extra_scores: extra scores of those new genotypes; only
                keys in ``keys_extra_scores`` are stored.

        Returns:
            The updated repertoire.
        """
        if batch_of_extra_scores is None:
            batch_of_extra_scores = {}

        filtered_batch_of_extra_scores = self.filter_extra_scores(batch_of_extra_scores)

        # gather individuals and fitnesses: population first, then the batch
        candidates = jax.tree.map(
            lambda x, y: jnp.concatenate((x, y), axis=0),
            self.genotypes,
            batch_of_genotypes,
        )
        candidates_fitnesses = jnp.concatenate(
            (self.fitnesses, batch_of_fitnesses), axis=0
        )
        # Extra scores must be pooled exactly like the genotypes: the survivor
        # indices address the concatenated (population + batch) axis. Indexing
        # the batch alone would misalign rows, and JAX silently clamps
        # out-of-range gather indices instead of raising.
        candidates_extra_scores = jax.tree.map(
            lambda x, y: jnp.concatenate((x, y), axis=0),
            self.extra_scores,
            filtered_batch_of_extra_scores,
        )

        # sort by fitnesses (descending)
        indices = jnp.argsort(jnp.sum(candidates_fitnesses, axis=1))[::-1]

        # keep only the best ones
        survivor_indices = indices[: self.size]

        new_candidates = jax.tree.map(lambda x: x[survivor_indices], candidates)
        new_extra_scores = jax.tree.map(
            lambda x: x[survivor_indices], candidates_extra_scores
        )
        new_repertoire = self.replace(
            genotypes=new_candidates,
            fitnesses=candidates_fitnesses[survivor_indices],
            extra_scores=new_extra_scores,
        )

        return new_repertoire  # type: ignore

    @classmethod
    def init(  # type: ignore
        cls,
        genotypes: Genotype,
        fitnesses: Fitness,
        population_size: int,
        *args,
        extra_scores: Optional[ExtraScores] = None,
        keys_extra_scores: Tuple[str, ...] = (),
        **kwargs,
    ) -> GARepertoire:
        """Initializes the repertoire.

        Starts with default values and adds a first batch of genotypes
        to the repertoire.

        Args:
            genotypes: first batch of genotypes
            fitnesses: corresponding fitnesses
            population_size: size of the population we want to evolve
            extra_scores: extra scores resulting from the evaluation of the genotypes
            keys_extra_scores: keys of the extra scores to store in the repertoire

        Returns:
            An initial repertoire.
        """

        if extra_scores is None:
            extra_scores = {}

        # create default fitnesses (-inf so any real individual beats them)
        default_fitnesses = -jnp.inf * jnp.ones(
            shape=(population_size, fitnesses.shape[-1])
        )

        # create default genotypes
        default_genotypes = jax.tree.map(
            lambda x: jnp.zeros(shape=(population_size,) + x.shape[1:]), genotypes
        )

        # create default extra scores
        filtered_extra_scores = {
            key: value
            for key, value in extra_scores.items()
            if key in keys_extra_scores
        }

        default_extra_scores = jax.tree.map(
            lambda x: jnp.zeros(shape=(population_size,) + x.shape[1:]),
            filtered_extra_scores,
        )

        # create an initial repertoire with those default values
        repertoire = cls(
            genotypes=default_genotypes,
            fitnesses=default_fitnesses,
            extra_scores=default_extra_scores,
            keys_extra_scores=keys_extra_scores,
        )

        new_repertoire = repertoire.add(genotypes, fitnesses, extra_scores)

        return new_repertoire  # type: ignore

size property

Gives the size of the population.

add(batch_of_genotypes, batch_of_fitnesses, batch_of_extra_scores=None)

Implements the repertoire addition rules.

Parents and offspring are gathered and only the population_size best individuals are kept; the others are discarded.

Parameters:
  • batch_of_genotypes (Genotype) –

    new genotypes that we try to add.

  • batch_of_fitnesses (Fitness) –

    fitness of those new genotypes.

Returns:
  The updated repertoire.
Source code in qdax/core/containers/ga_repertoire.py
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
def add(  # type: ignore
    self,
    batch_of_genotypes: Genotype,
    batch_of_fitnesses: Fitness,
    batch_of_extra_scores: Optional[ExtraScores] = None,
) -> GARepertoire:
    """Implements the repertoire addition rules.

    Parents and offspring are gathered and only the population_size
    best individuals (ranked by summed fitness) are kept; the others are
    discarded.

    Args:
        batch_of_genotypes: new genotypes that we try to add.
        batch_of_fitnesses: fitness of those new genotypes.
        batch_of_extra_scores: extra scores of those new genotypes; only keys
            in ``keys_extra_scores`` are stored.

    Returns:
        The updated repertoire.
    """
    if batch_of_extra_scores is None:
        batch_of_extra_scores = {}

    filtered_batch_of_extra_scores = self.filter_extra_scores(batch_of_extra_scores)

    # gather individuals and fitnesses: population first, then the batch
    candidates = jax.tree.map(
        lambda x, y: jnp.concatenate((x, y), axis=0),
        self.genotypes,
        batch_of_genotypes,
    )
    candidates_fitnesses = jnp.concatenate(
        (self.fitnesses, batch_of_fitnesses), axis=0
    )
    # Extra scores must be pooled exactly like the genotypes: the survivor
    # indices address the concatenated (population + batch) axis. Indexing the
    # batch alone would misalign rows, and JAX silently clamps out-of-range
    # gather indices instead of raising.
    candidates_extra_scores = jax.tree.map(
        lambda x, y: jnp.concatenate((x, y), axis=0),
        self.extra_scores,
        filtered_batch_of_extra_scores,
    )

    # sort by fitnesses (descending)
    indices = jnp.argsort(jnp.sum(candidates_fitnesses, axis=1))[::-1]

    # keep only the best ones
    survivor_indices = indices[: self.size]

    new_candidates = jax.tree.map(lambda x: x[survivor_indices], candidates)
    new_extra_scores = jax.tree.map(
        lambda x: x[survivor_indices], candidates_extra_scores
    )
    new_repertoire = self.replace(
        genotypes=new_candidates,
        fitnesses=candidates_fitnesses[survivor_indices],
        extra_scores=new_extra_scores,
    )

    return new_repertoire  # type: ignore

init(genotypes, fitnesses, population_size, *args, extra_scores=None, keys_extra_scores=(), **kwargs) classmethod

Initializes the repertoire.

Starts with default values and adds a first batch of genotypes to the repertoire.

Parameters:
  • genotypes (Genotype) –

    first batch of genotypes

  • fitnesses (Fitness) –

    corresponding fitnesses

  • population_size (int) –

    size of the population we want to evolve

  • extra_scores (Optional[ExtraScores], default: None ) –

    extra scores resulting from the evaluation of the genotypes

  • keys_extra_scores (Tuple[str, ...], default: () ) –

    keys of the extra scores to store in the repertoire

Returns:
  An initial repertoire.
Source code in qdax/core/containers/ga_repertoire.py
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
@classmethod
def init(  # type: ignore
    cls,
    genotypes: Genotype,
    fitnesses: Fitness,
    population_size: int,
    *args,
    extra_scores: Optional[ExtraScores] = None,
    keys_extra_scores: Tuple[str, ...] = (),
    **kwargs,
) -> GARepertoire:
    """Create a repertoire from default buffers and insert a first batch.

    The buffers hold ``population_size`` rows: zeros for genotypes and stored
    extra scores, ``-inf`` for fitnesses. The first batch is then inserted
    with ``add``.

    Args:
        genotypes: first batch of genotypes; leaf shapes beyond the leading
            axis define the per-individual storage.
        fitnesses: corresponding fitnesses; the trailing axis sets fitness_dim.
        population_size: size of the population we want to evolve.
        extra_scores: extra scores resulting from the evaluation of the genotypes.
        keys_extra_scores: keys of the extra scores to store in the repertoire.

    Returns:
        An initial repertoire.
    """
    scores = {} if extra_scores is None else extra_scores

    def empty_buffer(x):
        # population_size zero rows with the per-individual shape of x
        return jnp.zeros(shape=(population_size,) + x.shape[1:])

    kept_scores = {
        name: entry
        for name, entry in scores.items()
        if name in keys_extra_scores
    }

    blank = cls(
        genotypes=jax.tree.map(empty_buffer, genotypes),
        fitnesses=-jnp.inf
        * jnp.ones(shape=(population_size, fitnesses.shape[-1])),
        extra_scores=jax.tree.map(empty_buffer, kept_scores),
        keys_extra_scores=keys_extra_scores,
    )

    return blank.add(genotypes, fitnesses, scores)  # type: ignore

MELSRepertoire

Bases: MapElitesRepertoire

Class for the repertoire in MAP-Elites Low-Spread.

This class inherits from MapElitesRepertoire. In addition to the data stored in MapElitesRepertoire (genotypes, fitnesses, descriptors, centroids), this repertoire also maintains an array of spreads. We override the add and init_default methods of MapElitesRepertoire.

Refer to Mace 2023 for more info on MAP-Elites Low-Spread: https://dl.acm.org/doi/abs/10.1145/3583131.3590433

Parameters:
  • genotypes

    a PyTree containing all the genotypes in the repertoire ordered by the centroids. Each leaf has a shape (num_centroids, num_features). The PyTree can be a simple JAX array or a more complex nested structure such as to represent parameters of neural network in Flax.

  • fitnesses

    an array that contains the fitness of solutions in each cell of the repertoire, ordered by centroids. The array shape is (num_centroids,).

  • descriptors

    an array that contains the descriptors of solutions in each cell of the repertoire, ordered by centroids. The array shape is (num_centroids, num_descriptors).

  • centroids

    an array that contains the centroids of the tessellation. The array shape is (num_centroids, num_descriptors).

  • spreads

    an array that contains the spread of solutions in each cell of the repertoire, ordered by centroids. The array shape is (num_centroids,).

Source code in qdax/core/containers/mels_repertoire.py
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
class MELSRepertoire(MapElitesRepertoire):
    """Class for the repertoire in MAP-Elites Low-Spread.

    This class inherits from MapElitesRepertoire. In addition to the data stored
    in MapElitesRepertoire (genotypes, fitnesses, descriptors, centroids), this
    repertoire also maintains an array of spreads. It overrides the ``add`` and
    ``init_default`` methods of MapElitesRepertoire.

    Refer to Mace 2023 for more info on MAP-Elites Low-Spread:
    https://dl.acm.org/doi/abs/10.1145/3583131.3590433

    Args:
        genotypes: a PyTree containing all the genotypes in the repertoire ordered
            by the centroids. Each leaf has a shape (num_centroids, num_features).
            The PyTree can be a simple JAX array or a more complex nested
            structure, such as the parameters of a neural network in Flax.
        fitnesses: an array that contains the fitness of solutions in each cell of
            the repertoire, ordered by centroids. The array shape is
            (num_centroids, 1).
        descriptors: an array that contains the descriptors of solutions in each
            cell of the repertoire, ordered by centroids. The array shape
            is (num_centroids, num_descriptors).
        centroids: an array that contains the centroids of the tessellation. The
            array shape is (num_centroids, num_descriptors).
        spreads: an array that contains the spread of solutions in each cell of the
            repertoire, ordered by centroids. The array shape is (num_centroids,).
    """

    spreads: Spread

    def add(  # type: ignore
        self,
        batch_of_genotypes: Genotype,
        batch_of_descriptors: Descriptor,
        batch_of_fitnesses: Fitness,
        batch_of_extra_scores: Optional[ExtraScores] = None,
    ) -> MELSRepertoire:
        """Add a batch of elements to the repertoire.

        The key difference between this method and the default add() in
        MapElitesRepertoire is that it expects each individual to be evaluated
        `num_samples` times, resulting in `num_samples` fitnesses and
        `num_samples` descriptors per individual.

        An individual replaces the incumbent of its cell only if its mean fitness
        is strictly higher AND its spread is lower than or equal to the
        incumbent's.

        If several individuals of the batch are added to the same cell, this
        method arbitrarily keeps one of them. The exact choice depends on the
        implementation of jax.at[].set(), which can be non-deterministic:
        https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.ndarray.at.html
        We do not currently check whether one of these individuals dominates the
        others (an individual dominates when it has both the highest fitness and
        the lowest spread among the individuals for that cell).

        If `num_samples` is only 1, the spreads will default to 0.

        Args:
            batch_of_genotypes: a batch of genotypes to be added to the repertoire.
                Similarly to the self.genotypes argument, this is a PyTree in which
                the leaves have a shape (batch_size, num_features).
            batch_of_descriptors: an array that contains the descriptors of the
                aforementioned genotypes over all evals. Its shape is
                (batch_size, num_samples, num_descriptors). Note that we "aggregate"
                descriptors by finding the most frequent cell of each individual.
                Thus, the actual descriptors stored in the repertoire are just the
                coordinates of the centroid of the most frequent cell.
            batch_of_fitnesses: an array that contains the fitnesses of the
                aforementioned genotypes over all evals. Its shape is
                (batch_size, num_samples).
            batch_of_extra_scores: tree that contains the extra_scores of the
                aforementioned genotypes. It is filtered with
                ``self.filter_extra_scores`` and the kept entries are stored for
                the added individuals.

        Returns:
            The updated repertoire.
        """
        if batch_of_extra_scores is None:
            batch_of_extra_scores = {}

        filtered_batch_of_extra_scores = self.filter_extra_scores(batch_of_extra_scores)

        batch_size, num_samples = batch_of_fitnesses.shape

        # Compute indices/cells of all descriptors.
        batch_of_all_indices = get_cells_indices(
            batch_of_descriptors.reshape(batch_size * num_samples, -1), self.centroids
        ).reshape((batch_size, num_samples))

        # Compute most frequent cell of each solution. Shape: (batch_size, 1)
        batch_of_indices = jax.vmap(_mode)(batch_of_all_indices)[:, None]

        # Compute dispersion / spread, which is zero when num_samples is 1.
        # num_samples comes from a static array shape, so plain Python control
        # flow suffices; jax.lax.cond is not needed here.
        if num_samples == 1:
            batch_of_spreads = jnp.zeros(batch_size)
        else:
            batch_of_spreads = jax.vmap(_dispersion)(
                batch_of_descriptors.reshape((batch_size, num_samples, -1))
            )
        batch_of_spreads = jnp.expand_dims(batch_of_spreads, axis=-1)

        # Compute canonical descriptors as the descriptor of the centroid of the
        # most frequent cell. This redefines the earlier batch_of_descriptors.
        batch_of_descriptors = jnp.take_along_axis(
            self.centroids, batch_of_indices, axis=0
        )

        # Compute canonical fitnesses as the average fitness.
        # Shape: (batch_size, 1)
        batch_of_fitnesses = batch_of_fitnesses.mean(axis=-1, keepdims=True)

        num_centroids = self.centroids.shape[0]

        # Get the incumbent fitnesses and spreads of the target cells.
        current_fitnesses = jnp.take_along_axis(self.fitnesses, batch_of_indices, 0)
        repertoire_spreads = jnp.expand_dims(self.spreads, axis=-1)
        current_spreads = jnp.take_along_axis(repertoire_spreads, batch_of_indices, 0)

        # Accept only strictly better fitness with no larger spread.
        addition_condition = jnp.logical_and(
            batch_of_fitnesses > current_fitnesses,
            batch_of_spreads <= current_spreads,
        )

        # Rejected individuals are sent to index num_centroids, which is out of
        # bounds; out-of-bounds scatter updates are dropped by JAX.
        batch_of_indices = jnp.where(
            addition_condition, batch_of_indices, num_centroids
        )
        scatter_indices = batch_of_indices.squeeze(axis=-1)

        # create new repertoire
        new_repertoire_genotypes = jax.tree.map(
            lambda repertoire_genotypes, new_genotypes: repertoire_genotypes.at[
                scatter_indices
            ].set(new_genotypes),
            self.genotypes,
            batch_of_genotypes,
        )

        # compute new fitnesses, descriptors and spreads
        new_fitnesses = self.fitnesses.at[scatter_indices].set(batch_of_fitnesses)
        new_descriptors = self.descriptors.at[scatter_indices].set(
            batch_of_descriptors
        )
        new_spreads = self.spreads.at[scatter_indices].set(
            batch_of_spreads.squeeze(axis=-1)
        )

        # update extra scores
        new_extra_scores = jax.tree.map(
            lambda repertoire_scores, new_scores: repertoire_scores.at[
                scatter_indices
            ].set(new_scores),
            self.extra_scores,
            filtered_batch_of_extra_scores,
        )

        return MELSRepertoire(
            genotypes=new_repertoire_genotypes,
            fitnesses=new_fitnesses,
            extra_scores=new_extra_scores,
            keys_extra_scores=self.keys_extra_scores,
            descriptors=new_descriptors,
            centroids=self.centroids,
            spreads=new_spreads,
        )

    @classmethod
    def init_default(
        cls,
        genotype: Genotype,
        centroids: Centroid,
        one_extra_score: Optional[ExtraScores] = None,
        keys_extra_scores: Tuple[str, ...] = (),
    ) -> MELSRepertoire:
        """Initialize an empty MAP-Elites Low-Spread repertoire with default values.

        Requires the definition of centroids that can be computed with any method
        such as CVT or Euclidean mapping. As a classmethod, it can be called on the
        class directly (no existing repertoire instance is needed), e.g. from other
        modules.

        Args:
            genotype: the typical genotype that will be stored. Only the shape and
                dtype of its leaves are used.
            centroids: the centroids of the repertoire.
            one_extra_score: extra scores of a single individual, used as a
                template for the extra scores stored in the repertoire. Only the
                entries whose key is in ``keys_extra_scores`` are kept.
            keys_extra_scores: keys of the extra scores to store in the repertoire.

        Returns:
            A repertoire filled with default values: fitnesses of -inf, zero
            genotypes/descriptors/extra scores, and spreads of +inf.
        """
        if one_extra_score is None:
            one_extra_score = {}

        one_extra_score = {
            key: value
            for key, value in one_extra_score.items()
            if key in keys_extra_scores
        }

        # get number of centroids
        num_centroids = centroids.shape[0]

        # default fitness is -inf so that any finite fitness is strictly greater
        default_fitnesses = -jnp.inf * jnp.ones(shape=(num_centroids, 1))

        # default genotypes is all 0
        default_genotypes = jax.tree.map(
            lambda x: jnp.zeros(shape=(num_centroids,) + x.shape, dtype=x.dtype),
            genotype,
        )

        # default descriptor is all zeros
        default_descriptors = jnp.zeros_like(centroids)

        # default spread is inf so that any spread will be less
        default_spreads = jnp.full(shape=num_centroids, fill_value=jnp.inf)

        # default extra scores are zeros shaped like the kept template entries
        default_extra_scores = jax.tree.map(
            lambda x: jnp.zeros(shape=(num_centroids,) + x.shape, dtype=x.dtype),
            one_extra_score,
        )

        return cls(
            genotypes=default_genotypes,
            fitnesses=default_fitnesses,
            descriptors=default_descriptors,
            centroids=centroids,
            spreads=default_spreads,
            extra_scores=default_extra_scores,
            keys_extra_scores=keys_extra_scores,
        )

add(batch_of_genotypes, batch_of_descriptors, batch_of_fitnesses, batch_of_extra_scores=None)

Add a batch of elements to the repertoire.

The key difference between this method and the default add() in MapElitesRepertoire is that it expects each individual to be evaluated num_samples times, resulting in num_samples fitnesses and num_samples descriptors per individual.

If several individuals of the batch are added to the same cell, this method arbitrarily keeps one of them -- the exact choice depends on the implementation of jax.at[].set(), which can be non-deterministic: https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.ndarray.at.html We do not currently check whether one of these individuals dominates the others (an individual dominates when it has both the highest fitness and the lowest spread among the individuals for that cell).

If num_samples is only 1, the spreads will default to 0.

Parameters:
  • batch_of_genotypes (Genotype) –

    a batch of genotypes to be added to the repertoire. Similarly to the self.genotypes argument, this is a PyTree in which the leaves have a shape (batch_size, num_features)

  • batch_of_descriptors (Descriptor) –

    an array that contains the descriptors of the aforementioned genotypes over all evals. Its shape is (batch_size, num_samples, num_descriptors). Note that we "aggregate" descriptors by finding the most frequent cell of each individual. Thus, the actual descriptors stored in the repertoire are just the coordinates of the centroid of the most frequent cell.

  • batch_of_fitnesses (Fitness) –

    an array that contains the fitnesses of the aforementioned genotypes over all evals. Its shape is (batch_size, num_samples)

  • batch_of_extra_scores (Optional[ExtraScores], default: None ) –

    tree that contains the extra_scores of the aforementioned genotypes. It is filtered with filter_extra_scores and the kept entries are stored for the added individuals.

Returns:

    The updated repertoire.
Source code in qdax/core/containers/mels_repertoire.py
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
def add(  # type: ignore
    self,
    batch_of_genotypes: Genotype,
    batch_of_descriptors: Descriptor,
    batch_of_fitnesses: Fitness,
    batch_of_extra_scores: Optional[ExtraScores] = None,
) -> MELSRepertoire:
    """Add a batch of elements to the repertoire.

    The key difference between this method and the default add() in
    MapElitesRepertoire is that it expects each individual to be evaluated
    `num_samples` times, resulting in `num_samples` fitnesses and
    `num_samples` descriptors per individual.

    An individual replaces the incumbent of its cell only if its mean fitness
    is strictly higher AND its spread is lower than or equal to the
    incumbent's.

    If several individuals of the batch are added to the same cell, this
    method arbitrarily keeps one of them. The exact choice depends on the
    implementation of jax.at[].set(), which can be non-deterministic:
    https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.ndarray.at.html
    We do not currently check whether one of these individuals dominates the
    others (an individual dominates when it has both the highest fitness and
    the lowest spread among the individuals for that cell).

    If `num_samples` is only 1, the spreads will default to 0.

    Args:
        batch_of_genotypes: a batch of genotypes to be added to the repertoire.
            Similarly to the self.genotypes argument, this is a PyTree in which
            the leaves have a shape (batch_size, num_features).
        batch_of_descriptors: an array that contains the descriptors of the
            aforementioned genotypes over all evals. Its shape is
            (batch_size, num_samples, num_descriptors). Note that we "aggregate"
            descriptors by finding the most frequent cell of each individual.
            Thus, the actual descriptors stored in the repertoire are just the
            coordinates of the centroid of the most frequent cell.
        batch_of_fitnesses: an array that contains the fitnesses of the
            aforementioned genotypes over all evals. Its shape is
            (batch_size, num_samples).
        batch_of_extra_scores: tree that contains the extra_scores of the
            aforementioned genotypes. It is filtered with
            ``self.filter_extra_scores`` and the kept entries are stored for
            the added individuals.

    Returns:
        The updated repertoire.
    """
    if batch_of_extra_scores is None:
        batch_of_extra_scores = {}

    filtered_batch_of_extra_scores = self.filter_extra_scores(batch_of_extra_scores)

    batch_size, num_samples = batch_of_fitnesses.shape

    # Compute indices/cells of all descriptors.
    batch_of_all_indices = get_cells_indices(
        batch_of_descriptors.reshape(batch_size * num_samples, -1), self.centroids
    ).reshape((batch_size, num_samples))

    # Compute most frequent cell of each solution. Shape: (batch_size, 1)
    batch_of_indices = jax.vmap(_mode)(batch_of_all_indices)[:, None]

    # Compute dispersion / spread, which is zero when num_samples is 1.
    # num_samples comes from a static array shape, so plain Python control
    # flow suffices; jax.lax.cond is not needed here.
    if num_samples == 1:
        batch_of_spreads = jnp.zeros(batch_size)
    else:
        batch_of_spreads = jax.vmap(_dispersion)(
            batch_of_descriptors.reshape((batch_size, num_samples, -1))
        )
    batch_of_spreads = jnp.expand_dims(batch_of_spreads, axis=-1)

    # Compute canonical descriptors as the descriptor of the centroid of the
    # most frequent cell. This redefines the earlier batch_of_descriptors.
    batch_of_descriptors = jnp.take_along_axis(
        self.centroids, batch_of_indices, axis=0
    )

    # Compute canonical fitnesses as the average fitness.
    # Shape: (batch_size, 1)
    batch_of_fitnesses = batch_of_fitnesses.mean(axis=-1, keepdims=True)

    num_centroids = self.centroids.shape[0]

    # Get the incumbent fitnesses and spreads of the target cells.
    current_fitnesses = jnp.take_along_axis(self.fitnesses, batch_of_indices, 0)
    repertoire_spreads = jnp.expand_dims(self.spreads, axis=-1)
    current_spreads = jnp.take_along_axis(repertoire_spreads, batch_of_indices, 0)

    # Accept only strictly better fitness with no larger spread.
    addition_condition = jnp.logical_and(
        batch_of_fitnesses > current_fitnesses,
        batch_of_spreads <= current_spreads,
    )

    # Rejected individuals are sent to index num_centroids, which is out of
    # bounds; out-of-bounds scatter updates are dropped by JAX.
    batch_of_indices = jnp.where(
        addition_condition, batch_of_indices, num_centroids
    )
    scatter_indices = batch_of_indices.squeeze(axis=-1)

    # create new repertoire
    new_repertoire_genotypes = jax.tree.map(
        lambda repertoire_genotypes, new_genotypes: repertoire_genotypes.at[
            scatter_indices
        ].set(new_genotypes),
        self.genotypes,
        batch_of_genotypes,
    )

    # compute new fitnesses, descriptors and spreads
    new_fitnesses = self.fitnesses.at[scatter_indices].set(batch_of_fitnesses)
    new_descriptors = self.descriptors.at[scatter_indices].set(batch_of_descriptors)
    new_spreads = self.spreads.at[scatter_indices].set(
        batch_of_spreads.squeeze(axis=-1)
    )

    # update extra scores
    new_extra_scores = jax.tree.map(
        lambda repertoire_scores, new_scores: repertoire_scores.at[
            scatter_indices
        ].set(new_scores),
        self.extra_scores,
        filtered_batch_of_extra_scores,
    )

    return MELSRepertoire(
        genotypes=new_repertoire_genotypes,
        fitnesses=new_fitnesses,
        extra_scores=new_extra_scores,
        keys_extra_scores=self.keys_extra_scores,
        descriptors=new_descriptors,
        centroids=self.centroids,
        spreads=new_spreads,
    )

init_default(genotype, centroids, one_extra_score=None, keys_extra_scores=()) classmethod

Initialize a MAP-Elites Low-Spread repertoire with an initial population of genotypes. Requires the definition of centroids that can be computed with any method such as CVT or Euclidean mapping.

Note: as a classmethod, this function can be called directly on the class (no existing repertoire instance is needed), so it can easily be called from other modules.

Parameters:
  • genotype (Genotype) –

    the typical genotype that will be stored.

  • centroids (Centroid) –

    the centroids of the repertoire.

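  • one_extra_score (Optional[ExtraScores], default: None ) –

    extra scores of a single individual, used as a template for the extra scores stored in the repertoire; only the entries whose key is in keys_extra_scores are kept.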

  • keys_extra_scores (Tuple[str, ...], default: () ) –

    keys of the extra scores to store in the repertoire

Returns:

    A repertoire filled with default values.
Source code in qdax/core/containers/mels_repertoire.py
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
@classmethod
def init_default(
    cls,
    genotype: Genotype,
    centroids: Centroid,
    one_extra_score: Optional[ExtraScores] = None,
    keys_extra_scores: Tuple[str, ...] = (),
) -> MELSRepertoire:
    """Initialize an empty MAP-Elites Low-Spread repertoire with default values.

    Requires the definition of centroids that can be computed with any method
    such as CVT or Euclidean mapping. As a classmethod, it can be called on the
    class directly (no existing repertoire instance is needed), e.g. from other
    modules.

    Args:
        genotype: the typical genotype that will be stored. Only the shape and
            dtype of its leaves are used.
        centroids: the centroids of the repertoire.
        one_extra_score: extra scores of a single individual, used as a
            template for the extra scores stored in the repertoire. Only the
            entries whose key is in ``keys_extra_scores`` are kept.
        keys_extra_scores: keys of the extra scores to store in the repertoire.

    Returns:
        A repertoire filled with default values: fitnesses of -inf, zero
        genotypes/descriptors/extra scores, and spreads of +inf.
    """
    if one_extra_score is None:
        one_extra_score = {}

    # keep only the extra scores that the repertoire is asked to store
    one_extra_score = {
        key: value
        for key, value in one_extra_score.items()
        if key in keys_extra_scores
    }

    # get number of centroids
    num_centroids = centroids.shape[0]

    # default fitness is -inf so that any finite fitness is strictly greater
    default_fitnesses = -jnp.inf * jnp.ones(shape=(num_centroids, 1))

    # default genotypes is all 0
    default_genotypes = jax.tree.map(
        lambda x: jnp.zeros(shape=(num_centroids,) + x.shape, dtype=x.dtype),
        genotype,
    )

    # default descriptor is all zeros
    default_descriptors = jnp.zeros_like(centroids)

    # default spread is inf so that any spread will be less
    default_spreads = jnp.full(shape=num_centroids, fill_value=jnp.inf)

    # default extra scores are zeros shaped like the kept template entries
    default_extra_scores = jax.tree.map(
        lambda x: jnp.zeros(shape=(num_centroids,) + x.shape, dtype=x.dtype),
        one_extra_score,
    )

    return cls(
        genotypes=default_genotypes,
        fitnesses=default_fitnesses,
        descriptors=default_descriptors,
        centroids=centroids,
        spreads=default_spreads,
        extra_scores=default_extra_scores,
        keys_extra_scores=keys_extra_scores,
    )

MOMERepertoire

Bases: MapElitesRepertoire

Class for the repertoire in Multi Objective Map Elites

This class inherits from MapElitesRepertoire. The stored data is the same: genotypes, fitnesses, descriptors, centroids.

The shape of genotypes is (in the case where it's an array): (num_centroids, pareto_front_length, genotype_dim). When genotypes is a PyTree, the first two dimensions are the same but the third will depend on the leaves.

The shape of fitnesses is: (num_centroids, pareto_front_length, num_criteria)

The shape of descriptors and centroids are: (num_centroids, num_descriptors, pareto_front_length).

Source code in qdax/core/containers/mome_repertoire.py
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
class MOMERepertoire(MapElitesRepertoire):
    """Class for the repertoire in Multi Objective Map Elites.

    This class inherits from MapElitesRepertoire. The stored data
    is the same: genotypes, fitnesses, descriptors, centroids (plus the
    extra scores whose keys are listed in ``keys_extra_scores``). Each cell
    stores a fixed-size, masked Pareto front instead of a single elite; empty
    slots of a front are marked by a fitness of ``-inf``.

    The shape of genotypes is (in the case where it's an array):
    (num_centroids, pareto_front_length, genotype_dim).
    When the genotypes are a PyTree, the first two dimensions are the same
    but the remaining ones depend on the leaves.

    The shape of fitnesses is: (num_centroids, pareto_front_length, num_criteria)

    The shape of descriptors is:
    (num_centroids, pareto_front_length, num_descriptors).

    The shape of centroids is: (num_centroids, num_descriptors).
    """

    @property
    def repertoire_capacity(self) -> int:
        """Returns the maximum number of solutions the repertoire can
        contain which corresponds to the number of cells times the
        maximum pareto front length.

        Returns:
            The repertoire capacity.
        """
        first_leaf = jax.tree.leaves(self.genotypes)[0]
        return int(first_leaf.shape[0] * first_leaf.shape[1])

    def select(
        self,
        key: RNGKey,
        num_samples: int,
        selector: Optional[Selector[MOMERepertoireT]] = None,
    ) -> MOMERepertoireT:
        """Sample solutions from the repertoire.

        Args:
            key: random key forwarded to the selector.
            num_samples: number of solutions to sample.
            selector: selection strategy; ``MOMEUniformSelector()`` is used
                when None.

        Returns:
            The repertoire returned by ``selector.select``.
        """
        if selector is None:
            selector = MOMEUniformSelector()
        repertoire = selector.select(self, key, num_samples)
        return repertoire

    def _update_masked_pareto_front(
        self,
        pareto_front_fitnesses: ParetoFront[Fitness],
        pareto_front_genotypes: ParetoFront[Genotype],
        pareto_front_descriptors: ParetoFront[Descriptor],
        pareto_front_extra_scores: ParetoFront[ExtraScores],
        mask: Mask,
        new_batch_of_fitnesses: Fitness,
        new_batch_of_genotypes: Genotype,
        new_batch_of_descriptors: Descriptor,
        new_batch_of_extra_scores: ExtraScores,
        new_mask: Mask,
    ) -> Tuple[
        ParetoFront[Fitness],
        ParetoFront[Genotype],
        ParetoFront[Descriptor],
        ParetoFront[ExtraScores],
        Mask,
    ]:
        """Takes a fixed size pareto front, its mask and new points to add.
        Returns updated front and mask.

        The current front and the new points are merged (front first, new
        points after). The points selected by ``compute_masked_pareto_front``
        are moved to the first slots, and the result is truncated back to the
        original front length. If more points are selected than the front can
        hold, those with the highest merged indices (the new points first) are
        dropped.

        Args:
            pareto_front_fitnesses: fitness of the pareto front
            pareto_front_genotypes: corresponding genotypes
            pareto_front_descriptors: corresponding descriptors
            pareto_front_extra_scores: corresponding extra scores
            mask: mask of the front, True for empty (void) slots
            new_batch_of_fitnesses: new batch of fitness that is considered
                to be added to the pareto front
            new_batch_of_genotypes: corresponding genotypes
            new_batch_of_descriptors: corresponding descriptors
            new_batch_of_extra_scores: corresponding extra scores
            new_mask: corresponding mask (True to ignore a new point)

        Returns:
            The updated fitnesses, genotypes, descriptors and extra scores of
            the front, each with the same leading length as the input front,
            and the new mask (True for empty slots). Empty slots are
            zero-filled.
        """
        # get dimensions
        batch_size = new_batch_of_fitnesses.shape[0]
        pareto_front_len = pareto_front_fitnesses.shape[0]  # type: ignore
        total_len = pareto_front_len + batch_size

        def _concat(front: jax.Array, batch: jax.Array) -> jax.Array:
            return jnp.concatenate([front, batch], axis=0)

        # gather all data: existing front first, new points after
        cat_mask = jnp.concatenate([mask, new_mask], axis=-1)
        cat_fitnesses = _concat(pareto_front_fitnesses, new_batch_of_fitnesses)
        cat_genotypes = jax.tree.map(
            _concat, pareto_front_genotypes, new_batch_of_genotypes
        )
        cat_descriptors = _concat(pareto_front_descriptors, new_batch_of_descriptors)
        cat_extra_scores = jax.tree.map(
            _concat, pareto_front_extra_scores, new_batch_of_extra_scores
        )

        # boolean membership of each merged point in the new front
        cat_bool_front = compute_masked_pareto_front(
            batch_of_criteria=cat_fitnesses, mask=cat_mask
        )

        # Front members keep their own index and every other point gets the
        # last index as a filler; sorting then moves the members to the
        # beginning while preserving their relative order.
        indices = jnp.arange(start=0, stop=total_len) * cat_bool_front
        indices = indices + ~cat_bool_front * (total_len - 1)
        indices = jnp.sort(indices)

        def _take(x: jax.Array) -> jax.Array:
            return jnp.take(x, indices, axis=0)

        new_front_fitness = _take(cat_fitnesses)
        new_front_genotypes = jax.tree.map(_take, cat_genotypes)
        new_front_descriptors = _take(cat_descriptors)
        new_front_extra_scores = jax.tree.map(_take, cat_extra_scores)

        # only the first `num_front_elements` rows are real front members
        num_front_elements = jnp.sum(cat_bool_front)
        is_valid = jnp.arange(start=0, stop=total_len) < num_front_elements

        def _mask_and_truncate(x: jax.Array) -> jax.Array:
            # Zero the filler rows (broadcast over trailing dimensions) with
            # jnp.where rather than a multiplication, so inf values in filler
            # rows cannot turn into nan; then cut back to the front length.
            row_mask = jnp.reshape(
                is_valid, is_valid.shape + (1,) * (x.ndim - 1)
            )
            return jnp.where(row_mask, x, jnp.zeros_like(x))[:pareto_front_len]

        new_front_fitness = _mask_and_truncate(new_front_fitness)
        new_front_genotypes = jax.tree.map(_mask_and_truncate, new_front_genotypes)
        new_front_descriptors = _mask_and_truncate(new_front_descriptors)
        new_front_extra_scores = jax.tree.map(
            _mask_and_truncate, new_front_extra_scores
        )

        # returned mask follows the input convention: True marks an empty slot
        front_mask = ~is_valid[:pareto_front_len]

        return (
            new_front_fitness,
            new_front_genotypes,
            new_front_descriptors,
            new_front_extra_scores,
            front_mask,
        )

    def add(  # type: ignore
        self,
        batch_of_genotypes: Genotype,
        batch_of_descriptors: Descriptor,
        batch_of_fitnesses: Fitness,
        batch_of_extra_scores: Optional[ExtraScores] = None,
    ) -> MOMERepertoire:
        """Insert a batch of elements in the repertoire.

        Shape of the batch_of_genotypes (if an array):
        (batch_size, genotypes_dim)
        Shape of the batch_of_descriptors: (batch_size, num_descriptors)
        Shape of the batch_of_fitnesses: (batch_size, num_criteria)

        Args:
            batch_of_genotypes: a batch of genotypes that we are trying to
                insert into the repertoire.
            batch_of_descriptors: the descriptors of the genotypes we are
                trying to add to the repertoire.
            batch_of_fitnesses: the fitnesses of the genotypes we are trying
                to add to the repertoire.
            batch_of_extra_scores: tree that contains the extra_scores of
                aforementioned genotypes; it is passed through
                ``self.filter_extra_scores`` and the result is stored
                alongside the genotypes.

        Returns:
            The updated repertoire with potential new individuals.
        """
        if batch_of_extra_scores is None:
            batch_of_extra_scores = {}

        filtered_batch_of_extra_scores = self.filter_extra_scores(batch_of_extra_scores)

        # get the indices that corresponds to the descriptors in the repertoire
        batch_of_indices = get_cells_indices(batch_of_descriptors, self.centroids)
        batch_of_indices = jnp.expand_dims(batch_of_indices, axis=-1)

        def _add_one(
            carry: MOMERepertoire,
            data: Tuple[Genotype, Descriptor, Fitness, ExtraScores, jax.Array],
        ) -> Tuple[MOMERepertoire, Any]:
            # unwrap data
            genotype, descriptors, fitness, extra_scores, index = data

            index = index.astype(jnp.int32)

            # get current repertoire cell data
            cell_genotype = jax.tree.map(lambda x: x[index][0], carry.genotypes)
            cell_fitness = carry.fitnesses[index][0]
            cell_descriptor = carry.descriptors[index][0]
            cell_extra_scores = jax.tree.map(lambda x: x[index][0], carry.extra_scores)
            # empty slots are the ones holding a -inf fitness
            cell_mask = jnp.any(cell_fitness == -jnp.inf, axis=-1)

            new_genotypes = jax.tree.map(lambda x: jnp.expand_dims(x, axis=0), genotype)

            # update pareto front
            (
                cell_fitness,
                cell_genotype,  # new pf for cell
                cell_descriptor,
                cell_extra_scores,
                cell_mask,
            ) = self._update_masked_pareto_front(
                pareto_front_fitnesses=cell_fitness,
                pareto_front_genotypes=cell_genotype,
                pareto_front_descriptors=cell_descriptor,
                pareto_front_extra_scores=cell_extra_scores,
                mask=cell_mask,
                new_batch_of_fitnesses=jnp.expand_dims(fitness, axis=0),
                new_batch_of_genotypes=new_genotypes,
                new_batch_of_descriptors=jnp.expand_dims(descriptors, axis=0),
                new_batch_of_extra_scores=jax.tree.map(
                    lambda x: jnp.expand_dims(x, axis=0), extra_scores
                ),
                new_mask=jnp.zeros(shape=(1,), dtype=bool),
            )

            # Mark empty slots with -inf. jnp.where is used instead of
            # subtracting `inf * mask`, which evaluates to nan (inf * 0) on
            # the occupied slots under IEEE-754 arithmetic.
            cell_fitness = jnp.where(
                jnp.expand_dims(cell_mask, axis=-1), -jnp.inf, cell_fitness
            )

            # update grid
            new_genotypes = jax.tree.map(
                lambda x, y: x.at[index].set(y), carry.genotypes, cell_genotype
            )
            new_fitnesses = carry.fitnesses.at[index].set(cell_fitness)
            new_descriptors = carry.descriptors.at[index].set(cell_descriptor)
            new_extra_scores = jax.tree.map(
                lambda x, y: x.at[index].set(y), carry.extra_scores, cell_extra_scores
            )
            carry = carry.replace(  # type: ignore
                genotypes=new_genotypes,
                descriptors=new_descriptors,
                fitnesses=new_fitnesses,
                extra_scores=new_extra_scores,
            )

            # return new grid
            return carry, ()

        # scan the addition operation for all the data, one point at a time
        new_repertoire, _ = jax.lax.scan(
            _add_one,
            self,
            (
                batch_of_genotypes,
                batch_of_descriptors,
                batch_of_fitnesses,
                filtered_batch_of_extra_scores,
                batch_of_indices,
            ),
        )

        return new_repertoire

    @classmethod
    def init(  # type: ignore
        cls,
        genotypes: Genotype,
        fitnesses: Fitness,
        descriptors: Descriptor,
        centroids: Centroid,
        pareto_front_max_length: int,
        *args,
        extra_scores: Optional[ExtraScores] = None,
        keys_extra_scores: Tuple[str, ...] = (),
        **kwargs,
    ) -> MOMERepertoire:
        """
        Initialize a Multi Objective Map-Elites repertoire with an initial population
        of genotypes. Requires the definition of centroids that can be computed with
        any method such as CVT or Euclidean mapping.

        Note: this function is a classmethod, so it can easily be called from
        other modules.

        Args:
            genotypes: initial genotypes, pytree in which leaves
                have shape (batch_size, num_features)
            fitnesses: fitness of the initial genotypes of shape:
                (batch_size, num_criteria)
            descriptors: descriptors of the initial genotypes
                of shape (batch_size, num_descriptors)
            centroids: tessellation centroids of shape
                (num_centroids, num_descriptors)
            pareto_front_max_length: maximum size of the pareto fronts
            extra_scores: extra_scores of the initial genotypes; only the
                entries listed in keys_extra_scores are stored
            keys_extra_scores: keys of the extra_scores to store in the
                repertoire

        Returns:
            An initialized MOME repertoire
        """

        # NOTE(review): extra scores whose keys are in keys_extra_scores ARE
        # stored, so this message looks stale; kept unchanged for now.
        warnings.warn(
            (
                "This type of repertoire does not store the extra scores "
                "computed by the scoring function"
            ),
            stacklevel=2,
        )

        if extra_scores is None:
            extra_scores = {}

        filtered_extra_scores = {key: extra_scores[key] for key in keys_extra_scores}

        # get dimensions
        num_criteria = fitnesses.shape[1]
        num_descriptors = descriptors.shape[1]
        num_centroids = centroids.shape[0]

        # create default values: every slot starts empty (-inf fitness)
        default_fitnesses = -jnp.inf * jnp.ones(
            shape=(num_centroids, pareto_front_max_length, num_criteria)
        )
        default_genotypes = jax.tree.map(
            lambda x: jnp.zeros(
                shape=(num_centroids, pareto_front_max_length) + x.shape[1:]
            ),
            genotypes,
        )
        default_descriptors = jnp.zeros(
            shape=(num_centroids, pareto_front_max_length, num_descriptors)
        )

        default_extra_scores = jax.tree.map(
            lambda x: jnp.zeros(
                shape=(num_centroids, pareto_front_max_length, *x.shape[1:])
            ),
            filtered_extra_scores,
        )

        # create repertoire with default values; use `cls` so that
        # subclasses get an instance of their own type
        repertoire = cls(  # type: ignore
            genotypes=default_genotypes,
            fitnesses=default_fitnesses,
            descriptors=default_descriptors,
            centroids=centroids,
            extra_scores=default_extra_scores,
            keys_extra_scores=keys_extra_scores,
        )

        # add first batch of individuals in the repertoire
        new_repertoire = repertoire.add(genotypes, descriptors, fitnesses, extra_scores)

        return new_repertoire  # type: ignore

    @jax.jit
    def compute_global_pareto_front(
        self,
    ) -> Tuple[ParetoFront[Fitness], Mask]:
        """Merge all the pareto fronts of the MOME repertoire into a single one
        called global pareto front. Works for any number of criteria.

        Returns:
            The flattened fitnesses, of shape
            (num_centroids * pareto_front_length, num_criteria), where points
            that are not on the global front are set to -inf, and the boolean
            mask of shape (num_centroids * pareto_front_length,) that is True
            for the points that are kept.
        """
        num_criteria = self.fitnesses.shape[-1]
        fitnesses = jnp.reshape(self.fitnesses, (-1, num_criteria))
        mask = jnp.any(fitnesses == -jnp.inf, axis=-1)
        pareto_mask = compute_masked_pareto_front(fitnesses, mask)
        # broadcast the per-point mask over all criteria (no hard-coded count)
        pareto_front = jnp.where(
            jnp.expand_dims(pareto_mask, axis=-1), fitnesses, -jnp.inf
        )

        return pareto_front, pareto_mask

repertoire_capacity property

Returns the maximum number of solutions the repertoire can contain which corresponds to the number of cells times the maximum pareto front length.

Returns:
  • int

    The repertoire capacity.

add(batch_of_genotypes, batch_of_descriptors, batch_of_fitnesses, batch_of_extra_scores=None)

Insert a batch of elements in the repertoire.

Shape of the batch_of_genotypes (if an array): (batch_size, genotypes_dim). Shape of the batch_of_descriptors: (batch_size, num_descriptors). Shape of the batch_of_fitnesses: (batch_size, num_criteria).

Parameters:
  • batch_of_genotypes (Genotype) –

    a batch of genotypes that we are trying to insert into the repertoire.

  • batch_of_descriptors (Descriptor) –

    the descriptors of the genotypes we are trying to add to the repertoire.

  • batch_of_fitnesses (Fitness) –

    the fitnesses of the genotypes we are trying to add to the repertoire.

  • batch_of_extra_scores (Optional[ExtraScores], default: None ) –

    tree that contains the extra_scores of the aforementioned genotypes; it is filtered with filter_extra_scores and the kept entries are stored alongside the genotypes.

Returns:
  • MOMERepertoire

    The updated repertoire with potential new individuals.

Source code in qdax/core/containers/mome_repertoire.py
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
def add(  # type: ignore
    self,
    batch_of_genotypes: Genotype,
    batch_of_descriptors: Descriptor,
    batch_of_fitnesses: Fitness,
    batch_of_extra_scores: Optional[ExtraScores] = None,
) -> MOMERepertoire:
    """Insert a batch of elements in the repertoire.

    Shape of the batch_of_genotypes (if an array):
    (batch_size, genotypes_dim)
    Shape of the batch_of_descriptors: (batch_size, num_descriptors)
    Shape of the batch_of_fitnesses: (batch_size, num_criteria)

    Args:
        batch_of_genotypes: a batch of genotypes that we are trying to
            insert into the repertoire.
        batch_of_descriptors: the descriptors of the genotypes we are
            trying to add to the repertoire.
        batch_of_fitnesses: the fitnesses of the genotypes we are trying
            to add to the repertoire.
        batch_of_extra_scores: tree that contains the extra_scores of
            aforementioned genotypes; it is passed through
            ``self.filter_extra_scores`` and the result is stored
            alongside the genotypes.

    Returns:
        The updated repertoire with potential new individuals.
    """
    if batch_of_extra_scores is None:
        batch_of_extra_scores = {}

    filtered_batch_of_extra_scores = self.filter_extra_scores(batch_of_extra_scores)

    # get the indices that corresponds to the descriptors in the repertoire
    batch_of_indices = get_cells_indices(batch_of_descriptors, self.centroids)
    batch_of_indices = jnp.expand_dims(batch_of_indices, axis=-1)

    def _add_one(
        carry: MOMERepertoire,
        data: Tuple[Genotype, Descriptor, Fitness, ExtraScores, jax.Array],
    ) -> Tuple[MOMERepertoire, Any]:
        # unwrap data
        genotype, descriptors, fitness, extra_scores, index = data

        index = index.astype(jnp.int32)

        # get current repertoire cell data
        cell_genotype = jax.tree.map(lambda x: x[index][0], carry.genotypes)
        cell_fitness = carry.fitnesses[index][0]
        cell_descriptor = carry.descriptors[index][0]
        cell_extra_scores = jax.tree.map(lambda x: x[index][0], carry.extra_scores)
        # empty slots are the ones holding a -inf fitness
        cell_mask = jnp.any(cell_fitness == -jnp.inf, axis=-1)

        new_genotypes = jax.tree.map(lambda x: jnp.expand_dims(x, axis=0), genotype)

        # update pareto front
        (
            cell_fitness,
            cell_genotype,  # new pf for cell
            cell_descriptor,
            cell_extra_scores,
            cell_mask,
        ) = self._update_masked_pareto_front(
            pareto_front_fitnesses=cell_fitness,
            pareto_front_genotypes=cell_genotype,
            pareto_front_descriptors=cell_descriptor,
            pareto_front_extra_scores=cell_extra_scores,
            mask=cell_mask,
            new_batch_of_fitnesses=jnp.expand_dims(fitness, axis=0),
            new_batch_of_genotypes=new_genotypes,
            new_batch_of_descriptors=jnp.expand_dims(descriptors, axis=0),
            new_batch_of_extra_scores=jax.tree.map(
                lambda x: jnp.expand_dims(x, axis=0), extra_scores
            ),
            new_mask=jnp.zeros(shape=(1,), dtype=bool),
        )

        # Mark empty slots with -inf. jnp.where is used instead of
        # subtracting `inf * mask`, which evaluates to nan (inf * 0) on the
        # occupied slots under IEEE-754 arithmetic.
        cell_fitness = jnp.where(
            jnp.expand_dims(cell_mask, axis=-1), -jnp.inf, cell_fitness
        )

        # update grid
        new_genotypes = jax.tree.map(
            lambda x, y: x.at[index].set(y), carry.genotypes, cell_genotype
        )
        new_fitnesses = carry.fitnesses.at[index].set(cell_fitness)
        new_descriptors = carry.descriptors.at[index].set(cell_descriptor)
        new_extra_scores = jax.tree.map(
            lambda x, y: x.at[index].set(y), carry.extra_scores, cell_extra_scores
        )
        carry = carry.replace(  # type: ignore
            genotypes=new_genotypes,
            descriptors=new_descriptors,
            fitnesses=new_fitnesses,
            extra_scores=new_extra_scores,
        )

        # return new grid
        return carry, ()

    # scan the addition operation for all the data, one point at a time
    new_repertoire, _ = jax.lax.scan(
        _add_one,
        self,
        (
            batch_of_genotypes,
            batch_of_descriptors,
            batch_of_fitnesses,
            filtered_batch_of_extra_scores,
            batch_of_indices,
        ),
    )

    return new_repertoire

compute_global_pareto_front()

Merge all the pareto fronts of the MOME repertoire into a single one called global pareto front.

Returns:
  • Tuple[ParetoFront[Fitness], Mask]

    The pareto front and its mask.

Source code in qdax/core/containers/mome_repertoire.py
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
@jax.jit
def compute_global_pareto_front(
    self,
) -> Tuple[ParetoFront[Fitness], Mask]:
    """Merge all the pareto fronts of the MOME repertoire into a single one
    called global pareto front. Works for any number of criteria.

    Returns:
        The flattened fitnesses, of shape
        (num_centroids * pareto_front_length, num_criteria), where points
        that are not on the global front are set to -inf, and the boolean
        mask of shape (num_centroids * pareto_front_length,) that is True
        for the points that are kept.
    """
    num_criteria = self.fitnesses.shape[-1]
    fitnesses = jnp.reshape(self.fitnesses, (-1, num_criteria))
    mask = jnp.any(fitnesses == -jnp.inf, axis=-1)
    pareto_mask = compute_masked_pareto_front(fitnesses, mask)
    # broadcast the per-point mask over all criteria (no hard-coded count);
    # jnp.where avoids the nan produced by `inf * 0` on kept points
    pareto_front = jnp.where(
        jnp.expand_dims(pareto_mask, axis=-1), fitnesses, -jnp.inf
    )

    return pareto_front, pareto_mask

init(genotypes, fitnesses, descriptors, centroids, pareto_front_max_length, *args, extra_scores=None, keys_extra_scores=(), **kwargs) classmethod

Initialize a Multi Objective Map-Elites repertoire with an initial population of genotypes. Requires the definition of centroids that can be computed with any method such as CVT or Euclidean mapping.

Note: this function has been kept outside of the object MapElites, so it can easily be called from other modules.

Parameters:
  • genotypes (Genotype) –

    initial genotypes, pytree in which leaves have shape (batch_size, num_features)

  • fitnesses (Fitness) –

    fitness of the initial genotypes of shape: (batch_size, num_criteria)

  • descriptors (Descriptor) –

    descriptors of the initial genotypes of shape (batch_size, num_descriptors)

  • centroids (Centroid) –

    tessellation centroids of shape (num_centroids, num_descriptors)

  • pareto_front_max_length (int) –

    maximum size of the pareto fronts

  • extra_scores (Optional[ExtraScores], default: None ) –

    extra_scores of the initial genotypes; only the entries listed in keys_extra_scores are stored

  • keys_extra_scores (Tuple[str, ...], default: () ) –

    keys of the extra_scores of the initial genotypes

Returns:
  • MOMERepertoire

    An initialized MOME repertoire.

Source code in qdax/core/containers/mome_repertoire.py
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
@classmethod
def init(  # type: ignore
    cls,
    genotypes: Genotype,
    fitnesses: Fitness,
    descriptors: Descriptor,
    centroids: Centroid,
    pareto_front_max_length: int,
    *args,
    extra_scores: Optional[ExtraScores] = None,
    keys_extra_scores: Tuple[str, ...] = (),
    **kwargs,
) -> MOMERepertoire:
    """
    Initialize a Multi Objective Map-Elites repertoire with an initial population
    of genotypes. Requires the definition of centroids that can be computed with
    any method such as CVT or Euclidean mapping.

    Note: this function is a classmethod, so it can easily be called from
    other modules.

    Args:
        genotypes: initial genotypes, pytree in which leaves
            have shape (batch_size, num_features)
        fitnesses: fitness of the initial genotypes of shape:
            (batch_size, num_criteria)
        descriptors: descriptors of the initial genotypes
            of shape (batch_size, num_descriptors)
        centroids: tessellation centroids of shape
            (num_centroids, num_descriptors)
        pareto_front_max_length: maximum size of the pareto fronts
        extra_scores: extra_scores of the initial genotypes; only the
            entries listed in keys_extra_scores are stored
        keys_extra_scores: keys of the extra_scores to store in the
            repertoire

    Returns:
        An initialized MOME repertoire
    """

    # NOTE(review): extra scores whose keys are in keys_extra_scores ARE
    # stored, so this message looks stale; kept unchanged for now.
    warnings.warn(
        (
            "This type of repertoire does not store the extra scores "
            "computed by the scoring function"
        ),
        stacklevel=2,
    )

    if extra_scores is None:
        extra_scores = {}

    filtered_extra_scores = {key: extra_scores[key] for key in keys_extra_scores}

    # get dimensions
    num_criteria = fitnesses.shape[1]
    num_descriptors = descriptors.shape[1]
    num_centroids = centroids.shape[0]

    # create default values: every slot starts empty (-inf fitness)
    default_fitnesses = -jnp.inf * jnp.ones(
        shape=(num_centroids, pareto_front_max_length, num_criteria)
    )
    default_genotypes = jax.tree.map(
        lambda x: jnp.zeros(
            shape=(num_centroids, pareto_front_max_length) + x.shape[1:]
        ),
        genotypes,
    )
    default_descriptors = jnp.zeros(
        shape=(num_centroids, pareto_front_max_length, num_descriptors)
    )

    default_extra_scores = jax.tree.map(
        lambda x: jnp.zeros(
            shape=(num_centroids, pareto_front_max_length, *x.shape[1:])
        ),
        filtered_extra_scores,
    )

    # create repertoire with default values; use `cls` so that subclasses
    # get an instance of their own type
    repertoire = cls(  # type: ignore
        genotypes=default_genotypes,
        fitnesses=default_fitnesses,
        descriptors=default_descriptors,
        centroids=centroids,
        extra_scores=default_extra_scores,
        keys_extra_scores=keys_extra_scores,
    )

    # add first batch of individuals in the repertoire
    new_repertoire = repertoire.add(genotypes, descriptors, fitnesses, extra_scores)

    return new_repertoire  # type: ignore

MapElitesRepertoire

Bases: GARepertoire

Class for the repertoire in Map Elites.

Parameters:
  • genotypes

    a PyTree containing all the genotypes in the repertoire ordered by the centroids. Each leaf has a shape (num_centroids, num_features). The PyTree can be a simple JAX array or a more complex nested structure, such as the parameters of a neural network in Flax.

  • fitnesses

    an array that contains the fitness of solutions in each cell of the repertoire, ordered by centroids. The array shape is (num_centroids,).

  • descriptors

    an array that contains the descriptors of solutions in each cell of the repertoire, ordered by centroids. The array shape is (num_centroids, num_descriptors).

  • centroids

    an array that contains the centroids of the tessellation. The array shape is (num_centroids, num_descriptors).

  • extra_scores

    extra scores resulting from the evaluation of the genotypes

  • keys_extra_scores

    keys of the extra scores to store in the repertoire

Source code in qdax/core/containers/mapelites_repertoire.py
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
class MapElitesRepertoire(GARepertoire):
    """Class for the repertoire in Map Elites.

    Args:
        genotypes: a PyTree containing all the genotypes in the repertoire ordered
            by the centroids. Each leaf has a shape (num_centroids, num_features). The
            PyTree can be a simple JAX array or a more complex nested structure such
            as to represent parameters of neural network in Flax.
        fitnesses: an array that contains the fitness of solutions in each cell of the
            repertoire, ordered by centroids. The array shape is (num_centroids, 1);
            empty cells hold -inf.
        descriptors: an array that contains the descriptors of solutions in each cell
            of the repertoire, ordered by centroids. The array shape
            is (num_centroids, num_descriptors).
        centroids: an array that contains the centroids of the tessellation. The array
            shape is (num_centroids, num_descriptors).
        extra_scores: extra scores resulting from the evaluation of the genotypes
        keys_extra_scores: keys of the extra scores to store in the repertoire
    """

    descriptors: Descriptor
    centroids: Centroid

    def select(
        self,
        key: RNGKey,
        num_samples: int,
        selector: Optional[Selector[MapElitesRepertoireT]] = None,
    ) -> MapElitesRepertoireT:
        # Uniform sampling with replacement unless a selector is provided.
        if selector is None:
            selector = UniformSelector(select_with_replacement=True)
        repertoire = selector.select(self, key, num_samples)
        return repertoire

    def add(  # type: ignore
        self,
        batch_of_genotypes: Genotype,
        batch_of_descriptors: Descriptor,
        batch_of_fitnesses: Fitness,
        batch_of_extra_scores: Optional[ExtraScores] = None,
    ) -> MapElitesRepertoire:
        """
        Add a batch of elements to the repertoire.

        Each candidate is mapped to a cell with ``get_cells_indices``. Within a
        cell, only the best candidate of the batch competes with the current
        elite and replaces it if its fitness is strictly higher. When several
        candidates of the batch tie for the best fitness of a cell, the one that
        appears first in the batch is kept. As a result, the genotype, fitness,
        descriptor and extra scores written to a cell always come from the same
        candidate.

        Args:
            batch_of_genotypes: a batch of genotypes to be added to the repertoire.
                Similarly to the self.genotypes argument, this is a PyTree in which
                the leaves have a shape (batch_size, num_features)
            batch_of_descriptors: an array that contains the descriptors of the
                aforementioned genotypes. Its shape is (batch_size, num_descriptors)
            batch_of_fitnesses: an array that contains the fitnesses of the
                aforementioned genotypes. Its shape is (batch_size,)
            batch_of_extra_scores: tree that contains the extra_scores of
                aforementioned genotypes.

        Returns:
            The updated MAP-Elites repertoire, of the same class as ``self``.
        """
        if batch_of_extra_scores is None:
            batch_of_extra_scores = {}

        filtered_batch_of_extra_scores = self.filter_extra_scores(batch_of_extra_scores)

        batch_size = batch_of_descriptors.shape[0]
        num_centroids = self.centroids.shape[0]

        # cell index of each candidate, as a (batch_size, 1) column for
        # take_along_axis on the (num_centroids, 1) fitness arrays
        batch_of_indices = get_cells_indices(batch_of_descriptors, self.centroids)
        batch_of_indices = jnp.expand_dims(batch_of_indices, axis=-1)
        segment_ids = batch_of_indices.astype(jnp.int32).squeeze(axis=-1)

        batch_of_fitnesses = jnp.reshape(batch_of_fitnesses, (batch_size, 1))

        # get fitness segment max
        best_fitnesses = jax.ops.segment_max(
            batch_of_fitnesses,
            segment_ids,
            num_segments=num_centroids,
        )
        cond_values = jnp.take_along_axis(best_fitnesses, batch_of_indices, 0)
        is_cell_best = (batch_of_fitnesses == cond_values).squeeze(axis=-1)

        # Several candidates can tie for the best fitness of a cell. Scattering
        # all of them to the same index leaves the winning update unspecified
        # for each array separately, so keep only the first tied candidate.
        positions = jnp.arange(batch_size)
        first_best_positions = jax.ops.segment_min(
            jnp.where(is_cell_best, positions, batch_size),
            segment_ids,
            num_segments=num_centroids,
        )
        is_cell_winner = positions == first_best_positions[segment_ids]

        # put non-winning fitness to -jnp.inf
        batch_of_fitnesses = jnp.where(
            is_cell_winner[:, None], batch_of_fitnesses, -jnp.inf
        )

        # get addition condition
        current_fitnesses = jnp.take_along_axis(self.fitnesses, batch_of_indices, 0)
        addition_condition = batch_of_fitnesses > current_fitnesses

        # assign fake position when relevant : num_centroids is out of bound,
        # and out-of-bound scatter updates are dropped
        batch_of_indices = jnp.where(
            addition_condition, batch_of_indices, num_centroids
        )
        scatter_indices = batch_of_indices.squeeze(axis=-1)

        # create new repertoire
        new_repertoire_genotypes = jax.tree.map(
            lambda repertoire_genotypes, new_genotypes: repertoire_genotypes.at[
                scatter_indices
            ].set(new_genotypes),
            self.genotypes,
            batch_of_genotypes,
        )

        # compute new fitness and descriptors
        new_fitnesses = self.fitnesses.at[scatter_indices].set(batch_of_fitnesses)
        new_descriptors = self.descriptors.at[scatter_indices].set(
            batch_of_descriptors
        )

        # update extra scores
        new_extra_scores = jax.tree.map(
            lambda repertoire_scores, new_scores: repertoire_scores.at[
                scatter_indices
            ].set(new_scores),
            self.extra_scores,
            filtered_batch_of_extra_scores,
        )

        # replace keeps the concrete class (and any extra fields) of self
        return self.replace(  # type: ignore
            genotypes=new_repertoire_genotypes,
            fitnesses=new_fitnesses,
            descriptors=new_descriptors,
            extra_scores=new_extra_scores,
        )

    @classmethod
    def init(  # type: ignore
        cls,
        genotypes: Genotype,
        fitnesses: Fitness,
        descriptors: Descriptor,
        centroids: Centroid,
        *args,
        extra_scores: Optional[ExtraScores] = None,
        keys_extra_scores: Tuple[str, ...] = (),
        **kwargs,
    ) -> MapElitesRepertoire:
        """
        Initialize a Map-Elites repertoire with an initial population of genotypes.
        Requires the definition of centroids that can be computed with any method
        such as CVT or Euclidean mapping.

        Note: this function has been kept outside of the object MapElites, so it can
        easily be called from other modules.

        Args:
            genotypes: initial genotypes, pytree in which leaves
                have shape (batch_size, num_features)
            fitnesses: fitness of the initial genotypes of shape (batch_size,)
            descriptors: descriptors of the initial genotypes
                of shape (batch_size, num_descriptors)
            centroids: tessellation centroids of shape (num_centroids, num_descriptors)
            extra_scores: extra scores of the initial genotypes
            keys_extra_scores: keys of the extra scores to store in the repertoire

        Returns:
            an initialized MAP-Elite repertoire
        """

        if extra_scores is None:
            extra_scores = {}

        # only the requested extra scores are stored
        extra_scores = {
            key: value
            for key, value in extra_scores.items()
            if key in keys_extra_scores
        }

        # retrieve one genotype from the population (used for shapes/dtypes)
        first_genotype = jax.tree.map(lambda x: x[0], genotypes)
        first_extra_scores = jax.tree.map(lambda x: x[0], extra_scores)

        # create a repertoire with default values
        repertoire = cls.init_default(
            genotype=first_genotype,
            centroids=centroids,
            one_extra_score=first_extra_scores,
            keys_extra_scores=keys_extra_scores,
        )

        # add initial population to the repertoire
        new_repertoire = repertoire.add(genotypes, descriptors, fitnesses, extra_scores)

        return new_repertoire  # type: ignore

    @classmethod
    def init_default(
        cls,
        genotype: Genotype,
        centroids: Centroid,
        one_extra_score: Optional[ExtraScores] = None,
        keys_extra_scores: Tuple[str, ...] = (),
    ) -> MapElitesRepertoire:
        """Create a Map-Elites repertoire whose cells all hold default values.
        Requires the definition of centroids that can be computed with any
        method such as CVT or Euclidean mapping.

        Every cell holds zero genotypes, zero descriptors, zero extra scores
        and a fitness of -inf, so that any finite fitness beats it.

        Note: this function has been kept outside of the object MapElites, so
        it can easily be called from other modules.

        Args:
            genotype: the typical genotype that will be stored (no batch
                dimension); only its leaf shapes and dtypes are used.
            centroids: the centroids of the repertoire, of shape
                (num_centroids, num_descriptors).
            one_extra_score: extra scores of a single genotype (no batch
                dimension); only shapes and dtypes are used. Keys that are
                not in keys_extra_scores are dropped.
            keys_extra_scores: keys of the extra scores to store in the repertoire

        Returns:
            A repertoire filled with default values.
        """
        if one_extra_score is None:
            one_extra_score = {}

        one_extra_score = {
            key: value
            for key, value in one_extra_score.items()
            if key in keys_extra_scores
        }

        # get number of centroids
        num_centroids = centroids.shape[0]

        # default fitness is -inf
        default_fitnesses = -jnp.inf * jnp.ones(shape=(num_centroids, 1))

        # default genotypes is all 0
        default_genotypes = jax.tree.map(
            lambda x: jnp.zeros(shape=(num_centroids,) + x.shape, dtype=x.dtype),
            genotype,
        )

        # default descriptor is all zeros
        default_descriptors = jnp.zeros_like(centroids)

        # default extra scores are all zeros (empty dict if no keys requested)
        default_extra_scores = jax.tree.map(
            lambda x: jnp.zeros(shape=(num_centroids,) + x.shape, dtype=x.dtype),
            one_extra_score,
        )

        return cls(
            genotypes=default_genotypes,
            fitnesses=default_fitnesses,
            descriptors=default_descriptors,
            centroids=centroids,
            extra_scores=default_extra_scores,
            keys_extra_scores=keys_extra_scores,
        )

add(batch_of_genotypes, batch_of_descriptors, batch_of_fitnesses, batch_of_extra_scores=None)

Add a batch of elements to the repertoire.

Parameters:
  • batch_of_genotypes (Genotype) –

    a batch of genotypes to be added to the repertoire. Similarly to the self.genotypes argument, this is a PyTree in which the leaves have a shape (batch_size, num_features)

  • batch_of_descriptors (Descriptor) –

    an array that contains the descriptors of the aforementioned genotypes. Its shape is (batch_size, num_descriptors)

  • batch_of_fitnesses (Fitness) –

    an array that contains the fitnesses of the aforementioned genotypes. Its shape is (batch_size,)

  • batch_of_extra_scores (Optional[ExtraScores], default: None ) –

    tree that contains the extra_scores of aforementioned genotypes.

Returns:

  • The updated MAP-Elites repertoire.

Source code in qdax/core/containers/mapelites_repertoire.py
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
def add(  # type: ignore
    self,
    batch_of_genotypes: Genotype,
    batch_of_descriptors: Descriptor,
    batch_of_fitnesses: Fitness,
    batch_of_extra_scores: Optional[ExtraScores] = None,
) -> MapElitesRepertoire:
    """
    Add a batch of elements to the repertoire.

    Each candidate is mapped to a cell with ``get_cells_indices``. Within a
    cell, only the best candidate of the batch competes with the current
    elite and replaces it if its fitness is strictly higher. When several
    candidates of the batch tie for the best fitness of a cell, the one that
    appears first in the batch is kept. As a result, the genotype, fitness,
    descriptor and extra scores written to a cell always come from the same
    candidate.

    Args:
        batch_of_genotypes: a batch of genotypes to be added to the repertoire.
            Similarly to the self.genotypes argument, this is a PyTree in which
            the leaves have a shape (batch_size, num_features)
        batch_of_descriptors: an array that contains the descriptors of the
            aforementioned genotypes. Its shape is (batch_size, num_descriptors)
        batch_of_fitnesses: an array that contains the fitnesses of the
            aforementioned genotypes. Its shape is (batch_size,)
        batch_of_extra_scores: tree that contains the extra_scores of
            aforementioned genotypes.

    Returns:
        The updated MAP-Elites repertoire, of the same class as ``self``.
    """
    if batch_of_extra_scores is None:
        batch_of_extra_scores = {}

    filtered_batch_of_extra_scores = self.filter_extra_scores(batch_of_extra_scores)

    batch_size = batch_of_descriptors.shape[0]
    num_centroids = self.centroids.shape[0]

    # cell index of each candidate, as a (batch_size, 1) column for
    # take_along_axis on the (num_centroids, 1) fitness arrays
    batch_of_indices = get_cells_indices(batch_of_descriptors, self.centroids)
    batch_of_indices = jnp.expand_dims(batch_of_indices, axis=-1)
    segment_ids = batch_of_indices.astype(jnp.int32).squeeze(axis=-1)

    batch_of_fitnesses = jnp.reshape(batch_of_fitnesses, (batch_size, 1))

    # get fitness segment max
    best_fitnesses = jax.ops.segment_max(
        batch_of_fitnesses,
        segment_ids,
        num_segments=num_centroids,
    )
    cond_values = jnp.take_along_axis(best_fitnesses, batch_of_indices, 0)
    is_cell_best = (batch_of_fitnesses == cond_values).squeeze(axis=-1)

    # Several candidates can tie for the best fitness of a cell. Scattering
    # all of them to the same index leaves the winning update unspecified
    # for each array separately, so keep only the first tied candidate.
    positions = jnp.arange(batch_size)
    first_best_positions = jax.ops.segment_min(
        jnp.where(is_cell_best, positions, batch_size),
        segment_ids,
        num_segments=num_centroids,
    )
    is_cell_winner = positions == first_best_positions[segment_ids]

    # put non-winning fitness to -jnp.inf
    batch_of_fitnesses = jnp.where(
        is_cell_winner[:, None], batch_of_fitnesses, -jnp.inf
    )

    # get addition condition
    current_fitnesses = jnp.take_along_axis(self.fitnesses, batch_of_indices, 0)
    addition_condition = batch_of_fitnesses > current_fitnesses

    # assign fake position when relevant : num_centroids is out of bound,
    # and out-of-bound scatter updates are dropped
    batch_of_indices = jnp.where(
        addition_condition, batch_of_indices, num_centroids
    )
    scatter_indices = batch_of_indices.squeeze(axis=-1)

    # create new repertoire
    new_repertoire_genotypes = jax.tree.map(
        lambda repertoire_genotypes, new_genotypes: repertoire_genotypes.at[
            scatter_indices
        ].set(new_genotypes),
        self.genotypes,
        batch_of_genotypes,
    )

    # compute new fitness and descriptors
    new_fitnesses = self.fitnesses.at[scatter_indices].set(batch_of_fitnesses)
    new_descriptors = self.descriptors.at[scatter_indices].set(batch_of_descriptors)

    # update extra scores
    new_extra_scores = jax.tree.map(
        lambda repertoire_scores, new_scores: repertoire_scores.at[
            scatter_indices
        ].set(new_scores),
        self.extra_scores,
        filtered_batch_of_extra_scores,
    )

    # replace keeps the concrete class (and any extra fields) of self
    return self.replace(  # type: ignore
        genotypes=new_repertoire_genotypes,
        fitnesses=new_fitnesses,
        descriptors=new_descriptors,
        extra_scores=new_extra_scores,
    )

init(genotypes, fitnesses, descriptors, centroids, *args, extra_scores=None, keys_extra_scores=(), **kwargs) classmethod

Initialize a Map-Elites repertoire with an initial population of genotypes. Requires the definition of centroids that can be computed with any method such as CVT or Euclidean mapping.

Note: this function has been kept outside of the object MapElites, so it can easily be called from other modules.

Parameters:
  • genotypes (Genotype) –

    initial genotypes, pytree in which leaves have shape (batch_size, num_features)

  • fitnesses (Fitness) –

    fitness of the initial genotypes of shape (batch_size,)

  • descriptors (Descriptor) –

    descriptors of the initial genotypes of shape (batch_size, num_descriptors)

  • centroids (Centroid) –

    tessellation centroids of shape (num_centroids, num_descriptors)

  • extra_scores (Optional[ExtraScores], default: None ) –

    extra scores of the initial genotypes

  • keys_extra_scores (Tuple[str, ...], default: () ) –

    keys of the extra scores to store in the repertoire

Returns:

  • An initialized MAP-Elites repertoire.

Source code in qdax/core/containers/mapelites_repertoire.py
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
@classmethod
def init(  # type: ignore
    cls,
    genotypes: Genotype,
    fitnesses: Fitness,
    descriptors: Descriptor,
    centroids: Centroid,
    *args,
    extra_scores: Optional[ExtraScores] = None,
    keys_extra_scores: Tuple[str, ...] = (),
    **kwargs,
) -> MapElitesRepertoire:
    """
    Initialize a Map-Elites repertoire with an initial population of genotypes.
    Requires the definition of centroids that can be computed with any method
    such as CVT or Euclidean mapping.

    Note: this function has been kept outside of the object MapElites, so it can
    easily be called from other modules.

    Args:
        genotypes: initial genotypes, pytree in which leaves
            have shape (batch_size, num_features)
        fitnesses: fitness of the initial genotypes of shape (batch_size,)
        descriptors: descriptors of the initial genotypes
            of shape (batch_size, num_descriptors)
        centroids: tessellation centroids of shape (num_centroids, num_descriptors)
        extra_scores: extra scores of the initial genotypes
        keys_extra_scores: keys of the extra scores to store in the repertoire

    Returns:
        an initialized MAP-Elite repertoire
    """

    if extra_scores is None:
        extra_scores = {}

    # only the requested extra scores are stored
    extra_scores = {
        key: value
        for key, value in extra_scores.items()
        if key in keys_extra_scores
    }

    # retrieve one genotype from the population (used for shapes/dtypes)
    first_genotype = jax.tree.map(lambda x: x[0], genotypes)
    first_extra_scores = jax.tree.map(lambda x: x[0], extra_scores)

    # create a repertoire with default values
    repertoire = cls.init_default(
        genotype=first_genotype,
        centroids=centroids,
        one_extra_score=first_extra_scores,
        keys_extra_scores=keys_extra_scores,
    )

    # add initial population to the repertoire
    new_repertoire = repertoire.add(genotypes, descriptors, fitnesses, extra_scores)

    return new_repertoire  # type: ignore

init_default(genotype, centroids, one_extra_score=None, keys_extra_scores=()) classmethod

Create a Map-Elites repertoire whose cells all hold default values (zero genotypes, zero descriptors, zero extra scores and a fitness of -inf). Requires the definition of centroids that can be computed with any method such as CVT or Euclidean mapping.

Note: this function has been kept outside of the object MapElites, so it can easily be called from other modules.

Parameters:
  • genotype (Genotype) –

    the typical genotype that will be stored.

  • centroids (Centroid) –

    the centroids of the repertoire

  • keys_extra_scores (Tuple[str, ...], default: () ) –

    keys of the extra scores to store in the repertoire

Returns:

  • A repertoire filled with default values.

Source code in qdax/core/containers/mapelites_repertoire.py
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
@classmethod
def init_default(
    cls,
    genotype: Genotype,
    centroids: Centroid,
    one_extra_score: Optional[ExtraScores] = None,
    keys_extra_scores: Tuple[str, ...] = (),
) -> MapElitesRepertoire:
    """Create a Map-Elites repertoire whose cells all hold default values.
    Requires the definition of centroids that can be computed with any
    method such as CVT or Euclidean mapping.

    Every cell holds zero genotypes, zero descriptors, zero extra scores
    and a fitness of -inf, so that any finite fitness beats it.

    Note: this function has been kept outside of the object MapElites, so
    it can easily be called from other modules.

    Args:
        genotype: the typical genotype that will be stored (no batch
            dimension); only its leaf shapes and dtypes are used.
        centroids: the centroids of the repertoire, of shape
            (num_centroids, num_descriptors).
        one_extra_score: extra scores of a single genotype (no batch
            dimension); only shapes and dtypes are used. Keys that are
            not in keys_extra_scores are dropped.
        keys_extra_scores: keys of the extra scores to store in the repertoire

    Returns:
        A repertoire filled with default values.
    """
    if one_extra_score is None:
        one_extra_score = {}

    one_extra_score = {
        key: value
        for key, value in one_extra_score.items()
        if key in keys_extra_scores
    }

    # get number of centroids
    num_centroids = centroids.shape[0]

    # default fitness is -inf
    default_fitnesses = -jnp.inf * jnp.ones(shape=(num_centroids, 1))

    # default genotypes is all 0
    default_genotypes = jax.tree.map(
        lambda x: jnp.zeros(shape=(num_centroids,) + x.shape, dtype=x.dtype),
        genotype,
    )

    # default descriptor is all zeros
    default_descriptors = jnp.zeros_like(centroids)

    # default extra scores are all zeros (empty dict if no keys requested)
    default_extra_scores = jax.tree.map(
        lambda x: jnp.zeros(shape=(num_centroids,) + x.shape, dtype=x.dtype),
        one_extra_score,
    )

    return cls(
        genotypes=default_genotypes,
        fitnesses=default_fitnesses,
        descriptors=default_descriptors,
        centroids=centroids,
        extra_scores=default_extra_scores,
        keys_extra_scores=keys_extra_scores,
    )

NSGA2Repertoire

Bases: GARepertoire

Repertoire used for the NSGA2 algorithm.

Inherits from the GARepertoire. The data stored are the genotypes and their fitness. Several functions are inherited from GARepertoire, including size, sample and init.

Source code in qdax/core/containers/nsga2_repertoire.py
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
class NSGA2Repertoire(GARepertoire):
    """Repertoire used for the NSGA2 algorithm.

    Inherits from the GARepertoire. The data stored are the genotypes
    and their fitness. Several functions are inherited from GARepertoire,
    including size, sample and init.
    """

    def _compute_crowding_distances(
        self, fitnesses: Fitness, mask: jax.Array
    ) -> jax.Array:
        """Compute crowding distances.

        The crowding distance is the Manhattan distance in the objective
        space between the two neighbours of a solution, normalised per
        objective and averaged over objectives. This is used to rank
        individuals in the addition function.

        Args:
            fitnesses: fitnesses of the considered individuals. Here,
                fitness are vectors as we are doing multi-objective
                optimization; the array is indexed as
                (num_solutions, num_objectives).
            mask: boolean vector with one entry per solution. Solutions
                where the mask is True are shifted beyond every unmasked
                solution before sorting, so they sort last on every
                objective.

        Returns:
            The crowding distances, one per solution. The first solution of
            each objective's sort order gets an infinite distance.
        """
        num_solutions = fitnesses.shape[0]
        num_objective = fitnesses.shape[1]
        if num_solutions <= 2:
            return jnp.array([jnp.inf] * num_solutions)

        # Sort solutions on each objective. Masked solutions are offset by three
        # times the objective amplitude so that they come after unmasked ones.
        mask_dist = jnp.column_stack([mask] * num_objective)
        score_amplitude = jnp.max(fitnesses, axis=0) - jnp.min(fitnesses, axis=0)
        dist_fitnesses = (
            fitnesses + 3 * score_amplitude * jnp.ones_like(fitnesses) * mask_dist
        )
        sorted_index = jnp.argsort(dist_fitnesses, axis=0)
        objective_ids = jnp.arange(num_objective)
        srt_fitnesses = fitnesses[sorted_index, objective_ids]

        # Normalisation of each objective. If all values of an objective are
        # equal the norm is 0 and the divisions below yield NaN / inf.
        norm = jnp.max(srt_fitnesses, axis=0) - jnp.min(srt_fitnesses, axis=0)

        # Gaps between consecutive sorted solutions, padded with an infinite
        # gap before the first and after the last solution.
        dists = jnp.vstack(
            [srt_fitnesses, jnp.full(num_objective, jnp.inf)]
        ) - jnp.vstack([jnp.full(num_objective, -jnp.inf), srt_fitnesses])

        dist_to_last = dists[:-1] / norm
        dist_to_next = dists[1:] / norm

        # Sum the gaps on both sides and go back to the original ordering
        j = jnp.argsort(sorted_index, axis=0)
        crowding_distances = (
            jnp.sum(
                dist_to_last[j, objective_ids] + dist_to_next[j, objective_ids],
                axis=1,
            )
            / num_objective
        )

        return crowding_distances

    def add(  # type: ignore
        self,
        batch_of_genotypes: Genotype,
        batch_of_fitnesses: Fitness,
        batch_of_extra_scores: Optional[ExtraScores] = None,
    ) -> NSGA2Repertoire:
        """Implements the repertoire addition rules.

        The population is sorted into successive Pareto fronts. The first one
        is the global Pareto front. The second one is the Pareto front of the
        population where the first Pareto front has been removed, etc...

        The successive Pareto fronts are kept until the moment where adding a
        full Pareto front would exceed the population size.

        To decide the survival of this Pareto front, a crowding distance is
        computed in order to keep individuals that are spread in this last
        Pareto front. Hence, the individuals of that front with the biggest
        crowding distances are added until the population size is reached.

        Args:
            batch_of_genotypes: new genotypes that we try to add.
            batch_of_fitnesses: fitness of those new genotypes.
            batch_of_extra_scores: extra scores of those new genotypes.

        Returns:
            The updated repertoire.
        """

        if batch_of_extra_scores is None:
            batch_of_extra_scores = {}

        # All the candidates: current population followed by the new batch
        candidates = jax.tree.map(
            lambda x, y: jnp.concatenate((x, y), axis=0),
            self.genotypes,
            batch_of_genotypes,
        )

        candidate_fitnesses = jnp.concatenate((self.fitnesses, batch_of_fitnesses))

        # Extra scores must be aligned with ``candidates``: the survivor
        # indices computed below point into the concatenated population.
        filtered_batch_of_extra_scores = self.filter_extra_scores(batch_of_extra_scores)
        candidate_extra_scores = jax.tree.map(
            lambda x, y: jnp.concatenate((x, y), axis=0),
            self.extra_scores,
            filtered_batch_of_extra_scores,
        )

        first_leaf = jax.tree.leaves(candidates)[0]
        num_candidates = first_leaf.shape[0]

        def compute_current_front(
            val: Tuple[jax.Array, jax.Array]
        ) -> Tuple[jax.Array, jax.Array]:
            """Body function for the while loop. Computes the successive
            Pareto fronts in the data.

            Args:
                val: Value passed through the while loop. Here, it is
                    a tuple containing two values. The indexes of all
                    solutions to keep and the indexes of the last
                    computed front.

            Returns:
                The updated values to pass through the while loop. Updated
                number of solutions and updated front indexes.
            """
            to_keep_index, _ = val

            # mask the individual that are already kept
            front_index = compute_masked_pareto_front(
                candidate_fitnesses, mask=to_keep_index
            )

            # Add new indexes
            to_keep_index = to_keep_index + front_index

            # Update front & number of solutions
            return to_keep_index, front_index

        def condition_fn_1(val: Tuple[jax.Array, jax.Array]) -> bool:
            """Keep computing fronts while the number of kept solutions is
            smaller than the maximum size of the population.

            Args:
                val: Value passed through the while loop. Here, it is
                    a tuple containing two values. The indexes of all
                    solutions to keep and the indexes of the last
                    computed front.

            Returns:
                True while fewer solutions than the population size
                are kept.
            """
            to_keep_index, _ = val
            return jnp.sum(to_keep_index) < self.size  # type: ignore

        # get indexes of all first successive fronts and indexes of the last front
        to_keep_index, front_index = jax.lax.while_loop(
            condition_fn_1,
            compute_current_front,
            (
                jnp.zeros(num_candidates, dtype=bool),
                jnp.zeros(num_candidates, dtype=bool),
            ),
        )

        # remove the indexes of the last front - gives first indexes to keep
        to_keep_index = jnp.logical_and(to_keep_index, jnp.logical_not(front_index))

        # Compute crowding distances
        crowding_distances = self._compute_crowding_distances(
            candidate_fitnesses, ~front_index
        )
        # Only the last front competes for the remaining slots. Multiplying by
        # the mask would turn the inf distance of a masked individual into NaN
        # (inf * 0), which argsort puts last, i.e. among the "highest"
        # distances; -inf sorts those individuals first instead.
        crowding_distances = jnp.where(front_index, crowding_distances, -jnp.inf)
        highest_dist = jnp.argsort(crowding_distances)

        def add_to_front(val: Tuple[jax.Array, float]) -> Tuple[jax.Array, Any]:
            """Add the individual with the next highest distance to the front.

            Args:
                val: a tuple of two elements. A boolean vector with the positions that
                    will be kept, and a cursor with the number of individuals already
                    added during this process.

            Returns:
                The updated tuple, with the new booleans and the number of
                added elements.
            """
            front_index, num = val
            # highest_dist is ascending: -(num + 1) is the (num + 1)-th highest
            # distance (``-num`` would start at index 0, the lowest distance)
            front_index = front_index.at[highest_dist[-(num + 1)]].set(True)
            return front_index, num + 1

        def condition_fn_2(val: Tuple[jax.Array, jax.Array]) -> bool:
            """Keep adding individuals while the number of solutions is
            smaller than the maximum size of the population."""
            front_index, _ = val
            return jnp.sum(  # type: ignore
                jnp.logical_or(to_keep_index, front_index)
            ) < self.size

        # add the individuals with the highest distances
        front_index, _num = jax.lax.while_loop(
            condition_fn_2,
            add_to_front,
            (jnp.zeros(num_candidates, dtype=bool), 0),
        )

        # update index
        to_keep_index = jnp.logical_or(to_keep_index, front_index)

        # go from boolean vector to indices - offset by 1
        indices = jnp.arange(start=1, stop=num_candidates + 1) * to_keep_index

        # get rid of the zeros (that correspond to the False from the mask)
        fake_indice = num_candidates + 1  # bigger than all the other indices
        indices = jnp.where(indices == 0, fake_indice, indices)

        # sort the indices to remove the fake indices
        indices = jnp.sort(indices)[: self.size]

        # remove the offset
        indices = indices - 1

        # keep only the survivors
        new_candidates = jax.tree.map(lambda x: x[indices], candidates)
        new_scores = candidate_fitnesses[indices]
        new_extra_scores = jax.tree.map(lambda x: x[indices], candidate_extra_scores)

        new_repertoire = self.replace(
            genotypes=new_candidates,
            fitnesses=new_scores,
            extra_scores=new_extra_scores,
        )

        return new_repertoire  # type: ignore

add(batch_of_genotypes, batch_of_fitnesses, batch_of_extra_scores=None)

Implements the repertoire addition rules.

The population is sorted into successive Pareto fronts. The first one is the global Pareto front. The second one is the Pareto front of the population where the first Pareto front has been removed, etc...

The successive pareto fronts are kept until the moment where adding a full pareto front would exceed the population size.

To decide the survival of this pareto front, a crowding distance is computed in order to keep individuals that are spread in this last pareto front. Hence, the individuals with the biggest crowding distances are added until the population size is reached.

Parameters:
  • batch_of_genotypes (Genotype) –

    new genotypes that we try to add.

  • batch_of_fitnesses (Fitness) –

    fitness of those new genotypes.

  • batch_of_extra_scores (Optional[ExtraScores], default: None ) –

    extra scores of those new genotypes.

Returns:

  • The updated repertoire.

Source code in qdax/core/containers/nsga2_repertoire.py
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
def add(  # type: ignore
    self,
    batch_of_genotypes: Genotype,
    batch_of_fitnesses: Fitness,
    batch_of_extra_scores: Optional[ExtraScores] = None,
) -> NSGA2Repertoire:
    """Implements the repertoire addition rules.

    The population is sorted into successive Pareto fronts. The first one
    is the global Pareto front, the second one is the Pareto front of the
    population once the first front has been removed, and so on.

    Whole fronts are kept as long as adding the next full front does not
    exceed the population size.

    To decide which members of that last front survive, a crowding distance
    is computed and the front members with the largest crowding distances
    are added until the population size is reached.

    Args:
        batch_of_genotypes: new genotypes that we try to add.
        batch_of_fitnesses: fitness of those new genotypes.
        batch_of_extra_scores: extra scores of those new genotypes. After
            ``filter_extra_scores``, it must hold the same keys as the
            repertoire's ``extra_scores`` so that both can be concatenated.

    Returns:
        The updated repertoire.
    """

    if batch_of_extra_scores is None:
        batch_of_extra_scores = {}

    def concat(x: jax.Array, y: jax.Array) -> jax.Array:
        """Stack two leaves along the population axis."""
        return jnp.concatenate((x, y), axis=0)

    # All the candidates: the current population followed by the new batch.
    candidates = jax.tree.map(concat, self.genotypes, batch_of_genotypes)
    candidate_fitnesses = jnp.concatenate((self.fitnesses, batch_of_fitnesses))

    # Extra scores are concatenated in the same order as the genotypes so the
    # survivor indices computed below select matching rows (indexing the batch
    # extras alone would misalign them with the candidates).
    filtered_batch_of_extra_scores = self.filter_extra_scores(batch_of_extra_scores)
    candidate_extra_scores = jax.tree.map(
        concat, self.extra_scores, filtered_batch_of_extra_scores
    )

    first_leaf = jax.tree.leaves(candidates)[0]
    num_candidates = first_leaf.shape[0]

    def compute_current_front(
        val: Tuple[jax.Array, jax.Array]
    ) -> Tuple[jax.Array, jax.Array]:
        """Body of the first while loop: extract the next Pareto front.

        Args:
            val: tuple (mask of all solutions kept so far, mask of the
                last computed front).

        Returns:
            The updated kept mask and the mask of the newly computed front.
        """
        to_keep_index, _ = val

        # mask the individuals that are already kept
        front_index = compute_masked_pareto_front(
            candidate_fitnesses, mask=to_keep_index
        )

        return to_keep_index | front_index, front_index

    def condition_fn_1(val: Tuple[jax.Array, jax.Array]) -> jax.Array:
        """Keep extracting fronts while fewer than `self.size` are kept."""
        to_keep_index, _ = val
        # jnp.sum instead of the builtin, which iterates the traced array
        # element by element.
        return jnp.sum(to_keep_index) < self.size

    # get masks of all first successive fronts and of the last front
    to_keep_index, front_index = jax.lax.while_loop(
        condition_fn_1,
        compute_current_front,
        (
            jnp.zeros(num_candidates, dtype=bool),
            jnp.zeros(num_candidates, dtype=bool),
        ),
    )

    # Remove the last front: it is only partially kept, selected below.
    to_keep_index = to_keep_index & ~front_index

    # Compute crowding distances within the last front
    crowding_distances = self._compute_crowding_distances(
        candidate_fitnesses, ~front_index
    )
    # Individuals outside the last front get -inf so they sort strictly below
    # every front member. Multiplying by the mask instead would tie them with
    # zero-distance front members and turn any infinite distance into NaN.
    crowding_distances = jnp.where(front_index, crowding_distances, -jnp.inf)
    # Candidate indices ordered from the largest to the smallest distance.
    highest_dist = jnp.flip(jnp.argsort(crowding_distances))

    def add_to_front(val: Tuple[jax.Array, Any]) -> Tuple[jax.Array, Any]:
        """Body of the second while loop: add the next best front member.

        Args:
            val: tuple (mask of last-front members selected so far, number
                of members already selected).

        Returns:
            The updated mask and counter.
        """
        front_index, num = val
        # `num` members were already taken, so highest_dist[num] is the
        # member with the largest crowding distance not yet selected.
        front_index = front_index.at[highest_dist[num]].set(True)
        return front_index, num + 1

    def condition_fn_2(val: Tuple[jax.Array, Any]) -> jax.Array:
        """Keep adding front members while the population is not full."""
        front_index, _ = val
        return jnp.sum(to_keep_index | front_index) < self.size

    # add the individuals with the highest distances
    front_index, _num = jax.lax.while_loop(
        condition_fn_2,
        add_to_front,
        (jnp.zeros(num_candidates, dtype=bool), 0),
    )

    # update mask
    to_keep_index = to_keep_index | front_index

    # Boolean mask -> ascending survivor indices. Exactly `self.size` entries
    # are set by construction; `fill_value` only pads the static output size.
    indices = jnp.nonzero(to_keep_index, size=self.size, fill_value=num_candidates)[0]

    # keep only the survivors
    new_candidates = jax.tree.map(lambda x: x[indices], candidates)
    new_scores = candidate_fitnesses[indices]
    new_extra_scores = jax.tree.map(lambda x: x[indices], candidate_extra_scores)

    new_repertoire = self.replace(
        genotypes=new_candidates,
        fitnesses=new_scores,
        extra_scores=new_extra_scores,
    )

    return new_repertoire  # type: ignore

Repertoire

Bases: PyTreeNode, ABC

Abstract class for any repertoire of genotypes.

We decided not to declare the genotypes attribute here, even though it is shared by all child classes, because we want to keep the parent classes explicit and transparent.

Source code in qdax/core/containers/repertoire.py
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
class Repertoire(flax.struct.PyTreeNode, ABC):
    """Common abstract interface for every repertoire of genotypes.

    The genotypes attribute is deliberately not declared here, even though
    every child class stores genotypes: declaring attributes in the child
    classes keeps each parent class explicit and transparent.
    """

    @classmethod
    @abstractmethod
    def init(cls) -> Repertoire:  # noqa: N805
        """Build a new repertoire; concrete classes define the arguments."""
        ...

    @abstractmethod
    def select(
        self,
        key: RNGKey,
        num_samples: int,
        selector: Optional[Selector] = None,
    ) -> Repertoire:
        """Pick individuals out of this repertoire.

        Args:
            key: random key driving the selection.
            num_samples: how many individuals to pick.
            selector: optional selection strategy to apply.

        Returns:
            A repertoire made of the picked individuals.
        """
        ...

    @abstractmethod
    def add(self) -> Repertoire:  # noqa: N805
        """Apply the rule deciding how new genotypes enter the repertoire.

        Returns:
            The repertoire after insertion.
        """
        ...

add() abstractmethod

Implements the rule to add new genotypes to a repertoire.

Returns:
  • Repertoire

    The updated repertoire.
Source code in qdax/core/containers/repertoire.py
49
50
51
52
53
54
55
56
57
@abstractmethod
def add(self) -> Repertoire:  # noqa: N805
    """Apply the rule deciding how new genotypes enter the repertoire.

    Returns:
        The repertoire after insertion.
    """
    ...

init() abstractmethod classmethod

Create a repertoire.

Source code in qdax/core/containers/repertoire.py
24
25
26
27
28
@classmethod
@abstractmethod
def init(cls) -> Repertoire:  # noqa: N805
    """Build a new repertoire; concrete classes define the arguments."""
    ...

select(key, num_samples, selector=None) abstractmethod

Selects individuals from the repertoire.

Parameters:
  • key (RNGKey) –

    The random key to use for the selection.

  • num_samples (int) –

    The number of individuals to select.

  • selector (Optional[Selector], default: None ) –

    The selector to use for the selection.

Returns:
  • Repertoire

    A repertoire containing the selected individuals.

Source code in qdax/core/containers/repertoire.py
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
@abstractmethod
def select(
    self,
    key: RNGKey,
    num_samples: int,
    selector: Optional[Selector] = None,
) -> Repertoire:
    """Pick individuals out of this repertoire.

    Args:
        key: random key driving the selection.
        num_samples: how many individuals to pick.
        selector: optional selection strategy to apply.

    Returns:
        A repertoire made of the picked individuals.
    """
    ...

SPEA2Repertoire

Bases: GARepertoire

Repertoire used for the SPEA2 algorithm.

Inherits from the GARepertoire. The data stored are the genotypes and their fitnesses. Several functions are inherited from GARepertoire, including size, sample.

Source code in qdax/core/containers/spea2_repertoire.py
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
class SPEA2Repertoire(GARepertoire):
    """Repertoire used for the SPEA2 algorithm.

    Inherits from the GARepertoire. The data stored are the genotypes
    and their fitnesses. Several functions are inherited from GARepertoire,
    including size, sample.
    """

    # Number of nearest neighbours (in fitness space) used by the density
    # term of the strength score. Static field, not a pytree leaf.
    num_neighbours: int = flax.struct.field(pytree_node=False)

    def _compute_strength_scores(self, batch_of_fitnesses: Fitness) -> jax.Array:
        """Compute the strength score of every candidate solution.

        The candidates are the current population followed by
        ``batch_of_fitnesses``. A solution's score is the number of candidates
        strictly better than it on every objective, plus a density term
        ``1 / (1 + d)`` where ``d`` is the sum of the squared fitness
        distances to its ``num_neighbours + 1`` closest candidates (itself
        included, at distance 0). Lower is better. NaN scores (for instance
        from ``-inf`` placeholder fitnesses) are replaced by ``size + 2``.

        Args:
            batch_of_fitnesses: a batch of fitness vectors.

        Returns:
            Strength score of each candidate (population then batch).
        """
        fitnesses = jnp.concatenate((self.fitnesses, batch_of_fitnesses), axis=0)
        # dominates[i, j] is True when candidate j is strictly better than
        # candidate i on every objective
        dominates = jnp.all(
            (fitnesses - jnp.expand_dims(fitnesses, axis=1)) > 0, axis=-1
        )
        strength_scores = jnp.sum(dominates, axis=1)

        # density: squared euclidean distances in fitness space
        distance_matrix = jnp.sum(
            (fitnesses - jnp.expand_dims(fitnesses, axis=1)) ** 2, axis=-1
        )
        densities = jnp.sum(
            jnp.sort(distance_matrix, axis=1)[:, : self.num_neighbours + 1], axis=1
        )

        # sum both terms
        strength_scores = strength_scores + 1 / (1 + densities)
        strength_scores = jnp.nan_to_num(strength_scores, nan=self.size + 2)

        return strength_scores

    def add(  # type: ignore
        self,
        batch_of_genotypes: Genotype,
        batch_of_fitnesses: Fitness,
        batch_of_extra_scores: Optional[ExtraScores] = None,
    ) -> SPEA2Repertoire:
        """Updates the population with the new solutions.

        To decide which individuals to keep, we compute, for each solution,
        its strength score (number of solutions dominating it plus a density
        term) and keep the ``size`` solutions with the lowest scores, i.e.
        the least dominated ones.

        Args:
            batch_of_genotypes: genotypes of the new individuals that are
                considered to be added to the population.
            batch_of_fitnesses: their corresponding fitnesses.
            batch_of_extra_scores: extra scores of those new genotypes. After
                ``filter_extra_scores``, it must hold the same keys as the
                repertoire's ``extra_scores`` so both can be concatenated.

        Returns:
            Updated repertoire.
        """

        if batch_of_extra_scores is None:
            batch_of_extra_scores = {}

        # All the candidates: the current population followed by the batch
        candidates = jax.tree.map(
            lambda x, y: jnp.concatenate((x, y), axis=0),
            self.genotypes,
            batch_of_genotypes,
        )

        candidates_fitnesses = jnp.concatenate((self.fitnesses, batch_of_fitnesses))

        # Extra scores are concatenated in the same order as the genotypes so
        # the survivor indices below select matching rows.
        filtered_batch_of_extra_scores = self.filter_extra_scores(batch_of_extra_scores)
        candidates_extra_scores = jax.tree.map(
            lambda x, y: jnp.concatenate((x, y), axis=0),
            self.extra_scores,
            filtered_batch_of_extra_scores,
        )

        # compute strength score for all solutions
        strength_scores = self._compute_strength_scores(batch_of_fitnesses)

        # sort the strengths (the smaller the better (sic, respect paper's notation))
        indices = jnp.argsort(strength_scores)[: self.size]

        # keep only the survivors
        new_candidates = jax.tree.map(lambda x: x[indices], candidates)
        new_scores = candidates_fitnesses[indices]
        new_extra_scores = jax.tree.map(lambda x: x[indices], candidates_extra_scores)

        new_repertoire = self.replace(
            genotypes=new_candidates,
            fitnesses=new_scores,
            extra_scores=new_extra_scores,
        )

        return new_repertoire  # type: ignore

    @classmethod
    def init(  # type: ignore
        cls,
        genotypes: Genotype,
        fitnesses: Fitness,
        population_size: int,
        num_neighbours: int,
        *args,
        extra_scores: Optional[ExtraScores] = None,
        keys_extra_scores: Tuple[str, ...] = (),
        **kwargs,
    ) -> GARepertoire:
        """Initializes the repertoire.

        Start with default values (``-inf`` fitnesses, zero genotypes and
        extra scores) and adds a first batch of genotypes to the repertoire.

        Args:
            genotypes: first batch of genotypes
            fitnesses: corresponding fitnesses
            population_size: size of the population we want to evolve
            num_neighbours: number of nearest neighbours used by the density
                term of the strength score
            extra_scores: extra scores resulting from the evaluation of the genotypes
            keys_extra_scores: keys of the extra scores to store in the repertoire

        Returns:
            An initial repertoire.
        """

        if extra_scores is None:
            extra_scores = {}

        # create default fitnesses
        default_fitnesses = -jnp.inf * jnp.ones(
            shape=(population_size, fitnesses.shape[-1])
        )

        # create default genotypes
        default_genotypes = jax.tree.map(
            lambda x: jnp.zeros(shape=(population_size,) + x.shape[1:]), genotypes
        )

        # create default extra scores
        filtered_extra_scores = {
            key: value
            for key, value in extra_scores.items()
            if key in keys_extra_scores
        }

        default_extra_scores = jax.tree.map(
            lambda x: jnp.zeros(shape=(population_size,) + x.shape[1:]),
            filtered_extra_scores,
        )

        # create an initial repertoire with those default values
        repertoire = cls(
            genotypes=default_genotypes,
            fitnesses=default_fitnesses,
            extra_scores=default_extra_scores,
            keys_extra_scores=keys_extra_scores,
            num_neighbours=num_neighbours,
        )

        new_repertoire = repertoire.add(genotypes, fitnesses, extra_scores)

        return new_repertoire  # type: ignore

add(batch_of_genotypes, batch_of_fitnesses, batch_of_extra_scores=None)

Updates the population with the new solutions.

To decide which individuals to keep, we compute, for each solution, its strength score (the number of solutions dominating it plus a density term). We keep only the solutions with the lowest scores, i.e. the least dominated ones.

Parameters:
  • batch_of_genotypes (Genotype) –

    genotypes of the new individuals that are considered to be added to the population.

  • batch_of_fitnesses (Fitness) –

    their corresponding fitnesses.

  • batch_of_extra_scores (Optional[ExtraScores], default: None ) –

    extra scores of those new genotypes.

Returns:
  • SPEA2Repertoire

    Updated repertoire.
Source code in qdax/core/containers/spea2_repertoire.py
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
def add(  # type: ignore
    self,
    batch_of_genotypes: Genotype,
    batch_of_fitnesses: Fitness,
    batch_of_extra_scores: Optional[ExtraScores] = None,
) -> SPEA2Repertoire:
    """Updates the population with the new solutions.

    To decide which individuals to keep, we compute, for each solution,
    its strength score (number of solutions dominating it plus a density
    term) and keep the ``size`` solutions with the lowest scores, i.e.
    the least dominated ones.

    Args:
        batch_of_genotypes: genotypes of the new individuals that are
            considered to be added to the population.
        batch_of_fitnesses: their corresponding fitnesses.
        batch_of_extra_scores: extra scores of those new genotypes. After
            ``filter_extra_scores``, it must hold the same keys as the
            repertoire's ``extra_scores`` so both can be concatenated.

    Returns:
        Updated repertoire.
    """

    if batch_of_extra_scores is None:
        batch_of_extra_scores = {}

    # All the candidates: the current population followed by the batch
    candidates = jax.tree.map(
        lambda x, y: jnp.concatenate((x, y), axis=0),
        self.genotypes,
        batch_of_genotypes,
    )

    candidates_fitnesses = jnp.concatenate((self.fitnesses, batch_of_fitnesses))

    # Extra scores are concatenated in the same order as the genotypes so
    # the survivor indices below select matching rows.
    filtered_batch_of_extra_scores = self.filter_extra_scores(batch_of_extra_scores)
    candidates_extra_scores = jax.tree.map(
        lambda x, y: jnp.concatenate((x, y), axis=0),
        self.extra_scores,
        filtered_batch_of_extra_scores,
    )

    # compute strength score for all solutions
    strength_scores = self._compute_strength_scores(batch_of_fitnesses)

    # sort the strengths (the smaller the better (sic, respect paper's notation))
    indices = jnp.argsort(strength_scores)[: self.size]

    # keep only the survivors
    new_candidates = jax.tree.map(lambda x: x[indices], candidates)
    new_scores = candidates_fitnesses[indices]
    new_extra_scores = jax.tree.map(lambda x: x[indices], candidates_extra_scores)

    new_repertoire = self.replace(
        genotypes=new_candidates,
        fitnesses=new_scores,
        extra_scores=new_extra_scores,
    )

    return new_repertoire  # type: ignore

init(genotypes, fitnesses, population_size, num_neighbours, *args, extra_scores=None, keys_extra_scores=(), **kwargs) classmethod

Initializes the repertoire.

Start with default values and adds a first batch of genotypes to the repertoire.

Parameters:
  • genotypes (Genotype) –

    first batch of genotypes

  • fitnesses (Fitness) –

    corresponding fitnesses

  • population_size (int) –

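    size of the population we want to evolve

  • num_neighbours (int) –

    number of nearest neighbours used by the density term of the strength score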

  • extra_scores (Optional[ExtraScores], default: None ) –

    extra scores resulting from the evaluation of the genotypes

  • keys_extra_scores (Tuple[str, ...], default: () ) –

    keys of the extra scores to store in the repertoire

Returns:
  • GARepertoire

    An initial repertoire.
Source code in qdax/core/containers/spea2_repertoire.py
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
@classmethod
def init(  # type: ignore
    cls,
    genotypes: Genotype,
    fitnesses: Fitness,
    population_size: int,
    num_neighbours: int,
    *args,
    extra_scores: Optional[ExtraScores] = None,
    keys_extra_scores: Tuple[str, ...] = (),
    **kwargs,
) -> GARepertoire:
    """Build an initial SPEA2 repertoire from a first batch of solutions.

    An empty repertoire of ``population_size`` slots is created (``-inf``
    fitnesses, zero genotypes and zero extra scores), then the first batch
    is inserted through ``add``.

    Args:
        genotypes: first batch of genotypes
        fitnesses: corresponding fitnesses
        population_size: size of the population we want to evolve
        num_neighbours: number of nearest neighbours used by the density
            term of the strength score
        extra_scores: extra scores resulting from the evaluation of the genotypes
        keys_extra_scores: keys of the extra scores to store in the repertoire

    Returns:
        An initial repertoire.
    """
    scores = {} if extra_scores is None else extra_scores

    def _placeholder(leaf: jax.Array) -> jax.Array:
        # zero rows for every slot, keeping the batch's trailing dimensions
        return jnp.zeros(shape=(population_size,) + leaf.shape[1:])

    kept_scores = {
        name: score for name, score in scores.items() if name in keys_extra_scores
    }

    empty_repertoire = cls(
        genotypes=jax.tree.map(_placeholder, genotypes),
        fitnesses=-jnp.inf * jnp.ones(shape=(population_size, fitnesses.shape[-1])),
        extra_scores=jax.tree.map(_placeholder, kept_scores),
        keys_extra_scores=keys_extra_scores,
        num_neighbours=num_neighbours,
    )

    return empty_repertoire.add(genotypes, fitnesses, scores)  # type: ignore

UniformReplacementArchive

Bases: Archive

Stores jax.Array and uses uniform replacement once the maximum size is reached.

Instead of replacing elements in a FIFO manner, like the Archive, this implementation overwrites a uniformly sampled slot with each newly added element once the archive is full.

Most methods are inherited from Archive.

Source code in qdax/core/containers/uniform_replacement_archive.py
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
class UniformReplacementArchive(Archive):
    """Stores jax.Array and uses uniform replacement once the maximum size
    is reached.

    Instead of replacing elements in a FIFO manner, like the Archive,
    this implementation overwrites a uniformly sampled slot with each
    newly added element once the archive is full.

    Most methods are inherited from Archive.
    """

    # Random key, split at every insertion to draw the replacement slot.
    key: RNGKey

    @classmethod
    def create(  # type: ignore
        cls,
        acceptance_threshold: float,
        state_descriptor_size: int,
        max_size: int,
        key: RNGKey,
    ) -> Archive:
        """Create a UniformReplacementArchive instance.

        This class method provides a convenient way to create the archive while
        keeping the __init__ function for more general way to init an archive.

        Args:
            acceptance_threshold: the minimal distance to a stored descriptor to
                be respected for a new descriptor to be added.
            state_descriptor_size: the number of elements in a state descriptor.
            max_size: the maximal size of the archive. Once it is reached, each
                new element overwrites a uniformly sampled slot.
            key: a key to handle random operations (required, no default).

        Returns:
            A newly initialized archive holding ``key``.
        """

        archive = super().create(
            acceptance_threshold,
            state_descriptor_size,
            max_size,
        )

        return archive.replace(key=key)  # type: ignore

    def _single_insertion(self, state_descriptor: jax.Array) -> Archive:
        """Insert a single element.

        If the archive is not full yet, the new element replaces a fake
        border, if it is full, it replaces a random element from the archive.

        Args:
            state_descriptor: state descriptor to be added.

        Returns:
            Return the archive with the newly added element."""
        new_current_position = self.current_position + 1
        is_full = new_current_position >= self.max_size

        key, subkey = jax.random.split(self.key)
        random_index = jax.random.randint(
            subkey, shape=(1,), minval=0, maxval=self.max_size
        )

        # NOTE(review): the sequential write goes to current_position + 1; if
        # Archive.create starts current_position at 0, slot 0 is never filled
        # before the archive is "full" -- confirm against Archive.
        index = jnp.where(is_full, random_index, new_current_position)

        new_data = self.data.at[index].set(state_descriptor)

        return self.replace(  # type: ignore
            current_position=new_current_position, data=new_data, key=key
        )

create(acceptance_threshold, state_descriptor_size, max_size, key) classmethod

Create an Archive instance.

This class method provides a convenient way to create the archive while keeping the `__init__` function for a more general way to initialize an archive.

Parameters:
  • acceptance_threshold (float) –

    the minimal distance to a stored descriptor to be respected for a new descriptor to be added.

  • state_descriptor_size (int) –

    the number of elements in a state descriptor.

  • max_size (int) –

    the maximal size of the archive. Once it is reached, each new element overwrites a uniformly sampled slot.

  • key (RNGKey) –

    a key to handle random operations (required; there is no default).

Returns:
  • Archive

    A newly initialized archive.

Source code in qdax/core/containers/uniform_replacement_archive.py
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
@classmethod
def create(  # type: ignore
    cls,
    acceptance_threshold: float,
    state_descriptor_size: int,
    max_size: int,
    key: RNGKey,
) -> Archive:
    """Create a UniformReplacementArchive instance.

    This class method provides a convenient way to create the archive while
    keeping the __init__ function for more general way to init an archive.

    Args:
        acceptance_threshold: the minimal distance to a stored descriptor to
            be respected for a new descriptor to be added.
        state_descriptor_size: the number of elements in a state descriptor.
        max_size: the maximal size of the archive. Once it is reached, each
            new element overwrites a uniformly sampled slot.
        key: a key to handle random operations (required, no default).

    Returns:
        A newly initialized archive holding ``key``.
    """

    archive = super().create(
        acceptance_threshold,
        state_descriptor_size,
        max_size,
    )

    return archive.replace(key=key)  # type: ignore

UnstructuredRepertoire

Bases: GARepertoire

Class for the unstructured repertoire in Map Elites.

Parameters:
  • genotypes

    a PyTree containing all the genotypes in the repertoire ordered by the centroids. Each leaf has a shape (num_centroids, num_features). The PyTree can be a simple JAX array or a more complex nested structure such as to represent parameters of neural network in Flax.

  • fitnesses

    an array that contains the fitness of solutions in each cell of the repertoire, ordered by centroids. The array shape is (num_centroids,).

  • descriptors

    an array that contains the descriptors of solutions in each cell of the repertoire, ordered by centroids. The array shape is (num_centroids, num_descriptors).

  • extra_scores

    extra scores resulting from the evaluation of the genotypes

  • keys_extra_scores

    keys of the extra scores to store in the repertoire

Source code in qdax/core/containers/unstructured_repertoire.py
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
class UnstructuredRepertoire(GARepertoire):
    """
    Unstructured (archive-based) repertoire with a fixed maximal capacity.

    Unlike a centroid-based MAP-Elites repertoire, slots are not tied to
    predefined cells: a new individual either competes with the stored
    individual whose descriptor lies within ``l_value`` of its own, or is
    placed in a free slot. Empty slots are marked by a fitness of ``-jnp.inf``.

    Args:
        genotypes: a PyTree containing the stored genotypes. Each leaf has a
            leading dimension of size ``max_size``. The PyTree can be a simple
            JAX array or a more complex nested structure, e.g. the parameters
            of a Flax neural network.
        fitnesses: fitness of the individual stored in each slot, of shape
            (max_size, 1); empty slots hold ``-jnp.inf``.
        descriptors: descriptors of the individual stored in each slot, of
            shape (max_size, num_descriptors).
        l_value: distance threshold, broadcast by ``init`` to shape
            (max_size,); only ``l_value[0]`` is read by ``add``.
        max_size: maximal number of individuals (static, not a pytree node).
        extra_scores: extra scores resulting from the evaluation of the genotypes
        keys_extra_scores: keys of the extra scores to store in the repertoire
    """

    descriptors: Descriptor
    l_value: jax.Array
    max_size: int = flax.struct.field(pytree_node=False)

    def get_maximal_size(self) -> int:
        """Returns the maximal number of individuals in the repertoire."""
        return self.max_size

    def get_number_genotypes(self) -> jax.Array:
        """Returns the number of genotypes in the repertoire."""
        return jnp.sum(self.fitnesses != -jnp.inf)

    def add(  # type: ignore
        self,
        batch_of_genotypes: Genotype,
        batch_of_descriptors: Descriptor,
        batch_of_fitnesses: Fitness,
        batch_of_extra_scores: Optional[ExtraScores] = None,
    ) -> UnstructuredRepertoire:
        """Adds a batch of genotypes to the repertoire.

        For each candidate, the two nearest occupied slots (by descriptor
        distance) are looked up:

        * candidates whose second-nearest neighbour lies within ``l_value[0]``
          are rejected, to avoid clusters of individuals;
        * candidates whose nearest neighbour lies within ``l_value[0]`` compete
          for that neighbour's slot and only replace it with a higher fitness;
        * the other candidates are assigned to empty slots (while some remain,
          see the TODO below).

        Candidates are additionally filtered against each other with
        ``intra_batch_comp``, and only the fittest candidate(s) targeting a
        given slot remain eligible.

        Args:
            batch_of_genotypes: genotypes of the individuals to be considered
                for addition in the repertoire.
            batch_of_descriptors: associated descriptors.
            batch_of_fitnesses: associated fitness.
            batch_of_extra_scores: associated extra scores.

        Returns:
            A new unstructured repertoire where the relevant individuals have been
            added.
        """
        if batch_of_extra_scores is None:
            batch_of_extra_scores = {}

        filtered_batch_of_extra_scores = self.filter_extra_scores(batch_of_extra_scores)

        batch_of_fitnesses = batch_of_fitnesses.reshape(-1, 1)

        # One boolean per slot, True when the slot is empty (fitness -inf).
        # Fitnesses are stored with shape (max_size, 1): they are flattened per
        # slot first, because expanding the raw (max_size, 1) mask would make
        # the jnp.where below broadcast to (max_size, max_size, D).
        empty_slots = jnp.all(
            (self.fitnesses == -jnp.inf).reshape(self.fitnesses.shape[0], -1),
            axis=1,
        )

        # Empty slots get +inf descriptors so they are never nearest neighbours.
        filtered_descriptors = jnp.where(
            jnp.expand_dims(empty_slots, axis=-1),
            jnp.full(self.descriptors.shape[-1], fill_value=jnp.inf),
            self.descriptors,
        )

        batch_of_indices, batch_of_distances = get_cells_indices(
            batch_of_descriptors, filtered_descriptors, 2
        )

        # Save the distances to the second-nearest neighbours to check a condition
        second_neighbours = batch_of_distances.at[..., 1].get()

        # Keep the indices of the nearest neighbours
        batch_of_indices = batch_of_indices.at[..., 0].get()

        # Keep the distances to the nearest neighbours
        batch_of_distances = batch_of_distances.at[..., 0].get()

        # We remove individuals that are too close to the second nn.
        # This avoids having clusters of individuals after adding them.
        not_novel_enough = jnp.squeeze(second_neighbours <= self.l_value[0])

        filtered_batch_of_extra_scores = jax.tree.map(
            lambda x: jnp.expand_dims(x, axis=-1), filtered_batch_of_extra_scores
        )

        # Indices of the empty slots, padded with -1 up to the batch size.
        # TODO: Doesn't Work if Archive is full. Need to use the closest individuals
        # in that case.
        empty_indexes = jnp.squeeze(
            jnp.nonzero(
                empty_slots,
                size=batch_of_indices.shape[0],
                fill_value=-1,
            )[0]
        )
        # Candidates too far from every stored individual are flagged with -1.
        batch_of_indices = jnp.where(
            jnp.squeeze(batch_of_distances <= self.l_value[0]),
            jnp.squeeze(batch_of_indices),
            -1,
        )

        # We get all the indices of the empty descriptors first and then the filled ones
        # (because of -1)
        sorted_descriptors = jax.lax.top_k(
            -1 * batch_of_indices.squeeze(), batch_of_indices.shape[0]
        )[1]
        # Flagged candidates (now first) take the empty slots in order.
        batch_of_indices = jnp.where(
            jnp.squeeze(
                batch_of_distances.at[sorted_descriptors].get() <= self.l_value[0]
            ),
            batch_of_indices.at[sorted_descriptors].get(),
            empty_indexes,
        )

        batch_of_indices = jnp.expand_dims(batch_of_indices, axis=-1)

        # ReIndexing of all the inputs to the correct sorted way
        batch_of_descriptors = batch_of_descriptors.at[sorted_descriptors].get()
        batch_of_genotypes = jax.tree.map(
            lambda x: x.at[sorted_descriptors].get(), batch_of_genotypes
        )
        batch_of_fitnesses = batch_of_fitnesses.at[sorted_descriptors].get()

        filtered_batch_of_extra_scores = jax.tree.map(
            lambda x: x.at[sorted_descriptors].get(), filtered_batch_of_extra_scores
        )
        not_novel_enough = not_novel_enough.at[sorted_descriptors].get()

        # Check to find Individuals with same descriptor within the Batch.
        # No jax.jit wrapper: jitting a freshly created vmap'd function on each
        # call defeats jit caching, and under an outer jit it is inlined anyway.
        keep_indiv = jax.vmap(
            intra_batch_comp, in_axes=(0, 0, None, None, None), out_axes=0
        )(
            batch_of_descriptors.squeeze(),
            jnp.arange(
                0, batch_of_descriptors.shape[0], 1
            ),  # keep track of where we are in the batch to assure right comparisons
            batch_of_descriptors.squeeze(),
            batch_of_fitnesses.squeeze(),
            self.l_value[0],
        )

        keep_indiv = jnp.logical_and(keep_indiv, jnp.logical_not(not_novel_enough))

        # get fitness segment max
        best_fitnesses = jax.ops.segment_max(
            batch_of_fitnesses,
            batch_of_indices.astype(jnp.int32).squeeze(),
            num_segments=self.max_size,
        )

        cond_values = jnp.take_along_axis(best_fitnesses, batch_of_indices, 0)

        # put dominated fitness to -jnp.inf
        batch_of_fitnesses = jnp.where(
            batch_of_fitnesses == cond_values, batch_of_fitnesses, -jnp.inf
        )

        # get addition condition
        current_fitnesses = jnp.take_along_axis(self.fitnesses, batch_of_indices, 0)
        addition_condition = batch_of_fitnesses > current_fitnesses
        addition_condition = jnp.logical_and(
            addition_condition, jnp.expand_dims(keep_indiv, axis=-1)
        )

        # assign fake position when relevant : max_size is out of bounds, so
        # the scatters below drop those updates
        batch_of_indices = jnp.where(
            addition_condition,
            batch_of_indices,
            self.max_size,
        )

        # create new grid
        new_grid_genotypes = jax.tree.map(
            lambda grid_genotypes, new_genotypes: grid_genotypes.at[
                batch_of_indices.squeeze()
            ].set(new_genotypes),
            self.genotypes,
            batch_of_genotypes,
        )

        # compute new fitness and descriptors
        new_fitnesses = self.fitnesses.at[batch_of_indices.squeeze()].set(
            batch_of_fitnesses
        )
        new_descriptors = self.descriptors.at[batch_of_indices.squeeze()].set(
            batch_of_descriptors.squeeze()
        )

        new_extra_scores = jax.tree.map(
            lambda x, y: x.at[batch_of_indices.squeeze()].set(y.squeeze()).squeeze(),
            self.extra_scores,
            filtered_batch_of_extra_scores,
        )

        return UnstructuredRepertoire(
            genotypes=new_grid_genotypes,
            fitnesses=new_fitnesses,
            descriptors=new_descriptors.squeeze(),
            extra_scores=new_extra_scores,
            keys_extra_scores=self.keys_extra_scores,
            l_value=self.l_value,
            max_size=self.max_size,
        )

    def select(
        self,
        key: RNGKey,
        num_samples: int,
        selector: Optional[Selector[UnstructuredRepertoire]] = None,
    ) -> UnstructuredRepertoire:
        """Select elements in the repertoire.

        The sampling is delegated to ``selector``; when no selector is given,
        a ``UniformSelector`` sampling with replacement is used.

        Args:
            key: a random key to handle stochasticity.
            num_samples: number of samples to retrieve from the repertoire.
            selector: selector to choose the individuals. Defaults to None,
                i.e. uniform sampling with replacement.

        Returns:
            A repertoire containing the selected individuals.
        """

        if selector is None:
            selector = UniformSelector(select_with_replacement=True)

        # Explicitly cast return value to UnstructuredRepertoire
        repertoire: UnstructuredRepertoire = selector.select(self, key, num_samples)

        return repertoire

    @classmethod
    def init(  # type: ignore
        cls,
        genotypes: Genotype,
        fitnesses: Fitness,
        descriptors: Descriptor,
        l_value: jax.Array,
        max_size: int,
        *args,
        extra_scores: Optional[ExtraScores] = None,
        keys_extra_scores: Tuple[str, ...] = (),
        **kwargs,
    ) -> UnstructuredRepertoire:
        """Create an unstructured repertoire and fill it with an initial population.

        An empty repertoire of capacity ``max_size`` is built first (fitnesses
        set to ``-jnp.inf``, genotypes to NaN, descriptors and stored extra
        scores to zeros), then the initial population is inserted with
        ``add``. No centroids are required. Extra positional and keyword
        arguments are accepted and ignored.

        Args:
            genotypes: initial genotypes, pytree in which leaves
                have shape (batch_size, num_features)
            fitnesses: fitness of the initial genotypes of shape (batch_size,)
            descriptors: descriptors of the initial genotypes
                of shape (batch_size, num_descriptors)
            l_value: distance threshold of the repertoire; it is broadcast with
                ``jnp.full`` to shape (max_size,), so it should be a scalar
                (or broadcastable to that shape).
            max_size: maximal size of the container
            extra_scores: extra scores resulting from the evaluation of the genotypes
            keys_extra_scores: keys of the extra scores to store in the repertoire

        Returns:
            an initialized unstructured repertoire.
        """

        if extra_scores is None:
            extra_scores = {}

        # Initialize grid with default values
        default_fitnesses = -jnp.inf * jnp.ones(shape=(max_size, 1))
        default_genotypes = jax.tree.map(
            lambda x: jnp.full(shape=(max_size,) + x.shape[1:], fill_value=jnp.nan),
            genotypes,
        )
        default_descriptors = jnp.zeros(shape=(max_size, descriptors.shape[-1]))

        # create default extra scores
        filtered_extra_scores = {
            key: value
            for key, value in extra_scores.items()
            if key in keys_extra_scores
        }

        default_extra_scores = jax.tree.map(
            lambda x: jnp.zeros(shape=(max_size,) + x.shape[1:]),
            filtered_extra_scores,
        )

        repertoire = UnstructuredRepertoire(
            genotypes=default_genotypes,
            fitnesses=default_fitnesses,
            descriptors=default_descriptors,
            l_value=jnp.full(shape=(max_size,), fill_value=l_value),
            max_size=max_size,
            extra_scores=default_extra_scores,
            keys_extra_scores=keys_extra_scores,
        )

        return repertoire.add(  # type: ignore
            genotypes,
            descriptors,
            fitnesses,
            extra_scores,
        )

add(batch_of_genotypes, batch_of_descriptors, batch_of_fitnesses, batch_of_extra_scores=None)

Adds a batch of genotypes to the repertoire.

Parameters:
  • batch_of_genotypes (Genotype) –

    genotypes of the individuals to be considered for addition in the repertoire.

  • batch_of_descriptors (Descriptor) –

    associated descriptors.

  • batch_of_fitnesses (Fitness) –

    associated fitness.

  • batch_of_extra_scores (Optional[ExtraScores], default: None ) –

    associated extra scores.

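Returns:
  • UnstructuredRepertoire –

    A new unstructured repertoire where the relevant individuals have been added.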
Source code in qdax/core/containers/unstructured_repertoire.py
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
def add(  # type: ignore
    self,
    batch_of_genotypes: Genotype,
    batch_of_descriptors: Descriptor,
    batch_of_fitnesses: Fitness,
    batch_of_extra_scores: Optional[ExtraScores] = None,
) -> UnstructuredRepertoire:
    """Adds a batch of genotypes to the repertoire.

    For each candidate, the two nearest occupied slots (by descriptor
    distance) are looked up:

    * candidates whose second-nearest neighbour lies within ``l_value[0]``
      are rejected, to avoid clusters of individuals;
    * candidates whose nearest neighbour lies within ``l_value[0]`` compete
      for that neighbour's slot and only replace it with a higher fitness;
    * the other candidates are assigned to empty slots (while some remain,
      see the TODO below).

    Candidates are additionally filtered against each other with
    ``intra_batch_comp``, and only the fittest candidate(s) targeting a
    given slot remain eligible.

    Args:
        batch_of_genotypes: genotypes of the individuals to be considered
            for addition in the repertoire.
        batch_of_descriptors: associated descriptors.
        batch_of_fitnesses: associated fitness.
        batch_of_extra_scores: associated extra scores.

    Returns:
        A new unstructured repertoire where the relevant individuals have been
        added.
    """
    if batch_of_extra_scores is None:
        batch_of_extra_scores = {}

    filtered_batch_of_extra_scores = self.filter_extra_scores(batch_of_extra_scores)

    batch_of_fitnesses = batch_of_fitnesses.reshape(-1, 1)

    # One boolean per slot, True when the slot is empty (fitness -inf).
    # Fitnesses are stored with shape (max_size, 1): they are flattened per
    # slot first, because expanding the raw (max_size, 1) mask would make
    # the jnp.where below broadcast to (max_size, max_size, D).
    empty_slots = jnp.all(
        (self.fitnesses == -jnp.inf).reshape(self.fitnesses.shape[0], -1),
        axis=1,
    )

    # Empty slots get +inf descriptors so they are never nearest neighbours.
    filtered_descriptors = jnp.where(
        jnp.expand_dims(empty_slots, axis=-1),
        jnp.full(self.descriptors.shape[-1], fill_value=jnp.inf),
        self.descriptors,
    )

    batch_of_indices, batch_of_distances = get_cells_indices(
        batch_of_descriptors, filtered_descriptors, 2
    )

    # Save the distances to the second-nearest neighbours to check a condition
    second_neighbours = batch_of_distances.at[..., 1].get()

    # Keep the indices of the nearest neighbours
    batch_of_indices = batch_of_indices.at[..., 0].get()

    # Keep the distances to the nearest neighbours
    batch_of_distances = batch_of_distances.at[..., 0].get()

    # We remove individuals that are too close to the second nn.
    # This avoids having clusters of individuals after adding them.
    not_novel_enough = jnp.squeeze(second_neighbours <= self.l_value[0])

    filtered_batch_of_extra_scores = jax.tree.map(
        lambda x: jnp.expand_dims(x, axis=-1), filtered_batch_of_extra_scores
    )

    # Indices of the empty slots, padded with -1 up to the batch size.
    # TODO: Doesn't Work if Archive is full. Need to use the closest individuals
    # in that case.
    empty_indexes = jnp.squeeze(
        jnp.nonzero(
            empty_slots,
            size=batch_of_indices.shape[0],
            fill_value=-1,
        )[0]
    )
    # Candidates too far from every stored individual are flagged with -1.
    batch_of_indices = jnp.where(
        jnp.squeeze(batch_of_distances <= self.l_value[0]),
        jnp.squeeze(batch_of_indices),
        -1,
    )

    # We get all the indices of the empty descriptors first and then the filled ones
    # (because of -1)
    sorted_descriptors = jax.lax.top_k(
        -1 * batch_of_indices.squeeze(), batch_of_indices.shape[0]
    )[1]
    # Flagged candidates (now first) take the empty slots in order.
    batch_of_indices = jnp.where(
        jnp.squeeze(
            batch_of_distances.at[sorted_descriptors].get() <= self.l_value[0]
        ),
        batch_of_indices.at[sorted_descriptors].get(),
        empty_indexes,
    )

    batch_of_indices = jnp.expand_dims(batch_of_indices, axis=-1)

    # ReIndexing of all the inputs to the correct sorted way
    batch_of_descriptors = batch_of_descriptors.at[sorted_descriptors].get()
    batch_of_genotypes = jax.tree.map(
        lambda x: x.at[sorted_descriptors].get(), batch_of_genotypes
    )
    batch_of_fitnesses = batch_of_fitnesses.at[sorted_descriptors].get()

    filtered_batch_of_extra_scores = jax.tree.map(
        lambda x: x.at[sorted_descriptors].get(), filtered_batch_of_extra_scores
    )
    not_novel_enough = not_novel_enough.at[sorted_descriptors].get()

    # Check to find Individuals with same descriptor within the Batch.
    # No jax.jit wrapper: jitting a freshly created vmap'd function on each
    # call defeats jit caching, and under an outer jit it is inlined anyway.
    keep_indiv = jax.vmap(
        intra_batch_comp, in_axes=(0, 0, None, None, None), out_axes=0
    )(
        batch_of_descriptors.squeeze(),
        jnp.arange(
            0, batch_of_descriptors.shape[0], 1
        ),  # keep track of where we are in the batch to assure right comparisons
        batch_of_descriptors.squeeze(),
        batch_of_fitnesses.squeeze(),
        self.l_value[0],
    )

    keep_indiv = jnp.logical_and(keep_indiv, jnp.logical_not(not_novel_enough))

    # get fitness segment max
    best_fitnesses = jax.ops.segment_max(
        batch_of_fitnesses,
        batch_of_indices.astype(jnp.int32).squeeze(),
        num_segments=self.max_size,
    )

    cond_values = jnp.take_along_axis(best_fitnesses, batch_of_indices, 0)

    # put dominated fitness to -jnp.inf
    batch_of_fitnesses = jnp.where(
        batch_of_fitnesses == cond_values, batch_of_fitnesses, -jnp.inf
    )

    # get addition condition
    current_fitnesses = jnp.take_along_axis(self.fitnesses, batch_of_indices, 0)
    addition_condition = batch_of_fitnesses > current_fitnesses
    addition_condition = jnp.logical_and(
        addition_condition, jnp.expand_dims(keep_indiv, axis=-1)
    )

    # assign fake position when relevant : max_size is out of bounds, so
    # the scatters below drop those updates
    batch_of_indices = jnp.where(
        addition_condition,
        batch_of_indices,
        self.max_size,
    )

    # create new grid
    new_grid_genotypes = jax.tree.map(
        lambda grid_genotypes, new_genotypes: grid_genotypes.at[
            batch_of_indices.squeeze()
        ].set(new_genotypes),
        self.genotypes,
        batch_of_genotypes,
    )

    # compute new fitness and descriptors
    new_fitnesses = self.fitnesses.at[batch_of_indices.squeeze()].set(
        batch_of_fitnesses
    )
    new_descriptors = self.descriptors.at[batch_of_indices.squeeze()].set(
        batch_of_descriptors.squeeze()
    )

    new_extra_scores = jax.tree.map(
        lambda x, y: x.at[batch_of_indices.squeeze()].set(y.squeeze()).squeeze(),
        self.extra_scores,
        filtered_batch_of_extra_scores,
    )

    return UnstructuredRepertoire(
        genotypes=new_grid_genotypes,
        fitnesses=new_fitnesses,
        descriptors=new_descriptors.squeeze(),
        extra_scores=new_extra_scores,
        keys_extra_scores=self.keys_extra_scores,
        l_value=self.l_value,
        max_size=self.max_size,
    )

get_maximal_size()

Returns the maximal number of individuals in the repertoire.

Source code in qdax/core/containers/unstructured_repertoire.py
154
155
156
def get_maximal_size(self) -> int:
    """Return the capacity of the repertoire, i.e. its ``max_size`` field."""
    capacity: int = self.max_size
    return capacity

get_number_genotypes()

Returns the number of genotypes in the repertoire.

Source code in qdax/core/containers/unstructured_repertoire.py
158
159
160
def get_number_genotypes(self) -> jax.Array:
    """Count the occupied slots of the repertoire.

    A slot counts as empty only when its fitness equals ``-jnp.inf``; every
    other entry of ``self.fitnesses`` is counted as a stored genotype.
    """
    occupied = jnp.not_equal(self.fitnesses, -jnp.inf)
    return occupied.sum()

init(genotypes, fitnesses, descriptors, l_value, max_size, *args, extra_scores=None, keys_extra_scores=(), **kwargs) classmethod

Initialize an unstructured repertoire with an initial population of genotypes. An empty repertoire of capacity max_size is created first, and the initial population is then added to it; unlike a centroid-based MAP-Elites repertoire, no centroids are required.

Parameters:
  • genotypes (Genotype) –

    initial genotypes, pytree in which leaves have shape (batch_size, num_features)

  • fitnesses (Fitness) –

    fitness of the initial genotypes of shape (batch_size,)

  • descriptors (Descriptor) –

    descriptors of the initial genotypes of shape (batch_size, num_descriptors)

  • l_value (Array) –

    threshold distance of the repertoire.

  • max_size (int) –

    maximal size of the container

  • extra_scores (Optional[ExtraScores], default: None ) –

    extra scores resulting from the evaluation of the genotypes

  • keys_extra_scores (Tuple[str, ...], default: () ) –

    keys of the extra scores to store in the repertoire

Returns:
  • UnstructuredRepertoire –

    an initialized unstructured repertoire.
Source code in qdax/core/containers/unstructured_repertoire.py
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
@classmethod
def init(  # type: ignore
    cls,
    genotypes: Genotype,
    fitnesses: Fitness,
    descriptors: Descriptor,
    l_value: jax.Array,
    max_size: int,
    *args,
    extra_scores: Optional[ExtraScores] = None,
    keys_extra_scores: Tuple[str, ...] = (),
    **kwargs,
) -> UnstructuredRepertoire:
    """Create an unstructured repertoire and fill it with an initial population.

    An empty repertoire of capacity ``max_size`` is built first (fitnesses set
    to ``-jnp.inf``, genotypes to NaN, descriptors and stored extra scores to
    zeros), then the initial population is inserted with ``add``. No centroids
    are required. Extra positional and keyword arguments are accepted and
    ignored.

    Args:
        genotypes: initial genotypes, pytree in which leaves
            have shape (batch_size, num_features)
        fitnesses: fitness of the initial genotypes of shape (batch_size,)
        descriptors: descriptors of the initial genotypes
            of shape (batch_size, num_descriptors)
        l_value: distance threshold of the repertoire; it is broadcast with
            ``jnp.full`` to shape (max_size,), so it should be a scalar (or
            broadcastable to that shape).
        max_size: maximal size of the container
        extra_scores: extra scores resulting from the evaluation of the genotypes
        keys_extra_scores: keys of the extra scores to store in the repertoire

    Returns:
        an initialized unstructured repertoire.
    """

    if extra_scores is None:
        extra_scores = {}

    # Initialize grid with default values: -inf fitness marks an empty slot
    default_fitnesses = -jnp.inf * jnp.ones(shape=(max_size, 1))
    default_genotypes = jax.tree.map(
        lambda x: jnp.full(shape=(max_size,) + x.shape[1:], fill_value=jnp.nan),
        genotypes,
    )
    default_descriptors = jnp.zeros(shape=(max_size, descriptors.shape[-1]))

    # create default extra scores, keeping only the requested keys
    filtered_extra_scores = {
        key: value
        for key, value in extra_scores.items()
        if key in keys_extra_scores
    }

    default_extra_scores = jax.tree.map(
        lambda x: jnp.zeros(shape=(max_size,) + x.shape[1:]),
        filtered_extra_scores,
    )

    repertoire = UnstructuredRepertoire(
        genotypes=default_genotypes,
        fitnesses=default_fitnesses,
        descriptors=default_descriptors,
        l_value=jnp.full(shape=(max_size,), fill_value=l_value),
        max_size=max_size,
        extra_scores=default_extra_scores,
        keys_extra_scores=keys_extra_scores,
    )

    return repertoire.add(  # type: ignore
        genotypes,
        descriptors,
        fitnesses,
        extra_scores,
    )

select(key, num_samples, selector=None)

Select elements in the repertoire.

This method samples individuals from the repertoire with the given selector; when no selector is provided, a uniform selector sampling with replacement is used.

Parameters:
  • key (RNGKey) –

    a random key to handle stochasticity.

  • num_samples (int) –

    number of samples to retrieve from the repertoire.

  • selector (Optional[Selector[UnstructuredRepertoire]], default: None ) –

    selector to choose the individuals. Defaults to None.

Returns:
  • UnstructuredRepertoire –

    A repertoire containing the selected individuals.
Source code in qdax/core/containers/unstructured_repertoire.py
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
def select(
    self,
    key: RNGKey,
    num_samples: int,
    selector: Optional[Selector[UnstructuredRepertoire]] = None,
) -> UnstructuredRepertoire:
    """Select elements in the repertoire.

    The sampling is delegated to ``selector``; when no selector is given, a
    ``UniformSelector`` sampling with replacement is used.

    Args:
        key: a random key to handle stochasticity.
        num_samples: number of samples to retrieve from the repertoire.
        selector: selector to choose the individuals. Defaults to None, i.e.
            uniform sampling with replacement.

    Returns:
        A repertoire containing the selected individuals.
    """

    if selector is None:
        selector = UniformSelector(select_with_replacement=True)

    # Explicitly cast return value to UnstructuredRepertoire
    repertoire: UnstructuredRepertoire = selector.select(self, key, num_samples)

    return repertoire