Optimizing multiple objectives with MOME in Jax¶
This notebook shows how to use QDax to find diverse and high-performing parameters on a multi-objective Rastrigin problem, using the Multi-Objective MAP-Elites (MOME) algorithm. It can be run locally or on Google Colab. We recommend using a GPU. This notebook will show:
- how to define the problem
- how to create an emitter instance
- how to create a Multi-Objective MAP-Elites (MOME) instance
- which functions must be defined before training
- how to launch a certain number of training steps
- how to visualise the optimization process
import jax.numpy as jnp
import jax
from typing import Tuple
from functools import partial
try:
import brax
except:
!pip install git+https://github.com/google/brax.git@v0.9.2 |tail -n 1
import brax
try:
import flax
except:
!pip install --no-deps git+https://github.com/google/flax.git@v0.7.4 |tail -n 1
import flax
try:
import chex
except:
!pip install --no-deps git+https://github.com/deepmind/chex.git@v0.1.83 |tail -n 1
import chex
try:
import jumanji
except:
!pip install "jumanji==0.3.1"
import jumanji
try:
import qdax
except:
!pip install --no-deps git+https://github.com/adaptive-intelligent-robotics/QDax@main |tail -n 1
import qdax
from qdax.core.containers.mapelites_repertoire import compute_cvt_centroids
from qdax.core.mome import MOME
from qdax.core.emitters.mutation_operators import (
polynomial_mutation,
polynomial_crossover,
)
from qdax.core.emitters.standard_emitters import MixingEmitter
from qdax.utils.plotting import plot_2d_map_elites_repertoire, plot_mome_pareto_fronts
from qdax.utils.metrics import default_moqd_metrics
import matplotlib.pyplot as plt
from qdax.types import Fitness, Descriptor, RNGKey, ExtraScores
Set the hyperparameters¶
#@markdown ---
# Forwarded to `mome.init` when the repertoire is created.
pareto_front_max_length = 50 #@param {type:"integer"}
# Number of variables per genotype (second dimension of `init_genotypes`).
num_variables = 100 #@param {type:"integer"}
# Number of `mome.scan_update` steps run by `jax.lax.scan`.
num_iterations = 1000 #@param {type:"integer"}
# Number of centroids requested from `compute_cvt_centroids`.
num_centroids = 64 #@param {type:"integer"}
# Lower/upper bounds shared by the initial uniform sampling, the polynomial
# mutation, the centroid computation and the plots.
minval = -2 #@param {type:"number"}
maxval = 4 #@param {type:"number"}
# `proportion_to_mutate` and `eta` are bound into `polynomial_mutation`.
proportion_to_mutate = 0.6 #@param {type:"number"}
eta = 1 #@param {type:"number"}
# Bound into `polynomial_crossover`.
proportion_var_to_change = 0.5 #@param {type:"number"}
# Passed to `MixingEmitter` as `variation_percentage`.
crossover_percentage = 1. #@param {type:"number"}
# Size of the initial population and `batch_size` of the emitter.
batch_size = 100 #@param {type:"integer"}
# Offsets subtracted from the genotypes by `rastrigin_scorer`:
# `base_lag` for the first objective, `lag` for the second.
lag = 2.2 #@param {type:"number"}
base_lag = 0 #@param {type:"number"}
#@markdown ---
Define the scoring function: rastrigin multi-objective¶
We use two Rastrigin functions, each shifted by a different offset (`base_lag` and `lag`), to create a multi-objective problem.
def rastrigin_scorer(
    genotypes: jnp.ndarray, base_lag: float, lag: float
) -> Tuple[Fitness, Descriptor]:
    """Score genotypes on two shifted, negated Rastrigin objectives.

    Args:
        genotypes: 2D array holding one genotype per row.
        base_lag: offset subtracted from every variable for the first
            objective.
        lag: offset subtracted from every variable for the second objective.

    Returns:
        A pair ``(scores, descriptors)``. ``scores`` stacks the two objective
        values along the last axis; ``descriptors`` is made of the first two
        variables of each genotype.
    """

    def negated_rastrigin(offset: float) -> jnp.ndarray:
        # Rastrigin value of every row after shifting by `offset`, negated.
        shifted = genotypes - offset
        per_variable = shifted ** 2 - 10 * jnp.cos(2 * jnp.pi * shifted)
        return -(10 * genotypes.shape[1] + jnp.sum(per_variable, axis=1))

    objective_values = [negated_rastrigin(base_lag), negated_rastrigin(lag)]
    return jnp.stack(objective_values, axis=-1), genotypes[:, :2]
# Freeze both offsets so that the scorer only needs the genotypes.
scoring_function = partial(rastrigin_scorer, lag=lag, base_lag=base_lag)


