QDax docs

Installation

You will need Python 3.11 or later, and a working JAX installation. For example, you can install JAX with:

%pip install -U "jax[cuda]"

Then, install QDax from PyPI:

%pip install -U "qdax[examples]"
import functools
import time
from typing import Optional, Tuple

import jax
import jax.numpy as jnp

import optax
from IPython.display import HTML
from tqdm import tqdm

import qdax.tasks.brax as environments
from qdax.baselines.sac_pbt import PBTSAC, PBTSacConfig, PBTSacTrainingState
from qdax.core.containers.mapelites_repertoire import compute_cvt_centroids
from qdax.core.distributed_map_elites import DistributedMAPElites
from qdax.core.emitters.pbt_me_emitter import PBTEmitter, PBTEmitterConfig
from qdax.core.emitters.pbt_variation_operators import sac_pbt_variation_fn
from qdax.custom_types import Descriptor, ExtraScores, Fitness, Genotype, RNGKey
from qdax.utils.metrics import CSVLogger, default_qd_metrics
from qdax.utils.plotting import plot_map_elites_results
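# Use the CPU as the default backend; the heavy computation below is explicitly
# mapped onto the accelerator devices with jax.pmap
jax.config.update("jax_platform_name", "cpu")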
# Get devices (change 'gpu' to 'tpu' if needed)
devices = jax.devices('gpu')
num_devices = len(devices)
print(f"Detected the following {num_devices} device(s): {devices}")
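The cell above assumes at least one GPU is visible; `jax.devices('gpu')` raises an error otherwise. A minimal sketch of a fallback, assuming you want to try GPUs, then TPUs, then the CPU (the helper `pick_devices` is hypothetical and not part of QDax). On a single CPU device the `pmap` calls below still run, only much more slowly.

```python
import jax


def pick_devices(preferred=("gpu", "tpu", "cpu")):
    """Return the devices of the first available backend in `preferred`.

    Hypothetical helper for this notebook, not a QDax API.
    """
    for backend in preferred:
        try:
            found = jax.devices(backend)
        except RuntimeError:
            # jax.devices raises RuntimeError when the backend is unknown
            # or failed to initialize on this machine
            continue
        if found:
            return found
    raise RuntimeError(f"No devices found for backends {preferred}")


devices = pick_devices()
num_devices = len(devices)
print(f"Using {num_devices} device(s): {devices}")
```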
env_name = "anttrap"

seed = 0

# SAC config
batch_size = 256
episode_length = 1000
tau = 0.005
alpha_init = 1.0
critic_hidden_layer_size = (256, 256) 
policy_hidden_layer_size = (256, 256) 
fix_alpha = False
normalize_observations = False

# Emitter config
buffer_size = 100000
pg_population_size_per_device = 10
ga_population_size_per_device = 30
num_training_iterations = 10000
env_batch_size = 250
grad_updates_per_step = 1.0
iso_sigma = 0.005
line_sigma = 0.05

fraction_best_to_replace_from = 0.1
fraction_to_replace_from_best = 0.2
fraction_to_replace_from_samples = 0.4

eval_env_batch_size = 1

# MAP-Elites config
num_init_cvt_samples = 50000
num_centroids = 128
log_period = 1
num_loops = 10
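Some of these values are combined later: the training environment batch is `env_batch_size * pg_population_size_per_device`, and the emitter config receives `num_training_iterations // env_batch_size`. Working through that arithmetic with the defaults above (the names on the left-hand side are introduced here for illustration only):

```python
# Values copied from the configuration cell above
pg_population_size_per_device = 10
env_batch_size = 250
num_training_iterations = 10000

# The training environment runs env_batch_size parallel copies per PG agent
training_env_batch = env_batch_size * pg_population_size_per_device

# PBTEmitterConfig receives num_training_iterations // env_batch_size
emitter_training_iterations = num_training_iterations // env_batch_size

# Assuming each emitter iteration steps every parallel environment once,
# this recovers the environment steps seen by each PG agent per iteration
env_steps_per_pg_agent = emitter_training_iterations * env_batch_size

print(training_env_batch, emitter_training_iterations, env_steps_per_pg_agent)
# 2500 40 10000
```

Keeping `num_training_iterations` a multiple of `env_batch_size` avoids silently dropping steps in the integer division.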
# Initialize environments
env = environments.create(
    env_name=env_name,
    batch_size=env_batch_size * pg_population_size_per_device,
    episode_length=episode_length,
    auto_reset=True,
)

eval_env = environments.create(
    env_name=env_name,
    batch_size=eval_env_batch_size,
    episode_length=episode_length,
    auto_reset=True,
)
min_descriptor, max_descriptor = env.descriptor_limits
key = jax.random.key(seed)

key, subkey_1, subkey_2 = jax.random.split(key, 3)
env_states = jax.jit(env.reset)(rng=subkey_1)
eval_env_first_states = jax.jit(eval_env.reset)(rng=subkey_2)
# get agent
config = PBTSacConfig(
    batch_size=batch_size,
    episode_length=episode_length,
    tau=tau,
    normalize_observations=normalize_observations,
    alpha_init=alpha_init,
    critic_hidden_layer_size=critic_hidden_layer_size,
    policy_hidden_layer_size=policy_hidden_layer_size,
    fix_alpha=fix_alpha,
)

agent = PBTSAC(config=config, action_size=env.action_size)
# init emitter
emitter_config = PBTEmitterConfig(
    buffer_size=buffer_size,
    num_training_iterations=num_training_iterations // env_batch_size,
    env_batch_size=env_batch_size,
    grad_updates_per_step=grad_updates_per_step,
    pg_population_size_per_device=pg_population_size_per_device,
    ga_population_size_per_device=ga_population_size_per_device,
    num_devices=num_devices,
    fraction_best_to_replace_from=fraction_best_to_replace_from,
    fraction_to_replace_from_best=fraction_to_replace_from_best,
    fraction_to_replace_from_samples=fraction_to_replace_from_samples,
    fraction_sort_exchange=0.1,
)
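The `fraction_*` parameters control the exploit step of population-based training. As a rough illustration only (a generic PBT exploit step under our own assumptions about what the names mean, not a copy of `PBTEmitter`'s logic): the worst `fraction_to_replace_from_best` of the population is overwritten by copies of agents sampled from the top `fraction_best_to_replace_from`. The function `pbt_exploit_sketch` is hypothetical.

```python
import jax
import jax.numpy as jnp


def pbt_exploit_sketch(key, fitnesses, params, fraction_best_to_replace_from,
                       fraction_to_replace_from_best):
    """Hypothetical PBT exploit step on a population of parameter rows.

    Assumed semantics: the worst `fraction_to_replace_from_best` of the
    population is overwritten with copies of agents sampled uniformly from
    the top `fraction_best_to_replace_from`.
    """
    n = fitnesses.shape[0]
    num_donors = max(1, int(fraction_best_to_replace_from * n))
    num_replaced = int(fraction_to_replace_from_best * n)

    order = jnp.argsort(-fitnesses)  # indices sorted from best to worst
    donor_pool = order[:num_donors]
    replaced = order[n - num_replaced:]

    donors = jax.random.choice(key, donor_pool, shape=(num_replaced,), replace=True)
    new_params = params.at[replaced].set(params[donors])
    return new_params, replaced, donors


key = jax.random.key(0)
fitnesses = jnp.arange(10.0)  # agent 9 is the best, agent 0 the worst
params = jnp.arange(10.0)[:, None]  # one scalar "parameter" per agent
new_params, replaced, donors = pbt_exploit_sketch(key, fitnesses, params, 0.1, 0.2)
print(replaced, new_params[:, 0])
```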
variation_fn = functools.partial(
    sac_pbt_variation_fn, iso_sigma=iso_sigma, line_sigma=line_sigma
)
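`functools.partial` fixes `iso_sigma` and `line_sigma`, so the emitter only needs to supply genotypes and a key. These two parameters are the usual knobs of the Iso+LineDD operator: isotropic Gaussian noise plus Gaussian noise along the line joining two parents. Assuming `sac_pbt_variation_fn` applies something of this form to the policy parameters (we have not verified its internals), a minimal pytree version looks like this; `isoline_sketch` is hypothetical.

```python
import jax
import jax.numpy as jnp


def isoline_sketch(x1, x2, key, iso_sigma, line_sigma):
    """Iso+LineDD variation of parent x1 relative to parent x2.

    x1 and x2 are pytrees with identical structure (one individual each):
    child = x1 + iso_sigma * N(0, I) + line_sigma * N(0, 1) * (x2 - x1),
    with a single line coefficient shared by all leaves.
    """
    leaves1, treedef = jax.tree.flatten(x1)
    leaves2 = jax.tree.leaves(x2)
    key_line, *leaf_keys = jax.random.split(key, len(leaves1) + 1)
    line_coef = line_sigma * jax.random.normal(key_line)
    children = [
        a + iso_sigma * jax.random.normal(k, a.shape) + line_coef * (b - a)
        for a, b, k in zip(leaves1, leaves2, leaf_keys)
    ]
    return jax.tree.unflatten(treedef, children)


parent_a = {"w": jnp.zeros((3, 2)), "b": jnp.zeros(2)}
parent_b = {"w": jnp.ones((3, 2)), "b": jnp.ones(2)}
child = isoline_sketch(parent_a, parent_b, jax.random.key(0), iso_sigma=0.005, line_sigma=0.05)
print(jax.tree.map(lambda x: x.shape, child))
```

Setting both sigmas to zero returns the first parent unchanged, which is a handy sanity check.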
emitter = PBTEmitter(
    pbt_agent=agent,
    config=emitter_config,
    env=env,
    variation_fn=variation_fn,
)
# get scoring function
descriptor_extraction_fn = environments.descriptor_extractor[env_name]
eval_policy = agent.get_eval_qd_fn(eval_env, descriptor_extraction_fn=descriptor_extraction_fn)


def scoring_function(genotypes, key):
    population_size = jax.tree.leaves(genotypes)[0].shape[0]
    first_states = jax.tree.map(
        lambda x: jnp.expand_dims(x, axis=0), eval_env_first_states
    )
    first_states = jax.tree.map(
        lambda x: jnp.repeat(x, population_size, axis=0), first_states
    )
    population_returns, population_descriptors, _, _ = eval_policy(genotypes, first_states)
    return population_returns, population_descriptors, {}
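`scoring_function` evaluates every genotype from the same initial evaluation state: it adds a leading axis to each leaf of `eval_env_first_states` and repeats it `population_size` times (its `key` argument is unused). A self-contained sketch of that pattern on a made-up toy state, with an equivalent `jnp.broadcast_to` formulation:

```python
import jax
import jax.numpy as jnp

# Toy stand-in for an environment state pytree (one state, batch of 1 env)
first_state = {"obs": jnp.zeros((1, 4)), "done": jnp.zeros((1,))}
population_size = 5

# Pattern used in scoring_function: add a leading axis, then repeat along it
repeated = jax.tree.map(lambda x: jnp.expand_dims(x, axis=0), first_state)
repeated = jax.tree.map(lambda x: jnp.repeat(x, population_size, axis=0), repeated)

# Equivalent formulation with broadcasting
broadcast = jax.tree.map(
    lambda x: jnp.broadcast_to(x, (population_size,) + x.shape), first_state
)

print(jax.tree.map(lambda x: x.shape, repeated))
```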
# Get the reward offset (minimum reward value) so that the QD score stays positive
reward_offset = environments.reward_offset[env_name]

# Define a metrics function
metrics_function = functools.partial(
    default_qd_metrics,
    qd_offset=reward_offset * episode_length,
)

# Get the MAP-Elites algorithm
map_elites = DistributedMAPElites(
    scoring_function=scoring_function,
    emitter=emitter,
    metrics_function=metrics_function,
)
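The QD score sums fitnesses over filled cells, and episode returns can be negative, hence the offset; since a fitness here is a whole-episode return, the per-step `reward_offset` is multiplied by `episode_length`. A minimal sketch of the computation, assuming (as we understand `default_qd_metrics` to behave) that empty cells hold a fitness of `-inf` and that the offset is added once per filled cell; `qd_metrics_sketch` is hypothetical.

```python
import jax.numpy as jnp


def qd_metrics_sketch(fitnesses, qd_offset):
    """Hypothetical QD metrics for a 1-D array of cell fitnesses (-inf = empty)."""
    filled = fitnesses > -jnp.inf
    # Sum the fitnesses of filled cells, shifting each one by qd_offset
    qd_score = jnp.sum(jnp.where(filled, fitnesses + qd_offset, 0.0))
    coverage = 100.0 * jnp.mean(filled)
    max_fitness = jnp.max(fitnesses)
    return {"qd_score": qd_score, "coverage": coverage, "max_fitness": max_fitness}


metrics = qd_metrics_sketch(jnp.array([-1.0, -jnp.inf, 2.0]), qd_offset=3.0)
print(metrics)
# qd_score = (-1 + 3) + (2 + 3) = 7.0, coverage about 66.67, max_fitness = 2.0
```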
key, subkey = jax.random.split(key)
centroids = compute_cvt_centroids(
    num_descriptors=env.descriptor_length,
    num_init_cvt_samples=num_init_cvt_samples,
    num_centroids=num_centroids,
    minval=min_descriptor,
    maxval=max_descriptor,
    key=subkey,
)
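`compute_cvt_centroids` tessellates the descriptor space into `num_centroids` cells by clustering `num_init_cvt_samples` uniform samples drawn between `minval` and `maxval`. A minimal pure-JAX sketch of the same idea with plain Lloyd (k-means) iterations; the function `cvt_centroids_sketch` and its fixed iteration count are our assumptions, not QDax's implementation.

```python
import jax
import jax.numpy as jnp


def cvt_centroids_sketch(key, num_descriptors, num_samples, num_centroids,
                         minval, maxval, num_iters=20):
    """Approximate CVT centroids via k-means on uniform samples (hypothetical)."""
    key_x, key_init = jax.random.split(key)
    x = jax.random.uniform(
        key_x, (num_samples, num_descriptors), minval=minval, maxval=maxval
    )
    # Initialize centroids on distinct random samples
    idx = jax.random.choice(key_init, num_samples, (num_centroids,), replace=False)
    centroids = x[idx]
    for _ in range(num_iters):
        # Assign every sample to its nearest centroid
        d2 = jnp.sum((x[:, None, :] - centroids[None, :, :]) ** 2, axis=-1)
        one_hot = jax.nn.one_hot(jnp.argmin(d2, axis=1), num_centroids)
        counts = one_hot.sum(axis=0)
        sums = one_hot.T @ x
        # Move each centroid to the mean of its samples; keep empty clusters fixed
        centroids = jnp.where(
            counts[:, None] > 0, sums / jnp.maximum(counts, 1.0)[:, None], centroids
        )
    return centroids


centroids = cvt_centroids_sketch(
    jax.random.key(0), num_descriptors=2, num_samples=2000, num_centroids=16,
    minval=0.0, maxval=1.0,
)
print(centroids.shape)
```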
key, *keys = jax.random.split(key, num=1 + num_devices)
keys = jnp.stack(keys)
# get the initial training states and replay buffers
agent_init_fn = agent.get_init_fn(
    population_size=pg_population_size_per_device + ga_population_size_per_device,
    action_size=env.action_size,
    observation_size=env.observation_size,
    buffer_size=buffer_size,
)

# Convert the typed keys to raw key data (legacy PRNGKey format) because of github.com/jax-ml/jax/issues/23647
keys = jax.random.key_data(keys)

training_states, _ = jax.pmap(agent_init_fn, axis_name="p", devices=devices)(keys)
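The notebook creates new-style typed keys with `jax.random.key` and converts them to raw `uint32` key data before calling the `pmap`ped initializer. A short sketch of the two representations and how to move between them with `jax.random.key_data` and `jax.random.wrap_key_data`, assuming the default threefry PRNG implementation (whose raw keys have shape `(2,)`):

```python
import jax
import jax.numpy as jnp

num_devices = 2  # illustrative value

key = jax.random.key(0)  # typed key with scalar shape ()
key, *per_device = jax.random.split(key, num=1 + num_devices)
typed_keys = jnp.stack(per_device)  # typed key array of shape (num_devices,)

raw_keys = jax.random.key_data(typed_keys)  # uint32 array of shape (num_devices, 2)
print(typed_keys.shape, raw_keys.shape, raw_keys.dtype)

# Raw key data can be wrapped back into typed keys when needed
round_trip = jax.random.wrap_key_data(raw_keys)
assert bool((jax.random.key_data(round_trip) == raw_keys).all())
```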
# Empty the optimizer states to keep the repertoire from becoming too heavy
training_states = jax.pmap(
    jax.vmap(training_states.__class__.empty_optimizers_states),
    axis_name="p",
    devices=devices,
)(training_states)

# initialize map-elites
repertoire, emitter_state, init_metrics = map_elites.get_distributed_init_fn(
    devices=devices, centroids=centroids
)(genotypes=training_states, key=keys)
update_fn = map_elites.get_distributed_update_fn(
    num_iterations=log_period, devices=devices
)
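`get_distributed_update_fn(num_iterations=log_period, ...)` returns a function that performs `log_period` MAP-Elites iterations per call. A common way to build such a function is `jax.lax.scan` over a single-step update, splitting the key at every step so that each iteration sees fresh randomness. A toy sketch of that pattern (the `step` function is a made-up stand-in for one MAP-Elites update, and we have not checked how QDax organizes its own implementation):

```python
import jax
import jax.numpy as jnp


def step(state, key):
    """Toy stand-in for one update: adds Gaussian noise and reports a metric."""
    new_state = state + jax.random.normal(key, state.shape)
    return new_state, {"mean": jnp.mean(new_state)}


def make_multi_step_update(num_iterations):
    def update(state, key):
        def body(carry, _):
            state, key = carry
            key, subkey = jax.random.split(key)
            state, metrics = step(state, subkey)
            return (state, key), metrics

        (state, _), metrics = jax.lax.scan(
            body, (state, key), xs=None, length=num_iterations
        )
        # Metrics gain a leading axis of size num_iterations
        return state, metrics

    return jax.jit(update)


update = make_multi_step_update(num_iterations=3)
state, metrics = update(jnp.zeros(4), jax.random.key(0))
print(metrics["mean"].shape)  # (3,)
```

Because this `update` does not hand its key back, calling it twice with the same key replays the same random stream; draw a new key before every call.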
env_step_multiplier = (
    (pg_population_size_per_device + ga_population_size_per_device)
    * eval_env_batch_size
    * episode_length
    + num_training_iterations * pg_population_size_per_device
) * num_devices
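`env_step_multiplier` counts the environment steps consumed by one iteration across all devices: every agent in the population is evaluated for a full episode in each of the `eval_env_batch_size` evaluation environments, and the PG agents add `num_training_iterations` training steps each. With the values above and, for illustration, a single device:

```python
pg_population_size_per_device = 10
ga_population_size_per_device = 30
eval_env_batch_size = 1
episode_length = 1000
num_training_iterations = 10000
num_devices = 1  # illustrative; the notebook uses len(devices)

evaluation_steps = (
    (pg_population_size_per_device + ga_population_size_per_device)
    * eval_env_batch_size
    * episode_length
)  # 40 * 1 * 1000 = 40000
training_steps = num_training_iterations * pg_population_size_per_device  # 100000
env_step_multiplier = (evaluation_steps + training_steps) * num_devices
print(env_step_multiplier)  # 140000
```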
all_metrics = {}

# Log init_metrics
for _key, _value in init_metrics.items():
    all_metrics[_key] = _value

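for i in tqdm(range(num_loops // log_period), total=num_loops // log_period):
    start_time = time.time()

    # Draw fresh per-device keys for every call: update_fn does not return
    # its key, so reusing `keys` would replay the same random numbers
    key, *loop_keys = jax.random.split(key, num=1 + num_devices)
    keys = jax.random.key_data(jnp.stack(loop_keys))

    repertoire, emitter_state, metrics = update_fn(
        repertoire, emitter_state, keys
    )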
    metrics_cpu = jax.tree.map(
        lambda x: jax.device_put(x, jax.devices("cpu")[0])[0], metrics
    )
    timelapse = time.time() - start_time

    # log metrics
    for k, v in metrics_cpu.items():
        # take all values
        if k in all_metrics.keys():
            all_metrics[k] = jnp.concatenate([all_metrics[k], v])
        else:
            all_metrics[k] = v
# Create the performance evolution plots and visualize final grid
repertoire_cpu = jax.tree.map(
    lambda x: jax.device_put(x, jax.devices("cpu")[0])[0], repertoire
)
num_loops_with_init = num_loops + 1
env_steps = (jnp.arange(num_loops_with_init) + 1) * env_step_multiplier

fig, axes = plot_map_elites_results(
    env_steps=env_steps,
    metrics=all_metrics,
    repertoire=repertoire_cpu,
    min_descriptor=min_descriptor,
    max_descriptor=max_descriptor,
)

Visualize learnt behaviors

# Evaluate best individual of the repertoire
best_idx = jnp.argmax(repertoire_cpu.fitnesses)
best_fitness = jnp.max(repertoire_cpu.fitnesses)
best_descriptor = repertoire_cpu.descriptors[best_idx]
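In QDax's MAP-Elites repertoire, empty cells hold a fitness of `-inf`, so `jnp.argmax` over the fitnesses lands on a filled cell as long as at least one cell is filled. To inspect several elites rather than one, `jax.lax.top_k` returns the `k` best values and their indices. A sketch on toy values, assuming 1-D fitnesses (flatten them first if your repertoire stores a trailing axis of size 1):

```python
import jax
import jax.numpy as jnp

# Toy repertoire: 6 cells, two of them empty (-inf)
fitnesses = jnp.array([3.0, -jnp.inf, 7.5, 1.0, -jnp.inf, 5.0])
descriptors = jnp.arange(12.0).reshape(6, 2)

best_idx = jnp.argmax(fitnesses)
top_values, top_indices = jax.lax.top_k(fitnesses, k=3)

print(best_idx, top_indices, descriptors[top_indices])
```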
repertoire_cpu.descriptors.shape
# Alternatively, evaluate the agent that goes furthest along the first descriptor dimension
# best_idx = jnp.argmax(repertoire_cpu.descriptors[:, 0])
# best_fitness = repertoire_cpu.fitnesses[best_idx]
# best_descriptor = repertoire_cpu.descriptors[best_idx]
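When selecting by descriptor instead of fitness, keep in mind that empty cells still have descriptor entries (placeholder values rather than real behaviors), so an unmasked `argmax` over a descriptor column can pick an empty cell. A sketch that restricts the search to filled cells, assuming empty cells are marked by a fitness of `-inf`; the helper `furthest_along` is hypothetical.

```python
import jax.numpy as jnp


def furthest_along(fitnesses, descriptors, dim):
    """Index of the filled cell with the largest descriptor value along `dim`.

    Hypothetical helper: empty cells are assumed to have fitness -inf.
    """
    # reshape(-1) accepts fitnesses stored as (N,) or (N, 1)
    filled = fitnesses.reshape(-1) > -jnp.inf
    masked = jnp.where(filled, descriptors[:, dim], -jnp.inf)
    return jnp.argmax(masked)


fitnesses = jnp.array([1.0, -jnp.inf, 2.0])
descriptors = jnp.array([[0.5, 0.1], [9.0, 9.0], [0.2, 0.8]])  # cell 1 is empty
print(furthest_along(fitnesses, descriptors, dim=0))  # 0
print(furthest_along(fitnesses, descriptors, dim=1))  # 2
```

In the notebook, the call would be `furthest_along(repertoire_cpu.fitnesses, repertoire_cpu.descriptors, dim=0)`.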
print(
    f"Fitness of the selected agent: {best_fitness:.2f}\n"
    f"Descriptor of the selected agent: {best_descriptor}\n"
    f"Index in the repertoire of this individual: {best_idx}"
)
env = environments.create(env_name, episode_length=episode_length)
play_step_fn = jax.pmap(
    functools.partial(agent.play_step_fn, env=env, deterministic=True, evaluation=True),
    axis_name="p",
    devices=devices[:1],
)
training_state = jax.tree.map(lambda x: x[best_idx], repertoire_cpu.genotypes)
rollout = []
key, subkey = jax.random.split(key)
env_state = jax.jit(env.reset)(rng=subkey)

training_state, env_state = jax.tree.map(
    lambda x: jnp.expand_dims(x, axis=0), (training_state, env_state)
)

for _ in range(episode_length):

    rollout.append(env_state)
    env_state, training_state, _ = play_step_fn(env_state, training_state)

print(f"The trajectory of this individual contains {len(rollout)} transitions.")
rollout = [
    jax.tree.map(lambda x: jax.device_put(x[0], jax.devices("cpu")[0]), env_state)
    for env_state in rollout
]
from brax.io import html

HTML(html.render(env.sys, [s.pipeline_state for s in rollout[:episode_length]]))