Genetic Algorithm class

Core class of a genetic algorithm.

This class implements default methods to run a simple genetic algorithm with a simple repertoire.

Parameters:
  • scoring_function (Callable[[Genotype, RNGKey], Tuple[Fitness, ExtraScores]]) –

    a function that takes a batch of genotypes and a random key, and computes their fitnesses and extra scores

  • emitter (Emitter) –

    an emitter is used to suggest offspring given a repertoire. It has two compulsory functions: one that emits a new population, and one that updates the internal state of the emitter

  • metrics_function (Callable[[GARepertoire], Metrics]) –

    a function that takes a repertoire and computes any useful metric to track its evolution

Source code in qdax/baselines/genetic_algorithm.py
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
class GeneticAlgorithm:
    """Baseline genetic algorithm built around a ``GARepertoire``.

    The class wires together three user-provided components and exposes the
    ``init`` / ``update`` / ``scan_update`` entry points.

    Args:
        scoring_function: callable evaluating a batch of genotypes with a
            random key; returns the fitnesses and the extra scores.
        emitter: object that proposes new genotypes from a repertoire
            (``emit``) and maintains its own state (``init`` and
            ``state_update``).
        metrics_function: callable mapping a repertoire to the metrics that
            should be tracked.
    """

    def __init__(
        self,
        scoring_function: Callable[[Genotype, RNGKey], Tuple[Fitness, ExtraScores]],
        emitter: Emitter,
        metrics_function: Callable[[GARepertoire], Metrics],
    ) -> None:
        # keep references to the three building blocks used by init/update
        self._emitter = emitter
        self._metrics_function = metrics_function
        self._scoring_function = scoring_function

    def init(
        self, genotypes: Genotype, population_size: int, key: RNGKey
    ) -> Tuple[GARepertoire, Optional[EmitterState], Metrics]:
        """Score the initial genotypes and build the starting repertoire.

        Args:
            genotypes: the initial batch of genotypes.
            population_size: forwarded to ``GARepertoire.init`` as the size
                of the repertoire.
            key: a random key, split into one sub-key for scoring and one
                for the emitter initialisation.

        Returns:
            A tuple made of the initial repertoire, the initial emitter state
            and the metrics computed on the initial repertoire.
        """
        # one sub-key for the scoring function, one for the emitter
        key, scoring_key = jax.random.split(key)
        _, emitter_key = jax.random.split(key)

        # evaluate the initial population
        initial_fitnesses, initial_scores = self._scoring_function(
            genotypes, scoring_key
        )

        # build the repertoire from the scored population
        initial_repertoire = GARepertoire.init(
            genotypes=genotypes,
            fitnesses=initial_fitnesses,
            population_size=population_size,
        )

        # let the emitter create its starting state (no descriptors in a GA)
        initial_emitter_state = self._emitter.init(
            key=emitter_key,
            repertoire=initial_repertoire,
            genotypes=genotypes,
            fitnesses=initial_fitnesses,
            descriptors=None,
            extra_scores=initial_scores,
        )

        initial_metrics = self._metrics_function(initial_repertoire)
        return initial_repertoire, initial_emitter_state, initial_metrics

    def update(
        self,
        repertoire: GARepertoire,
        emitter_state: Optional[EmitterState],
        key: RNGKey,
    ) -> Tuple[GARepertoire, Optional[EmitterState], Metrics]:
        """Run a single generation of the genetic algorithm.

        The emitter proposes offspring from the current repertoire, the
        offspring are scored, and the scored offspring are added to the
        repertoire. The emitter state is then refreshed and metrics are
        computed on the new repertoire.

        Args:
            repertoire: the current repertoire.
            emitter_state: the current emitter state (may be ``None``).
            key: a jax PRNG key, split into one sub-key for emission and one
                for scoring.

        Returns:
            A tuple made of the updated repertoire, the updated emitter state
            and the metrics computed on the updated repertoire.
        """
        # one sub-key for the emitter, one for the scoring function
        key, emit_key = jax.random.split(key)
        _, score_key = jax.random.split(key)

        # ask the emitter for offspring, then evaluate them
        offspring, emit_info = self._emitter.emit(
            repertoire, emitter_state, emit_key
        )
        offspring_fitnesses, offspring_scores = self._scoring_function(
            offspring, score_key
        )

        # insert the scored offspring into the repertoire
        new_repertoire = repertoire.add(offspring, offspring_fitnesses)

        # the emitter receives the scoring extras merged with its own emission
        # extras; on a key clash the emitter's entries take precedence
        merged_scores = {**offspring_scores, **emit_info}
        new_emitter_state = self._emitter.state_update(
            emitter_state=emitter_state,
            repertoire=new_repertoire,
            genotypes=offspring,
            fitnesses=offspring_fitnesses,
            descriptors=None,
            extra_scores=merged_scores,
        )

        metrics = self._metrics_function(new_repertoire)
        return new_repertoire, new_emitter_state, metrics

    def scan_update(
        self,
        carry: Tuple[GARepertoire, Optional[EmitterState], RNGKey],
        _: Any,
    ) -> Tuple[Tuple[GARepertoire, Optional[EmitterState], RNGKey], Metrics]:
        """Wrap ``update`` in the ``(carry, x) -> (carry, y)`` form expected by
        ``jax.lax.scan``.

        Args:
            carry: a tuple holding the repertoire, the emitter state and a
                random key.
            _: unused; present only to match the ``jax.lax.scan`` signature.

        Returns:
            The next carry (updated repertoire, updated emitter state and a
            fresh random key) together with the metrics of this step.
        """
        prev_repertoire, prev_emitter_state, key = carry

        # keep one half of the split for the next step, spend the other now
        next_key, step_key = jax.random.split(key)
        next_repertoire, next_emitter_state, metrics = self.update(
            prev_repertoire, prev_emitter_state, step_key
        )

        next_carry = (next_repertoire, next_emitter_state, next_key)
        return next_carry, metrics

