Optimizing with CMA-ME in JAX

This notebook shows how to use QDax to find diverse, high-performing solutions to the Rastrigin or Sphere problem with CMA-ME. It can be run locally or on Google Colab. We recommend using a GPU. This notebook will show:

  • how to define the problem
  • how to create a CMA-ME emitter
  • how to create a MAP-Elites instance
  • which functions must be defined before training
  • how to launch a certain number of training steps
  • how to visualise the optimization process

Installation

You will need Python 3.11 or later, and a working JAX installation. For example, you can install JAX with:

In [ ]:
%pip install -U "jax[cuda]"

Then, install QDax from PyPI:

In [ ]:
%pip install -U "qdax[examples]"
In [ ]:
import math

import matplotlib as mpl
import matplotlib.cm as cm
import matplotlib.pyplot as plt

import jax
import jax.numpy as jnp

from qdax.core.map_elites import MAPElites
from qdax.core.emitters.cma_opt_emitter import CMAOptimizingEmitter
from qdax.core.emitters.cma_rnd_emitter import CMARndEmitter
from qdax.core.emitters.cma_improvement_emitter import CMAImprovementEmitter
from qdax.core.emitters.cma_pool_emitter import CMAPoolEmitter
from qdax.core.containers.mapelites_repertoire import compute_euclidean_centroids, MapElitesRepertoire

from typing import Dict

Set the hyperparameters

Most hyperparameters are similar to those introduced in the Differentiable Quality Diversity paper.

In [ ]:
#@title QD Training Definitions Fields
#@markdown ---
num_iterations = 70000 #70000 #10000
num_dimensions = 100 #1000 #@param {type:"integer"}
grid_shape = (500, 500) # (500, 500)
batch_size = 36 #36 #@param {type:"integer"}
sigma_g = .5 #@param {type:"number"}
minval = -5.12 #@param {type:"number"}
maxval = 5.12 #@param {type:"number"}
min_descriptor = -5.12 * 0.5 * num_dimensions #@param {type:"number"}
max_descriptor = 5.12 * 0.5 * num_dimensions #@param {type:"number"}
emitter_type = "imp" #@param["opt", "imp", "rnd"]
pool_size = 15 #@param {type:"integer"}
optim_problem = "rastrigin" #@param["rastrigin", "sphere"]
#@markdown ---
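
The descriptor bounds follow from the problem definition: each descriptor sums half of the genotype after the clip function, and every clipped term lies in [-maxval, maxval], so its magnitude is at most maxval * num_dimensions / 2. The sketch below redoes that arithmetic in plain Python. The names num_cells and total_evaluations are illustrative and do not appear in the notebook, and the evaluation count assumes each iteration scores batch_size solutions, as the notebook's final print statement does.

```python
import math

# Values copied from the hyperparameter cell above.
num_iterations = 70000
num_dimensions = 100
grid_shape = (500, 500)
batch_size = 36
maxval = 5.12

# Each descriptor sums num_dimensions / 2 terms, each within [-maxval, maxval].
max_descriptor = maxval * 0.5 * num_dimensions   # about 256
min_descriptor = -max_descriptor                 # about -256

# Illustrative names, not part of the notebook.
num_cells = math.prod(grid_shape)                # number of cells in the grid
total_evaluations = num_iterations * batch_size  # assumed number of scored solutions

print(min_descriptor, max_descriptor, num_cells, total_evaluations)
```

With these settings the grid has far more cells than a single batch can fill, which is why the run uses many iterations.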

Define the scoring function: Rastrigin or Sphere

In [ ]:
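def rastrigin_scoring(x: jax.Array):
    # Shifted Rastrigin, negated so that higher is better; optimum at x = -minval * 0.4
    first_term = 10 * x.shape[-1]
    second_term = jnp.sum((x + minval * 0.4) ** 2 - 10 * jnp.cos(2 * jnp.pi * (x + minval * 0.4)), axis=-1)
    return -(first_term + second_term)

def sphere_scoring(x: jax.Array):
    # Shifted sphere function, negated so that higher is better
    return -jnp.sum((x + minval * 0.4) * (x + minval * 0.4), axis=-1)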

if optim_problem == "sphere":
    fitness_scoring = sphere_scoring
elif optim_problem == "rastrigin":
    fitness_scoring = rastrigin_scoring
else:
    raise ValueError("Invalid opt function name given")

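def clip(x: jax.Array):
    # Keep in-bound values; map out-of-bound values to maxval / x,
    # which shrinks toward zero as |x| grows
    in_bound = (x <= maxval) * (x >= minval)
    return jnp.where(
        in_bound,
        x,
        (maxval / x)
    )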

def _descriptor_1(x: jax.Array):
    return jnp.sum(clip(x[:x.shape[-1]//2]))

def _descriptor_2(x: jax.Array):
    return jnp.sum(clip(x[x.shape[-1]//2:]))

def _descriptors(x: jax.Array):
    return jnp.array([_descriptor_1(x), _descriptor_2(x)])
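
As a sanity check on the definitions above, here is a minimal plain-Python sketch of the same formulas, written without JAX so it can be run anywhere. It assumes the constants minval = -5.12 and maxval = 5.12 from the hyperparameter cell. The function names rastrigin_score, clip_value and descriptors are illustrative re-implementations, not QDax functions.

```python
import math

MINVAL, MAXVAL = -5.12, 5.12
SHIFT = MINVAL * 0.4  # same shift as in the scoring functions above

def rastrigin_score(x):
    # Negated, shifted Rastrigin: higher is better, maximum is 0.
    z = [xi + SHIFT for xi in x]
    return -(10 * len(z) + sum(zi ** 2 - 10 * math.cos(2 * math.pi * zi) for zi in z))

def clip_value(xi):
    # In-bound values are kept; out-of-bound values become MAXVAL / xi.
    return xi if MINVAL <= xi <= MAXVAL else MAXVAL / xi

def descriptors(x):
    # One descriptor per half of the genotype.
    half = len(x) // 2
    return (
        sum(clip_value(v) for v in x[:half]),
        sum(clip_value(v) for v in x[half:]),
    )

# The score peaks where every coordinate cancels the shift.
score_at_optimum = rastrigin_score([-SHIFT] * 4)
score_at_corner = rastrigin_score([MINVAL] * 4)

# 10.24 is out of bounds, so it contributes 5.12 / 10.24 = 0.5.
desc = descriptors([1.0, 10.24, -2.0, 3.0])

print(score_at_optimum, score_at_corner, desc)
```

The check shows why best_objective is evaluated at 5.12 * 0.4 in the metrics cell: that is the point where the shifted argument is zero.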
In [ ]:
def scoring_function(x):
    scores, descriptors = fitness_scoring(x), _descriptors(x)
    return scores, descriptors, {}

def scoring_fn(x, key):
    fitnesses, descriptors, extra_scores = jax.vmap(scoring_function)(x)
    return fitnesses, descriptors, extra_scores

Define the metrics that will be used

In [ ]:
worst_objective = fitness_scoring(-jnp.ones(num_dimensions) * 5.12)
best_objective = fitness_scoring(jnp.ones(num_dimensions) * 5.12 * 0.4)

num_centroids = math.prod(grid_shape)

def metrics_fn(repertoire: MapElitesRepertoire) -> Dict[str, jax.Array]:

    # get metrics
    grid_empty = repertoire.fitnesses == -jnp.inf
    adjusted_fitness = (
        (repertoire.fitnesses - worst_objective) * 100 / (best_objective - worst_objective)
    )
    qd_score = jnp.sum(adjusted_fitness, where=~grid_empty) # / num_centroids
    coverage = 100 * jnp.mean(1.0 - grid_empty)
    max_fitness = jnp.max(adjusted_fitness)
    return {"qd_score": qd_score, "max_fitness": max_fitness, "coverage": coverage}
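
To make the three metrics concrete, the following plain-Python sketch applies the same formulas to a toy repertoire of four cells, two of them empty. The bounds -100 and 0 are hypothetical values chosen for easy arithmetic; they are not the bounds the notebook computes.

```python
import math

# Hypothetical objective bounds, chosen only for readable numbers.
worst_objective, best_objective = -100.0, 0.0

# Toy repertoire: empty cells hold -inf, as in the metrics function above.
fitnesses = [-math.inf, -50.0, -math.inf, -20.0]

filled = [f for f in fitnesses if f != -math.inf]

# Rescale each filled cell to the range [0, 100].
adjusted = [
    (f - worst_objective) * 100 / (best_objective - worst_objective)
    for f in filled
]

qd_score = sum(adjusted)                       # sum over filled cells only
coverage = 100 * len(filled) / len(fitnesses)  # percentage of filled cells
max_fitness = max(adjusted)                    # best rescaled fitness

print(qd_score, coverage, max_fitness)
```

Because the fitness is rescaled before summing, every filled cell contributes a non-negative amount to the QD score, so filling a new cell never lowers it.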

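Define initial population, emitter and MAP-Elites instance

The emitter is built from one of the CMA-ME emitter classes (optimizing, improvement, or random direction, selected by emitter_type) and wrapped in a CMAPoolEmitter that holds pool_size emitter states. This emitter is given to a MAP-Elites instance to create an instance of the CMA-ME algorithm.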

In [ ]:
key = jax.random.key(0)
# in CMA-ME settings (from the paper), there is no init population
# we multiply by zero to reproduce this setting
initial_population = jax.random.uniform(key, shape=(batch_size, num_dimensions)) * 0.

centroids = compute_euclidean_centroids(
    grid_shape=grid_shape,
    minval=min_descriptor,
    maxval=max_descriptor,
)

emitter_kwargs = {
    "batch_size": batch_size,
    "genotype_dim": num_dimensions,
    "centroids": centroids,
    "sigma_g": sigma_g,
    "min_count": 1,
    "max_count": None,
}

if emitter_type == "opt":
    emitter = CMAOptimizingEmitter(**emitter_kwargs)
elif emitter_type == "imp":
    emitter = CMAImprovementEmitter(**emitter_kwargs)
elif emitter_type == "rnd":
    emitter = CMARndEmitter(**emitter_kwargs)
else:
    raise ValueError("Invalid emitter type")

emitter = CMAPoolEmitter(
    num_states=pool_size,
    emitter=emitter
)

map_elites = MAPElites(
    scoring_function=scoring_fn,
    emitter=emitter,
    metrics_function=metrics_fn
)

Initialize the repertoire and emitter state

In [ ]:
key, subkey = jax.random.split(key)
repertoire, emitter_state, init_metrics = map_elites.init(initial_population, centroids, subkey)

Run optimization/illumination process

In [ ]:
(repertoire, emitter_state, key,), metrics = jax.lax.scan(
    map_elites.scan_update,
    (repertoire, emitter_state, key),
    (),
    length=num_iterations,
)
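
For readers new to jax.lax.scan, the call above behaves like a loop that threads a carry through a step function and collects one output per iteration. The sketch below emulates that pattern in plain Python. scan_like and toy_step are hypothetical stand-ins and not JAX or QDax functions, and the real scan returns stacked arrays rather than a Python list.

```python
def scan_like(step, carry, length):
    # Call step `length` times, threading the carry and collecting outputs.
    outputs = []
    for _ in range(length):
        carry, out = step(carry, None)  # second argument is unused, as with xs=()
        outputs.append(out)
    return carry, outputs

def toy_step(carry, _):
    # Stand-in for an update step: advance the state, report a metric.
    count = carry + 1
    return count, {"count": count}

final_carry, history = scan_like(toy_step, 0, length=5)

# The last collected output corresponds to the final iteration.
print(final_carry, history[-1]["count"])
```

This is why the notebook reads the final value of each metric with v[-1]: the collected metrics hold one entry per iteration.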
In [ ]:
for k, v in metrics.items():
    print(f"{k} after {num_iterations * batch_size} evaluations: {v[-1]}")

Plot results

Update the savefig variable to save your results locally.

In [ ]:
env_steps = jnp.arange(num_iterations) * batch_size


# Customize matplotlib params
font_size = 16
params = {
    "axes.labelsize": font_size,
    "axes.titlesize": font_size,
    "legend.fontsize": font_size,
    "xtick.labelsize": font_size,
    "ytick.labelsize": font_size,
    "text.usetex": False,
    "axes.titlepad": 10,
}

mpl.rcParams.update(params)

# Visualize the training evolution and final repertoire
fig, axes = plt.subplots(nrows=1, ncols=3, figsize=(40, 10))

# env_steps = jnp.arange(num_iterations) * episode_length * batch_size

axes[0].plot(env_steps, metrics["coverage"])
axes[0].set_xlabel("Environment steps")
axes[0].set_ylabel("Coverage in %")
axes[0].set_title("Coverage evolution during training")
axes[0].set_aspect(0.95 / axes[0].get_data_ratio(), adjustable="box")

axes[1].plot(env_steps, metrics["max_fitness"])
axes[1].set_xlabel("Environment steps")
axes[1].set_ylabel("Maximum fitness")
axes[1].set_title("Maximum fitness evolution during training")
axes[1].set_aspect(0.95 / axes[1].get_data_ratio(), adjustable="box")

axes[2].plot(env_steps, metrics["qd_score"])
axes[2].set_xlabel("Environment steps")
axes[2].set_ylabel("QD Score")
axes[2].set_title("QD Score evolution during training")
axes[2].set_aspect(0.95 / axes[2].get_data_ratio(), adjustable="box")

# update this variable to save your results locally
savefig = False
if savefig:
    figname = "cma_me_" + optim_problem + "_" + str(num_dimensions) + "_" + emitter_type + ".png"
    print("Save figure in: ", figname)
    plt.savefig(figname)