Open In Colab

Optimizing with CMA-ES in Jax

This notebook shows how to use QDax to find high-performing parameters on the Rastrigin and Sphere problems with CMA-ES. It can be run locally or on Google Colab. We recommend using a GPU. This notebook will show:

  • how to define the problem
  • how to create a CMA-ES optimizer
  • how to launch a certain number of optimizing steps
  • how to visualise the optimization process
import jax
import jax.numpy as jnp

try:
    import brax
except:
    !pip install git+https://github.com/google/brax.git@v0.9.2 |tail -n 1
    import brax

try:
    import flax
except:
    !pip install --no-deps git+https://github.com/google/flax.git@v0.7.4 |tail -n 1
    import flax

try:
    import chex
except:
    !pip install --no-deps git+https://github.com/deepmind/chex.git@v0.1.83 |tail -n 1
    import chex

try:
    import jumanji
except:
    !pip install "jumanji==0.3.1"
    import jumanji

try:
    import qdax
except:
    !pip install --no-deps git+https://github.com/adaptive-intelligent-robotics/QDax@main |tail -n 1
    import qdax

import matplotlib.pyplot as plt
from matplotlib.patches import Ellipse

from qdax.core.cmaes import CMAES

Set the hyperparameters

#@title Hyperparameters
#@markdown ---
# Upper bound on the number of sample/update iterations of the optimization loop.
num_iterations = 1000 #@param {type:"integer"}
# Dimensionality of the search space (passed to CMAES as search_dim).
num_dimensions = 100 #@param {type:"integer"}
# Passed to CMAES as population_size.
batch_size = 36 #@param {type:"integer"}
# Passed to CMAES as num_best; presumably the number of top-ranked samples
# used in each update - confirm against the QDax CMAES documentation.
num_best = 18 #@param {type:"integer"}
# Passed to CMAES as init_sigma.
sigma_g = 0.5 # 0.5 #@param {type:"number"}
# Used by the scoring functions to shift their input by minval * 0.4.
minval = -5.12 #@param {type:"number"}
# Objective to optimize; the selection code accepts "rastrigin" or "sphere".
optim_problem = "sphere" #@param["rastrigin", "sphere"]
#@markdown ---

Define the fitness function - choose rastrigin or sphere

def rastrigin_scoring(x: jnp.ndarray) -> jnp.ndarray:
    """Return the negated Rastrigin value of ``x`` after shifting it.

    The input is offset by ``minval * 0.4`` (``minval`` is a module-level
    hyperparameter) before the Rastrigin function is evaluated, and the
    result is negated so that larger scores are better. The score is at
    most 0, reached where ``x + minval * 0.4`` is zero.

    Args:
        x: array whose last axis holds the coordinates of a point. Any
            leading axes are kept, so a batch of points can be scored
            directly as well as through ``jax.vmap``.

    Returns:
        The score reduced over the last axis of ``x``.
    """
    shifted = x + minval * 0.4
    first_term = 10 * x.shape[-1]
    # Reduce over the coordinate axis only, matching first_term (which counts
    # the last axis) and the reduction used by sphere_scoring.
    second_term = jnp.sum(
        shifted**2 - 10 * jnp.cos(2 * jnp.pi * shifted), axis=-1
    )
    return -(first_term + second_term)

def sphere_scoring(x: jnp.ndarray):
    """Return the negated squared norm of ``x`` shifted by ``minval * 0.4``.

    The sum runs over the last axis, so leading axes of ``x`` are preserved
    in the result. ``minval`` is read from module scope.
    """
    shifted = x + minval * 0.4
    squared = shifted * shifted
    return -jnp.sum(squared, axis=-1)

# Select the fitness function for the chosen problem. rastrigin_scoring is
# applied per sample through jax.vmap; sphere_scoring reduces over the last
# axis itself and is used on the whole batch directly.
if optim_problem == "sphere":
    fitness_fn = sphere_scoring
elif optim_problem == "rastrigin":
    fitness_fn = jax.vmap(rastrigin_scoring)
else:
    # ValueError (an Exception subclass) names the failure precisely, and the
    # message reports the rejected value to make the misconfiguration obvious.
    raise ValueError(
        f"Invalid opt function name given: {optim_problem!r} "
        '(expected "rastrigin" or "sphere")'
    )

