Genetic Algorithm class

qdax.baselines.genetic_algorithm.GeneticAlgorithm

Core class of a genetic algorithm.

This class implements default methods to run a simple genetic algorithm with a simple repertoire.

Parameters:
  • scoring_function (Callable[[Union[jax.Array, numpy.ndarray, numpy.bool_, numpy.number, Iterable[ArrayTree], Mapping[Any, ArrayTree]], jax.Array], Tuple[jax.Array, Dict[str, Union[jax.Array, numpy.ndarray, numpy.bool_, numpy.number, Iterable[ArrayTree], Mapping[Any, ArrayTree]]], jax.Array]]) – a function that takes a batch of genotypes and computes their fitnesses

  • emitter (Emitter) – an emitter is used to suggest offspring given a repertoire. It has two compulsory functions: one that emits a new population, and one that updates the internal state of the emitter

  • metrics_function (Callable[[qdax.core.containers.ga_repertoire.GARepertoire], Dict[str, jax.Array]]) – a function that takes a repertoire and computes any useful metric to track its evolution

Source code in qdax/baselines/genetic_algorithm.py
class GeneticAlgorithm:
    """Driver for a simple genetic algorithm backed by a ``GARepertoire``.

    The class ties together three user-supplied components and exposes
    jitted entry points: ``init`` builds the first repertoire and emitter
    state, ``update`` runs one iteration, and ``scan_update`` adapts
    ``update`` to the ``jax.lax.scan`` calling convention.

    Args:
        scoring_function: callable invoked as
            ``scoring_function(genotypes, random_key)``; it returns the
            fitnesses, the extra scores and a new random key.
        emitter: object used to propose offspring from a repertoire. Its
            ``init``, ``emit`` and ``state_update`` methods are called by
            this class.
        metrics_function: callable that takes a repertoire and returns the
            metrics used to track its evolution.
    """

    def __init__(
        self,
        scoring_function: Callable[
            [Genotype, RNGKey], Tuple[Fitness, ExtraScores, RNGKey]
        ],
        emitter: Emitter,
        metrics_function: Callable[[GARepertoire], Metrics],
    ) -> None:
        # keep the three collaborators; every other method goes through them
        self._metrics_function = metrics_function
        self._emitter = emitter
        self._scoring_function = scoring_function

    @partial(jax.jit, static_argnames=("self", "population_size"))
    def init(
        self, init_genotypes: Genotype, population_size: int, random_key: RNGKey
    ) -> Tuple[GARepertoire, Optional[EmitterState], RNGKey]:
        """Build the initial repertoire and emitter state.

        ``self`` and ``population_size`` are static arguments of the jitted
        function.

        Args:
            init_genotypes: the initial population of genotypes
            population_size: the maximal size of the repertoire
            random_key: a random key to handle stochastic operations

        Returns:
            A tuple made of the initial repertoire, the initial emitter state
            and a new random key.
        """
        # evaluate the starting population
        scored = self._scoring_function(init_genotypes, random_key)
        initial_fitnesses, initial_extras, key = scored

        # store the scored population in a fresh repertoire
        initial_repertoire = GARepertoire.init(
            genotypes=init_genotypes,
            fitnesses=initial_fitnesses,
            population_size=population_size,
        )

        # create the emitter state, threading the key returned by the scoring
        initial_state, key = self._emitter.init(
            init_genotypes=init_genotypes, random_key=key
        )

        # let the emitter see the first evaluation results
        initial_state = self._emitter.state_update(
            emitter_state=initial_state,
            repertoire=initial_repertoire,
            genotypes=init_genotypes,
            fitnesses=initial_fitnesses,
            descriptors=None,
            extra_scores=initial_extras,
        )

        return initial_repertoire, initial_state, key

    @partial(jax.jit, static_argnames=("self",))
    def update(
        self,
        repertoire: GARepertoire,
        emitter_state: Optional[EmitterState],
        random_key: RNGKey,
    ) -> Tuple[GARepertoire, Optional[EmitterState], Metrics, RNGKey]:
        """Run one iteration of the genetic algorithm.

        The iteration chains four steps:
        1. The emitter proposes offspring from the current repertoire.
        2. The offspring are scored with the scoring function.
        3. The scored offspring are added to the repertoire.
        4. The emitter state is updated and the metrics are computed on the
            updated repertoire.

        Args:
            repertoire: a repertoire
            emitter_state: state of the emitter
            random_key: a jax PRNG random key

        Returns:
            the updated repertoire
            the updated (if needed) emitter state
            metrics about the updated repertoire
            a new jax PRNG key
        """
        # ask the emitter for new candidates
        offspring, key = self._emitter.emit(repertoire, emitter_state, random_key)

        # evaluate the candidates
        scored = self._scoring_function(offspring, key)
        offspring_fitnesses, offspring_extras, key = scored

        # insert the candidates in the repertoire
        updated_repertoire = repertoire.add(offspring, offspring_fitnesses)

        # the emitter is updated with the repertoire that already holds them
        updated_state = self._emitter.state_update(
            emitter_state=emitter_state,
            repertoire=updated_repertoire,
            genotypes=offspring,
            fitnesses=offspring_fitnesses,
            descriptors=None,
            extra_scores=offspring_extras,
        )

        # metrics are computed last, on the updated repertoire
        step_metrics = self._metrics_function(updated_repertoire)

        return updated_repertoire, updated_state, step_metrics, key

    @partial(jax.jit, static_argnames=("self",))
    def scan_update(
        self,
        carry: Tuple[GARepertoire, Optional[EmitterState], RNGKey],
        unused: Any,
    ) -> Tuple[Tuple[GARepertoire, Optional[EmitterState], RNGKey], Metrics]:
        """Adapt ``update`` to the signature expected by ``jax.lax.scan``.

        Args:
            carry: a tuple containing the repertoire, the emitter state and a
                random key.
            unused: unused element, necessary to respect jax.lax.scan API.

        Returns:
            A pair whose first element is the new carry (updated repertoire,
            updated emitter state, new random key) and whose second element
            is the metrics of this iteration.
        """
        # unpack the carry and delegate the actual work to ``update``
        current_repertoire, current_state, current_key = carry
        step_output = self.update(current_repertoire, current_state, current_key)
        next_repertoire, next_state, step_metrics, next_key = step_output

        # repack: the metrics leave the carry and become the per-step output
        return (next_repertoire, next_state, next_key), step_metrics

