Optimizing multiple objectives with MOME in JAX¶
This notebook shows how to use QDax to find diverse and high-performing parameters on a multi-objective Rastrigin problem, using the Multi-Objective MAP-Elites (MOME) algorithm. It can be run locally or on Google Colab. We recommend using a GPU. This notebook will show:
- how to define the problem
- how to create an emitter instance
- how to create a Multi-Objective MAP-Elites instance
- which functions must be defined before training
- how to launch a certain number of training steps
- how to visualise the optimization process
Installation¶
You will need Python 3.11 or later, and a working JAX installation. For example, you can install JAX with:
In [ ]:
Copied!
# Install (or upgrade, via -U) JAX with the CUDA extra for GPU support.
%pip install -U "jax[cuda]"
%pip install -U "jax[cuda]"
Then, install QDax from PyPI:
In [ ]:
Copied!
# Install (or upgrade, via -U) QDax together with its "examples" extra.
%pip install -U "qdax[examples]"
%pip install -U "qdax[examples]"
In [ ]:
Copied!
import jax.numpy as jnp
import jax
from typing import Tuple
from functools import partial
from qdax.core.containers.mapelites_repertoire import compute_cvt_centroids
from qdax.core.mome import MOME
from qdax.core.emitters.mutation_operators import (
polynomial_mutation,
polynomial_crossover,
)
from qdax.core.emitters.standard_emitters import MixingEmitter
from qdax.utils.plotting import plot_2d_map_elites_repertoire, plot_mome_pareto_fronts
from qdax.utils.metrics import default_moqd_metrics
import matplotlib.pyplot as plt
from qdax.custom_types import Fitness, Descriptor, RNGKey, ExtraScores
import jax.numpy as jnp
import jax
from typing import Tuple
from functools import partial
from qdax.core.containers.mapelites_repertoire import compute_cvt_centroids
from qdax.core.mome import MOME
from qdax.core.emitters.mutation_operators import (
polynomial_mutation,
polynomial_crossover,
)
from qdax.core.emitters.standard_emitters import MixingEmitter
from qdax.utils.plotting import plot_2d_map_elites_repertoire, plot_mome_pareto_fronts
from qdax.utils.metrics import default_moqd_metrics
import matplotlib.pyplot as plt
from qdax.custom_types import Fitness, Descriptor, RNGKey, ExtraScores
Set the hyperparameters¶
In [ ]:
Copied!
#@markdown ---
# Problem: genotype size and bounds, plus the offsets of the two Rastrigin optima.
num_variables = 100 #@param {type:"integer"}
minval = -2 #@param {type:"number"}
maxval = 4 #@param {type:"number"}
base_lag = 0 #@param {type:"number"}
lag = 2.2 #@param {type:"number"}
# Algorithm: run length, repertoire size and batch size.
num_iterations = 1000 #@param {type:"integer"}
num_centroids = 64 #@param {type:"integer"}
pareto_front_max_length = 50 #@param {type:"integer"}
batch_size = 100 #@param {type:"integer"}
# Variation operators used by the mixing emitter.
crossover_percentage = 1.0 #@param {type:"number"}
proportion_var_to_change = 0.5 #@param {type:"number"}
eta = 1 #@param {type:"number"}
proportion_to_mutate = 0.6 #@param {type:"number"}
#@markdown ---
#@markdown ---
# Problem: genotype size and bounds, plus the offsets of the two Rastrigin optima.
num_variables = 100 #@param {type:"integer"}
minval = -2 #@param {type:"number"}
maxval = 4 #@param {type:"number"}
base_lag = 0 #@param {type:"number"}
lag = 2.2 #@param {type:"number"}
# Algorithm: run length, repertoire size and batch size.
num_iterations = 1000 #@param {type:"integer"}
num_centroids = 64 #@param {type:"integer"}
pareto_front_max_length = 50 #@param {type:"integer"}
batch_size = 100 #@param {type:"integer"}
# Variation operators used by the mixing emitter.
crossover_percentage = 1.0 #@param {type:"number"}
proportion_var_to_change = 0.5 #@param {type:"number"}
eta = 1 #@param {type:"number"}
proportion_to_mutate = 0.6 #@param {type:"number"}
#@markdown ---
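Define the scoring function: multi-objective Rastrigin¶
We use two Rastrigin functions with different offsets (`base_lag` and `lag`) to create a multi-objective problem: each function reaches its optimum at a different point, so the two objectives conflict and no single solution maximizes both.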
In [ ]:
Copied!
def rastrigin_scorer(
    genotypes: jax.Array, base_lag: float, lag: float
) -> Tuple[Fitness, Descriptor]:
    """Score genotypes on a bi-objective Rastrigin problem.

    Each objective is the negated Rastrigin function evaluated on a shifted
    copy of the genotypes: the first is centred on ``base_lag`` and the second
    on ``lag``. Because the two optima sit at different points, the objectives
    conflict. The first two genotype dimensions are used as descriptors.

    Args:
        genotypes: batch of genotypes indexed as ``(batch, num_variables)``.
        base_lag: location of the first Rastrigin function's optimum.
        lag: location of the second Rastrigin function's optimum.

    Returns:
        ``(scores, descriptors)`` where ``scores`` stacks the two objectives
        along the last axis (shape ``(batch, 2)``) and ``descriptors`` is
        ``genotypes[:, :2]``.
    """
    num_dims = genotypes.shape[1]

    def _negated_rastrigin(offset: float) -> jax.Array:
        # Rastrigin: 10 * n + sum_i (x_i^2 - 10 * cos(2 * pi * x_i)), minimal at
        # x = 0. It is evaluated at x = genotypes - offset and negated, so that
        # maximizing the score minimizes the Rastrigin value.
        shifted = genotypes - offset
        return -(
            10 * num_dims
            + jnp.sum(shifted**2 - 10 * jnp.cos(2 * jnp.pi * shifted), axis=1)
        )

    descriptors = genotypes[:, :2]
    scores = jnp.stack(
        [_negated_rastrigin(base_lag), _negated_rastrigin(lag)], axis=-1
    )
    return scores, descriptors
def rastrigin_scorer(
    genotypes: jax.Array, base_lag: float, lag: float
) -> Tuple[Fitness, Descriptor]:
    """Score genotypes on a bi-objective Rastrigin problem.

    Each objective is the negated Rastrigin function evaluated on a shifted
    copy of the genotypes: the first is centred on ``base_lag`` and the second
    on ``lag``. Because the two optima sit at different points, the objectives
    conflict. The first two genotype dimensions are used as descriptors.

    Args:
        genotypes: batch of genotypes indexed as ``(batch, num_variables)``.
        base_lag: location of the first Rastrigin function's optimum.
        lag: location of the second Rastrigin function's optimum.

    Returns:
        ``(scores, descriptors)`` where ``scores`` stacks the two objectives
        along the last axis (shape ``(batch, 2)``) and ``descriptors`` is
        ``genotypes[:, :2]``.
    """
    num_dims = genotypes.shape[1]

    def _negated_rastrigin(offset: float) -> jax.Array:
        # Rastrigin: 10 * n + sum_i (x_i^2 - 10 * cos(2 * pi * x_i)), minimal at
        # x = 0. It is evaluated at x = genotypes - offset and negated, so that
        # maximizing the score minimizes the Rastrigin value.
        shifted = genotypes - offset
        return -(
            10 * num_dims
            + jnp.sum(shifted**2 - 10 * jnp.cos(2 * jnp.pi * shifted), axis=1)
        )

    descriptors = genotypes[:, :2]
    scores = jnp.stack(
        [_negated_rastrigin(base_lag), _negated_rastrigin(lag)], axis=-1
    )
    return scores, descriptors
In [ ]:
Copied!
# Bind the two Rastrigin offsets so the scorer only needs the genotypes.
scoring_function = partial(rastrigin_scorer, base_lag=base_lag, lag=lag)


def scoring_fn(
    genotypes: jax.Array, key: RNGKey
) -> Tuple[Fitness, Descriptor, ExtraScores]:
    """Adapt the Rastrigin scorer to the scoring-function signature given to MOME.

    Args:
        genotypes: batch of genotypes to evaluate.
        key: random key. It is unused because the Rastrigin objectives are
            deterministic, but it is accepted to match the expected signature.

    Returns:
        ``(fitnesses, descriptors, extra_scores)``, where ``extra_scores`` is an
        empty dict. The previous annotation also listed an ``RNGKey`` that was
        never returned.
    """
    fitnesses, descriptors = scoring_function(genotypes)
    return fitnesses, descriptors, {}
# Bind the two Rastrigin offsets so the scorer only needs the genotypes.
scoring_function = partial(rastrigin_scorer, base_lag=base_lag, lag=lag)


def scoring_fn(
    genotypes: jax.Array, key: RNGKey
) -> Tuple[Fitness, Descriptor, ExtraScores]:
    """Adapt the Rastrigin scorer to the scoring-function signature given to MOME.

    Args:
        genotypes: batch of genotypes to evaluate.
        key: random key. It is unused because the Rastrigin objectives are
            deterministic, but it is accepted to match the expected signature.

    Returns:
        ``(fitnesses, descriptors, extra_scores)``, where ``extra_scores`` is an
        empty dict. The previous annotation also listed an ``RNGKey`` that was
        never returned.
    """
    fitnesses, descriptors = scoring_function(genotypes)
    return fitnesses, descriptors, {}
Define the metrics function that will be used¶
In [ ]:
Copied!
# Hypervolumes are measured relative to this point; it has one coordinate per
# objective (the scorer returns two objectives).
reference_point = jnp.array([-150, -150])

# Bind the reference point; the remaining arguments (the repertoire) are
# supplied when MOME computes its metrics.
metrics_function = partial(default_moqd_metrics, reference_point=reference_point)
# Hypervolumes are measured relative to this point; it has one coordinate per
# objective (the scorer returns two objectives).
reference_point = jnp.array([-150, -150])

# Bind the reference point; the remaining arguments (the repertoire) are
# supplied when MOME computes its metrics.
metrics_function = partial(default_moqd_metrics, reference_point=reference_point)
Define the initial population and the emitter¶
In [ ]:
Copied!
# --- Initial population -------------------------------------------------
# Fixed seed for reproducibility; a fresh subkey is split off for sampling.
key = jax.random.key(42)
key, subkey = jax.random.split(key)
genotypes = jax.random.uniform(
    subkey,
    shape=(batch_size, num_variables),
    dtype=jnp.float32,
    minval=minval,
    maxval=maxval,
)

# --- Variation operators ------------------------------------------------
crossover_function = partial(
    polynomial_crossover, proportion_var_to_change=proportion_var_to_change
)
# Mutation uses the same [minval, maxval] bounds as the initial sampling.
mutation_function = partial(
    polynomial_mutation,
    proportion_to_mutate=proportion_to_mutate,
    eta=eta,
    minval=minval,
    maxval=maxval,
)

# --- Emitter ------------------------------------------------------------
# Combines both operators; crossover_percentage is passed as the variation
# percentage (presumably the share of each batch produced by crossover).
mixing_emitter = MixingEmitter(
    batch_size=batch_size,
    mutation_fn=mutation_function,
    variation_fn=crossover_function,
    variation_percentage=crossover_percentage,
)
# --- Initial population -------------------------------------------------
# Fixed seed for reproducibility; a fresh subkey is split off for sampling.
key = jax.random.key(42)
key, subkey = jax.random.split(key)
genotypes = jax.random.uniform(
    subkey,
    shape=(batch_size, num_variables),
    dtype=jnp.float32,
    minval=minval,
    maxval=maxval,
)

# --- Variation operators ------------------------------------------------
crossover_function = partial(
    polynomial_crossover, proportion_var_to_change=proportion_var_to_change
)
# Mutation uses the same [minval, maxval] bounds as the initial sampling.
mutation_function = partial(
    polynomial_mutation,
    proportion_to_mutate=proportion_to_mutate,
    eta=eta,
    minval=minval,
    maxval=maxval,
)

# --- Emitter ------------------------------------------------------------
# Combines both operators; crossover_percentage is passed as the variation
# percentage (presumably the share of each batch produced by crossover).
mixing_emitter = MixingEmitter(
    batch_size=batch_size,
    mutation_fn=mutation_function,
    variation_fn=crossover_function,
    variation_percentage=crossover_percentage,
)
Compute the centroids¶
In [ ]:
Copied!
# The descriptors are the first two genotype coordinates, so the CVT
# tessellation is 2-D and spans the genotype range [minval, maxval].
key, subkey = jax.random.split(key)
centroids = compute_cvt_centroids(
    num_descriptors=2,
    num_centroids=num_centroids,
    num_init_cvt_samples=20_000,
    minval=minval,
    maxval=maxval,
    key=subkey,
)
# The descriptors are the first two genotype coordinates, so the CVT
# tessellation is 2-D and spans the genotype range [minval, maxval].
key, subkey = jax.random.split(key)
centroids = compute_cvt_centroids(
    num_descriptors=2,
    num_centroids=num_centroids,
    num_init_cvt_samples=20_000,
    minval=minval,
    maxval=maxval,
    key=subkey,
)
Define a MOME instance¶
In [ ]:
Copied!
# Assemble MOME from the emitter, scorer and metrics defined above.
# pareto_front_max_length caps the size of each stored Pareto front
# (per the parameter name).
mome = MOME(
    emitter=mixing_emitter,
    scoring_function=scoring_fn,
    metrics_function=metrics_function,
    pareto_front_max_length=pareto_front_max_length,
)
# Assemble MOME from the emitter, scorer and metrics defined above.
# pareto_front_max_length caps the size of each stored Pareto front
# (per the parameter name).
mome = MOME(
    emitter=mixing_emitter,
    scoring_function=scoring_fn,
    metrics_function=metrics_function,
    pareto_front_max_length=pareto_front_max_length,
)
Initialize the algorithm¶
In [ ]:
Copied!
# Build the starting repertoire on the CVT centroids from the initial
# genotypes; this also yields the emitter state and the initial metrics.
key, subkey = jax.random.split(key)
repertoire, emitter_state, init_metrics = mome.init(genotypes, centroids, subkey)
# Build the starting repertoire on the CVT centroids from the initial
# genotypes; this also yields the emitter state and the initial metrics.
key, subkey = jax.random.split(key)
repertoire, emitter_state, init_metrics = mome.init(genotypes, centroids, subkey)
Run MOME iterations¶
In [ ]:
Copied!
# Run num_iterations MOME updates with jax.lax.scan. The carry threads the
# repertoire, emitter state and random key through the iterations; the
# per-iteration metrics come back stacked along a new leading axis.
carry, metrics = jax.lax.scan(
    mome.scan_update, (repertoire, emitter_state, key), (), length=num_iterations
)
repertoire, emitter_state, key = carry
# Run num_iterations MOME updates with jax.lax.scan. The carry threads the
# repertoire, emitter state and random key through the iterations; the
# per-iteration metrics come back stacked along a new leading axis.
carry, metrics = jax.lax.scan(
    mome.scan_update, (repertoire, emitter_state, key), (), length=num_iterations
)
repertoire, emitter_state, key = carry
Plot the results¶
In [ ]:
Copied!
# Per-iteration MOQD score: sum over the last axis, skipping -inf entries
# (presumably empty cells).
moqd_scores = jnp.sum(
    metrics["moqd_score"],
    axis=-1,
    where=metrics["moqd_score"] != -jnp.inf,
)
# Per-iteration MOQD score: sum over the last axis, skipping -inf entries
# (presumably empty cells).
moqd_scores = jnp.sum(
    metrics["moqd_score"],
    axis=-1,
    where=metrics["moqd_score"] != -jnp.inf,
)
In [ ]:
Copied!
f, (ax1, ax2, ax3, ax4) = plt.subplots(1, 4, figsize=(25, 5))
# x-axis: iteration index scaled by the batch size.
steps = batch_size * jnp.arange(start=0, stop=num_iterations)

# One panel per tracked quantity, drawn in the same order as the panels.
curves = (
    (ax1, moqd_scores, 'MOQD Score'),
    (ax2, metrics["max_hypervolume"], 'Max Hypervolume'),
    (ax3, metrics["max_sum_scores"], 'Max Sum Scores'),
    (ax4, metrics["coverage"], 'Coverage'),
)
for ax, values, label in curves:
    ax.plot(steps, values)
    ax.set_xlabel('Num steps')
    ax.set_ylabel(label)
plt.show()
f, (ax1, ax2, ax3, ax4) = plt.subplots(1, 4, figsize=(25, 5))
# x-axis: iteration index scaled by the batch size.
steps = batch_size * jnp.arange(start=0, stop=num_iterations)

# One panel per tracked quantity, drawn in the same order as the panels.
curves = (
    (ax1, moqd_scores, 'MOQD Score'),
    (ax2, metrics["max_hypervolume"], 'Max Hypervolume'),
    (ax3, metrics["max_sum_scores"], 'Max Sum Scores'),
    (ax4, metrics["coverage"], 'Coverage'),
)
for ax, values, label in curves:
    ax.plot(steps, values)
    ax.set_xlabel('Num steps')
    ax.set_ylabel(label)
plt.show()
In [ ]:
Copied!
fig, axes = plt.subplots(ncols=3, figsize=(18, 6))

# Pareto-front plots of the final repertoire (with_global=True presumably
# adds the global front as well).
axes = plot_mome_pareto_fronts(
    centroids,
    repertoire,
    minval=minval,
    maxval=maxval,
    axes=axes,
    color_style='spectral',
    with_global=True,
)

# Last axis: the per-cell MOQD scores from the final iteration, drawn as a
# 2-D MAP-Elites map over the CVT centroids.
final_cell_scores = metrics["moqd_score"][-1]
plot_2d_map_elites_repertoire(
    centroids=centroids,
    repertoire_fitnesses=final_cell_scores,
    minval=minval,
    maxval=maxval,
    ax=axes[2],
)
plt.show()
fig, axes = plt.subplots(ncols=3, figsize=(18, 6))

# Pareto-front plots of the final repertoire (with_global=True presumably
# adds the global front as well).
axes = plot_mome_pareto_fronts(
    centroids,
    repertoire,
    minval=minval,
    maxval=maxval,
    axes=axes,
    color_style='spectral',
    with_global=True,
)

# Last axis: the per-cell MOQD scores from the final iteration, drawn as a
# 2-D MAP-Elites map over the CVT centroids.
final_cell_scores = metrics["moqd_score"][-1]
plot_2d_map_elites_repertoire(
    centroids=centroids,
    repertoire_fitnesses=final_cell_scores,
    minval=minval,
    maxval=maxval,
    ax=axes[2],
)
plt.show()