Define a CMA-ES optimizer instance

# Build the CMA-ES optimizer from the hyperparameters and the selected
# fitness function.
cmaes = CMAES(
    population_size=batch_size,
    num_best=num_best,
    search_dim=num_dimensions,
    fitness_function=fitness_fn,
    # initial mean is the zero vector of length num_dimensions
    mean_init=jnp.zeros((num_dimensions,)),
    init_sigma=sigma_g,
    # NOTE(review): presumably postpones the eigen decomposition of the
    # covariance matrix instead of recomputing it at every update - confirm
    # against the QDax CMAES documentation.
    delay_eigen_decomposition=True,
)

Initialize the CMA-ES optimizer state

# Create the initial optimizer state.
state = cmaes.init()
# PRNG key with a fixed seed (0); it is passed to and returned by cmaes.sample
# in the optimization loop.
random_key = jax.random.PRNGKey(0)

Run optimization iterations

%%time

# Trajectory of the search distribution, starting with the initial state:
# the mean and the covariance matrix scaled by sigma**2.
means = [state.mean]
covs = [(state.sigma**2) * state.cov_matrix]

# Number of iterations actually run; stays 0 if the loop body never executes.
iteration_count = 0
for iteration_count in range(1, num_iterations + 1):
    # sample
    samples, random_key = cmaes.sample(state, random_key)

    # update
    state = cmaes.update(state, samples)

    # store data for plotting; done before the stop check so that the state
    # on which the run stops is recorded too and means[-1] / covs[-1] always
    # describe the final distribution
    means.append(state.mean)
    covs.append((state.sigma**2) * state.cov_matrix)

    # check stop condition
    stop_condition = cmaes.stop_condition(state)

    if stop_condition:
        break

print("Num iterations before stop condition: ", iteration_count)

Check final fitnesses and distribution mean

# Score the most recently sampled population with the selected fitness function.
fitnesses = fitness_fn(samples)

# Report the worst, average and best score of that population.
for message, reduce_fn in (
    ("Min fitness in the final population: ", jnp.min),
    ("Mean fitness in the final population: ", jnp.mean),
    ("Max fitness in the final population: ", jnp.max),
):
    print(message, reduce_fn(fitnesses))

# Last recorded mean of the search distribution.
print("Final mean of the distribution: \n", means[-1])
# print("Final covariance matrix of the distribution: ", covs[-1])

Visualization of the optimization trajectory

# Plot a 2-D view of the fitness landscape and overlay the recorded CMA-ES
# trajectory, using the first two coordinates of each mean.
fig, ax = plt.subplots(figsize=(12, 6))

# sample points to show fitness landscape
# (100000 two-dimensional points drawn uniformly between -4 and 8)
random_key, subkey = jax.random.split(random_key)
x = jax.random.uniform(subkey, minval=-4, maxval=8, shape=(100000, 2))
f_x = fitness_fn(x)

# plot fitness landscape
# (each sampled point is coloured by its fitness value)
points = ax.scatter(x[:, 0], x[:, 1], c=f_x, s=0.1)
fig.colorbar(points)

# plot cma-es trajectory
# traj_min / traj_max select which recorded steps are drawn; the slice below
# excludes index traj_max itself.
traj_min = 0
traj_max = iteration_count
for mean, cov in zip(means[traj_min:traj_max], covs[traj_min:traj_max]):
    # NOTE(review): cov[0, 0] and cov[1, 1] are variances and are passed as the
    # ellipse width and height; off-diagonal terms are ignored. If a 1-sigma
    # contour is intended, the axis lengths would be 2 * sqrt(variance) -
    # confirm the intended scaling.
    ellipse = Ellipse((mean[0], mean[1]), cov[0, 0], cov[1, 1], fill=False, color='k', ls='--')
    ax.add_patch(ellipse)
    # mark the distribution mean of this step
    ax.plot(mean[0], mean[1], color='k', marker='x')

ax.set_title(f"Optimization trajectory of CMA-ES between step {traj_min} and step {traj_max}")
plt.show()