init(self, init_genotypes, population_size, random_key)

Initialize a GARepertoire with an initial population of genotypes.

Parameters:
  • init_genotypes (Union[jax.Array, numpy.ndarray, numpy.bool_, numpy.number, Iterable[ArrayTree], Mapping[Any, ArrayTree]]) – the initial population of genotypes

  • population_size (int) – the maximal size of the repertoire

  • random_key (Array) – a random key to handle stochastic operations

Returns:
  • Tuple[qdax.core.containers.ga_repertoire.GARepertoire, Optional[qdax.core.emitters.emitter.EmitterState], jax.Array] – The initial repertoire, an initial emitter state and a new random key.

Source code in qdax/baselines/genetic_algorithm.py
@partial(jax.jit, static_argnames=("self", "population_size"))
def init(
    self, init_genotypes: Genotype, population_size: int, random_key: RNGKey
) -> Tuple[GARepertoire, Optional[EmitterState], RNGKey]:
    """Create the first repertoire and emitter state from a population.

    ``self`` and ``population_size`` are static arguments of the jitted
    function.

    Args:
        init_genotypes: the initial population of genotypes
        population_size: the maximal size of the repertoire
        random_key: a random key to handle stochastic operations

    Returns:
        A tuple made of the initial repertoire, the initial emitter state and
        a new random key.
    """
    # step 1: evaluate the provided genotypes
    scoring_output = self._scoring_function(init_genotypes, random_key)
    start_fitnesses, start_extra_scores, next_key = scoring_output

    # step 2: build the repertoire from the evaluated genotypes
    start_repertoire = GARepertoire.init(
        genotypes=init_genotypes,
        fitnesses=start_fitnesses,
        population_size=population_size,
    )

    # step 3: create the emitter state with the key returned by the scoring
    start_state, next_key = self._emitter.init(
        init_genotypes=init_genotypes, random_key=next_key
    )

    # step 4: feed the first evaluation results back to the emitter
    start_state = self._emitter.state_update(
        emitter_state=start_state,
        repertoire=start_repertoire,
        genotypes=init_genotypes,
        fitnesses=start_fitnesses,
        descriptors=None,
        extra_scores=start_extra_scores,
    )

    return start_repertoire, start_state, next_key

update(self, repertoire, emitter_state, random_key)

