Optimizing with QDPG in JAX

This notebook shows how to use QDax to find diverse, high-performing controllers in MDPs with QDPG (Quality Diversity Policy Gradient in MAP-Elites). It can be run locally or on Google Colab. We recommend using a GPU. This notebook will show:

  • how to define the problem
  • how to create the QDPG emitter
  • how to create a MAP-Elites instance
  • which functions must be defined before training
  • how to launch a certain number of training steps
  • how to visualize the results of the training process

Installation

You will need Python 3.11 or later, and a working JAX installation. For example, you can install JAX with:

In [ ]:
%pip install -U "jax[cuda]"

Then, install QDax from PyPI:

In [ ]:
%pip install -U "qdax[examples]"
In [ ]:
import os

from IPython.display import clear_output
import functools
import time

import jax
import jax.numpy as jnp

from qdax.core.containers.archive import score_euclidean_novelty
from qdax.core.emitters.dpg_emitter import DiversityPGConfig
from qdax.core.emitters.qdpg_emitter import QDPGEmitter, QDPGEmitterConfig
from qdax.core.emitters.qpg_emitter import QualityPGConfig
from qdax.core.map_elites import MAPElites
from qdax.core.containers.mapelites_repertoire import compute_cvt_centroids
import qdax.tasks.brax as environments
from qdax.tasks.brax.env_creators import scoring_function_brax_envs as scoring_function
from qdax.core.neuroevolution.buffers.buffer import QDTransition
from qdax.core.neuroevolution.networks.networks import MLP
from qdax.core.emitters.mutation_operators import isoline_variation
from qdax.utils.plotting import plot_map_elites_results

from qdax.core.emitters.pga_me_emitter import PGAMEConfig, PGAMEEmitter
from qdax.utils.metrics import CSVLogger, default_qd_metrics
In [ ]:
#@title QD Training Definitions Fields
#@markdown ---
env_name = 'walker2d_uni'#@param['ant_uni', 'hopper_uni', 'walker2d_uni', 'halfcheetah_uni', 'humanoid_uni', 'ant_omni', 'humanoid_omni']
episode_length = 250 #@param {type:"integer"}
num_iterations = 100 #@param {type:"integer"}
seed = 42 #@param {type:"integer"}
policy_hidden_layer_sizes = (256, 256) #@param {type:"raw"}
iso_sigma = 0.005 #@param {type:"number"}
line_sigma = 0.05 #@param {type:"number"}
num_init_cvt_samples = 50000 #@param {type:"integer"}
num_centroids = 1024 #@param {type:"integer"}
min_descriptor = 0. #@param {type:"number"}
max_descriptor = 1.0 #@param {type:"number"}

# Batch size of each mutation type
quality_pg_batch_size = 34 #@param {type:"integer"}
diversity_pg_batch_size = 33 #@param {type:"integer"}
ga_batch_size = 33 #@param {type:"integer"}

env_batch_size = quality_pg_batch_size + diversity_pg_batch_size + ga_batch_size

# TD3 params
replay_buffer_size = 1000000 #@param {type:"number"}
critic_hidden_layer_size = (256, 256) #@param {type:"raw"}
critic_learning_rate = 3e-4 #@param {type:"number"}
greedy_learning_rate = 3e-4 #@param {type:"number"}
policy_learning_rate = 1e-3 #@param {type:"number"}
noise_clip = 0.5 #@param {type:"number"}
policy_noise = 0.2 #@param {type:"number"}
discount = 0.99 #@param {type:"number"}
reward_scaling = 1.0 #@param {type:"number"}
transitions_batch_size = 256 #@param {type:"number"}
soft_tau_update = 0.005 #@param {type:"number"}
num_critic_training_steps = 300 #@param {type:"number"}
num_pg_training_steps = 100 #@param {type:"number"}
policy_delay = 2 #@param {type:"number"}

archive_acceptance_threshold = 0.1 #@param {type:"number"}
archive_max_size = 10000 #@param {type:"integer"}

num_nearest_neighb = 5 #@param {type:"integer"}
novelty_scaling_ratio = 1.0 #@param {type:"number"}
#@markdown --- 

Init environment, policy, population params, init states of the env

Define the environment in which the policies will be trained. In this notebook, we focus on controllers learning to move a robot in a physical simulation. We also define the shared policy network that every individual in the population will use. Once the policy is defined, each individual is fully described by its parameters, which correspond to its genotype.
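
To make the idea of a shared policy concrete, here is a minimal sketch in plain Python rather than JAX. It assumes a toy linear policy with a tanh output; `init_params` and `policy_apply` are hypothetical helpers written for this illustration and are not part of QDax. Every individual calls the same `policy_apply` function, and only the parameter set (the genotype) differs.

```python
import math
import random


def init_params(rng, obs_size, action_size):
    """Hypothetical helper: draw one genotype for a linear tanh policy."""
    weights = [
        [rng.uniform(-1.0, 1.0) for _ in range(obs_size)]
        for _ in range(action_size)
    ]
    biases = [0.0] * action_size
    return {"weights": weights, "biases": biases}


def policy_apply(params, obs):
    """Shared policy: the same function for every individual, only params differ."""
    return [
        math.tanh(sum(w * o for w, o in zip(row, obs)) + b)
        for row, b in zip(params["weights"], params["biases"])
    ]


# A population is a collection of genotypes for one shared architecture.
rng = random.Random(42)
population = [init_params(rng, obs_size=3, action_size=2) for _ in range(4)]

# The same observation produces different actions for different genotypes.
obs = [0.5, -0.2, 0.1]
actions = [policy_apply(params, obs) for params in population]
```

In the notebook itself, the same pattern is expressed by applying `jax.vmap` to `policy_network.init`, which creates all genotypes in one batched call.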

In [ ]:
# Init environment
env = environments.create(env_name, episode_length=episode_length)
reset_fn = jax.jit(env.reset)

# Init a random key
key = jax.random.key(seed)

# Init policy network
policy_layer_sizes = policy_hidden_layer_sizes + (env.action_size,)
policy_network = MLP(
    layer_sizes=policy_layer_sizes,
    kernel_init=jax.nn.initializers.lecun_uniform(),
    final_activation=jnp.tanh,
)

# Init population of controllers
key, subkey = jax.random.split(key)
keys = jax.random.split(subkey, num=env_batch_size)
fake_batch = jnp.zeros(shape=(env_batch_size, env.observation_size))
init_variables = jax.vmap(policy_network.init)(keys, fake_batch)

Define the way the policy interacts with the env

In [ ]:
# Define the function to play a step with the policy in the environment
def play_step_fn(
    env_state,
    policy_params,
    key,
):
    """
    Play an environment step and return the updated state and the transition.
    """

    actions = policy_network.apply(policy_params, env_state.obs)

    state_desc = env_state.info["state_descriptor"]
    next_state = env.step(env_state, actions)

    transition = QDTransition(
        obs=env_state.obs,
        next_obs=next_state.obs,
        rewards=next_state.reward,
        dones=next_state.done,
        actions=actions,
        truncations=next_state.info["truncation"],
        state_desc=state_desc,
        next_state_desc=next_state.info["state_descriptor"],
    )

    return next_state, policy_params, key, transition

Define the scoring function and the way metrics are computed

The scoring function is used in the evaluation step to determine the fitness and descriptor of each individual.
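
As a rough illustration of what a scoring function returns, the sketch below evaluates three individuals in a hypothetical 1-D environment, using plain Python. It assumes that the fitness is the sum of rewards over an episode and that the descriptor is the average of a per-step state descriptor; `toy_step` and `score` are made up for this example and are much simpler than the Brax-based scoring function used in this notebook.

```python
def toy_step(position, action):
    """Hypothetical 1-D environment: the action moves the agent along a line."""
    next_position = position + action
    reward = action  # moving right is rewarded
    state_descriptor = 1.0 if next_position > 0.0 else 0.0
    return next_position, reward, state_descriptor


def score(constant_action, episode_length):
    """Roll out one individual and return its (fitness, descriptor)."""
    position = 0.0
    total_reward = 0.0
    descriptor_sum = 0.0
    for _ in range(episode_length):
        position, reward, state_descriptor = toy_step(position, constant_action)
        total_reward += reward
        descriptor_sum += state_descriptor
    # Fitness: sum of rewards. Descriptor: fraction of time spent at position > 0.
    return total_reward, descriptor_sum / episode_length


# Each individual is reduced to a single parameter: the action it always takes.
genotypes = [-0.5, 0.0, 0.5]
results = [score(genotype, episode_length=4) for genotype in genotypes]
fitnesses = [fitness for fitness, _ in results]
descriptors = [descriptor for _, descriptor in results]
```

The fitness decides whether an individual replaces the current occupant of a cell, while the descriptor decides which cell of the repertoire it competes for.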

In [ ]:
# Prepare the scoring function
descriptor_extraction_fn = environments.descriptor_extractor[env_name]
scoring_fn = functools.partial(
    scoring_function,
    episode_length=episode_length,
    play_reset_fn=reset_fn,
    play_step_fn=play_step_fn,
    descriptor_extractor=descriptor_extraction_fn,
)

# Get the minimum reward value to make sure the QD score is positive
reward_offset = environments.reward_offset[env_name]

# Define a metrics function
metrics_function = functools.partial(
    default_qd_metrics,
    qd_offset=reward_offset * episode_length,
)

Define the emitter: QDPG Emitter

The emitter is used to evolve the population at each mutation step. In this example, the emitter is the Quality Diversity Policy Gradient emitter, the one used in QDPG. It trains two critics with the transitions experienced in the environment: one uses the extrinsic reward from the environment, the other uses diversity rewards based on the novelty of the states encountered. The critics are then used to apply policy gradient updates to the evolved policies.
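
The novelty-based diversity reward can be sketched as follows in plain Python. This is a simplified illustration, not QDax's `score_euclidean_novelty`: it assumes novelty is the mean Euclidean distance from a state descriptor to its k nearest neighbours in an archive of previously seen descriptors, multiplied by a scaling ratio. The `novelty` function and the archive contents are hypothetical.

```python
import math


def novelty(state_desc, archive, num_nearest_neighb, scaling_ratio=1.0):
    """Mean Euclidean distance to the k nearest archived state descriptors."""
    if not archive:
        # Assumption for this sketch: nothing to compare against, so no reward.
        return 0.0
    distances = sorted(math.dist(state_desc, other) for other in archive)
    nearest = distances[:num_nearest_neighb]
    return scaling_ratio * sum(nearest) / len(nearest)


# Archive of state descriptors encountered so far (hypothetical values).
archive = [(0.0, 0.0), (1.0, 0.0), (0.0, 1.0)]

# A state inside the explored region gets a low diversity reward ...
familiar_reward = novelty((0.0, 0.0), archive, num_nearest_neighb=2)

# ... while a state far from everything in the archive gets a high one.
novel_reward = novelty((3.0, 4.0), archive, num_nearest_neighb=2)
```

A critic trained on such rewards pushes policies towards states that are far from those already visited, which complements the critic trained on the environment reward.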

In [ ]:
# Define the Quality PG emitter config
qpg_emitter_config = QualityPGConfig(
    env_batch_size=quality_pg_batch_size,
    batch_size=transitions_batch_size,
    critic_hidden_layer_size=critic_hidden_layer_size,
    critic_learning_rate=critic_learning_rate,
    actor_learning_rate=greedy_learning_rate,
    policy_learning_rate=policy_learning_rate,
    noise_clip=noise_clip,
    policy_noise=policy_noise,
    discount=discount,
    reward_scaling=reward_scaling,
    replay_buffer_size=replay_buffer_size,
    soft_tau_update=soft_tau_update,
    num_critic_training_steps=num_critic_training_steps,
    num_pg_training_steps=num_pg_training_steps,
    policy_delay=policy_delay,
)

# Define the Diversity PG emitter config
dpg_emitter_config = DiversityPGConfig(
    env_batch_size=diversity_pg_batch_size,
    batch_size=transitions_batch_size,
    critic_hidden_layer_size=critic_hidden_layer_size,
    critic_learning_rate=critic_learning_rate,
    actor_learning_rate=greedy_learning_rate,
    policy_learning_rate=policy_learning_rate,
    noise_clip=noise_clip,
    policy_noise=policy_noise,
    discount=discount,
    reward_scaling=reward_scaling,
    replay_buffer_size=replay_buffer_size,
    soft_tau_update=soft_tau_update,
    num_critic_training_steps=num_critic_training_steps,
    num_pg_training_steps=num_pg_training_steps,
    policy_delay=policy_delay,
    archive_acceptance_threshold=archive_acceptance_threshold,
    archive_max_size=archive_max_size,
)

# Define the QDPG Emitter config
qdpg_emitter_config = QDPGEmitterConfig(
    qpg_config=qpg_emitter_config,
    dpg_config=dpg_emitter_config,
    iso_sigma=iso_sigma,
    line_sigma=line_sigma,
    ga_batch_size=ga_batch_size,
)
In [ ]:
# Define the novelty scoring function
score_novelty = jax.jit(
    functools.partial(
        score_euclidean_novelty,
        num_nearest_neighb=num_nearest_neighb,
        scaling_ratio=novelty_scaling_ratio,
    )
)

# define the QDPG emitter
qdpg_emitter = QDPGEmitter(
    config=qdpg_emitter_config,
    policy_network=policy_network,
    env=env,
    score_novelty=score_novelty,
)

Instantiate and initialise the MAP Elites algorithm

In [ ]:
# Instantiate MAP Elites
map_elites = MAPElites(
    scoring_function=scoring_fn,
    emitter=qdpg_emitter,
    metrics_function=metrics_function,
)

# Compute the centroids
key, subkey = jax.random.split(key)
centroids = compute_cvt_centroids(
    num_descriptors=env.descriptor_length,
    num_init_cvt_samples=num_init_cvt_samples,
    num_centroids=num_centroids,
    minval=min_descriptor,
    maxval=max_descriptor,
    key=subkey,
)

# compute initial repertoire
key, subkey = jax.random.split(key)
repertoire, emitter_state, init_metrics = map_elites.init(
    init_variables, centroids, subkey
)
In [ ]:
log_period = 10
num_loops = num_iterations // log_period

# Initialize metrics
metrics = {name: jnp.array([]) for name in ["iteration", "qd_score", "coverage", "max_fitness", "time"]}

# Set up init metrics
init_metrics = jax.tree.map(lambda x: jnp.array([x]) if x.shape == () else x, init_metrics)
init_metrics["iteration"] = jnp.array([0], dtype=jnp.int32)
init_metrics["time"] = jnp.array([0.0])  # No time recorded for initialization

# Convert init_metrics to match the metrics dictionary structure
metrics = jax.tree.map(lambda metric, init_metric: jnp.concatenate([metric, init_metric], axis=0), metrics, init_metrics)

# Initialize CSV logger
csv_logger = CSVLogger(
    "qdpg-logs.csv",
    header=list(metrics.keys())
)

# Main loop
map_elites_scan_update = map_elites.scan_update
for i in range(num_loops):
    start_time = time.time()
    (
        repertoire,
        emitter_state,
        key,
    ), current_metrics = jax.lax.scan(
        map_elites_scan_update,
        (repertoire, emitter_state, key),
        (),
        length=log_period,
    )
    timelapse = time.time() - start_time

    # Metrics
    current_metrics["iteration"] = jnp.arange(1+log_period*i, 1+log_period*(i+1), dtype=jnp.int32)
    current_metrics["time"] = jnp.repeat(timelapse, log_period)
    metrics = jax.tree.map(lambda metric, current_metric: jnp.concatenate([metric, current_metric], axis=0), metrics, current_metrics)

    # Log
    csv_logger.log(jax.tree.map(lambda x: x[-1], metrics))
In [ ]:
#@title Visualization

# Create the x-axis array
env_steps = metrics["iteration"]

%matplotlib inline
# Create the plots and the grid
fig, axes = plot_map_elites_results(env_steps=env_steps, metrics=metrics, repertoire=repertoire, min_descriptor=min_descriptor, max_descriptor=max_descriptor)