QDax docs

Optimizing multiple objectives with MOME in JAX¶

This notebook shows how to use QDax to find diverse, high-performing parameters on a multi-objective Rastrigin problem using the Multi-Objective MAP-Elites (MOME) algorithm. It can be run locally or on Google Colab. We recommend using a GPU. This notebook will show:

  • how to define the problem
  • how to create an emitter instance
  • how to create a Multi-Objective MAP-Elites (MOME) instance
  • which functions must be defined before training
  • how to launch a certain number of training steps
  • how to visualise the optimization process

Installation¶

You will need Python 3.11 or later, and a working JAX installation. For example, you can install JAX with:

%pip install -U "jax[cuda]"

Then, install QDax from PyPI:

%pip install -U "qdax[examples]"

Next, import the modules used throughout the notebook:

import jax.numpy as jnp
import jax
from typing import Tuple

from functools import partial

from qdax.core.containers.mapelites_repertoire import compute_cvt_centroids
from qdax.core.mome import MOME
from qdax.core.emitters.mutation_operators import (
    polynomial_mutation,
    polynomial_crossover,
)
from qdax.core.emitters.standard_emitters import MixingEmitter
from qdax.utils.plotting import plot_2d_map_elites_repertoire, plot_mome_pareto_fronts

from qdax.utils.metrics import default_moqd_metrics

import matplotlib.pyplot as plt

from qdax.custom_types import Fitness, Descriptor, RNGKey, ExtraScores

Set the hyperparameters¶

#@markdown ---
pareto_front_max_length = 50 #@param {type:"integer"}
num_variables = 100 #@param {type:"integer"}
num_iterations = 1000 #@param {type:"integer"}

num_centroids = 64 #@param {type:"integer"}
minval = -2 #@param {type:"number"}
maxval = 4 #@param {type:"number"}
proportion_to_mutate = 0.6 #@param {type:"number"}
eta = 1 #@param {type:"number"}
proportion_var_to_change = 0.5 #@param {type:"number"}
crossover_percentage = 1. #@param {type:"number"}
batch_size = 100 #@param {type:"integer"}
lag = 2.2 #@param {type:"number"}
base_lag = 0 #@param {type:"number"}
#@markdown ---
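
Two quantities follow directly from these settings: the evaluation budget of the main loop and the upper bound on how many solutions the repertoire can hold. The short sketch below works them out in plain Python. It assumes that each iteration evaluates exactly one batch of offspring and that each cell stores at most pareto_front_max_length solutions.

```python
# Values copied from the hyperparameter cell above.
pareto_front_max_length = 50
num_centroids = 64
num_iterations = 1000
batch_size = 100

# One batch of offspring is evaluated per iteration (assumption stated above).
evaluations_in_main_loop = batch_size * num_iterations

# Each of the num_centroids cells can hold a Pareto front of bounded length.
max_solutions_in_repertoire = num_centroids * pareto_front_max_length

print(evaluations_in_main_loop)     # 100000
print(max_solutions_in_repertoire)  # 3200
```

The initial population adds one more batch of evaluations on top of the main-loop budget.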

Define the scoring function: rastrigin multi-objective¶

We use two Rastrigin functions, one shifted by base_lag and the other by lag, to create a multi-objective problem. The first two dimensions of each genotype serve as its descriptor.

def rastrigin_scorer(
    genotypes: jax.Array, base_lag: float, lag: float
) -> Tuple[Fitness, Descriptor]:
    """
    Rastrigin Scorer with first two dimensions as descriptors
    """
    descriptors = genotypes[:, :2]
    f1 = -(
        10 * genotypes.shape[1]
        + jnp.sum(
            (genotypes - base_lag) ** 2
            - 10 * jnp.cos(2 * jnp.pi * (genotypes - base_lag)),
            axis=1,
        )
    )

    f2 = -(
        10 * genotypes.shape[1]
        + jnp.sum(
            (genotypes - lag) ** 2 - 10 * jnp.cos(2 * jnp.pi * (genotypes - lag)),
            axis=1,
        )
    )
    scores = jnp.stack([f1, f2], axis=-1)

    return scores, descriptors
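
To see why the two objectives conflict, it helps to evaluate them by hand. The sketch below is a plain-Python, single-solution version of the scorer above, using the offsets from the hyperparameter cell. The helper name shifted_rastrigin is ours and is introduced only for illustration.

```python
import math


def shifted_rastrigin(x, offset):
    """Negated Rastrigin function whose maximum (value 0) sits at `offset`."""
    total = 10 * len(x)
    for xi in x:
        z = xi - offset
        total += z**2 - 10 * math.cos(2 * math.pi * z)
    return -total


base_lag, lag = 0.0, 2.2  # offsets from the hyperparameter cell
n = 100                   # num_variables

at_f1_optimum = [base_lag] * n
at_f2_optimum = [lag] * n

# Each objective reaches its maximum of 0 at its own offset ...
f1_best = shifted_rastrigin(at_f1_optimum, base_lag)
f2_best = shifted_rastrigin(at_f2_optimum, lag)

# ... but is strongly negative at the other objective's optimum.
f1_at_f2_optimum = shifted_rastrigin(at_f2_optimum, base_lag)
f2_at_f1_optimum = shifted_rastrigin(at_f1_optimum, lag)

print(f1_best, f2_best)
print(f1_at_f2_optimum, f2_at_f1_optimum)
```

With these settings, each objective scores roughly -1175 at the other objective's optimum. Since no single genotype maximizes both objectives, the algorithm looks for a set of trade-off solutions (a Pareto front) in each cell.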

The offsets are then bound with partial, and the scorer is wrapped in a function that also accepts a random key and returns an empty dictionary of extra scores:

scoring_function = partial(rastrigin_scorer, base_lag=base_lag, lag=lag)

def scoring_fn(genotypes: jax.Array, key: RNGKey) -> Tuple[Fitness, Descriptor, ExtraScores]:
    fitnesses, descriptors = scoring_function(genotypes)
    return fitnesses, descriptors, {}

Define the metrics function that will be used¶

reference_point = jnp.array([-150, -150])

# how to compute metrics from a repertoire
metrics_function = partial(
    default_moqd_metrics,
    reference_point=reference_point
)
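
The reference point is the corner against which hypervolume is measured: with two maximized objectives, the hypervolume of a front is the area that lies above the reference point and is dominated by at least one solution. The plain-Python sketch below computes that area for a few made-up fitness pairs. It illustrates the concept only and is not QDax's implementation; hypervolume_2d is a hypothetical helper.

```python
def hypervolume_2d(points, reference):
    """Area dominated by `points` and bounded below by `reference` (maximization)."""
    rx, ry = reference
    # Only points that beat the reference in both objectives contribute.
    valid = [(x, y) for x, y in points if x > rx and y > ry]
    # Sweep from the largest first objective to the smallest, adding the
    # horizontal strip that each point newly covers.
    valid.sort(key=lambda p: p[0], reverse=True)
    area = 0.0
    best_y = ry
    for x, y in valid:
        if y > best_y:
            area += (x - rx) * (y - best_y)
            best_y = y
    return area


reference = (-150, -150)
front = [(-50, -100), (-100, -50)]

print(hypervolume_2d(front, reference))                   # 7500.0
# A dominated point does not change the result.
print(hypervolume_2d(front + [(-120, -120)], reference))  # 7500.0
# A point below the reference in one objective contributes nothing.
print(hypervolume_2d([(-200, -10)], reference))           # 0.0
```

Solutions that do not exceed the reference point in both objectives add no area, so the reference point should sit below the fitness values of interest.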

Define the initial population and the emitter¶

# initial population
key = jax.random.key(42)
key, subkey = jax.random.split(key)
genotypes = jax.random.uniform(
    subkey, (batch_size, num_variables), minval=minval, maxval=maxval, dtype=jnp.float32
)

# crossover function
crossover_function = partial(
    polynomial_crossover,
    proportion_var_to_change=proportion_var_to_change
)

# mutation function
mutation_function = partial(
    polynomial_mutation,
    eta=eta,
    minval=minval,
    maxval=maxval,
    proportion_to_mutate=proportion_to_mutate
)

# Define emitter
mixing_emitter = MixingEmitter(
    mutation_fn=mutation_function,
    variation_fn=crossover_function,
    variation_percentage=crossover_percentage,
    batch_size=batch_size
)
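
The variation_percentage argument controls how each batch of offspring is divided between the two operators. The sketch below shows the split under the assumption that the emitter assigns int(batch_size * variation_percentage) offspring to crossover and the remainder to mutation; check the MixingEmitter source for the exact rule in your QDax version. The function split_batch is a hypothetical helper, not a QDax function.

```python
def split_batch(batch_size, variation_percentage):
    """Return (n_crossover, n_mutation) under the assumed splitting rule."""
    n_crossover = int(batch_size * variation_percentage)
    n_mutation = batch_size - n_crossover
    return n_crossover, n_mutation


# Settings used in this notebook: every offspring comes from crossover.
print(split_batch(100, 1.0))  # (100, 0)

# A mixed setting, for comparison.
print(split_batch(100, 0.5))  # (50, 50)
```

Under this rule, crossover_percentage = 1.0 leaves no offspring for the mutation operator, so eta and proportion_to_mutate would only matter once crossover_percentage is lowered.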

Compute the centroids¶

key, subkey = jax.random.split(key)
centroids = compute_cvt_centroids(
    num_descriptors=2,
    num_init_cvt_samples=20000,
    num_centroids=num_centroids,
    minval=minval,
    maxval=maxval,
    key=subkey,
)
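
A centroidal Voronoi tessellation (CVT) divides the descriptor space into cells of roughly equal size. A common way to approximate one is to run k-means on points sampled uniformly from the space. The sketch below does this in plain Python at a much smaller scale than the cell above. It illustrates the idea under that assumption and is not the code QDax runs; cvt_centroids_sketch is a hypothetical name.

```python
import random


def cvt_centroids_sketch(num_centroids, num_samples, minval, maxval, seed=0, iters=10):
    """Approximate CVT centroids in 2D with a few Lloyd (k-means) iterations."""
    rng = random.Random(seed)
    samples = [
        (rng.uniform(minval, maxval), rng.uniform(minval, maxval))
        for _ in range(num_samples)
    ]
    # Start from randomly chosen samples.
    centroids = rng.sample(samples, num_centroids)

    for _ in range(iters):
        # Assignment step: group samples by their nearest centroid.
        clusters = [[] for _ in range(num_centroids)]
        for px, py in samples:
            nearest = min(
                range(num_centroids),
                key=lambda i: (px - centroids[i][0]) ** 2 + (py - centroids[i][1]) ** 2,
            )
            clusters[nearest].append((px, py))
        # Update step: move each centroid to the mean of its cluster
        # (an empty cluster keeps its previous centroid).
        centroids = [
            (sum(p[0] for p in c) / len(c), sum(p[1] for p in c) / len(c)) if c else old
            for c, old in zip(clusters, centroids)
        ]
    return centroids


centroids = cvt_centroids_sketch(num_centroids=8, num_samples=2000, minval=-2.0, maxval=4.0)
for cx, cy in centroids:
    print(f"({cx:.2f}, {cy:.2f})")
```

More samples give a closer approximation of the tessellation, which is the role num_init_cvt_samples plays in the cell above.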

Define a MOME instance¶

mome = MOME(
    scoring_function=scoring_fn,
    emitter=mixing_emitter,
    metrics_function=metrics_function,
    pareto_front_max_length=pareto_front_max_length,
)
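
MOME differs from MAP-Elites in what each cell stores: instead of a single best solution, a cell keeps a Pareto front, that is, the solutions not dominated by any other solution in the cell, capped at pareto_front_max_length. The plain-Python sketch below shows the dominance test and the resulting front for a few made-up fitness pairs, with both objectives maximized. The helpers dominates and pareto_front are illustrative and not part of QDax.

```python
def dominates(a, b):
    """True if `a` is at least as good as `b` everywhere and strictly better somewhere."""
    return all(x >= y for x, y in zip(a, b)) and any(x > y for x, y in zip(a, b))


def pareto_front(fitnesses):
    """Keep the fitness vectors that no other vector dominates."""
    return [
        f for f in fitnesses
        if not any(dominates(other, f) for other in fitnesses)
    ]


candidates = [(-10, -90), (-40, -40), (-90, -10), (-50, -50), (-95, -95)]
print(pareto_front(candidates))
# [(-10, -90), (-40, -40), (-90, -10)]
```

Here (-50, -50) and (-95, -95) are dropped because (-40, -40) is better on both objectives, while the three remaining points each win on at least one objective.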

Init the algorithm¶

key, subkey = jax.random.split(key)
repertoire, emitter_state, init_metrics = mome.init(
    genotypes,
    centroids,
    subkey
)
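
During initialization, the initial genotypes are scored and added to the repertoire. Each solution is placed in the cell whose centroid is closest to its descriptor. A minimal plain-Python sketch of that lookup, using three made-up centroids (nearest_cell is a hypothetical helper, not a QDax function):

```python
def nearest_cell(descriptor, centroids):
    """Index of the centroid closest (in Euclidean distance) to `descriptor`."""
    def squared_distance(c):
        return sum((d - ci) ** 2 for d, ci in zip(descriptor, c))

    return min(range(len(centroids)), key=lambda i: squared_distance(centroids[i]))


centroids = [(-1.0, -1.0), (1.0, 1.0), (3.0, 3.0)]

print(nearest_cell((0.8, 1.5), centroids))   # 1
print(nearest_cell((-2.0, 0.0), centroids))  # 0
print(nearest_cell((2.5, 3.9), centroids))   # 2
```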

Run MOME iterations¶

# Run the algorithm
(repertoire, emitter_state, key,), metrics = jax.lax.scan(
    mome.scan_update,
    (repertoire, emitter_state, key),
    (),
    length=num_iterations,
)
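
jax.lax.scan runs the update function num_iterations times inside a single compiled loop, threading the (repertoire, emitter_state, key) tuple from one call to the next and stacking the per-iteration metrics. Its semantics are roughly those of the plain-Python loop below, shown with a toy update function. The names scan_reference and toy_update are illustrative; the real scan_update does much more.

```python
def scan_reference(step_fn, init_carry, length):
    """Plain-Python reference for a scan with no per-step inputs."""
    carry = init_carry
    outputs = []
    for _ in range(length):
        # Each step receives the previous carry and returns the next one
        # together with a per-step output.
        carry, out = step_fn(carry, None)
        outputs.append(out)
    return carry, outputs


def toy_update(carry, _unused):
    """Toy stand-in for an update step: the carry is (best_so_far, step)."""
    best_so_far, step = carry
    candidate = -abs(5 - step)  # peaks at step 5
    best_so_far = max(best_so_far, candidate)
    metrics = {"best": best_so_far}
    return (best_so_far, step + 1), metrics


final_carry, metrics = scan_reference(toy_update, (float("-inf"), 0), length=8)
print(final_carry)                   # (0, 8)
print([m["best"] for m in metrics])  # [-5, -4, -3, -2, -1, 0, 0, 0]
```

In JAX the per-step outputs are stacked into arrays rather than collected in a list, which is why each entry of metrics has a leading axis of length num_iterations.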

Plot the results¶

moqd_scores = jnp.sum(metrics["moqd_score"], where=metrics["moqd_score"] != -jnp.inf, axis=-1)
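
The "moqd_score" entry holds one value per cell and per iteration, and the where= mask skips entries equal to -inf before summing over cells. We assume here that -inf marks cells that do not contain any solution yet. In plain Python, with a made-up 2 x 4 table of per-cell scores, the same reduction reads:

```python
import math

# Rows are iterations, columns are cells; -inf is assumed to mark an empty cell.
per_cell_scores = [
    [120.0, -math.inf, -math.inf, 30.0],
    [150.0, 10.0, -math.inf, 45.0],
]

# Sum over cells, ignoring the -inf entries, to get one score per iteration.
moqd_per_iteration = [
    sum(v for v in row if v != -math.inf)
    for row in per_cell_scores
]

print(moqd_per_iteration)  # [150.0, 205.0]
```

Without the mask, a single empty cell would turn the whole sum into -inf.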

The metrics collected at each iteration can then be plotted:

f, (ax1, ax2, ax3, ax4) = plt.subplots(1, 4, figsize=(25, 5))

steps = batch_size * jnp.arange(start=0, stop=num_iterations)
ax1.plot(steps, moqd_scores)
ax1.set_xlabel('Num steps')
ax1.set_ylabel('MOQD Score')

ax2.plot(steps, metrics["max_hypervolume"])
ax2.set_xlabel('Num steps')
ax2.set_ylabel('Max Hypervolume')

ax3.plot(steps, metrics["max_sum_scores"])
ax3.set_xlabel('Num steps')
ax3.set_ylabel('Max Sum Scores')

ax4.plot(steps, metrics["coverage"])
ax4.set_xlabel('Num steps')
ax4.set_ylabel('Coverage')
plt.show()

Finally, plot the Pareto fronts and the per-cell MOQD scores of the final repertoire:

fig, axes = plt.subplots(figsize=(18, 6), ncols=3)

# plot pareto fronts
axes = plot_mome_pareto_fronts(
    centroids,
    repertoire,
    minval=minval,
    maxval=maxval,
    color_style='spectral',
    axes=axes,
    with_global=True
)

# add the MAP-Elites plot on the last axis
plot_2d_map_elites_repertoire(
    centroids=centroids,
    repertoire_fitnesses=metrics["moqd_score"][-1],
    minval=minval,
    maxval=maxval,
    ax=axes[2]
)
plt.show()