Optimizing with CMA-MEGA in Jax¶
This notebook shows how to use QDax to find diverse and high-performing parameters on the Rastrigin problem with CMA-MEGA. It can be run locally or on Google Colab. We recommend using a GPU. This notebook will show:
- how to define the problem
- how to create a CMA-MEGA emitter
- how to create a MAP-Elites instance
- which functions must be defined before training
- how to launch a certain number of training steps
- how to visualise the optimization process
import jax
import jax.numpy as jnp
try:
import brax
except:
!pip install git+https://github.com/google/brax.git@v0.9.2 |tail -n 1
import brax
try:
import flax
except:
!pip install --no-deps git+https://github.com/google/flax.git@v0.7.4 |tail -n 1
import flax
try:
import chex
except:
!pip install --no-deps git+https://github.com/deepmind/chex.git@v0.1.83 |tail -n 1
import chex
try:
import jumanji
except:
!pip install "jumanji==0.3.1"
import jumanji
try:
import qdax
except:
!pip install --no-deps git+https://github.com/adaptive-intelligent-robotics/QDax@main |tail -n 1
import qdax
from qdax.core.map_elites import MAPElites
from qdax.core.emitters.cma_mega_emitter import CMAMEGAEmitter
from qdax.core.containers.mapelites_repertoire import compute_cvt_centroids, MapElitesRepertoire
from qdax.utils.plotting import plot_map_elites_results
from typing import Dict
Set the hyperparameters¶
Most hyperparameters are similar to those introduced in the Differentiable Quality Diversity paper.
#@title QD Training Definitions Fields
#@markdown ---
num_iterations = 20000 #@param {type:"integer"}
num_dimensions = 1000 #@param {type:"integer"}
num_centroids = 10000 #@param {type:"integer"}
# Bounds used by the descriptor clipping, the centroid computation and the plot.
# (Declared once; a second identical declaration would only duplicate the form fields.)
minval = -5.12 #@param {type:"number"}
maxval = 5.12 #@param {type:"number"}
batch_size = 36 #@param {type:"integer"}
learning_rate = 1 #@param {type:"number"}
sigma_g = 3.16 #@param {type:"number"} # square root of 10, the value given in the paper
#@markdown ---
Define the scoring function: Rastrigin¶
As we are in the Differentiable QD setting, the scoring function returns not only the fitness and descriptors, but also their gradients.
def rastrigin_scoring(x: jnp.ndarray):
    """Return the negated Rastrigin value of ``x`` after a constant shift.

    The input is shifted by ``minval * 0.4`` before the Rastrigin formula is
    applied, and the result is negated so that larger values are better.

    Args:
        x: solution; the size of its last axis is used as the dimension count.

    Returns:
        The negated, shifted Rastrigin value, summed over all entries of ``x``.
    """
    shifted = x + minval * 0.4
    n_dims = x.shape[-1]
    per_entry_terms = shifted**2 - 10 * jnp.cos(2 * jnp.pi * shifted)
    return -(10 * n_dims + jnp.sum(per_entry_terms))
def clip(x: jnp.ndarray):
    """Map out-of-bound entries of ``x`` to ``maxval / x``; keep the others.

    Entries inside ``[minval, maxval]`` are returned unchanged, every other
    entry is replaced by ``maxval / x``.

    The selection is done with ``jnp.where`` rather than by multiplying with
    boolean masks: the mask form evaluates ``maxval / x * 0`` for in-bound
    entries, which is ``inf * 0 = NaN`` when ``x == 0``.

    Args:
        x: array of solution values.

    Returns:
        Array of the same shape as ``x`` with out-of-bound entries remapped.
    """
    in_bound = (x >= minval) & (x <= maxval)
    # Use a harmless denominator for in-bound entries so that neither the value
    # nor the gradient of the unselected branch can become inf/NaN at x == 0.
    safe_x = jnp.where(in_bound, 1.0, x)
    return jnp.where(in_bound, x, maxval / safe_x)
def _rastrigin_descriptor_1(x: jnp.ndarray):
    """Return the mean of the clipped first half of ``x``.

    The split index is half the size of the last axis of ``x``.
    """
    half = x.shape[-1] // 2
    first_half = x[:half]
    return jnp.mean(clip(first_half))
def _rastrigin_descriptor_2(x: jnp.ndarray):
    """Return the mean of the clipped second half of ``x``.

    The split index is half the size of the last axis of ``x``.
    """
    half = x.shape[-1] // 2
    second_half = x[half:]
    return jnp.mean(clip(second_half))
def rastrigin_descriptors(x: jnp.ndarray):
    """Return both descriptors of ``x`` as a two-element array."""
    first = _rastrigin_descriptor_1(x)
    second = _rastrigin_descriptor_2(x)
    return jnp.array([first, second])
# Gradient of the objective with respect to the solution.
rastrigin_grad_scores = jax.grad(rastrigin_scoring)