def scoring_fn(
    genotypes: jnp.ndarray, random_key: RNGKey
) -> Tuple[Fitness, Descriptor, ExtraScores, RNGKey]:
    """Wrap `scoring_function` into the four-value form handed to MOME.

    The scorer uses no randomness, so `random_key` is returned unchanged and
    the extra scores are an empty dictionary.
    """
    objective_scores, behaviour = scoring_function(genotypes)
    no_extra_scores: ExtraScores = {}
    return objective_scores, behaviour, no_extra_scores, random_key
Define the metrics function that will be used¶
# Reference point handed to `default_moqd_metrics`, one entry per objective.
reference_point = jnp.array([-150, -150])

# Pre-bind the reference point; the resulting callable computes the metrics
# from a repertoire.
metrics_function = partial(default_moqd_metrics, reference_point=reference_point)
Define the initial population and the emitter¶
# initial population
random_key = jax.random.PRNGKey(42)
# Sample with the dedicated `subkey`: `random_key` is consumed again by the
# following cells (centroid computation, MOME initialisation), so drawing the
# population from it would reuse one key for independent random draws.
random_key, subkey = jax.random.split(random_key)
init_genotypes = jax.random.uniform(
    subkey,
    (batch_size, num_variables),
    minval=minval,
    maxval=maxval,
    dtype=jnp.float32,
)
# Variation operators: each `partial` pre-binds the hyperparameters chosen
# above, leaving the remaining arguments to be supplied by the emitter.

# crossover
crossover_function = partial(
    polynomial_crossover,
    proportion_var_to_change=proportion_var_to_change,
)

# mutation, bounded by the same [minval, maxval] range as the initial sampling
mutation_function = partial(
    polynomial_mutation,
    proportion_to_mutate=proportion_to_mutate,
    eta=eta,
    minval=minval,
    maxval=maxval,
)

# Emitter combining both operators; `crossover_percentage` is passed as its
# `variation_percentage`.
mixing_emitter = MixingEmitter(
    batch_size=batch_size,
    mutation_fn=mutation_function,
    variation_fn=crossover_function,
    variation_percentage=crossover_percentage,
)
Compute the centroids¶
# Centroids of the 2D descriptor space, computed over the same
# [minval, maxval] range as the genotypes (the descriptors are the first two
# genotype variables). The call also hands back an updated random key.
centroids, random_key = compute_cvt_centroids(
    random_key=random_key,
    num_descriptors=2,
    num_centroids=num_centroids,
    num_init_cvt_samples=20000,
    minval=minval,
    maxval=maxval,
)
Define a MOME instance¶
# Assemble the algorithm from the emitter, the scoring wrapper and the
# metrics callable defined in the previous cells.
mome = MOME(
    emitter=mixing_emitter,
    scoring_function=scoring_fn,
    metrics_function=metrics_function,
)
Init the algorithm¶
# Create the initial repertoire and emitter state from the initial
# population; the random key returned by `init` replaces the current one.
repertoire, emitter_state, random_key = mome.init(
    init_genotypes, centroids, pareto_front_max_length, random_key
)
Run MOME iterations¶
%%time
# Run the algorithm
(repertoire, emitter_state, random_key,), metrics = jax.lax.scan(
mome.scan_update,
(repertoire, emitter_state, random_key),
(),
length=num_iterations,
)
Plot the results¶
# Total MOQD score per iteration: sum over the last axis, leaving out the
# entries equal to -inf.
moqd_scores = jnp.sum(
    metrics["moqd_score"],
    where=metrics["moqd_score"] != -jnp.inf,
    axis=-1,
)

f, (ax1, ax2, ax3, ax4) = plt.subplots(1, 4, figsize=(25, 5))

# x-axis: iteration index scaled by the batch size
steps = batch_size * jnp.arange(start=0, stop=num_iterations)

# One panel per metric: (axis, values to plot, y-axis label)
panels = (
    (ax1, moqd_scores, 'MOQD Score'),
    (ax2, metrics["max_hypervolume"], 'Max Hypervolume'),
    (ax3, metrics["max_sum_scores"], 'Max Sum Scores'),
    (ax4, metrics["coverage"], 'Coverage'),
)
for panel_ax, series, y_label in panels:
    panel_ax.plot(steps, series)
    panel_ax.set_xlabel('Num steps')
    panel_ax.set_ylabel(y_label)

plt.show()
fig, axes = plt.subplots(ncols=3, figsize=(18, 6))

# Draw the Pareto fronts; the axes returned by the helper are reused below.
axes = plot_mome_pareto_fronts(
    centroids,
    repertoire,
    axes=axes,
    minval=minval,
    maxval=maxval,
    color_style='spectral',
    with_global=True,
)

# Add the MAP-Elites grid on the last axis, coloured by the MOQD scores of
# the final iteration.
plot_2d_map_elites_repertoire(
    centroids=centroids,
    repertoire_fitnesses=metrics["moqd_score"][-1],
    ax=axes[2],
    minval=minval,
    maxval=maxval,
)

plt.show()