QDax docs

Optimizing Uncertain Domains with ME-LS in JAX

This notebook shows how to discover controllers that achieve consistent performance in uncertain MDP domains using the MAP-Elites Low-Spread (ME-LS) algorithm. It can be run locally or on Google Colab; we recommend using a GPU. This notebook will show:

  • how to define the problem
  • how to create an emitter
  • how to create an ME-LS instance
  • which functions must be defined before training
  • how to launch a certain number of training steps
  • how to visualise the optimization process

Installation

You will need Python 3.11 or later and a working JAX installation. For example, you can install JAX with CUDA support with:

%pip install -U "jax[cuda]"

Then, install QDax from PyPI:

%pip install -U "qdax[examples]"
import os

from IPython.display import clear_output
import functools
import time

import jax
import jax.numpy as jnp

from qdax.core.mels import MELS
from qdax.core.containers.mapelites_repertoire import compute_cvt_centroids
from qdax.core.containers.mels_repertoire import MELSRepertoire
import qdax.tasks.brax as environments
from qdax.tasks.brax.env_creators import scoring_function_brax_envs
from qdax.core.neuroevolution.buffers.buffer import QDTransition
from qdax.core.neuroevolution.networks.networks import MLP
from qdax.core.emitters.mutation_operators import isoline_variation
from qdax.core.emitters.standard_emitters import MixingEmitter
from qdax.utils.plotting import plot_map_elites_results

from qdax.utils.metrics import CSVLogger, default_qd_metrics

from jax.flatten_util import ravel_pytree

from IPython.display import HTML
#@title QD Training Definitions Fields
#@markdown ---
batch_size = 100 #@param {type:"number"}
env_name = 'walker2d_uni'#@param['ant_uni', 'hopper_uni', 'walker2d_uni', 'halfcheetah_uni', 'humanoid_uni', 'ant_omni', 'humanoid_omni']
num_samples = 5 #@param {type:"number"}
episode_length = 100 #@param {type:"integer"}
num_iterations = 1000 #@param {type:"integer"}
seed = 42 #@param {type:"integer"}
policy_hidden_layer_sizes = (64, 64) #@param {type:"raw"}
iso_sigma = 0.005 #@param {type:"number"}
line_sigma = 0.05 #@param {type:"number"}
num_init_cvt_samples = 50000 #@param {type:"integer"}
num_centroids = 1024 #@param {type:"integer"}
min_descriptor = 0. #@param {type:"number"}
max_descriptor = 1.0 #@param {type:"number"}
#@markdown ---

Init environment, policy, population params, init states of the env

Define the environment in which the policies will be trained. In this notebook, each controller is evaluated num_samples times, each time starting from a different, randomly sampled initial state of the environment.

# Init environment
env = environments.create(env_name, episode_length=episode_length)

# Init a random key
key = jax.random.key(seed)

# Init policy network
policy_layer_sizes = policy_hidden_layer_sizes + (env.action_size,)
policy_network = MLP(
    layer_sizes=policy_layer_sizes,
    kernel_init=jax.nn.initializers.lecun_uniform(),
    final_activation=jnp.tanh,
)

# Init population of controllers. There are batch_size controllers, and each
# controller will be evaluated num_samples times.
key, subkey = jax.random.split(key)
keys = jax.random.split(subkey, num=batch_size)
fake_batch = jnp.zeros(shape=(batch_size, env.observation_size))
init_variables = jax.vmap(policy_network.init)(keys, fake_batch)

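To make this evaluation protocol concrete, here is a minimal, self-contained sketch of evaluating a batch of controllers num_samples times, each time with its own random key. It does not use QDax: toy_evaluate and evaluate_num_samples are hypothetical helpers, and the random key merely stands in for a randomly sampled initial state of the Brax environment.

```python
import jax
import jax.numpy as jnp


def toy_evaluate(params, key):
    """Hypothetical single-episode evaluation of one controller.

    The key stands in for a randomly sampled initial state, so the same
    parameters can yield a different fitness and descriptor on each call.
    """
    noise = jax.random.normal(key, (2,))
    fitness = jnp.sum(params) + 0.1 * noise[0]
    descriptor = jax.nn.sigmoid(params[:2] + noise)
    return fitness, descriptor


def evaluate_num_samples(params, key, num_samples):
    """Evaluate each row of params num_samples times with distinct keys."""
    batch_size = params.shape[0]
    controller_keys = jax.random.split(key, batch_size)

    def evaluate_controller(p, k):
        sample_keys = jax.random.split(k, num_samples)
        # Same parameters, different keys: vmap over the sample keys only.
        return jax.vmap(toy_evaluate, in_axes=(None, 0))(p, sample_keys)

    # Outer vmap over controllers -> shapes (batch_size, num_samples, ...).
    return jax.vmap(evaluate_controller)(params, controller_keys)


# Usage: 4 identical toy controllers, each evaluated 5 times.
fitnesses, descriptors = evaluate_num_samples(jnp.ones((4, 3)), jax.random.key(0), 5)
print(fitnesses.shape, descriptors.shape)  # (4, 5) (4, 5, 2)
```

Each controller ends up with one fitness and one descriptor per sample, so identical parameters can still land in different cells from one sample to the next; this per-sample variability is what ME-LS has to cope with.
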
Define the way the policy interacts with the env

Now that the environment and policy have been defined, we need a function that describes how the policy interacts with the environment and how transition data is stored. This is identical to the function in the MAP-Elites tutorial.

# Define the function to play a step with the policy in the environment
def play_step_fn(
    env_state,
    policy_params,
    key,
):
    """Play an environment step and return the updated state and the
    transition."""

    actions = policy_network.apply(policy_params, env_state.obs)

    state_desc = env_state.info["state_descriptor"]
    next_state = env.step(env_state, actions)

    transition = QDTransition(
        obs=env_state.obs,
        next_obs=next_state.obs,
        rewards=next_state.reward,
        dones=next_state.done,
        actions=actions,
        truncations=next_state.info["truncation"],
        state_desc=state_desc,
        next_state_desc=next_state.info["state_descriptor"],
    )

    return next_state, policy_params, key, transition

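During evaluation, play_step_fn is called repeatedly for episode_length steps. The sketch below illustrates that unrolling pattern with jax.lax.scan on a hypothetical toy environment: ToyEnvState, toy_env_step, toy_policy and toy_rollout are made up for illustration and are not Brax or QDax APIs. As in play_step_fn, the carry holds (env_state, policy_params, key) and each step emits one transition.

```python
from typing import NamedTuple

import jax
import jax.numpy as jnp


class ToyEnvState(NamedTuple):
    """Hypothetical environment state; not a Brax State."""
    obs: jax.Array
    reward: jax.Array


def toy_env_step(state, action):
    # Hypothetical dynamics: the action shifts the observation.
    next_obs = state.obs + action
    return ToyEnvState(obs=next_obs, reward=-jnp.sum(next_obs ** 2))


def toy_policy(params, obs):
    # Linear policy with a tanh output, echoing the MLP's final activation.
    return jnp.tanh(params @ obs)


def toy_play_step(carry, _):
    """Mirrors play_step_fn, with (env_state, policy_params, key) packed into the scan carry."""
    env_state, policy_params, key = carry
    action = toy_policy(policy_params, env_state.obs)
    next_state = toy_env_step(env_state, action)
    transition = {"obs": env_state.obs, "actions": action, "rewards": next_state.reward}
    return (next_state, policy_params, key), transition


def toy_rollout(policy_params, init_state, key, episode_length):
    # scan threads the carry through episode_length steps and stacks transitions.
    _, transitions = jax.lax.scan(
        toy_play_step, (init_state, policy_params, key), None, length=episode_length
    )
    return transitions


init_state = ToyEnvState(obs=jnp.ones(2), reward=jnp.zeros(()))
transitions = toy_rollout(0.5 * jnp.eye(2), init_state, jax.random.key(0), episode_length=7)
print(transitions["obs"].shape)  # (7, 2)
```

Because scan stacks the per-step outputs, every field of the returned transitions gains a leading axis of length episode_length.
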
Define the scoring function and the way metrics are computed

The scoring function is used in the evaluation step to determine the fitness and descriptor of each individual. Here, scoring_function_brax_envs receives env.reset as play_reset_fn, so initial states are generated randomly by resetting the environment with fresh random keys, instead of being taken from a fixed set of initial states. This is necessary since we are evaluating each controller across sampled initial states: if the initial states were kept the same for all evaluations, there would be no stochasticity.

# Prepare the scoring function
descriptor_extraction_fn = environments.descriptor_extractor[env_name]
scoring_fn = functools.partial(
    scoring_function_brax_envs,
    episode_length=episode_length,
    play_reset_fn=env.reset,
    play_step_fn=play_step_fn,
    descriptor_extractor=descriptor_extraction_fn,
)

# Get the minimum reward value, used to make sure the QD score is positive
reward_offset = environments.reward_offset[env_name]

# Define a metrics function
metrics_fn = functools.partial(
    default_qd_metrics,
    qd_offset=reward_offset * episode_length,
)

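The role of qd_offset can be seen with a small worked example. The sketch below is written independently of default_qd_metrics and assumes that empty cells hold a fitness of -inf. It adds the offset once per filled cell so that, assuming reward_offset * episode_length bounds the magnitude of the lowest possible episode return, every filled cell contributes a non-negative amount to the QD score. qd_metrics_sketch and the numbers used are hypothetical.

```python
import jax.numpy as jnp


def qd_metrics_sketch(fitnesses, qd_offset):
    """Illustrative QD metrics; assumes empty cells hold a fitness of -inf."""
    filled = fitnesses != -jnp.inf
    num_filled = jnp.sum(filled)
    # Add the offset once per filled cell so each contributes a non-negative amount.
    qd_score = jnp.sum(jnp.where(filled, fitnesses, 0.0)) + qd_offset * num_filled
    coverage = 100.0 * jnp.mean(filled)  # percentage of filled cells
    max_fitness = jnp.max(fitnesses)
    return {"qd_score": qd_score, "coverage": coverage, "max_fitness": max_fitness}


# Hypothetical example: 4 cells, 2 filled, offset = reward_offset * episode_length = 1.0 * 100.
example_metrics = qd_metrics_sketch(jnp.array([-50.0, 20.0, -jnp.inf, -jnp.inf]), qd_offset=100.0)
print(example_metrics["qd_score"], example_metrics["coverage"])  # 170.0 50.0
```

Without the offset, the two filled cells would sum to -30, so adding a new (poor) solution could lower the QD score; with the offset, the score is 170 and filling more cells can only increase it.
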
Define the emitter

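The emitter is used to generate new candidate solutions at each iteration. Here, MixingEmitter is configured with mutation_fn=None and variation_percentage=1.0, so every offspring is produced by the isoline variation operator (controlled by iso_sigma and line_sigma) applied to pairs of parents sampled from the repertoire.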

# Define emitter
variation_fn = functools.partial(
    isoline_variation, iso_sigma=iso_sigma, line_sigma=line_sigma
)
mixing_emitter = MixingEmitter(
    mutation_fn=None,
    variation_fn=variation_fn,
    variation_percentage=1.0,
    batch_size=batch_size
)

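isoline_variation follows the Iso+LineDD idea: a child is built from two parents x1 and x2 by adding isotropic Gaussian noise scaled by iso_sigma and a random step along the line from x1 to x2 scaled by line_sigma. The following sketch shows that formula on flat (batch, dim) arrays only; as we understand it, QDax's own function applies the operator to parameter pytrees, and iso_line_sketch is a hypothetical name used only here.

```python
import jax
import jax.numpy as jnp


def iso_line_sketch(x1, x2, key, iso_sigma, line_sigma):
    """Iso+LineDD-style variation on flat (batch, dim) parent arrays.

    child = x1 + iso_sigma * N(0, I) + line_sigma * N(0, 1) * (x2 - x1)
    """
    key_iso, key_line = jax.random.split(key)
    # Independent Gaussian noise on every parameter.
    iso_noise = iso_sigma * jax.random.normal(key_iso, x1.shape)
    # One scalar per child, broadcast along the parameter dimension.
    line_noise = line_sigma * jax.random.normal(key_line, (x1.shape[0], 1))
    return x1 + iso_noise + line_noise * (x2 - x1)


parents_a = jnp.zeros((3, 4))
parents_b = jnp.ones((3, 4))
children = iso_line_sketch(parents_a, parents_b, jax.random.key(0), iso_sigma=0.005, line_sigma=0.05)
print(children.shape)  # (3, 4)
```

Setting iso_sigma to 0 leaves only the line term, so each child then lies exactly on the line through its two parents.
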
Instantiate and initialise the ME-LS algorithm

# Instantiate ME-LS.
mels = MELS(
    scoring_function=scoring_fn,
    emitter=mixing_emitter,
    metrics_function=metrics_fn,
    num_samples=num_samples,
)

# Compute the centroids
key, subkey = jax.random.split(key)
centroids = compute_cvt_centroids(
    num_descriptors=env.descriptor_length,
    num_init_cvt_samples=num_init_cvt_samples,
    num_centroids=num_centroids,
    minval=min_descriptor,
    maxval=max_descriptor,
    key=subkey,
)

# Compute initial repertoire and emitter state
key, subkey = jax.random.split(key)
repertoire, emitter_state, init_metrics = mels.init(init_variables, centroids, subkey)

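compute_cvt_centroids partitions the descriptor space into num_centroids cells with a centroidal Voronoi tessellation (CVT), following the CVT-MAP-Elites recipe: sample num_init_cvt_samples points uniformly in [min_descriptor, max_descriptor] and cluster them with k-means, keeping the cluster centres as centroids. The sketch below reproduces that idea with a few Lloyd iterations in plain JAX. It is illustrative only, not the library's implementation, and cvt_centroids_sketch and num_iters are hypothetical.

```python
import jax
import jax.numpy as jnp


def cvt_centroids_sketch(key, num_descriptors, num_init_samples, num_centroids,
                         minval, maxval, num_iters=20):
    """Approximate a CVT with Lloyd's k-means on uniform samples (illustrative)."""
    key_samples, key_init = jax.random.split(key)
    samples = jax.random.uniform(
        key_samples, (num_init_samples, num_descriptors), minval=minval, maxval=maxval
    )
    # Initialise the centroids with randomly chosen, distinct samples.
    init_idx = jax.random.choice(key_init, num_init_samples, (num_centroids,), replace=False)
    centroids = samples[init_idx]

    def lloyd_step(centroids, _):
        # Assign every sample to its nearest centroid.
        dists = jnp.sum((samples[:, None, :] - centroids[None, :, :]) ** 2, axis=-1)
        one_hot = jax.nn.one_hot(jnp.argmin(dists, axis=1), num_centroids)
        counts = one_hot.sum(axis=0)[:, None]
        sums = one_hot.T @ samples
        # Move each centroid to the mean of its samples (keep it if it has none).
        new_centroids = jnp.where(counts > 0, sums / jnp.maximum(counts, 1.0), centroids)
        return new_centroids, None

    centroids, _ = jax.lax.scan(lloyd_step, centroids, None, length=num_iters)
    return centroids


toy_centroids = cvt_centroids_sketch(
    jax.random.key(0), num_descriptors=2, num_init_samples=500,
    num_centroids=8, minval=0.0, maxval=1.0,
)
print(toy_centroids.shape)  # (8, 2)
```

This sketch materialises a samples-by-centroids distance matrix, so it is only suitable for small sizes; with the notebook's 50000 samples and 1024 centroids, a dedicated k-means implementation is the better choice.
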
Launch ME-LS iterations

log_period = 10
num_loops = num_iterations // log_period

# Initialize metrics
metrics = {key: jnp.array([]) for key in ["iteration", "qd_score", "coverage", "max_fitness", "time"]}

# Set up init metrics
init_metrics = jax.tree.map(lambda x: jnp.array([x]) if x.shape == () else x, init_metrics)
init_metrics["iteration"] = jnp.array([0], dtype=jnp.int32)
init_metrics["time"] = jnp.array([0.0])  # No time recorded for initialization

# Convert init_metrics to match the metrics dictionary structure
metrics = jax.tree.map(lambda metric, init_metric: jnp.concatenate([metric, init_metric], axis=0), metrics, init_metrics)

# Initialize CSV logger
csv_logger = CSVLogger(
    "mels-logs.csv",
    header=list(metrics.keys())
)

# Main loop
mels_scan_update = mels.scan_update
for i in range(num_loops):
    start_time = time.time()
    (
        repertoire,
        emitter_state,
        key,
    ), current_metrics = jax.lax.scan(
        mels_scan_update,
        (repertoire, emitter_state, key),
        (),
        length=log_period,
    )
    timelapse = time.time() - start_time

    # Metrics
    current_metrics["iteration"] = jnp.arange(1+log_period*i, 1+log_period*(i+1), dtype=jnp.int32)
    current_metrics["time"] = jnp.repeat(timelapse, log_period)
    metrics = jax.tree.map(lambda metric, current_metric: jnp.concatenate([metric, current_metric], axis=0), metrics, current_metrics)

    # Log
    csv_logger.log(jax.tree.map(lambda x: x[-1], metrics))
#@title Visualization

# Create the x-axis array
env_steps = metrics["iteration"]

%matplotlib inline
# Create the plots and the grid
fig, axes = plot_map_elites_results(env_steps=env_steps, metrics=metrics, repertoire=repertoire, min_descriptor=min_descriptor, max_descriptor=max_descriptor)

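The main loop above runs num_iterations // log_period calls to jax.lax.scan, each performing log_period updates and returning metrics stacked along a leading axis; jnp.arange(1+log_period*i, 1+log_period*(i+1)) then labels them with iterations 1 to num_iterations. The sketch below replays that bookkeeping with a hypothetical toy_scan_update in place of mels.scan_update, which, like it, maps (carry, unused input) to (new carry, metrics).

```python
import jax
import jax.numpy as jnp


def toy_scan_update(carry, _):
    """Hypothetical stand-in for mels.scan_update: (carry, unused) -> (carry, metrics)."""
    best, key = carry
    key, subkey = jax.random.split(key)
    best = jnp.maximum(best, jax.random.uniform(subkey))
    return (best, key), {"max_fitness": best}


def run_loop(num_iterations, log_period, key):
    carry = (jnp.zeros(()), key)
    iterations, max_fitness = [], []
    for i in range(num_iterations // log_period):
        # One scan call = log_period updates; metrics come back with shape (log_period,).
        carry, step_metrics = jax.lax.scan(toy_scan_update, carry, None, length=log_period)
        # Same labelling as the notebook: iterations 1 + log_period*i .. log_period*(i+1).
        iterations.append(jnp.arange(1 + log_period * i, 1 + log_period * (i + 1)))
        max_fitness.append(step_metrics["max_fitness"])
    return jnp.concatenate(iterations), jnp.concatenate(max_fitness)


its, max_fit = run_loop(num_iterations=50, log_period=10, key=jax.random.key(0))
print(its[0], its[-1], max_fit.shape)  # 1 50 (50,)
```

Note that num_iterations // log_period discards any remainder: with the notebook's values (1000 and 10) all 1000 iterations run, but a num_iterations that is not a multiple of log_period would silently run fewer.
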
Get the best individual of the repertoire

Note that in ME-LS, an individual is assigned to the archive cell that occurs most frequently among the cells of its num_samples descriptors. Thus, the descriptor stored for each individual in the archive is not its mean descriptor; instead, it is the centroid of that most frequent cell.

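To illustrate this cell-assignment rule, here is a small sketch that maps each of an individual's sampled descriptors to its nearest centroid, counts how often each cell is hit, and returns the most frequent cell together with its centroid. most_frequent_cell is a hypothetical helper written for this example; ME-LS's internal implementation may differ, for instance in how ties are broken.

```python
import jax.numpy as jnp


def most_frequent_cell(descriptors, centroids):
    """Return (cell index, centroid) of the most frequent cell among the samples.

    descriptors: (num_samples, d) descriptors of one individual.
    centroids: (num_centroids, d) archive centroids.
    """
    # Nearest-centroid assignment for every sampled descriptor.
    dists = jnp.sum((descriptors[:, None, :] - centroids[None, :, :]) ** 2, axis=-1)
    cells = jnp.argmin(dists, axis=1)
    # Count the hits per cell; argmax picks the mode (lowest index on ties).
    counts = jnp.bincount(cells, length=centroids.shape[0])
    cell = jnp.argmax(counts)
    return cell, centroids[cell]


toy_archive = jnp.array([[0.25, 0.25], [0.75, 0.75]])
sampled_descriptors = jnp.array([[0.2, 0.3], [0.3, 0.2], [0.9, 0.8]])
cell, stored_descriptor = most_frequent_cell(sampled_descriptors, toy_archive)
print(int(cell), stored_descriptor)  # 0 [0.25 0.25]
```

Here two of the three samples fall in cell 0, so the stored descriptor is [0.25, 0.25], even though the mean descriptor is about [0.47, 0.43].
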
best_idx = jnp.argmax(repertoire.fitnesses)
best_fitness = jnp.max(repertoire.fitnesses)
best_descriptor = repertoire.descriptors[best_idx]
best_spread = repertoire.spreads[best_idx]
print(
    f"Best fitness in the repertoire: {best_fitness:.2f}\n"
    f"Descriptor of the best individual in the repertoire: {best_descriptor}\n"
    f"Spread of the best individual in the repertoire: {best_spread}\n"
    f"Index in the repertoire of this individual: {best_idx}\n"
)
# select the parameters of the best individual
my_params = jax.tree.map(
    lambda x: x[best_idx],
    repertoire.genotypes
)

Play some steps in the environment

jit_env_reset = jax.jit(env.reset)
jit_env_step = jax.jit(env.step)
jit_inference_fn = jax.jit(policy_network.apply)
rollout = []
key, subkey = jax.random.split(key)
state = jit_env_reset(rng=subkey)
while not state.done:
    rollout.append(state)
    action = jit_inference_fn(my_params, state.obs)
    state = jit_env_step(state, action)

print(f"The trajectory of this individual contains {len(rollout)} transitions.")
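# `html` (Brax's HTML visualisation module) is not imported in the cells above.
# Import the module matching the installed Brax version and the state type used
# here (e.g. `brax.v1.io.html` for states exposing `.qp`) before running this cell.
HTML(html.render(env.sys, [s.qp for s in rollout[:500]]))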