def scoring_function(x):
    """Score one solution and return its fitness, descriptors and gradients.

    Args:
        x: a single solution (this function is vmapped over a batch by
            ``scoring_fn``).

    Returns:
        A tuple ``(fitness, descriptors, extra_scores)`` where ``extra_scores``
        holds the raw gradients under ``'gradients'`` and their normalized
        version under ``'normalized_grads'``.
    """
    fitness = rastrigin_scoring(x)
    descriptor_values = rastrigin_descriptors(x)

    # Stack the objective gradient and both descriptor gradients, then
    # transpose so that the three gradients end up along the last axis.
    objective_grad = rastrigin_grad_scores(x)
    descriptor_1_grad = jax.grad(_rastrigin_descriptor_1)(x)
    descriptor_2_grad = jax.grad(_rastrigin_descriptor_2)(x)
    stacked = jnp.array([objective_grad, descriptor_1_grad, descriptor_2_grad]).T
    # Replace NaN/inf entries produced by the differentiation with finite numbers.
    raw_gradients = jnp.nan_to_num(stacked)

    # Compute normalized gradients.
    # NOTE(review): for a 1-D ``x`` the array is (n_dims, 3), so axis=1 takes the
    # norm across the three gradients of each dimension, not over each gradient
    # vector -- confirm this is the intended normalization.
    norms = jnp.linalg.norm(raw_gradients, axis=1, keepdims=True)
    # A zero norm yields 0/0; nan_to_num turns those entries into zeros.
    unit_gradients = jnp.nan_to_num(raw_gradients / norms)

    extra_scores = {
        'gradients': raw_gradients,
        'normalized_grads': unit_gradients,
    }
    return fitness, descriptor_values, extra_scores
def scoring_fn(x, random_key):
    """Score a batch of solutions by vmapping ``scoring_function`` over ``x``.

    Args:
        x: batch of solutions, batched along the first axis.
        random_key: PRNG key; not consumed here and returned unchanged.

    Returns:
        ``(fitnesses, descriptors, extra_scores, random_key)``.
    """
    batched_scoring = jax.vmap(scoring_function)
    fitnesses, descriptors, extra_scores = batched_scoring(x)
    return fitnesses, descriptors, extra_scores, random_key
Define the metrics that will be used¶
# Reference objective values used by metrics_fn to rescale fitnesses.
# They are derived from ``minval`` (instead of a hard-coded 5.12) so that they
# stay consistent with the shift applied inside ``rastrigin_scoring`` when the
# bounds are changed in the form above.
# Lower reference: objective evaluated with every coordinate at the lower bound.
worst_objective = rastrigin_scoring(jnp.ones(num_dimensions) * minval)
# Upper reference: objective at the point where the shift cancels out
# (x + minval * 0.4 == 0).
best_objective = rastrigin_scoring(jnp.ones(num_dimensions) * (-minval * 0.4))
def metrics_fn(repertoire: MapElitesRepertoire) -> Dict[str, jnp.ndarray]:
    """Compute QD score, maximum fitness and coverage of a repertoire.

    Fitnesses are rescaled with ``worst_objective`` and ``best_objective``
    before being aggregated. Cells whose fitness is ``-inf`` count as empty.

    Args:
        repertoire: repertoire whose ``fitnesses`` attribute is read.

    Returns:
        Dictionary with the keys ``"qd_score"``, ``"max_fitness"`` and
        ``"coverage"`` (coverage is expressed as a percentage).
    """
    is_empty = repertoire.fitnesses == -jnp.inf
    objective_range = best_objective - worst_objective
    rescaled = (repertoire.fitnesses - worst_objective) / objective_range

    metrics = {}
    # Sum over filled cells only, divided by the configured number of centroids.
    metrics["qd_score"] = jnp.sum(rescaled, where=~is_empty) / num_centroids
    metrics["max_fitness"] = jnp.max(rescaled)
    metrics["coverage"] = 100 * jnp.mean(1.0 - is_empty)
    return metrics
Define the initial population, the emitter and the MAP Elites instance¶
The emitter is defined using the CMAMEGA emitter class. This emitter is given to a MAP-Elites instance to create an instance of the CMA-MEGA algorithm.
random_key = jax.random.PRNGKey(0)

# No real initial population: every individual starts at zero, the same value
# the original code obtained by multiplying random samples by 0. Building the
# zeros directly avoids consuming `random_key` for values that are discarded
# and then reusing that same key for the centroid computation below.
initial_population = jnp.zeros((batch_size, num_dimensions))

# Compute the CVT centroids of the 2-D descriptor space; the call also returns
# an updated random key.
centroids, random_key = compute_cvt_centroids(
    num_descriptors=2,
    num_init_cvt_samples=10000,
    num_centroids=num_centroids,
    minval=minval,
    maxval=maxval,
    random_key=random_key,
)

# CMA-MEGA emitter configured with the hyperparameters defined above.
emitter = CMAMEGAEmitter(
    scoring_function=scoring_fn,
    batch_size=batch_size,
    learning_rate=learning_rate,
    num_descriptors=2,
    centroids=centroids,
    sigma_g=sigma_g,
)

# MAP-Elites instance driven by the CMA-MEGA emitter.
map_elites = MAPElites(
    scoring_function=scoring_fn,
    emitter=emitter,
    metrics_function=metrics_fn
)

# Initialise the repertoire and the emitter state from the initial population.
repertoire, emitter_state, random_key = map_elites.init(initial_population, centroids, random_key)
%%time
# Run `num_iterations` update steps inside a single `jax.lax.scan`; the carry
# holds the repertoire, the emitter state and the random key.
initial_carry = (repertoire, emitter_state, random_key)
final_carry, metrics = jax.lax.scan(
    map_elites.scan_update,
    initial_carry,
    (),
    length=num_iterations,
)
repertoire, emitter_state, random_key = final_carry

# Report the last recorded value of every metric.
for k, v in metrics.items():
    print(f"{k} after {num_iterations * batch_size}: {v[-1]}")
Visualise results¶
#@title Visualization
# x-axis values: iteration index scaled by the batch size
env_steps = batch_size * jnp.arange(num_iterations)

# draw the metric curves and the repertoire grid
fig, axes = plot_map_elites_results(
    env_steps=env_steps,
    metrics=metrics,
    repertoire=repertoire,
    min_bd=minval,
    max_bd=maxval,
)