Optimizing with CMA-ES in JAX¶
This notebook shows how to use QDax to find high-performing parameters on the Rastrigin and Sphere problems with CMA-ES. It can be run locally or on Google Colab. We recommend using a GPU. This notebook will show:
- how to define the problem
- how to create a CMA-ES optimizer
- how to launch a certain number of optimizing steps
- how to visualise the optimization process
Installation¶
You will need Python 3.11 or later, and a working JAX installation. For example, you can install JAX with:
In [ ]:
Copied!
%pip install -U "jax[cuda]"
%pip install -U "jax[cuda]"
Then, install QDax from PyPI:
In [ ]:
Copied!
%pip install -U "qdax[examples]"
%pip install -U "qdax[examples]"
In [ ]:
Copied!
import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt
from matplotlib.patches import Ellipse
from qdax.baselines.cmaes import CMAES
import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt
from matplotlib.patches import Ellipse
from qdax.baselines.cmaes import CMAES
Set the hyperparameters¶
In [ ]:
Copied!
#@title Hyperparameters
#@markdown ---
# Upper bound on the number of sample/update steps; the optimisation loop
# also exits early when the CMA-ES stop condition is met.
num_iterations = 1000 #@param {type:"integer"}
# Dimensionality of the search space (passed to CMAES as `search_dim` and used
# to size the initial mean vector).
num_dimensions = 100 #@param {type:"integer"}
# Passed to CMAES as `population_size`.
batch_size = 36 #@param {type:"integer"}
# Passed to CMAES as `num_best`.
# NOTE(review): presumably the number of top samples used in each update - confirm.
num_best = 18 #@param {type:"integer"}
# Passed to CMAES as `init_sigma`.
sigma_g = 0.5 #@param {type:"number"}
# Only used by the fitness functions, which shift their input by `minval * 0.4`;
# the optimum therefore sits at x = -0.4 * minval in every coordinate.
minval = -5.12 #@param {type:"number"}
# Selects which fitness function is optimised: "sphere" or "rastrigin".
optim_problem = "sphere" #@param["rastrigin", "sphere"]
#@markdown ---
#@title Hyperparameters
#@markdown ---
# Upper bound on the number of sample/update steps; the optimisation loop
# also exits early when the CMA-ES stop condition is met.
num_iterations = 1000 #@param {type:"integer"}
# Dimensionality of the search space (passed to CMAES as `search_dim` and used
# to size the initial mean vector).
num_dimensions = 100 #@param {type:"integer"}
# Passed to CMAES as `population_size`.
batch_size = 36 #@param {type:"integer"}
# Passed to CMAES as `num_best`.
# NOTE(review): presumably the number of top samples used in each update - confirm.
num_best = 18 #@param {type:"integer"}
# Passed to CMAES as `init_sigma`.
sigma_g = 0.5 #@param {type:"number"}
# Only used by the fitness functions, which shift their input by `minval * 0.4`;
# the optimum therefore sits at x = -0.4 * minval in every coordinate.
minval = -5.12 #@param {type:"number"}
# Selects which fitness function is optimised: "sphere" or "rastrigin".
optim_problem = "sphere" #@param["rastrigin", "sphere"]
#@markdown ---
Define the fitness function - choose rastrigin or sphere¶
In [ ]:
Copied!
def rastrigin_scoring(x: jax.Array) -> jax.Array:
    """Negated, shifted Rastrigin function used as a fitness score.

    The input is shifted by ``minval * 0.4`` before evaluation, so the optimum
    lies at ``x = -0.4 * minval`` in every coordinate rather than at the
    origin. The Rastrigin value is negated, so a larger score is better.

    The reduction runs over the last axis only (as ``sphere_scoring`` does), so
    the function gives one score per solution whether it receives a single
    solution, e.g. through ``jax.vmap``, or a whole batch directly.

    Args:
        x: array whose last axis holds the coordinates of a solution; any
            leading axes are treated as batch axes.

    Returns:
        One score per solution (a scalar for a 1-D input): approximately 0 at
        the optimum and negative elsewhere.
    """
    shifted = x + minval * 0.4
    # Constant part of Rastrigin: 10 * number of coordinates.
    first_term = 10 * x.shape[-1]
    # axis=-1 keeps batch axes intact; for a 1-D input this is a plain sum.
    second_term = jnp.sum(
        shifted**2 - 10 * jnp.cos(2 * jnp.pi * shifted), axis=-1
    )
    return -(first_term + second_term)