Performs one iteration of a genetic algorithm: (1) a batch of genotypes is sampled from the repertoire and the genotypes are copied; (2) the copies are mutated and crossed over; (3) the resulting offspring are scored and then added to the repertoire.

Parameters:
  • repertoire (GARepertoire) – a repertoire

  • emitter_state (Optional[qdax.core.emitters.emitter.EmitterState]) – state of the emitter

  • random_key (Array) – a jax PRNG random key

Returns:
  • Tuple[qdax.core.containers.ga_repertoire.GARepertoire, Optional[qdax.core.emitters.emitter.EmitterState], Dict[str, jax.Array], jax.Array] – the updated GA repertoire, the updated (if needed) emitter state, metrics about the updated repertoire, and a new jax PRNG key

Source code in qdax/baselines/genetic_algorithm.py
@partial(jax.jit, static_argnames=("self",))
def update(
    self,
    repertoire: GARepertoire,
    emitter_state: Optional[EmitterState],
    random_key: RNGKey,
) -> Tuple[GARepertoire, Optional[EmitterState], Metrics, RNGKey]:
    """Run one iteration of the genetic algorithm.

    The iteration chains four steps:
    1. The emitter proposes offspring from the current repertoire.
    2. The offspring are scored with the scoring function.
    3. The scored offspring are added to the repertoire.
    4. The emitter state is updated and the metrics are computed on the
        updated repertoire.

    Args:
        repertoire: a repertoire
        emitter_state: state of the emitter
        random_key: a jax PRNG random key

    Returns:
        the updated repertoire
        the updated (if needed) emitter state
        metrics about the updated repertoire
        a new jax PRNG key
    """
    # emission: candidates come from the emitter
    candidates, next_key = self._emitter.emit(repertoire, emitter_state, random_key)

    # evaluation: the key returned by the emitter is passed on to the scoring
    evaluation = self._scoring_function(candidates, next_key)
    candidate_fitnesses, candidate_extras, next_key = evaluation

    # insertion: the repertoire receives the evaluated candidates
    new_repertoire = repertoire.add(candidates, candidate_fitnesses)

    # feedback: the emitter is updated with the already-updated repertoire
    new_emitter_state = self._emitter.state_update(
        emitter_state=emitter_state,
        repertoire=new_repertoire,
        genotypes=candidates,
        fitnesses=candidate_fitnesses,
        descriptors=None,
        extra_scores=candidate_extras,
    )

    # reporting: metrics are computed last, on the updated repertoire
    new_metrics = self._metrics_function(new_repertoire)

    return new_repertoire, new_emitter_state, new_metrics, next_key

scan_update(self, carry, unused)

Rewrites the update function in a way that makes it compatible with the jax.lax.scan primitive.

Parameters:
  • carry (Tuple[qdax.core.containers.ga_repertoire.GARepertoire, Optional[qdax.core.emitters.emitter.EmitterState], jax.Array]) – a tuple containing the repertoire, the emitter state and a random key.

  • unused (Any) – unused element, necessary to respect jax.lax.scan API.

Returns:
  • Tuple[Tuple[qdax.core.containers.ga_repertoire.GARepertoire, Optional[qdax.core.emitters.emitter.EmitterState], jax.Array], Dict[str, jax.Array]] – The updated repertoire and emitter state, with a new random key and metrics.

Source code in qdax/baselines/genetic_algorithm.py
@partial(jax.jit, static_argnames=("self",))
def scan_update(
    self,
    carry: Tuple[GARepertoire, Optional[EmitterState], RNGKey],
    unused: Any,
) -> Tuple[Tuple[GARepertoire, Optional[EmitterState], RNGKey], Metrics]:
    """Adapt ``update`` to the signature expected by ``jax.lax.scan``.

    Args:
        carry: a tuple containing the repertoire, the emitter state and a
            random key.
        unused: unused element, necessary to respect jax.lax.scan API.

    Returns:
        A pair whose first element is the new carry (updated repertoire,
        updated emitter state, new random key) and whose second element is
        the metrics of this iteration.
    """
    # split the carry into the three arguments ``update`` expects
    prev_repertoire, prev_emitter_state, prev_key = carry

    # one iteration of the algorithm
    outcome = self.update(prev_repertoire, prev_emitter_state, prev_key)
    new_repertoire, new_emitter_state, iteration_metrics, new_key = outcome

    # the metrics leave the carry and become the per-step output
    new_carry = (new_repertoire, new_emitter_state, new_key)
    return new_carry, iteration_metrics