QDax docs

Optimizing multiple objectives with NSGA2 & SPEA2 in JAX

This notebook shows how to use QDax to find diverse, high-performing parameters for a multi-objective Rastrigin problem using the NSGA2 and SPEA2 algorithms. It can be run locally or on Google Colab; we recommend using a GPU. This notebook shows:

  • how to define the problem
  • how to create an emitter instance
  • how to create an NSGA2 instance
  • how to create an SPEA2 instance
  • which functions must be defined before training
  • how to run a given number of training iterations
  • how to visualize the optimization process

Installation

You will need Python 3.11 or later and a working JAX installation. For example, to install JAX with CUDA support:

%pip install -U "jax[cuda]"

Then, install QDax from PyPI:

%pip install -U "qdax[examples]"
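
Since a GPU is recommended for this notebook, it can be worth checking which backend JAX actually picked up after installation. A minimal check using the standard jax.default_backend() and jax.devices() calls:

```python
import jax

# Platform name of the default backend, e.g. "cpu", "gpu" or "tpu"
backend = jax.default_backend()
# Devices available on the default backend
devices = jax.devices()
print(f"JAX backend: {backend}, devices: {devices}")
```

If this reports "cpu" on a machine with a GPU, the CUDA-enabled JAX build was most likely not installed.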

from functools import partial
from typing import Tuple

import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt

from qdax.baselines.nsga2 import NSGA2
from qdax.baselines.spea2 import SPEA2
from qdax.core.emitters.mutation_operators import (
    polynomial_crossover,
    polynomial_mutation,
)
from qdax.core.emitters.standard_emitters import MixingEmitter
from qdax.custom_types import Descriptor, Fitness
from qdax.utils.metrics import default_ga_metrics
from qdax.utils.pareto_front import compute_pareto_front
from qdax.utils.plotting import plot_global_pareto_front

Set the hyperparameters

#@markdown ---
# Number of individuals kept in the population (also the emitter batch size)
population_size = 1000 #@param {type:"integer"}
# Number of generations run by jax.lax.scan
num_iterations = 1000 #@param {type:"integer"}
# Fraction of offspring produced by mutation; the rest come from crossover
proportion_mutation = 0.80 #@param {type:"number"}
# Bounds of the search space (the usual Rastrigin domain)
minval = -5.12 #@param {type:"number"}
maxval = 5.12 #@param {type:"number"}
genotype_dim = 6 #@param {type:"integer"}
# Positions of the optima of the two objectives
lag = 2.2 #@param {type:"number"}
base_lag = 0 #@param {type:"number"}
# For SPEA2: number of neighbours used in its density estimate
num_neighbours = 1 #@param {type:"integer"}
#@markdown ---

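Define the scoring function: multi-objective Rastrigin

We combine two Rastrigin functions into a two-objective problem: one has its optimum where every coordinate equals base_lag, the other where every coordinate equals lag. Both are negated because QDax maximizes fitness, so each objective peaks at 0. The first two genotype dimensions are also returned as descriptors, although NSGA2 and SPEA2 only use the fitnesses.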

def rastrigin_scorer(
    genotypes: jax.Array, base_lag: float, lag: float
) -> Tuple[Fitness, Descriptor]:
    """
    Two-objective Rastrigin scorer; the first two dimensions are used as descriptors.
    """
    descriptors = genotypes[:, :2]
    # First objective: negated Rastrigin with its optimum at base_lag
    f1 = -(
        10 * genotypes.shape[1]
        + jnp.sum(
            (genotypes - base_lag) ** 2
            - 10 * jnp.cos(2 * jnp.pi * (genotypes - base_lag)),
            axis=1,
        )
    )

    # Second objective: negated Rastrigin with its optimum at lag
    f2 = -(
        10 * genotypes.shape[1]
        + jnp.sum(
            (genotypes - lag) ** 2 - 10 * jnp.cos(2 * jnp.pi * (genotypes - lag)),
            axis=1,
        )
    )
    # Fitnesses of shape (batch_size, 2), one column per objective
    scores = jnp.stack([f1, f2], axis=-1)

    return scores, descriptors
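
As a quick sanity check on the two objectives: each negated Rastrigin term reaches its maximum of 0 when every coordinate equals its offset, so a point that is optimal for one objective is heavily penalized by the other. The sketch below re-implements a single objective as a hypothetical helper `neg_rastrigin` (same formula as in `rastrigin_scorer`) and evaluates it at both optima, assuming the notebook's default `lag = 2.2`, `base_lag = 0` and `genotype_dim = 6`:

```python
import jax.numpy as jnp


def neg_rastrigin(x: jnp.ndarray, offset: float) -> jnp.ndarray:
    """Hypothetical helper: negated Rastrigin shifted by `offset`, one value per row."""
    shifted = x - offset
    return -(
        10 * x.shape[1]
        + jnp.sum(shifted**2 - 10 * jnp.cos(2 * jnp.pi * shifted), axis=1)
    )


# Mirrors the notebook's default hyperparameters
genotype_dim = 6
base_lag, lag = 0.0, 2.2

# Row 0 sits at the optimum of objective 1, row 1 at the optimum of objective 2
candidates = jnp.stack([
    jnp.full((genotype_dim,), base_lag),
    jnp.full((genotype_dim,), lag),
])

f1 = neg_rastrigin(candidates, base_lag)
f2 = neg_rastrigin(candidates, lag)
print(f1, f2)  # each optimum scores about 0 on its own objective and about -70.5 on the other
```

Neither candidate dominates the other, which is what makes the problem genuinely multi-objective: the Pareto front is made of trade-offs between the two optima.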

# Scoring function: bind the two offsets
scoring_function = partial(
    rastrigin_scorer,
    lag=lag,
    base_lag=base_lag
)

# NSGA2 and SPEA2 call the scoring function as (genotypes, key) and get back
# (fitnesses, extra_scores); the key is unused here and descriptors are dropped
def scoring_fn(x, key):
    return scoring_function(x)[0], {}

Define initial population and emitter

# Initial population: sampled uniformly in [minval, maxval] on every dimension
key = jax.random.key(0)
key, subkey = jax.random.split(key)
genotypes = jax.random.uniform(
    subkey, (population_size, genotype_dim), minval=minval, maxval=maxval, dtype=jnp.float32
)

# Mutation & Crossover
crossover_function = partial(
    polynomial_crossover,
    proportion_var_to_change=0.5,
)

mutation_function = partial(
    polynomial_mutation,
    proportion_to_mutate=0.5,
    eta=0.05,
    minval=minval,
    maxval=maxval
)

# Define the emitter
# NSGA2 and SPEA2 use batch size = population size; a fraction
# 1 - proportion_mutation of each batch is produced by crossover
mixing_emitter = MixingEmitter(
    mutation_fn=mutation_function,
    variation_fn=crossover_function,
    variation_percentage=1 - proportion_mutation,
    batch_size=population_size,
)
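
MixingEmitter produces each batch of offspring partly by crossover (`variation_fn`) and partly by mutation (`mutation_fn`), with `variation_percentage` controlling the split. The hypothetical helper `split_batch` below sketches that split, assuming the crossover share is obtained by truncating `batch_size * variation_percentage` to an integer; check the MixingEmitter source for the exact rule:

```python
def split_batch(batch_size: int, variation_percentage: float) -> tuple[int, int]:
    """Hypothetical helper: (number of crossover offspring, number of mutation offspring)."""
    # Assumption: the crossover share is truncated towards zero
    n_variation = int(batch_size * variation_percentage)
    # Mutation fills the remainder so the batch is always complete
    n_mutation = batch_size - n_variation
    return n_variation, n_mutation


# Notebook defaults: population_size = 1000, proportion_mutation = 0.80
n_crossover, n_mutation = split_batch(1000, 1 - 0.80)
print(n_crossover, n_mutation)
```

Because `1 - 0.80` is slightly smaller than 0.2 in floating point, truncation under this assumption gives 199 crossover and 801 mutation offspring rather than an exact 200/800 split. The difference is negligible here, but it matters when reasoning about very small batches.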

Instantiate and init NSGA2

# Instantiate NSGA2
nsga2 = NSGA2(
    scoring_function=scoring_fn,
    emitter=mixing_emitter,
    metrics_function=default_ga_metrics
)

# Initialize NSGA2 with the initial population
key, subkey = jax.random.split(key)
repertoire, emitter_state, init_metrics = nsga2.init(
    genotypes,
    population_size,
    subkey
)
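
In the standard NSGA-II algorithm, survivors are chosen by non-dominated front first and, within a front, by crowding distance, so that individuals in sparsely populated regions of the front are preferred. The sketch below computes the textbook crowding distance for one front as a hypothetical helper `crowding_distance`; it illustrates the idea and is not QDax's internal implementation, which may normalize or handle edge cases differently. It assumes the front holds at least two points:

```python
import jax.numpy as jnp


def crowding_distance(front: jnp.ndarray) -> jnp.ndarray:
    """Hypothetical helper: textbook NSGA-II crowding distance for one front.

    `front` has shape (n, num_objectives) with n >= 2.
    """
    n, num_objectives = front.shape
    distances = jnp.zeros(n)
    for m in range(num_objectives):
        order = jnp.argsort(front[:, m])
        sorted_vals = front[order, m]
        # Normalize by the objective's range (guard against a zero range)
        span = sorted_vals[-1] - sorted_vals[0]
        span = jnp.where(span > 0, span, 1.0)
        # Interior points: gap between their two neighbours on objective m
        gaps = (sorted_vals[2:] - sorted_vals[:-2]) / span
        # Boundary points get an infinite distance so they are always kept
        contrib = jnp.concatenate([jnp.array([jnp.inf]), gaps, jnp.array([jnp.inf])])
        # Scatter the contributions back to the original point order
        distances = distances.at[order].add(contrib)
    return distances


front = jnp.array([[0.0, 4.0], [1.0, 3.0], [1.5, 2.5], [4.0, 0.0]])
distances = crowding_distance(front)
print(distances)
```

Here [1.5, 2.5] gets a larger distance than [1.0, 3.0] because its neighbours are further apart, so it would be preferred if the front had to be truncated.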

Run and visualize result

# Run optimization loop
(repertoire, emitter_state, key), _ = jax.lax.scan(
    nsga2.scan_update, (repertoire, emitter_state, key), (), length=num_iterations
)
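
jax.lax.scan applies scan_update num_iterations times inside a single compiled loop, threading the carry (repertoire, emitter_state, key) from one step to the next; passing () as the sequence together with length means no per-step inputs are consumed. The toy sketch below uses the same call pattern with a hypothetical `toy_scan_update` whose carry is a counter and a PRNG key, to show the contract such a function follows: take (carry, unused) and return (new_carry, per_step_output):

```python
import jax
import jax.numpy as jnp


def toy_scan_update(carry, unused):
    """Hypothetical stand-in for scan_update with the same (carry, x) -> (carry, y) shape."""
    counter, key = carry
    key, subkey = jax.random.split(key)
    # Stand-in for one generation of work
    sample = jax.random.uniform(subkey)
    # New carry for the next step, plus a per-step output that scan stacks
    return (counter + 1, key), sample


init_carry = (jnp.array(0), jax.random.key(0))
(final_count, final_key), samples = jax.lax.scan(
    toy_scan_update, init_carry, (), length=5
)
print(final_count, samples.shape)
```

The stacked per-step outputs play the role of the metrics that the notebook discards with `_`; the carry must keep the same structure, shapes and dtypes at every step.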

fig, ax = plt.subplots(figsize=(9, 6))
pareto_bool = compute_pareto_front(repertoire.fitnesses)
plot_global_pareto_front(repertoire.fitnesses[pareto_bool], ax=ax)
ax.set_title("Pareto front obtained by NSGA2", fontsize=16)
ax.set_xlabel("Fitness Dimension 1", fontsize=14)
ax.set_ylabel("Fitness Dimension 2", fontsize=14)
plt.grid()
plt.show()
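
compute_pareto_front returns a boolean mask selecting the non-dominated fitness vectors, which is why the plot indexes repertoire.fitnesses[pareto_bool]. A minimal sketch of such a mask for maximization, written as a hypothetical `pareto_mask` helper; QDax's own implementation may differ in details such as how duplicate points are treated:

```python
import jax.numpy as jnp


def pareto_mask(fitnesses: jnp.ndarray) -> jnp.ndarray:
    """Hypothetical helper: True for points that no other point dominates (maximization)."""
    # ge[i, j]: point i is at least as good as point j on every objective
    ge = jnp.all(fitnesses[:, None, :] >= fitnesses[None, :, :], axis=-1)
    # gt[i, j]: point i is strictly better than point j on at least one objective
    gt = jnp.any(fitnesses[:, None, :] > fitnesses[None, :, :], axis=-1)
    dominates = ge & gt
    # A point is on the front if no row dominates its column
    return ~jnp.any(dominates, axis=0)


fitnesses = jnp.array([[-1.0, -5.0], [-3.0, -2.0], [-4.0, -4.0], [-6.0, -0.5]])
mask = pareto_mask(fitnesses)
front = fitnesses[mask]
print(mask)
```

The third point is excluded because [-3.0, -2.0] is better on both objectives; the other three each win on at least one objective. The pairwise comparison costs O(n²) memory, which is fine for a population of 1000.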

Instantiate and init SPEA2

# Instantiate SPEA2
spea2 = SPEA2(
    scoring_function=scoring_fn,
    emitter=mixing_emitter,
    metrics_function=default_ga_metrics
)

# Initialize SPEA2 with the initial population
key, subkey = jax.random.split(key)
repertoire, emitter_state, init_metrics = spea2.init(
    genotypes,
    population_size,
    num_neighbours,
    subkey,
)
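
SPEA2 differs from NSGA2 mainly in how it scores individuals. In the textbook algorithm, each point's strength S(i) is the number of points it dominates, its raw fitness R(i) is the sum of the strengths of the points dominating it, and a density term D(i) = 1 / (sigma_k + 2) based on the distance to the k-th nearest neighbour breaks ties; lower F = R + D is better. The `num_neighbours` argument plausibly plays the role of k, but QDax's exact formula may differ. The hypothetical `spea2_fitness` helper below is a sketch of the textbook version for maximization:

```python
import jax.numpy as jnp


def spea2_fitness(fitnesses: jnp.ndarray, k: int) -> jnp.ndarray:
    """Hypothetical helper: textbook SPEA2 fitness F = R + D (lower is better).

    Objectives are maximized; requires k < number of points.
    """
    diff = fitnesses[:, None, :] - fitnesses[None, :, :]
    # dominates[i, j]: i is >= j on every objective and > on at least one
    dominates = jnp.all(diff >= 0, axis=-1) & jnp.any(diff > 0, axis=-1)
    # Strength S(i): number of points that i dominates
    strength = jnp.sum(dominates, axis=1)
    # Raw fitness R(j): sum of the strengths of the points dominating j
    raw = jnp.sum(jnp.where(dominates, strength[:, None], 0), axis=0)
    # Density D(i) = 1 / (sigma_k + 2), sigma_k = distance to the k-th nearest neighbour
    distances = jnp.sqrt(jnp.sum(diff**2, axis=-1))
    sigma_k = jnp.sort(distances, axis=1)[:, k]  # column 0 is the point itself
    density = 1.0 / (sigma_k + 2.0)
    return raw + density


fitnesses = jnp.array([[-1.0, -5.0], [-3.0, -2.0], [-4.0, -4.0], [-6.0, -0.5]])
scores = spea2_fitness(fitnesses, k=1)
print(scores)
```

Non-dominated points always score below 1, since their raw fitness is 0 and the density term is at most 0.5; this is how SPEA2 keeps them ahead of dominated points while still favouring isolated ones.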

# run optimization loop
(repertoire, emitter_state, key), _ = jax.lax.scan(
    spea2.scan_update, (repertoire, emitter_state, key), (), length=num_iterations
)

Run and visualize result

fig, ax = plt.subplots(figsize=(9, 6))
pareto_bool = compute_pareto_front(repertoire.fitnesses)
plot_global_pareto_front(repertoire.fitnesses[pareto_bool], ax=ax)
ax.set_title("Pareto front obtained by SPEA2", fontsize=16)
ax.set_xlabel("Fitness Dimension 1", fontsize=14)
ax.set_ylabel("Fitness Dimension 2", fontsize=14)
plt.grid()
plt.show()
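
The NSGA2 and SPEA2 fronts can also be compared numerically. A common measure for a two-objective front is the hypervolume: the area dominated by the front and bounded by a reference point that every useful solution should beat. The hypothetical `hypervolume_2d` helper below is a minimal sweep-line sketch for maximization; the reference point is an assumption you have to choose. Note that the notebook reuses the name repertoire for SPEA2, so NSGA2's fitnesses would need to be saved under another name before the SPEA2 cells run in order to compare the two.

```python
import numpy as np


def hypervolume_2d(front: np.ndarray, reference_point: np.ndarray) -> float:
    """Hypothetical helper: area dominated by a 2-objective front (maximization),
    bounded below by `reference_point`."""
    front = np.asarray(front, dtype=np.float64)
    reference_point = np.asarray(reference_point, dtype=np.float64)
    # Keep only points that strictly improve on the reference point
    front = front[np.all(front > reference_point, axis=1)]
    # Sweep from the largest first objective to the smallest
    front = front[np.argsort(-front[:, 0])]
    volume, best_f2 = 0.0, reference_point[1]
    for f1, f2 in front:
        # A point adds a rectangle only above the best second objective seen so far
        if f2 > best_f2:
            volume += (f1 - reference_point[0]) * (f2 - best_f2)
            best_f2 = f2
    return float(volume)


hv = hypervolume_2d(np.array([[3.0, 1.0], [2.0, 2.0], [1.0, 3.0]]), np.array([0.0, 0.0]))
print(hv)
```

With the notebook's objects this could be called as hypervolume_2d(np.asarray(repertoire.fitnesses[pareto_bool]), np.array([-150.0, -150.0])), where the reference point is an arbitrary illustrative choice; points that do not beat it are simply ignored, and a larger value indicates a front that is both closer to the optima and better spread.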