SPEA2 class

Bases: GeneticAlgorithm

Implements the main functions of the SPEA2 algorithm.

This class inherits most functions from GeneticAlgorithm. The init function is overridden in order to specify the type of repertoire used in SPEA2 (SPEA2Repertoire).

Link to paper: "https://www.semanticscholar.org/paper/SPEA2%3A-Improving-the-strength-pareto-evolutionary-Zitzler-Laumanns/b13724cb54ae4171916f3f969d304b9e9752a57f"
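The SPEA2 paper ranks each individual with a fitness made of three parts: a strength S(i) (how many individuals i dominates), a raw fitness R(i) (the sum of the strengths of the individuals that dominate i) and a density D(i) = 1 / (σ_i^k + 2), where σ_i^k is the distance from i to its k-th nearest neighbour in objective space. The final score is F(i) = R(i) + D(i), and lower is better. The sketch below is a minimal, standalone JAX version of these textbook formulas, assuming every objective is maximized (as fitnesses are in QDax) and taking k = num_neighbours. The helper `spea2_fitness` is hypothetical: it is not the code of SPEA2Repertoire, whose scoring may differ in its details.

```python
import jax.numpy as jnp


def spea2_fitness(fitnesses: jnp.ndarray, num_neighbours: int) -> jnp.ndarray:
    """Hypothetical helper: textbook SPEA2 fitness, lower is better.

    Assumes `fitnesses` has shape (num_individuals, num_objectives), that every
    objective is maximized, and that num_neighbours < num_individuals.
    """
    # dominates[i, j] is True when i Pareto-dominates j (maximization):
    # i is >= j on every objective and strictly > j on at least one.
    geq = jnp.all(fitnesses[:, None, :] >= fitnesses[None, :, :], axis=-1)
    gt = jnp.any(fitnesses[:, None, :] > fitnesses[None, :, :], axis=-1)
    dominates = geq & gt

    # Strength S(i): number of individuals that i dominates.
    strength = jnp.sum(dominates, axis=1)

    # Raw fitness R(i): sum of S(j) over every j that dominates i.
    raw = jnp.sum(dominates * strength[:, None], axis=0)

    # Density D(i) = 1 / (sigma_i^k + 2). After sorting each row of the
    # distance matrix, column 0 is the distance of i to itself (0), so
    # column k holds the distance to the k-th nearest neighbour.
    dists = jnp.linalg.norm(fitnesses[:, None, :] - fitnesses[None, :, :], axis=-1)
    sigma_k = jnp.sort(dists, axis=1)[:, num_neighbours]
    density = 1.0 / (sigma_k + 2.0)

    return raw + density


# Three points in a 2-objective space: (2, 2) dominates (1, 1).
points = jnp.array([[1.0, 1.0], [2.0, 2.0], [0.0, 3.0]])
scores = spea2_fitness(points, num_neighbours=1)
```

Because D(i) is always below 1/2, a dominated individual (here (1, 1), with R = 1) always scores at least 1, while non-dominated ones score below 1. The density term only breaks ties among individuals with the same raw fitness, favouring those in less crowded regions.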

Source code in qdax/baselines/spea2.py
class SPEA2(GeneticAlgorithm):
    """Implements the main functions of the SPEA2 algorithm.

    This class inherits most functions from GeneticAlgorithm.
    The init function is overridden in order to specify the type
    of repertoire used in SPEA2.

    Link to paper: "https://www.semanticscholar.org/paper/SPEA2%3A-
    Improving-the-strength-pareto-evolutionary-Zitzler-Laumanns/
    b13724cb54ae4171916f3f969d304b9e9752a57f"
    """

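    def init(  # type: ignore
        self,
        genotypes: Genotype,
        population_size: int,
        num_neighbours: int,
        key: RNGKey,
    ) -> Tuple[SPEA2Repertoire, Optional[EmitterState], Metrics]:
        """Initialize the SPEA2 repertoire, emitter state and metrics.

        Args:
            genotypes: initial genotypes, scored and used to build the
                repertoire.
            population_size: number of individuals kept in the repertoire.
            num_neighbours: number of neighbours used by SPEA2Repertoire
                when computing the density part of its scores.
            key: JAX random key, split to score the genotypes and to
                initialize the emitter.

        Returns:
            The initial repertoire, the initial emitter state and the
            initial metrics.
        """
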
        # score initial genotypes
        key, subkey = jax.random.split(key)
        fitnesses, extra_scores = self._scoring_function(genotypes, subkey)

        # init the repertoire
        repertoire = SPEA2Repertoire.init(
            genotypes=genotypes,
            fitnesses=fitnesses,
            population_size=population_size,
            num_neighbours=num_neighbours,
        )

        # get initial state of the emitter
        key, subkey = jax.random.split(key)
        emitter_state = self._emitter.init(
            key=subkey,
            repertoire=repertoire,
            genotypes=genotypes,
            fitnesses=fitnesses,
            descriptors=None,
            extra_scores=extra_scores,
        )

        # update emitter state
        emitter_state = self._emitter.state_update(
            emitter_state=emitter_state,
            repertoire=repertoire,
            genotypes=genotypes,
            fitnesses=fitnesses,
            extra_scores=extra_scores,
        )

        # calculate the initial metrics
        metrics = self._metrics_function(repertoire)

        return repertoire, emitter_state, metrics
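The key handling in init follows the usual JAX pattern: the incoming key is split before each stochastic step, so the scoring function and the emitter each receive an independent subkey, and the whole initialization is reproducible for a given key. The standalone sketch below mirrors only that flow. `toy_scoring_function` and `init_key_flow` are hypothetical stand-ins, and no QDax repertoire, emitter or metrics function is involved.

```python
from typing import Dict, Tuple

import jax
import jax.numpy as jnp


def toy_scoring_function(
    genotypes: jnp.ndarray, key: jax.Array
) -> Tuple[jnp.ndarray, Dict]:
    """Hypothetical scoring function: two objectives to maximize, plus noise."""
    f1 = -jnp.sum(genotypes**2, axis=1)
    f2 = -jnp.sum((genotypes - 1.0) ** 2, axis=1)
    # The evaluation noise depends only on the subkey passed in.
    noise = 0.01 * jax.random.normal(key, (genotypes.shape[0], 2))
    return jnp.stack([f1, f2], axis=1) + noise, {}


def init_key_flow(
    genotypes: jnp.ndarray, key: jax.Array
) -> Tuple[jnp.ndarray, jax.Array]:
    """Mirrors how SPEA2.init threads its random key (toy components only)."""
    # First split: the subkey is consumed by the scoring function.
    key, scoring_key = jax.random.split(key)
    fitnesses, _extra_scores = toy_scoring_function(genotypes, scoring_key)
    # Second split: the emitter initialization gets its own subkey.
    key, emitter_key = jax.random.split(key)
    return fitnesses, emitter_key


genotypes = jnp.linspace(0.0, 1.0, 8).reshape(4, 2)
fitnesses_a, emitter_key_a = init_key_flow(genotypes, jax.random.PRNGKey(0))
fitnesses_b, emitter_key_b = init_key_flow(genotypes, jax.random.PRNGKey(0))
fitnesses_c, _ = init_key_flow(genotypes, jax.random.PRNGKey(1))
```

Calling the sketch twice with PRNGKey(0) gives identical fitnesses and emitter keys, while PRNGKey(1) changes the evaluation noise. Threading the key explicitly, instead of relying on global random state, is the standard JAX way to keep randomness reproducible and compatible with transformations such as jax.jit.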