init(genotypes, population_size, key)

Initialize a GARepertoire with an initial population of genotypes.

Parameters:
  • genotypes (Genotype) –

    the initial population of genotypes

  • population_size (int) –

    the maximal size of the repertoire

  • key (RNGKey) –

    a random key to handle stochastic operations

Returns:
  • Tuple[GARepertoire, Optional[EmitterState], Metrics]

    The initial repertoire, the initial emitter state and the initial metrics.

Source code in qdax/baselines/genetic_algorithm.py
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
def init(
    self, genotypes: Genotype, population_size: int, key: RNGKey
) -> Tuple[GARepertoire, Optional[EmitterState], Metrics]:
    """Score the initial genotypes and build the starting repertoire.

    Args:
        genotypes: the initial batch of genotypes.
        population_size: forwarded to ``GARepertoire.init`` as the size
            of the repertoire.
        key: a random key, split into one sub-key for scoring and one
            for the emitter initialisation.

    Returns:
        A tuple made of the initial repertoire, the initial emitter state
        and the metrics computed on the initial repertoire.
    """
    # one sub-key for the scoring function, one for the emitter
    key, scoring_key = jax.random.split(key)
    _, emitter_key = jax.random.split(key)

    # evaluate the initial population
    initial_fitnesses, initial_scores = self._scoring_function(
        genotypes, scoring_key
    )

    # build the repertoire from the scored population
    initial_repertoire = GARepertoire.init(
        genotypes=genotypes,
        fitnesses=initial_fitnesses,
        population_size=population_size,
    )

    # let the emitter create its starting state (no descriptors in a GA)
    initial_emitter_state = self._emitter.init(
        key=emitter_key,
        repertoire=initial_repertoire,
        genotypes=genotypes,
        fitnesses=initial_fitnesses,
        descriptors=None,
        extra_scores=initial_scores,
    )

    initial_metrics = self._metrics_function(initial_repertoire)
    return initial_repertoire, initial_emitter_state, initial_metrics

scan_update(carry, _)

Rewrites the update function in a way that makes it compatible with the jax.lax.scan primitive.

