Optimizing with CMA-ES in JAX

This notebook shows how to use QDax to find high-performing parameters on the Rastrigin and Sphere problems with CMA-ES. It can be run locally or on Google Colab. We recommend using a GPU. This notebook shows:

  • how to define the problem
  • how to create a CMA-ES optimizer
  • how to launch a certain number of optimizing steps
  • how to visualise the optimization process

Installation

You will need Python 3.11 or later, and a working JAX installation. For example, you can install JAX with:

%pip install -U "jax[cuda]"

Then, install QDax from PyPI:

%pip install -U "qdax[examples]"

import jax
import jax.numpy as jnp

import matplotlib.pyplot as plt
from matplotlib.patches import Ellipse

from qdax.baselines.cmaes import CMAES
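Since a GPU is recommended, it can help to confirm which backend JAX picked up before running the rest of the notebook. The short check below uses only standard JAX calls; the variable names are illustrative and not part of the original notebook.

```python
import jax

# Name of the default backend JAX will run on (for example "cpu" or "gpu").
backend = jax.default_backend()
# All devices visible to JAX on that backend.
devices = jax.devices()

print(f"JAX backend: {backend}")
print(f"Visible devices: {devices}")
```

If the backend is reported as CPU on a machine with a GPU, the JAX installation is the first thing to check.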

Set the hyperparameters

#@title Hyperparameters
#@markdown ---
num_iterations = 1000 #@param {type:"integer"}
num_dimensions = 100 #@param {type:"integer"}
batch_size = 36 #@param {type:"integer"}
num_best = 18 #@param {type:"integer"}
sigma_g = 0.5 #@param {type:"number"}
minval = -5.12 #@param {type:"number"}
optim_problem = "sphere" #@param["rastrigin", "sphere"]
#@markdown ---

Define the fitness function - choose rastrigin or sphere

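def rastrigin_scoring(x: jax.Array):
    # Shifted Rastrigin for a single solution, negated so that higher is better.
    # The optimum (fitness 0) is at x = -minval * 0.4 in every dimension.
    first_term = 10 * x.shape[-1]
    second_term = jnp.sum((x + minval * 0.4) ** 2 - 10 * jnp.cos(2 * jnp.pi * (x + minval * 0.4)))
    return -(first_term + second_term)

def sphere_scoring(x: jax.Array):
    # Shifted sphere, negated; summing over the last axis handles a whole batch.
    return -jnp.sum((x + minval * 0.4) * (x + minval * 0.4), axis=-1)

if optim_problem == "sphere":
    fitness_fn = sphere_scoring
elif optim_problem == "rastrigin":
    # rastrigin_scoring sums over all axes, so vmap it over the batch dimension
    fitness_fn = jax.vmap(rastrigin_scoring)
else:
    raise ValueError(f"Invalid optimization problem: {optim_problem!r}")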
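Both functions add `minval * 0.4` to the input, which moves the optimum away from the origin to `x = -minval * 0.4` (2.048 with the default `minval`). They are also negated, so the best possible fitness is 0. The sketch below restates the two functions in a self-contained form (with a helper variable `shift` that is not in the original notebook) and evaluates them at that point as a sanity check.

```python
import jax.numpy as jnp

minval = -5.12        # same value as in the hyperparameter cell
shift = minval * 0.4  # offset applied inside both fitness functions

def sphere_scoring(x):
    # Negated shifted sphere; works on a single vector or a batch.
    return -jnp.sum((x + shift) ** 2, axis=-1)

def rastrigin_scoring(x):
    # Negated shifted Rastrigin; works on a single vector or a batch.
    z = x + shift
    return -(10 * x.shape[-1] + jnp.sum(z**2 - 10 * jnp.cos(2 * jnp.pi * z), axis=-1))

# The shifted optimum sits at x = -shift in every dimension.
x_opt = jnp.full((5,), -shift)
x_origin = jnp.zeros((5,))

sphere_at_opt = sphere_scoring(x_opt)
rastrigin_at_opt = rastrigin_scoring(x_opt)
sphere_at_origin = sphere_scoring(x_origin)
print(sphere_at_opt, rastrigin_at_opt, sphere_at_origin)
```

Starting CMA-ES from a zero mean therefore means starting away from the optimum, which is what makes the trajectory worth plotting.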

Define a CMA-ES optimizer instance

cmaes = CMAES(
    population_size=batch_size,
    num_best=num_best,
    search_dim=num_dimensions,
    fitness_function=fitness_fn,
    mean_init=jnp.zeros((num_dimensions,)),
    init_sigma=sigma_g,
    delay_eigen_decomposition=True,
)
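The `num_best` argument sets how many of the `population_size` samples contribute to each update. In the textbook formulation of CMA-ES, the selected samples are combined with positive, log-decreasing weights. The sketch below computes those weights in plain JAX. It illustrates the standard formulation only and is not taken from the QDax source, so the function name `recombination_weights` is hypothetical.

```python
import jax.numpy as jnp

def recombination_weights(num_best: int) -> jnp.ndarray:
    """Textbook CMA-ES weights: w_i proportional to ln(mu + 1/2) - ln(i)."""
    ranks = jnp.arange(1, num_best + 1)
    raw = jnp.log(num_best + 0.5) - jnp.log(ranks)
    # Normalize so the weights sum to one.
    return raw / jnp.sum(raw)

weights = recombination_weights(18)  # num_best from the hyperparameter cell
# Variance-effective selection mass: lies between 1 and num_best.
mu_eff = 1.0 / jnp.sum(weights**2)
print(weights[:3], mu_eff)
```

With these weights the best-ranked sample has the largest influence on the new mean, and samples ranked below `num_best` have none.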

Init the CMA-ES optimizer state

state = cmaes.init()
key = jax.random.key(0)

Run optimization iterations

means = [state.mean]
covs = [(state.sigma**2) * state.cov_matrix]

iteration_count = 0
sample_fn = jax.jit(cmaes.sample)
update_fn = jax.jit(cmaes.update)
stop_condition_fn = jax.jit(cmaes.stop_condition)
for _ in range(num_iterations):
    iteration_count += 1

    # sample
    key, subkey = jax.random.split(key)
    samples = sample_fn(state, subkey)

    # update
    state = update_fn(state, samples)

    # check stop condition
    stop_condition = stop_condition_fn(state)

    if stop_condition:
        break

    # store data for plotting
    means.append(state.mean)
    covs.append((state.sigma**2) * state.cov_matrix)

print("Num iterations before stop condition: ", iteration_count)
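The loop splits the PRNG key before every call to the sampler. JAX random functions are deterministic in their key, so reusing a single key would produce the same draw at every iteration. The minimal illustration below uses `jax.random.normal` as a stand-in for the CMA-ES sampler, an assumption made to keep the example independent of QDax.

```python
import jax
import jax.numpy as jnp

key = jax.random.key(0)

# Reusing the same key reproduces exactly the same draw.
draw_a = jax.random.normal(key, (3,))
draw_b = jax.random.normal(key, (3,))

# Splitting yields a fresh subkey per iteration and a new key to carry forward.
key, subkey_1 = jax.random.split(key)
key, subkey_2 = jax.random.split(key)
draw_1 = jax.random.normal(subkey_1, (3,))
draw_2 = jax.random.normal(subkey_2, (3,))

same_key_matches = bool(jnp.array_equal(draw_a, draw_b))
subkeys_match = bool(jnp.array_equal(draw_1, draw_2))
print(same_key_matches, subkeys_match)
```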

Check final fitnesses and distribution mean

# checking final fitness values
fitnesses = fitness_fn(samples)

print("Min fitness in the final population: ", jnp.min(fitnesses))
print("Mean fitness in the final population: ", jnp.mean(fitnesses))
print("Max fitness in the final population: ", jnp.max(fitnesses))

# checking mean of the final distribution
print("Final mean of the distribution: \n", means[-1])
# print("Final covariance matrix of the distribution: ", covs[-1])
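Because both fitness functions are negated, the maximum fitness is the best one, and both reach 0 at `x = -minval * 0.4` in every dimension. One way to judge convergence is therefore to measure how far the distribution mean is from that point. In the sketch below, `distance_to_optimum` is a hypothetical helper and `final_mean` is a stand-in array rather than the output of an actual run.

```python
import jax.numpy as jnp

minval = -5.12
num_dimensions = 100

def distance_to_optimum(mean: jnp.ndarray) -> jnp.ndarray:
    """Euclidean distance between a mean vector and the shifted optimum."""
    optimum = jnp.full(mean.shape, -minval * 0.4)
    return jnp.linalg.norm(mean - optimum)

# Stand-in for means[-1]: a point slightly off the optimum in every dimension.
final_mean = jnp.full((num_dimensions,), -minval * 0.4) + 1e-3
dist = distance_to_optimum(final_mean)
print("Distance to optimum:", dist)
```

In the notebook, the same helper could be applied to `means[-1]` after the optimization loop.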

Visualization of the optimization trajectory

fig, ax = plt.subplots(figsize=(12, 6))

# sample points to show fitness landscape
key, subkey = jax.random.split(key)
x = jax.random.uniform(subkey, minval=-4, maxval=8, shape=(100000, 2))
f_x = fitness_fn(x)

# plot fitness landscape
points = ax.scatter(x[:, 0], x[:, 1], c=f_x, s=0.1)
fig.colorbar(points)

# plot cma-es trajectory
traj_min = 0
traj_max = iteration_count
for mean, cov in zip(means[traj_min:traj_max], covs[traj_min:traj_max]):
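    # Ellipse expects full widths, so use 2 * std (not the variance) for a 1-sigma outline;
    # only the first two dimensions are drawn and their correlation is ignored
    ellipse = Ellipse(
        (mean[0], mean[1]), 2 * jnp.sqrt(cov[0, 0]), 2 * jnp.sqrt(cov[1, 1]),
        fill=False, color='k', ls='--',
    )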
    ax.add_patch(ellipse)
    ax.plot(mean[0], mean[1], color='k', marker='x')

ax.set_title(f"Optimization trajectory of CMA-ES between step {traj_min} and step {traj_max}")
plt.show()
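The plot draws axis-aligned ellipses from the diagonal entries of the covariance. If the correlation between the two plotted dimensions matters, one option is to derive the ellipse from an eigendecomposition of the 2x2 covariance block. The helper below is a sketch under that assumption; `ellipse_params` and `n_std` are hypothetical names and not part of QDax or the original notebook.

```python
import jax.numpy as jnp

def ellipse_params(cov_2d: jnp.ndarray, n_std: float = 1.0):
    """Return (width, height, angle_deg) of the n_std ellipse of a 2x2 covariance."""
    eigvals, eigvecs = jnp.linalg.eigh(cov_2d)  # eigenvalues in ascending order
    eigvals = jnp.maximum(eigvals, 0.0)         # guard against tiny negative values
    major = eigvecs[:, 1]                       # direction of largest variance
    width = 2.0 * n_std * jnp.sqrt(eigvals[1])  # full length along the major axis
    height = 2.0 * n_std * jnp.sqrt(eigvals[0]) # full length along the minor axis
    angle_deg = jnp.degrees(jnp.arctan2(major[1], major[0]))
    return float(width), float(height), float(angle_deg)

# Example: variance 4 along x and 1 along y, no correlation.
width, height, angle = ellipse_params(jnp.array([[4.0, 0.0], [0.0, 1.0]]))
print(width, height, angle)
```

The result could then be passed to matplotlib as `Ellipse((mean[0], mean[1]), width, height, angle=angle, fill=False)`, with `cov[:2, :2]` as the input block.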