
Optimizing Uncertain Domains with ME-LS in JAX

This notebook shows how to discover controllers that achieve consistent performance in uncertain MDP domains using the MAP-Elites Low-Spread (ME-LS) algorithm. It can be run locally or on Google Colab. We recommend using a GPU. This notebook covers:

  • how to define the problem
  • how to create an emitter
  • how to create an ME-LS instance
  • which functions must be defined before training
  • how to launch a certain number of training steps
  • how to visualise the optimization process
  • how to save/load a repertoire
#@title Installs and Imports
!pip install ipympl |tail -n 1
# %matplotlib widget
# from google.colab import output
# output.enable_custom_widget_manager()

import os

from IPython.display import clear_output
import functools
import time

import jax
import jax.numpy as jnp

try:
    import brax
except ImportError:
    !pip install git+https://github.com/google/brax.git@v0.9.2 |tail -n 1
    import brax

try:
    import flax
except ImportError:
    !pip install --no-deps git+https://github.com/google/flax.git@v0.7.4 |tail -n 1
    import flax

try:
    import chex
except ImportError:
    !pip install --no-deps git+https://github.com/deepmind/chex.git@v0.1.83 |tail -n 1
    import chex

try:
    import jumanji
except ImportError:
    !pip install "jumanji==0.3.1" |tail -n 1
    import jumanji

try:
    import qdax
except ImportError:
    !pip install --no-deps git+https://github.com/adaptive-intelligent-robotics/QDax@main |tail -n 1
    import qdax


from qdax.core.mels import MELS
from qdax.core.containers.mapelites_repertoire import compute_cvt_centroids
from qdax.core.containers.mels_repertoire import MELSRepertoire
from qdax import environments
from qdax.tasks.brax_envs import reset_based_scoring_function_brax_envs
from qdax.core.neuroevolution.buffers.buffer import QDTransition
from qdax.core.neuroevolution.networks.networks import MLP
from qdax.core.emitters.mutation_operators import isoline_variation
from qdax.core.emitters.standard_emitters import MixingEmitter
from qdax.utils.plotting import plot_map_elites_results

from qdax.utils.metrics import CSVLogger, default_qd_metrics

from jax.flatten_util import ravel_pytree

from IPython.display import HTML
from brax.v1.io import html



if "COLAB_TPU_ADDR" in os.environ:
  from jax.tools import colab_tpu
  colab_tpu.setup_tpu()


clear_output()
#@title QD Training Definitions Fields
#@markdown ---
batch_size = 100 #@param {type:"number"}
env_name = 'walker2d_uni' #@param['ant_uni', 'hopper_uni', 'walker2d_uni', 'halfcheetah_uni', 'humanoid_uni', 'ant_omni', 'humanoid_omni']
num_samples = 5 #@param {type:"number"}
episode_length = 100 #@param {type:"integer"}
num_iterations = 1000 #@param {type:"integer"}
seed = 42 #@param {type:"integer"}
policy_hidden_layer_sizes = (64, 64) #@param {type:"raw"}
iso_sigma = 0.005 #@param {type:"number"}
line_sigma = 0.05 #@param {type:"number"}
num_init_cvt_samples = 50000 #@param {type:"integer"}
num_centroids = 1024 #@param {type:"integer"}
min_bd = 0. #@param {type:"number"}
max_bd = 1.0 #@param {type:"number"}
#@markdown ---

Init environment, policy, and population params

Define the environment in which the policies will be trained. In this notebook, we consider the setting where each controller is evaluated num_samples times, each time starting from a different randomly sampled initial state, so its fitness and behaviour descriptor vary from one evaluation to the next.

# Init environment
env = environments.create(env_name, episode_length=episode_length)

# Init a random key
random_key = jax.random.PRNGKey(seed)

# Init policy network
policy_layer_sizes = policy_hidden_layer_sizes + (env.action_size,)
policy_network = MLP(
    layer_sizes=policy_layer_sizes,
    kernel_init=jax.nn.initializers.lecun_uniform(),
    final_activation=jnp.tanh,
)

# Init population of controllers. There are batch_size controllers, and each
# controller will be evaluated num_samples times.
random_key, subkey = jax.random.split(random_key)
keys = jax.random.split(subkey, num=batch_size)
fake_batch = jnp.zeros(shape=(batch_size, env.observation_size))
init_variables = jax.vmap(policy_network.init)(keys, fake_batch)
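
As a quick optional check, the next cell prints the shape of every parameter leaf in init_variables; each leaf should carry a leading dimension equal to batch_size.

# Optional: inspect the batched parameter pytree. Every leaf should have a
# leading dimension equal to batch_size.
print(jax.tree_util.tree_map(lambda x: x.shape, init_variables))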

Define the way the policy interacts with the env

Now that the environment and policy have been defined, we need a function that describes how the policy interacts with the environment and stores the transition data. This is identical to the function used in the MAP-Elites tutorial.

# Define the function to play a step with the policy in the environment
def play_step_fn(
    env_state,
    policy_params,
    random_key,
):
    """Play an environment step and return the updated state and the
    transition."""

    actions = policy_network.apply(policy_params, env_state.obs)

    state_desc = env_state.info["state_descriptor"]
    next_state = env.step(env_state, actions)

    transition = QDTransition(
        obs=env_state.obs,
        next_obs=next_state.obs,
        rewards=next_state.reward,
        dones=next_state.done,
        actions=actions,
        truncations=next_state.info["truncation"],
        state_desc=state_desc,
        next_state_desc=next_state.info["state_descriptor"],
    )

    return next_state, policy_params, random_key, transition
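
As an optional sanity check (a minimal sketch, not part of the original tutorial), the cell below plays a single step with the first controller of the population from a freshly reset state and prints the resulting reward and done flag.

# Optional sanity check: one environment step with the first controller.
check_key = jax.random.PRNGKey(0)
single_params = jax.tree_util.tree_map(lambda x: x[0], init_variables)
check_state = env.reset(rng=check_key)
_, _, _, check_transition = play_step_fn(check_state, single_params, check_key)
print(check_transition.rewards, check_transition.dones)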

Define the scoring function and the way metrics are computed

The scoring function is used in the evaluation step to determine the fitness and behavior descriptor of each individual. Note that while the MAP-Elites tutorial uses scoring_function_brax_envs as the basis for the scoring function, we use reset_based_scoring_function_brax_envs. The difference is that reset_based_scoring_function_brax_envs generates initial states randomly instead of taking in a fixed set of initial states. This is necessary since we are evaluating each controller across sampled initial states. If the initial states were kept the same for all evaluations, there would be no stochasticity in the behavior.

# Prepare the scoring function
bd_extraction_fn = environments.behavior_descriptor_extractor[env_name]
scoring_fn = functools.partial(
    reset_based_scoring_function_brax_envs,
    episode_length=episode_length,
    play_reset_fn=env.reset,
    play_step_fn=play_step_fn,
    behavior_descriptor_extractor=bd_extraction_fn,
)

# Get minimum reward value to make sure qd_score are positive
reward_offset = environments.reward_offset[env_name]

# Define a metrics function
metrics_fn = functools.partial(
    default_qd_metrics,
    qd_offset=reward_offset * episode_length,
)
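
Before moving on, one can optionally evaluate the initial population once with a throwaway key. This sketch assumes the partially-applied scoring function keeps QDax's usual interface of (genotypes, random_key) -> (fitnesses, descriptors, extra_scores, random_key).

# Optional: one-off evaluation of the initial population (throwaway key so the
# main random_key is left untouched).
eval_key = jax.random.PRNGKey(0)
init_fitnesses, init_descriptors, _, _ = scoring_fn(init_variables, eval_key)
print(init_fitnesses.shape)    # (batch_size,)
print(init_descriptors.shape)  # (batch_size, env.behavior_descriptor_length)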

Define the emitter

The emitter is used to evolve the population at each mutation step.

# Define emitter
variation_fn = functools.partial(
    isoline_variation, iso_sigma=iso_sigma, line_sigma=line_sigma
)
mixing_emitter = MixingEmitter(
    mutation_fn=None,
    variation_fn=variation_fn,
    variation_percentage=1.0,
    batch_size=batch_size
)
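
For intuition, here is a minimal sketch of the Iso+LineDD operator that isoline_variation implements: a parent x1 is perturbed with isotropic Gaussian noise plus noise directed along the line towards a second parent x2. This toy example works on flat vectors only; the QDax operator handles arbitrary parameter pytrees.

# Illustration only (not a call to isoline_variation): Iso+LineDD on two toy genotypes.
demo_key = jax.random.PRNGKey(0)
x1, x2 = jnp.zeros(4), jnp.ones(4)
demo_key, k_iso, k_line = jax.random.split(demo_key, 3)
child = (
    x1
    + iso_sigma * jax.random.normal(k_iso, x1.shape)      # isotropic component
    + line_sigma * jax.random.normal(k_line) * (x2 - x1)  # directional component
)
print(child)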

Instantiate and initialise the ME-LS algorithm

# Instantiate ME-LS.
mels = MELS(
    scoring_function=scoring_fn,
    emitter=mixing_emitter,
    metrics_function=metrics_fn,
    num_samples=num_samples,
)

# Compute the centroids
centroids, random_key = compute_cvt_centroids(
    num_descriptors=env.behavior_descriptor_length,
    num_init_cvt_samples=num_init_cvt_samples,
    num_centroids=num_centroids,
    minval=min_bd,
    maxval=max_bd,
    random_key=random_key,
)

# Compute initial repertoire and emitter state
repertoire, emitter_state, random_key = mels.init(init_variables, centroids, random_key)
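
Optionally, the shapes stored in the freshly initialised repertoire can be inspected. The spreads field is specific to MELSRepertoire and quantifies how consistent each elite's behaviour is across its num_samples evaluations (lower spread means a more reproducible solution).

# Optional: inspect the initial repertoire.
print(repertoire.fitnesses.shape)    # (num_centroids,)
print(repertoire.descriptors.shape)  # (num_centroids, behavior_descriptor_length)
print(repertoire.spreads.shape)      # (num_centroids,)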

Launch ME-LS iterations

log_period = 10
num_loops = int(num_iterations / log_period)

csv_logger = CSVLogger(
    "mapelites-logs.csv",
    header=["loop", "iteration", "qd_score", "max_fitness", "coverage", "time"]
)
all_metrics = {}

# main loop
mels_scan_update = mels.scan_update
for i in range(num_loops):
    start_time = time.time()
    # main iterations
    (repertoire, emitter_state, random_key,), metrics = jax.lax.scan(
        mels_scan_update,
        (repertoire, emitter_state, random_key),
        (),
        length=log_period,
    )
    timelapse = time.time() - start_time

    # log metrics
    logged_metrics = {"time": timelapse, "loop": 1+i, "iteration": 1 + i*log_period}
    for key, value in metrics.items():
        # take last value
        logged_metrics[key] = value[-1]

        # take all values
        if key in all_metrics.keys():
            all_metrics[key] = jnp.concatenate([all_metrics[key], value])
        else:
            all_metrics[key] = value

    csv_logger.log(logged_metrics)
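
Once the loop has finished, the logged metrics can optionally be read back with pandas (assumed to be available in the Colab runtime) to check the last few entries.

# Optional: read the metrics log back into a dataframe.
import pandas as pd
log_df = pd.read_csv("mapelites-logs.csv")
print(log_df.tail())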
#@title Visualization

# create the x-axis array: each iteration evaluates batch_size controllers,
# each num_samples times, for episode_length steps
env_steps = jnp.arange(num_iterations) * episode_length * batch_size * num_samples

# create the plots and the grid
fig, axes = plot_map_elites_results(env_steps=env_steps, metrics=all_metrics, repertoire=repertoire, min_bd=min_bd, max_bd=max_bd)

How to save/load a repertoire

The following cells show how to save or load a repertoire of individuals and add a few lines to visualise the best performing individual in a simulation.

Save the final repertoire

repertoire_path = "./last_repertoire/"
os.makedirs(repertoire_path, exist_ok=True)
repertoire.save(path=repertoire_path)
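
As a quick check, the files written by repertoire.save can be listed (the file names are whatever QDax uses internally).

# Optional: list the files produced by repertoire.save.
print(os.listdir(repertoire_path))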

Build the reconstruction function

# Init population of policies
random_key, subkey = jax.random.split(random_key)
fake_batch = jnp.zeros(shape=(env.observation_size,))
fake_params = policy_network.init(subkey, fake_batch)

_, reconstruction_fn = ravel_pytree(fake_params)
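
As an optional sanity check, flattening the fake params and rebuilding them with reconstruction_fn should give back an identical pytree (chex is already imported above).

# Optional: ravel_pytree flattens the params into one vector; reconstruction_fn inverts it.
flat_params, _ = ravel_pytree(fake_params)
rebuilt_params = reconstruction_fn(flat_params)
chex.assert_trees_all_close(rebuilt_params, fake_params)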

Use the reconstruction function to load and re-create the repertoire

repertoire = MELSRepertoire.load(reconstruction_fn=reconstruction_fn, path=repertoire_path)

Get the best individual of the repertoire

Note that in ME-LS, the individual's cell is computed by finding its most frequent archive cell among its num_samples behavior descriptors. Thus, the descriptor associated with each individual in the archive is not its mean descriptor. Rather, we set the descriptor in the archive to be the centroid of the individual's most frequent archive cell.
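
The cell below illustrates that idea (a sketch, not QDax's internal implementation): num_samples random descriptors are mapped to their nearest centroid and the most frequent cell is kept.

# Illustration only: most frequent cell among num_samples sampled descriptors.
demo_key = jax.random.PRNGKey(0)
sampled_descriptors = jax.random.uniform(
    demo_key,
    (num_samples, env.behavior_descriptor_length),
    minval=min_bd,
    maxval=max_bd,
)
dists = jnp.linalg.norm(sampled_descriptors[:, None, :] - centroids[None, :, :], axis=-1)
cells = jnp.argmin(dists, axis=-1)  # nearest centroid per sample
most_frequent_cell = jnp.argmax(jnp.bincount(cells, length=num_centroids))
print(cells, most_frequent_cell)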

best_idx = jnp.argmax(repertoire.fitnesses)
best_fitness = jnp.max(repertoire.fitnesses)
best_bd = repertoire.descriptors[best_idx]
best_spread = repertoire.spreads[best_idx]
print(
    f"Best fitness in the repertoire: {best_fitness:.2f}\n"
    f"Behavior descriptor of the best individual in the repertoire: {best_bd}\n"
    f"Spread of the best individual in the repertoire: {best_spread}\n"
    f"Index in the repertoire of this individual: {best_idx}\n"
)
my_params = jax.tree_util.tree_map(
    lambda x: x[best_idx],
    repertoire.genotypes
)

Play some steps in the environment

jit_env_reset = jax.jit(env.reset)
jit_env_step = jax.jit(env.step)
jit_inference_fn = jax.jit(policy_network.apply)
rollout = []
rng = jax.random.PRNGKey(seed=1)
state = jit_env_reset(rng=rng)
while not state.done:
    rollout.append(state)
    action = jit_inference_fn(my_params, state.obs)
    state = jit_env_step(state, action)

print(f"The trajectory of this individual contains {len(rollout)} transitions.")
HTML(html.render(env.sys, [s.qp for s in rollout[:500]]))
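
The rendered rollout can also be written to disk as a standalone HTML file, since html.render returns an HTML string (a minimal optional addition).

# Optional: save the rollout visualisation to a file that can be opened in a browser.
with open("rollout.html", "w") as f:
    f.write(html.render(env.sys, [s.qp for s in rollout[:500]]))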