Parameters:
  • carry (Tuple[GARepertoire, Optional[EmitterState], RNGKey]) –

    a tuple containing the repertoire, the emitter state and a random key.

  • _ (Any) –

    unused element, necessary to respect jax.lax.scan API.

Returns:
  • Tuple[Tuple[GARepertoire, Optional[EmitterState], RNGKey], Metrics]

    The updated repertoire and emitter state, with a new random key and metrics.

Source code in qdax/baselines/genetic_algorithm.py
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
def scan_update(
    self,
    carry: Tuple[GARepertoire, Optional[EmitterState], RNGKey],
    _: Any,
) -> Tuple[Tuple[GARepertoire, Optional[EmitterState], RNGKey], Metrics]:
    """Wrap ``update`` in the ``(carry, x) -> (carry, y)`` form expected by
    ``jax.lax.scan``.

    Args:
        carry: a tuple holding the repertoire, the emitter state and a
            random key.
        _: unused; present only to match the ``jax.lax.scan`` signature.

    Returns:
        The next carry (updated repertoire, updated emitter state and a
        fresh random key) together with the metrics of this step.
    """
    prev_repertoire, prev_emitter_state, key = carry

    # keep one half of the split for the next step, spend the other now
    next_key, step_key = jax.random.split(key)
    next_repertoire, next_emitter_state, metrics = self.update(
        prev_repertoire, prev_emitter_state, step_key
    )

    next_carry = (next_repertoire, next_emitter_state, next_key)
    return next_carry, metrics

update(repertoire, emitter_state, key)

Performs one iteration of a genetic algorithm: (1) a batch of genotypes is sampled from the repertoire and the genotypes are copied; (2) the copies are mutated and crossed over; (3) the resulting offspring are scored and then added to the repertoire.

Parameters:
  • repertoire (GARepertoire) –

    a repertoire

  • emitter_state (Optional[EmitterState]) –

    state of the emitter

  • key (RNGKey) –

    a jax PRNG random key

Returns:
  • GARepertoire

    the updated repertoire

  • Optional[EmitterState]

    the updated (if needed) emitter state

  • Metrics

    metrics about the updated repertoire

  • Tuple[GARepertoire, Optional[EmitterState], Metrics]

    the three values above, returned together as a tuple (no PRNG key is returned)

Source code in qdax/baselines/genetic_algorithm.py
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
def update(
    self,
    repertoire: GARepertoire,
    emitter_state: Optional[EmitterState],
    key: RNGKey,
) -> Tuple[GARepertoire, Optional[EmitterState], Metrics]:
    """Run a single generation of the genetic algorithm.

    The emitter proposes offspring from the current repertoire, the
    offspring are scored, and the scored offspring are added to the
    repertoire. The emitter state is then refreshed and metrics are
    computed on the new repertoire.

    Args:
        repertoire: the current repertoire.
        emitter_state: the current emitter state (may be ``None``).
        key: a jax PRNG key, split into one sub-key for emission and one
            for scoring.

    Returns:
        A tuple made of the updated repertoire, the updated emitter state
        and the metrics computed on the updated repertoire.
    """
    # one sub-key for the emitter, one for the scoring function
    key, emit_key = jax.random.split(key)
    _, score_key = jax.random.split(key)

    # ask the emitter for offspring, then evaluate them
    offspring, emit_info = self._emitter.emit(
        repertoire, emitter_state, emit_key
    )
    offspring_fitnesses, offspring_scores = self._scoring_function(
        offspring, score_key
    )

    # insert the scored offspring into the repertoire
    new_repertoire = repertoire.add(offspring, offspring_fitnesses)

    # the emitter receives the scoring extras merged with its own emission
    # extras; on a key clash the emitter's entries take precedence
    merged_scores = {**offspring_scores, **emit_info}
    new_emitter_state = self._emitter.state_update(
        emitter_state=emitter_state,
        repertoire=new_repertoire,
        genotypes=offspring,
        fitnesses=offspring_fitnesses,
        descriptors=None,
        extra_scores=merged_scores,
    )

    metrics = self._metrics_function(new_repertoire)
    return new_repertoire, new_emitter_state, metrics