def sphere_scoring(x: jax.Array):
    """Negated, shifted sphere function used as a fitness score.

    The input is shifted by ``minval * 0.4`` and the squared norm over the
    last axis is negated, so the score is highest (zero) at
    ``x = -0.4 * minval`` and a larger score is better.
    """
    offset = x + minval * 0.4
    squared_norm = jnp.sum(offset * offset, axis=-1)
    return -squared_norm
# Pick the fitness function named by `optim_problem`; reject anything else
# up front so a typo fails here rather than later in the optimisation.
if optim_problem not in ("sphere", "rastrigin"):
    raise Exception("Invalid opt function name given")

# The Rastrigin scorer is wrapped in vmap so it is applied to one solution at
# a time; the sphere scorer is used on the batch directly.
fitness_fn = (
    sphere_scoring
    if optim_problem == "sphere"
    else jax.vmap(rastrigin_scoring)
)
def rastrigin_scoring(x: jax.Array) -> jax.Array:
    """Negated, shifted Rastrigin function used as a fitness score.

    The input is shifted by ``minval * 0.4`` before evaluation, so the optimum
    lies at ``x = -0.4 * minval`` in every coordinate rather than at the
    origin. The Rastrigin value is negated, so a larger score is better.

    The reduction runs over the last axis only (as ``sphere_scoring`` does), so
    the function gives one score per solution whether it receives a single
    solution, e.g. through ``jax.vmap``, or a whole batch directly.

    Args:
        x: array whose last axis holds the coordinates of a solution; any
            leading axes are treated as batch axes.

    Returns:
        One score per solution (a scalar for a 1-D input): approximately 0 at
        the optimum and negative elsewhere.
    """
    shifted = x + minval * 0.4
    # Constant part of Rastrigin: 10 * number of coordinates.
    first_term = 10 * x.shape[-1]
    # axis=-1 keeps batch axes intact; for a 1-D input this is a plain sum.
    second_term = jnp.sum(
        shifted**2 - 10 * jnp.cos(2 * jnp.pi * shifted), axis=-1
    )
    return -(first_term + second_term)
def sphere_scoring(x: jax.Array):
    """Negated, shifted sphere function used as a fitness score.

    The input is shifted by ``minval * 0.4`` and the squared norm over the
    last axis is negated, so the score is highest (zero) at
    ``x = -0.4 * minval`` and a larger score is better.
    """
    offset = x + minval * 0.4
    squared_norm = jnp.sum(offset * offset, axis=-1)
    return -squared_norm
# Pick the fitness function named by `optim_problem`; reject anything else
# up front so a typo fails here rather than later in the optimisation.
if optim_problem not in ("sphere", "rastrigin"):
    raise Exception("Invalid opt function name given")

# The Rastrigin scorer is wrapped in vmap so it is applied to one solution at
# a time; the sphere scorer is used on the batch directly.
fitness_fn = (
    sphere_scoring
    if optim_problem == "sphere"
    else jax.vmap(rastrigin_scoring)
)
Define a CMA-ES optimizer instance¶
In [ ]:
Copied!
# Build the CMA-ES optimizer from the hyperparameters and fitness function
# defined in the previous cells.
cmaes = CMAES(
    # Problem definition.
    search_dim=num_dimensions,
    fitness_function=fitness_fn,
    # Initial search distribution: mean at the origin, initial sigma `sigma_g`.
    mean_init=jnp.zeros(num_dimensions),
    init_sigma=sigma_g,
    # Number of samples per step and number of best samples kept.
    population_size=batch_size,
    num_best=num_best,
    # NOTE(review): the flag name suggests the eigen decomposition of the
    # covariance is not recomputed at every step - confirm in the QDax docs.
    delay_eigen_decomposition=True,
)
# Build the CMA-ES optimizer from the hyperparameters and fitness function
# defined in the previous cells.
cmaes = CMAES(
    # Problem definition.
    search_dim=num_dimensions,
    fitness_function=fitness_fn,
    # Initial search distribution: mean at the origin, initial sigma `sigma_g`.
    mean_init=jnp.zeros(num_dimensions),
    init_sigma=sigma_g,
    # Number of samples per step and number of best samples kept.
    population_size=batch_size,
    num_best=num_best,
    # NOTE(review): the flag name suggests the eigen decomposition of the
    # covariance is not recomputed at every step - confirm in the QDax docs.
    delay_eigen_decomposition=True,
)
Init the CMA-ES optimizer state¶
In [ ]:
Copied!
# Fixed seed so that the sampling sequence is reproducible from run to run.
key = jax.random.key(0)

