Optimizing multiple objectives with NSGA2 & SPEA2 in Jax¶
This notebook shows how to use QDax to find diverse and performing parameters on a multi-objectives Rastrigin problem, using NSGA2 and SPEA2 algorithms. It can be run locally or on Google Colab. We recommand to use a GPU. This notebook will show:
- how to define the problem
- how to create an emitter instance
- how to create an NSGA2 instance
- how to create an SPEA2 instance
- which functions must be defined before training
- how to launch a certain number of training steps
- how to visualise the optimization process
import jax.numpy as jnp
import jax
from typing import Tuple
import matplotlib.pyplot as plt
from functools import partial
try:
import brax
except:
!pip install git+https://github.com/google/brax.git@v0.9.2 |tail -n 1
import brax
try:
import flax
except:
!pip install --no-deps git+https://github.com/google/flax.git@v0.7.4 |tail -n 1
import flax
try:
import chex
except:
!pip install --no-deps git+https://github.com/deepmind/chex.git@v0.1.83 |tail -n 1
import chex
try:
import jumanji
except:
!pip install "jumanji==0.3.1"
import jumanji
try:
import qdax
except:
!pip install --no-deps git+https://github.com/adaptive-intelligent-robotics/QDax@main |tail -n 1
import qdax
from qdax.baselines.nsga2 import (
NSGA2
)
from qdax.baselines.spea2 import (
SPEA2
)
from qdax.core.emitters.mutation_operators import (
polynomial_crossover,
polynomial_mutation
)
from qdax.core.emitters.standard_emitters import MixingEmitter
from qdax.utils.pareto_front import compute_pareto_front
from qdax.utils.plotting import plot_global_pareto_front
from qdax.utils.pareto_front import compute_pareto_front
from qdax.utils.plotting import plot_global_pareto_front
from qdax.utils.metrics import default_ga_metrics
from qdax.types import Genotype, Fitness, Descriptor
Set the hyperparameters¶
#@markdown ---
population_size = 1000 #@param {type:"integer"}
num_iterations = 1000 #@param {type:"integer"}
proportion_mutation = 0.80 #@param {type:"number"}
minval = -5.12 #@param {type:"number"}
maxval = 5.12 #@param {type:"number"}
batch_size = 100 #@param {type:"integer"}
genotype_dim = 6 #@param {type:"integer"}
lag = 2.2 #@param {type:"number"}
base_lag = 0 #@param {type:"number"}
# for spea2
num_neighbours=1 #@param {type:"integer"}
#@markdown ---
Define the scoring function: rastrigin multi-objective¶
We use two rastrigin functions with an offset to create a multi-objective problem.
def rastrigin_scorer(
genotypes: jnp.ndarray, base_lag: float, lag: float
) -> Tuple[Fitness, Descriptor]:
"""
Rastrigin Scorer with first two dimensions as descriptors
"""
descriptors = genotypes[:, :2]
f1 = -(
10 * genotypes.shape[1]
+ jnp.sum(
(genotypes - base_lag) ** 2
- 10 * jnp.cos(2 * jnp.pi * (genotypes - base_lag)),
axis=1,
)
)
f2 = -(
10 * genotypes.shape[1]
+ jnp.sum(
(genotypes - lag) ** 2 - 10 * jnp.cos(2 * jnp.pi * (genotypes - lag)),
axis=1,
)
)
scores = jnp.stack([f1, f2], axis=-1)
return scores, descriptors
# Scoring function
scoring_function = partial(
rastrigin_scorer,
lag=lag,
base_lag=base_lag
)
def scoring_fn(x, random_key):
return scoring_function(x)[0], {}, random_key
Define initial population and emitter¶
# Initial population
random_key = jax.random.PRNGKey(0)
random_key, subkey = jax.random.split(random_key)
init_genotypes = jax.random.uniform(
subkey, (batch_size, genotype_dim), minval=minval, maxval=maxval, dtype=jnp.float32
)
# Mutation & Crossover
crossover_function = partial(
polynomial_crossover,
proportion_var_to_change=0.5,
)
mutation_function = partial(
polynomial_mutation,
proportion_to_mutate=0.5,
eta=0.05,
minval=minval,
maxval=maxval
)
# Define the emitter
mixing_emitter = MixingEmitter(
mutation_fn=mutation_function,
variation_fn=crossover_function,
variation_percentage=1-proportion_mutation,
batch_size=batch_size
)
Instantiate and init NSGA2¶
# instantitiate nsga2
nsga2 = NSGA2(
scoring_function=scoring_fn,
emitter=mixing_emitter,
metrics_function=default_ga_metrics
)
# init nsga2
repertoire, emitter_state, random_key = nsga2.init(
init_genotypes,
population_size,
random_key
)
Run and visualize result¶
%%time
# run optimization loop
(repertoire, emitter_state, random_key), _ = jax.lax.scan(
nsga2.scan_update, (repertoire, emitter_state, random_key), (), length=num_iterations
)
fig, ax = plt.subplots(figsize=(9, 6))
pareto_bool = compute_pareto_front(repertoire.fitnesses)
plot_global_pareto_front(repertoire.fitnesses[pareto_bool], ax=ax)
ax.set_title("Pareto front obtained by NSGA2", fontsize=16)
ax.set_xlabel("Fitness Dimension 1", fontsize=14)
ax.set_ylabel("Fitness Dimension 2", fontsize=14)
plt.grid()
plt.show()
Instantiate and init SPEA2¶
# instantitiate spea2
spea2 = SPEA2(
scoring_function=scoring_fn,
emitter=mixing_emitter,
metrics_function=default_ga_metrics
)
# init spea2
repertoire, emitter_state, random_key = spea2.init(
init_genotypes,
population_size,
num_neighbours,
random_key
)
%%time
# run optimization loop
(repertoire, emitter_state, random_key), _ = jax.lax.scan(
spea2.scan_update, (repertoire, emitter_state, random_key), (), length=num_iterations
)
Run and visualize result¶
fig, ax = plt.subplots(figsize=(9, 6))
pareto_bool = compute_pareto_front(repertoire.fitnesses)
plot_global_pareto_front(repertoire.fitnesses[pareto_bool], ax=ax)
ax.set_title("Pareto front obtained by SPEA2", fontsize=16)
ax.set_xlabel("Fitness Dimension 1", fontsize=14)
ax.set_ylabel("Fitness Dimension 2", fontsize=14)
plt.grid()
plt.show()