Optimizing with CMA-MEGA in JAX

This notebook shows how to use QDax to find diverse, high-performing parameters on the Rastrigin problem with CMA-MEGA. It can be run locally or on Google Colab; we recommend using a GPU. This notebook shows:

  • how to define the problem
  • how to create a CMA-MEGA emitter
  • how to create a MAP-Elites instance
  • which functions must be defined before training
  • how to launch a certain number of training steps
  • how to visualise the optimization process

Installation

You will need Python 3.11 or later, and a working JAX installation. For example, you can install JAX with:

In [ ]:
%pip install -U "jax[cuda]"

Then, install QDax from PyPI:

In [ ]:
%pip install -U "qdax[examples]"
In [ ]:
import jax
import jax.numpy as jnp

from qdax.core.map_elites import MAPElites
from qdax.core.emitters.cma_mega_emitter import CMAMEGAEmitter
from qdax.core.containers.mapelites_repertoire import compute_cvt_centroids, MapElitesRepertoire
from qdax.utils.plotting import plot_map_elites_results

from typing import Dict

Set the hyperparameters

Most hyperparameters are similar to those introduced in the Differentiable Quality Diversity paper.

In [ ]:
#@title QD Training Definitions Fields
#@markdown ---
num_iterations = 20000 #@param {type:"integer"}
num_dimensions = 1000 #@param {type:"integer"}
num_centroids = 10000 #@param {type:"integer"}
minval = -5.12 #@param {type:"number"}
maxval = 5.12 #@param {type:"number"}
batch_size = 36 #@param {type:"integer"}
learning_rate = 1 #@param {type:"number"}
sigma_g = 3.16 #@param {type:"number"} # approximately the square root of 10, the value given in the paper
#@markdown ---

Define the scoring function: Rastrigin

Because we are in the Differentiable QD setting, the scoring function returns not only the fitness and descriptors but also their gradients with respect to the solution parameters.

In [ ]:
def rastrigin_scoring(x: jax.Array) -> jax.Array:
    # Negated Rastrigin (to be maximised), shifted so the optimum lies at x = 0.4 * maxval
    shifted = x + minval * 0.4
    return -(10 * x.shape[-1] + jnp.sum(shifted**2 - 10 * jnp.cos(2 * jnp.pi * shifted)))

def clip(x: jax.Array) -> jax.Array:
    # Identity inside [minval, maxval], maxval / x outside
    in_bounds = (x >= minval) & (x <= maxval)
    return jnp.where(in_bounds, x, maxval / x)

def _rastrigin_descriptor_1(x: jax.Array) -> jax.Array:
    # Mean of the clipped first half of the parameters
    return jnp.mean(clip(x[:x.shape[-1]//2]))

def _rastrigin_descriptor_2(x: jax.Array) -> jax.Array:
    # Mean of the clipped second half of the parameters
    return jnp.mean(clip(x[x.shape[-1]//2:]))

def rastrigin_descriptors(x: jax.Array) -> jax.Array:
    return jnp.array([_rastrigin_descriptor_1(x), _rastrigin_descriptor_2(x)])

rastrigin_grad_scores = jax.grad(rastrigin_scoring)
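
As a quick sanity check of the objective and its gradient, the sketch below re-declares the shifted Rastrigin function in a small dimension and evaluates `jax.grad` at two points. The name `shifted_rastrigin` and the dimension 4 are illustrative choices, not part of the notebook; the bounds are copied from the hyperparameter cell.

```python
import jax
import jax.numpy as jnp

minval, maxval = -5.12, 5.12  # same bounds as in the hyperparameter cell


def shifted_rastrigin(x: jax.Array) -> jax.Array:
    """Negated Rastrigin, shifted so that the optimum sits at x = 0.4 * maxval."""
    z = x + minval * 0.4
    return -(10 * x.shape[-1] + jnp.sum(z**2 - 10 * jnp.cos(2 * jnp.pi * z)))


grad_fn = jax.grad(shifted_rastrigin)

x_opt = jnp.full((4,), 0.4 * maxval)  # expected location of the optimum
x_other = jnp.zeros((4,))             # an arbitrary non-optimal point

fitness_at_opt = shifted_rastrigin(x_opt)
grad_at_opt = grad_fn(x_opt)
grad_at_other = grad_fn(x_other)

print("fitness at optimum:", fitness_at_opt)
print("gradient at optimum:", grad_at_opt)
print("gradient at zero:", grad_at_other)
```

The fitness and the gradient should both vanish at the optimum, while the gradient at zero is non-zero and has the same shape as the input, which is what the emitter needs to take gradient steps.
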
In [ ]:
def scoring_function(x):
    scores, descriptors = rastrigin_scoring(x), rastrigin_descriptors(x)
    gradients = jnp.array([rastrigin_grad_scores(x), jax.grad(_rastrigin_descriptor_1)(x), jax.grad(_rastrigin_descriptor_2)(x)]).T
    gradients = jnp.nan_to_num(gradients)

    # Compute normalized gradients
    norm_gradients = jax.tree.map(
        lambda x: jnp.linalg.norm(x, axis=1, keepdims=True),
        gradients,
    )
    grads = jax.tree.map(
        lambda x, y: x / y, gradients, norm_gradients
    )
    grads = jnp.nan_to_num(grads)
    extra_scores = {
        'gradients': gradients,
        'normalized_grads': grads
    }

    return scores, descriptors, extra_scores

def scoring_fn(x, key):
    fitnesses, descriptors, extra_scores = jax.vmap(scoring_function)(x)
    return fitnesses, descriptors, extra_scores
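
To make the array layout concrete, here is a minimal sketch with a toy objective (the functions `fitness`, `descriptors` and `score_one` are hypothetical stand-ins, not the notebook's Rastrigin functions). It stacks the fitness gradient and the two descriptor gradients as columns, as the cell above does, and shows the shapes obtained after `jax.vmap`.

```python
import jax
import jax.numpy as jnp


def fitness(x: jax.Array) -> jax.Array:
    # Toy objective: negated squared norm
    return -jnp.sum(x**2)


def descriptors(x: jax.Array) -> jax.Array:
    # Toy descriptors: mean of each half of the parameters
    half = x.shape[-1] // 2
    return jnp.array([jnp.mean(x[:half]), jnp.mean(x[half:])])


def score_one(x: jax.Array):
    grad_f = jax.grad(fitness)(x)         # shape (D,)
    jac_d = jax.jacobian(descriptors)(x)  # shape (2, D)
    # One column per objective: fitness first, then the two descriptors
    grads = jnp.concatenate([grad_f[None, :], jac_d], axis=0).T  # shape (D, 3)
    return fitness(x), descriptors(x), {"gradients": grads}


batch = jax.random.normal(jax.random.key(0), (5, 8))  # 5 solutions, D = 8
fit, desc, extra = jax.vmap(score_one)(batch)

print(fit.shape, desc.shape, extra["gradients"].shape)
```

With a batch of 5 solutions of dimension 8, the fitnesses have shape `(5,)`, the descriptors `(5, 2)` and the gradients `(5, 8, 3)`.
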

Define the metrics that will be used

In [ ]:
# Reference values used to rescale fitnesses to the [0, 1] range
worst_objective = rastrigin_scoring(-jnp.ones(num_dimensions) * 5.12)
best_objective = rastrigin_scoring(jnp.ones(num_dimensions) * 5.12 * 0.4)


def metrics_fn(repertoire: MapElitesRepertoire) -> Dict[str, jax.Array]:

    # get metrics
    grid_empty = repertoire.fitnesses == -jnp.inf
    adjusted_fitness = (
        (repertoire.fitnesses - worst_objective) / (best_objective - worst_objective)
    )
    # empty cells are excluded from the sum, which is averaged over all centroids
    qd_score = jnp.sum(adjusted_fitness, where=~grid_empty) / num_centroids
    coverage = 100 * jnp.mean(1.0 - grid_empty)
    max_fitness = jnp.max(adjusted_fitness)
    return {"qd_score": qd_score, "max_fitness": max_fitness, "coverage": coverage}
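
The three metrics can be traced by hand on a tiny example. The sketch below applies the same formulas to a made-up repertoire of four cells, two of them empty; the fitness values and the `worst`/`best` bounds are invented for illustration.

```python
import jax.numpy as jnp

# Hypothetical repertoire with 4 cells; -inf marks an empty cell
fitnesses = jnp.array([-jnp.inf, -10.0, 0.0, -jnp.inf])
worst, best = -20.0, 0.0

grid_empty = fitnesses == -jnp.inf
adjusted = (fitnesses - worst) / (best - worst)  # filled cells: 0.5 and 1.0

qd_score = jnp.sum(adjusted, where=~grid_empty) / fitnesses.shape[0]
coverage = 100 * jnp.mean(1.0 - grid_empty)
max_fitness = jnp.max(adjusted)

print(qd_score)     # 0.375 = (0.5 + 1.0) / 4
print(coverage)     # 50.0, since half of the cells are filled
print(max_fitness)  # 1.0
```

Dividing by the total number of cells rather than the number of filled cells means the QD score rewards coverage as well as quality.
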

Define the initial population, the emitter and the MAP Elites instance

The emitter is defined using the CMAMEGAEmitter class. Passing this emitter to a MAPElites instance yields the CMA-MEGA algorithm.

In [ ]:
key = jax.random.key(0)
# no initial population: all individuals take the same value (zero) as the emitter's initial value
initial_population = jnp.zeros((batch_size, num_dimensions))

key, subkey = jax.random.split(key)
centroids = compute_cvt_centroids(
    num_descriptors=2,
    num_init_cvt_samples=10000,
    num_centroids=num_centroids,
    minval=minval,
    maxval=maxval,
    key=subkey,
)

emitter = CMAMEGAEmitter(
    scoring_function=scoring_fn,
    batch_size=batch_size,
    learning_rate=learning_rate,
    num_descriptors=2,
    centroids=centroids,
    sigma_g=sigma_g,
)

map_elites = MAPElites(
    scoring_function=scoring_fn,
    emitter=emitter,
    metrics_function=metrics_fn
)
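
For intuition about what the emitter does with the gradients, the following sketch shows one candidate-generation step in the spirit of CMA-MEGA: coefficients are sampled from a Gaussian and used to combine the fitness and descriptor gradients around the current solution. This is an illustration under stated assumptions, not QDax's implementation: the function `branch`, its arguments, the initial covariance `sigma_g**2 * I`, and forcing the fitness coefficient to be positive are all choices made here for the example.

```python
import jax
import jax.numpy as jnp


def branch(key, theta, grads, mean, cov, batch_size):
    """Sample candidates around theta along the gradient directions.

    theta: (D,) current solution.
    grads: (D, k + 1) columns are the fitness gradient, then k descriptor gradients.
    mean, cov: parameters of the coefficient distribution, shapes (k + 1,) and (k + 1, k + 1).
    """
    coeffs = jax.random.multivariate_normal(key, mean, cov, shape=(batch_size,))  # (B, k + 1)
    # Assumption of this sketch: always move uphill on the fitness
    coeffs = coeffs.at[:, 0].set(jnp.abs(coeffs[:, 0]))
    return theta + coeffs @ grads.T  # (B, D)


sigma_g = 3.16
theta = jnp.zeros(6)
grads = jnp.eye(6)[:, :3]  # toy gradients: three orthogonal unit directions
mean = jnp.zeros(3)
cov = sigma_g**2 * jnp.eye(3)

candidates = branch(jax.random.key(0), theta, grads, mean, cov, batch_size=4)
print(candidates.shape)
```

In the full algorithm the coefficient distribution is then adapted by CMA-ES according to how much each candidate improves the repertoire; that part is omitted here.
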
In [ ]:
key, subkey = jax.random.split(key)
repertoire, emitter_state, init_metrics = map_elites.init(initial_population, centroids, subkey)
In [ ]:
(repertoire, emitter_state, key,), metrics = jax.lax.scan(
    map_elites.scan_update,
    (repertoire, emitter_state, key),
    (),
    length=num_iterations,
)
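
The `jax.lax.scan` call above uses an empty input `()` together with `length`, so the update function is simply applied `num_iterations` times to the carry while its second output is stacked along a new leading axis. A minimal sketch of the same pattern, with a toy `step` function standing in for `map_elites.scan_update`:

```python
import jax
import jax.numpy as jnp


def step(carry, _):
    # carry plays the role of (repertoire, emitter_state, key)
    count, key = carry
    key, subkey = jax.random.split(key)  # a real update would consume subkey
    count = count + 1
    metrics = {"count": count}           # per-iteration outputs, stacked by scan
    return (count, key), metrics


(final_count, final_key), history = jax.lax.scan(
    step,
    (jnp.array(0), jax.random.key(0)),
    (),
    length=5,
)

print(final_count)       # 5
print(history["count"])  # [1 2 3 4 5]
```

This is why each entry of `metrics` in the notebook is an array with one value per iteration, and why the last value is read with `v[-1]`.
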
In [ ]:
for k, v in metrics.items():
    print(f"{k} after {num_iterations * batch_size}: {v[-1]}")

Visualise results

In [ ]:
#@title Visualization

# create the x-axis array
env_steps = jnp.arange(num_iterations) * batch_size

# create the plots and the grid
fig, axes = plot_map_elites_results(
    env_steps=env_steps, metrics=metrics, repertoire=repertoire, min_descriptor=minval, max_descriptor=maxval
)