# Initial CMA-ES state; its mean, sigma and covariance matrix are read by the
# optimisation loop.
state = cmaes.init()
# Fixed seed so that the sampling sequence is reproducible from run to run.
key = jax.random.key(0)

# Initial CMA-ES state; its mean, sigma and covariance matrix are read by the
# optimisation loop.
state = cmaes.init()
Run optimization iterations¶
In [ ]:
Copied!
# JIT-compile the three CMA-ES entry points once, outside the loop.
sample_fn = jax.jit(cmaes.sample)
update_fn = jax.jit(cmaes.update)
stop_condition_fn = jax.jit(cmaes.stop_condition)

# Plotting history, seeded with the initial distribution. The stored
# covariance is the covariance matrix scaled by sigma squared.
means = [state.mean]
covs = [state.sigma**2 * state.cov_matrix]

# The loop variable doubles as the step counter; the explicit 0 keeps it
# defined when `num_iterations` is 0 and the loop body never runs.
iteration_count = 0
for iteration_count in range(1, num_iterations + 1):
    # Draw a new population with a fresh subkey.
    key, subkey = jax.random.split(key)
    samples = sample_fn(state, subkey)

    # Update the search distribution from the samples.
    state = update_fn(state, samples)

    # Exit before recording when the stop condition holds, so the state that
    # triggered the stop is not appended to the history.
    stop_condition = stop_condition_fn(state)
    if stop_condition:
        break

    means.append(state.mean)
    covs.append(state.sigma**2 * state.cov_matrix)

print("Num iterations before stop condition: ", iteration_count)
# JIT-compile the three CMA-ES entry points once, outside the loop.
sample_fn = jax.jit(cmaes.sample)
update_fn = jax.jit(cmaes.update)
stop_condition_fn = jax.jit(cmaes.stop_condition)

# Plotting history, seeded with the initial distribution. The stored
# covariance is the covariance matrix scaled by sigma squared.
means = [state.mean]
covs = [state.sigma**2 * state.cov_matrix]

# The loop variable doubles as the step counter; the explicit 0 keeps it
# defined when `num_iterations` is 0 and the loop body never runs.
iteration_count = 0
for iteration_count in range(1, num_iterations + 1):
    # Draw a new population with a fresh subkey.
    key, subkey = jax.random.split(key)
    samples = sample_fn(state, subkey)

    # Update the search distribution from the samples.
    state = update_fn(state, samples)

    # Exit before recording when the stop condition holds, so the state that
    # triggered the stop is not appended to the history.
    stop_condition = stop_condition_fn(state)
    if stop_condition:
        break

    means.append(state.mean)
    covs.append(state.sigma**2 * state.cov_matrix)

print("Num iterations before stop condition: ", iteration_count)
Check final fitnesses and distribution mean¶
In [ ]:
Copied!
# Score the last population drawn by the optimisation loop.
fitnesses = fitness_fn(samples)

# One line per summary statistic: (message, reduction applied to the scores).
fitness_summary = (
    ("Min fitness in the final population: ", jnp.min),
    ("Mean fitness in the final population: ", jnp.mean),
    ("Max fitness in the final population: ", jnp.max),
)
for message, reduce_fn in fitness_summary:
    print(message, reduce_fn(fitnesses))

# Last recorded mean of the search distribution (the matching covariance is
# available as covs[-1] but is not printed here).
print("Final mean of the distribution: \n", means[-1])
# Score the last population drawn by the optimisation loop.
fitnesses = fitness_fn(samples)

# One line per summary statistic: (message, reduction applied to the scores).
fitness_summary = (
    ("Min fitness in the final population: ", jnp.min),
    ("Mean fitness in the final population: ", jnp.mean),
    ("Max fitness in the final population: ", jnp.max),
)
for message, reduce_fn in fitness_summary:
    print(message, reduce_fn(fitnesses))

# Last recorded mean of the search distribution (the matching covariance is
# available as covs[-1] but is not printed here).
print("Final mean of the distribution: \n", means[-1])
Visualization of the optimization trajectory¶
In [ ]:
Copied!
fig, ax = plt.subplots(figsize=(12, 6))

# Sample random 2-D points to colour the fitness landscape. The fitness is
# evaluated on 2-D inputs here, so the plot shows the 2-D version of the
# problem; the trajectory drawn over it uses the first two coordinates of the
# search distribution.
key, subkey = jax.random.split(key)
x = jax.random.uniform(subkey, minval=-4, maxval=8, shape=(100000, 2))
f_x = fitness_fn(x)

# Plot the fitness landscape.
points = ax.scatter(x[:, 0], x[:, 1], c=f_x, s=0.1)
fig.colorbar(points)

# Plot the CMA-ES trajectory: one cross per recorded mean plus the 1-sigma
# ellipse of the covariance restricted to the first two coordinates.
traj_min = 0
traj_max = iteration_count
for mean, cov in zip(means[traj_min:traj_max], covs[traj_min:traj_max]):
    # Ellipse takes axis lengths (diameters) and a rotation, not variances,
    # so derive the principal axes of the 2x2 block [[var_x, cov_xy],
    # [cov_xy, var_y]] in closed form.
    var_x, var_y, cov_xy = cov[0, 0], cov[1, 1], cov[0, 1]
    half_trace = (var_x + var_y) / 2
    spread = jnp.sqrt(((var_x - var_y) / 2) ** 2 + cov_xy**2)
    major_std = jnp.sqrt(half_trace + spread)
    # Clamp at zero: rounding can push the smaller eigenvalue slightly negative.
    minor_std = jnp.sqrt(jnp.maximum(half_trace - spread, 0.0))
    # Orientation of the major axis, measured from the horizontal axis.
    angle_deg = jnp.degrees(0.5 * jnp.arctan2(2 * cov_xy, var_x - var_y))

    ellipse = Ellipse(
        (float(mean[0]), float(mean[1])),
        width=2 * float(major_std),
        height=2 * float(minor_std),
        angle=float(angle_deg),
        fill=False,
        color='k',
        ls='--',
    )
    ax.add_patch(ellipse)
    ax.plot(mean[0], mean[1], color='k', marker='x')

ax.set_xlabel("x[0]")
ax.set_ylabel("x[1]")
ax.set_title(f"Optimization trajectory of CMA-ES between step {traj_min} and step {traj_max}")
plt.show()
fig, ax = plt.subplots(figsize=(12, 6))

# Sample random 2-D points to colour the fitness landscape. The fitness is
# evaluated on 2-D inputs here, so the plot shows the 2-D version of the
# problem; the trajectory drawn over it uses the first two coordinates of the
# search distribution.
key, subkey = jax.random.split(key)
x = jax.random.uniform(subkey, minval=-4, maxval=8, shape=(100000, 2))
f_x = fitness_fn(x)

# Plot the fitness landscape.
points = ax.scatter(x[:, 0], x[:, 1], c=f_x, s=0.1)
fig.colorbar(points)

# Plot the CMA-ES trajectory: one cross per recorded mean plus the 1-sigma
# ellipse of the covariance restricted to the first two coordinates.
traj_min = 0
traj_max = iteration_count
for mean, cov in zip(means[traj_min:traj_max], covs[traj_min:traj_max]):
    # Ellipse takes axis lengths (diameters) and a rotation, not variances,
    # so derive the principal axes of the 2x2 block [[var_x, cov_xy],
    # [cov_xy, var_y]] in closed form.
    var_x, var_y, cov_xy = cov[0, 0], cov[1, 1], cov[0, 1]
    half_trace = (var_x + var_y) / 2
    spread = jnp.sqrt(((var_x - var_y) / 2) ** 2 + cov_xy**2)
    major_std = jnp.sqrt(half_trace + spread)
    # Clamp at zero: rounding can push the smaller eigenvalue slightly negative.
    minor_std = jnp.sqrt(jnp.maximum(half_trace - spread, 0.0))
    # Orientation of the major axis, measured from the horizontal axis.
    angle_deg = jnp.degrees(0.5 * jnp.arctan2(2 * cov_xy, var_x - var_y))

    ellipse = Ellipse(
        (float(mean[0]), float(mean[1])),
        width=2 * float(major_std),
        height=2 * float(minor_std),
        angle=float(angle_deg),
        fill=False,
        color='k',
        ls='--',
    )
    ax.add_patch(ellipse)
    ax.plot(mean[0], mean[1], color='k', marker='x')

ax.set_xlabel("x[0]")
ax.set_ylabel("x[1]")
ax.set_title(f"Optimization trajectory of CMA-ES between step {traj_min} and step {traj_max}")
plt.show()