QDax docs
  • Home
  • Installation
  • Overview
  • Caveats

Guides

  • Contributing

Examples

  • Optimizing with MAP-Elites in JAX
  • Optimizing with PGAME in JAX
  • Optimizing with DCRL-ME in JAX
  • Optimizing with CMA-ME in JAX
  • Optimizing with QDPG in JAX
  • Optimizing with OMG-MEGA in JAX
  • Optimizing with CMA-MEGA in JAX
  • Optimizing multiple objectives with MOME in JAX
  • Optimizing with MEES in JAX
  • Training DIAYN with JAX
  • Training DADS with JAX
  • Training DIAYN SMERL with JAX
  • Optimizing with CMA-ES in JAX
  • Optimizing multiple objectives with NSGA2 & SPEA2 in JAX
  • Optimizing with AURORA in JAX
  • Optimizing with PGA-AURORA in JAX
  • PBT
    • Visualize learnt behaviors
  • MAPElites PBT
  • Optimizing Uncertain Domains with ME-LS in JAX
  • Training a population on Jumanji-Snake with QDax

API documentation

  • Core
    • Core algorithms
      • MAP Elites
      • PGAME
      • DCRLME
      • QDPG
      • CMA ME
      • OMG MEGA
      • CMA MEGA
      • MOME
      • ME ES
      • AURORA
      • PGA AURORA
      • ME PBT
      • ME LS
    • Baseline algorithms
      • SMERL
      • DIAYN
      • DADS
      • SAC
      • TD3
      • Genetic Algorithm
      • NSGA2
      • SPEA2
      • PBT
      • CMAES
    • Containers
    • Emitters
    • Neuroevolution
  • Environments
  • Environments
  • Utils
QDax docs
  • Examples
  • PBT
  • Edit on QDax

Installation

You will need Python 3.11 or later and a working JAX installation. For example, to install JAX with NVIDIA GPU (CUDA) support, run:

In [ ]:
Copied!
# Install (or upgrade) JAX with CUDA (NVIDIA GPU) support.
# NOTE(review): recent JAX releases document version-specific extras such as
# "jax[cuda12]"; confirm the plain "cuda" extra is still accepted by the JAX
# version being installed, otherwise a CPU-only JAX may end up installed.
%pip install -U "jax[cuda]"
%pip install -U "jax[cuda]"

Then, install QDax from PyPI:

In [ ]:
Copied!
# Install (or upgrade) QDax from PyPI together with its `examples` extra.
%pip install -U "qdax[examples]"
%pip install -U "qdax[examples]"
In [ ]:
Copied!
import functools

import jax
import jax.numpy as jnp

from IPython.display import HTML
from tqdm import tqdm

import qdax.tasks.brax as environments
from qdax.baselines.pbt import PBT
from qdax.baselines.sac_pbt import PBTSAC, PBTSacConfig
import functools import jax import jax.numpy as jnp from IPython.display import HTML from tqdm import tqdm import qdax.tasks.brax as environments from qdax.baselines.pbt import PBT from qdax.baselines.sac_pbt import PBTSAC, PBTSacConfig
In [ ]:
Copied!
# Make the CPU the default JAX backend. The pmapped computations further down
# pass an explicit `devices=` list, so they are not governed by this default.
jax.config.update("jax_platform_name", "cpu")
jax.config.update("jax_platform_name", "cpu")
In [ ]:
Copied!
# Get devices (change "gpu" to "tpu" if needed). When no GPU backend is
# available, fall back to the CPU devices so the notebook still runs (slowly)
# on CPU-only machines instead of failing here; the default platform is
# already forced to CPU above.
try:
    devices = jax.devices("gpu")
except RuntimeError:  # raised by JAX when the requested backend is absent
    devices = jax.devices("cpu")
num_devices = len(devices)
print(f"Detected the following {num_devices} device(s): {devices}")
# Get devices (change gpu by tpu if needed) devices = jax.devices('gpu') num_devices = len(devices) print(f"Detected the following {num_devices} device(s): {devices}")
In [ ]:
Copied!
# --- Experiment ---------------------------------------------------------------
env_name = "humanoidtrap"
seed = 0
num_loops = 10  # rounds of train -> evaluate -> PBT selection
print_freq = 1  # print evaluation stats every `print_freq` rounds

# --- Population and data collection ------------------------------------------
env_batch_size = 250  # environments per population member
population_size_per_device = 10
population_size = population_size_per_device * num_devices
num_steps = 10_000  # split into num_steps // env_batch_size iterations per round
buffer_size = 100_000  # passed to the agent as the replay-buffer size

# --- PBT config ---------------------------------------------------------------
num_best_to_replace_from = 1
num_worse_to_replace = 1

# --- SAC config ---------------------------------------------------------------
batch_size = 256
episode_length = 1000
grad_updates_per_step = 1.0
tau = 0.005
alpha_init = 1.0
fix_alpha = False
critic_hidden_layer_size = (256, 256)
policy_hidden_layer_size = (256, 256)
normalize_observations = False
env_name = "humanoidtrap" seed = 0 env_batch_size = 250 population_size_per_device = 10 population_size = population_size_per_device * num_devices num_steps = 10000 buffer_size = 100000 # PBT Config num_best_to_replace_from = 1 num_worse_to_replace = 1 # SAC config batch_size = 256 episode_length = 1000 grad_updates_per_step = 1.0 tau = 0.005 alpha_init = 1.0 critic_hidden_layer_size = (256, 256) policy_hidden_layer_size = (256, 256) fix_alpha = False normalize_observations = False num_loops = 10 print_freq = 1
In [ ]:
Copied!
# Initialize environments: `env` is used for training and `eval_env` for
# evaluation. They are built with identical settings; each batch holds
# `env_batch_size` environments for every one of the
# `population_size_per_device` agents placed on a device.
env, eval_env = [
    environments.create(
        env_name=env_name,
        batch_size=env_batch_size * population_size_per_device,
        episode_length=episode_length,
        auto_reset=True,
    )
    for _ in range(2)
]
# Initialize environments env = environments.create( env_name=env_name, batch_size=env_batch_size * population_size_per_device, episode_length=episode_length, auto_reset=True, ) eval_env = environments.create( env_name=env_name, batch_size=env_batch_size * population_size_per_device, episode_length=episode_length, auto_reset=True, )
In [ ]:
Copied!
@jax.jit
def init_environments(key):
    """Reset the training and evaluation environments for one device.

    Both ``env`` and ``eval_env`` are reset from the same ``key``. Every leaf of
    the resulting states has a leading batch axis of size
    ``population_size_per_device * env_batch_size``, which is split into
    ``(population_size_per_device, env_batch_size)``.

    Returns:
        ``(env_states, eval_env_first_states)`` with the split batch axis.
    """

    def split_population_axis(leaf):
        # (pop * batch, ...) -> (pop, batch, ...)
        new_shape = (population_size_per_device, env_batch_size, *leaf.shape[1:])
        return jnp.reshape(leaf, new_shape)

    train_reset = jax.jit(env.reset)(rng=key)
    eval_reset = jax.jit(eval_env.reset)(rng=key)

    return (
        jax.tree.map(split_population_axis, train_reset),
        jax.tree.map(split_population_axis, eval_reset),
    )
@jax.jit def init_environments(key): env_states = jax.jit(env.reset)(rng=key) eval_env_first_states = jax.jit(eval_env.reset)(rng=key) reshape_fn = jax.jit( lambda tree: jax.tree.map( lambda x: jnp.reshape( x, ( population_size_per_device, env_batch_size, ) + x.shape[1:], ), tree, ), ) env_states = reshape_fn(env_states) eval_env_first_states = reshape_fn(eval_env_first_states) return env_states, eval_env_first_states
In [ ]:
Copied!
key = jax.random.key(seed)
# Keep the first key for later use; the remaining `num_devices` keys seed the
# per-device environment resets.
split_keys = jax.random.split(key, num=1 + num_devices)
key, keys = split_keys[0], split_keys[1:]

env_states, eval_env_first_states = jax.pmap(
    init_environments, axis_name="p", devices=devices
)(keys)
key = jax.random.key(seed) key, *keys = jax.random.split(key, num=1 + num_devices) keys = jnp.stack(keys) env_states, eval_env_first_states = jax.pmap( init_environments, axis_name="p", devices=devices )(keys)
In [ ]:
Copied!
# Build the PBT-SAC agent from the hyperparameters defined above.
config = PBTSacConfig(
    batch_size=batch_size,
    episode_length=episode_length,
    tau=tau,
    alpha_init=alpha_init,
    fix_alpha=fix_alpha,
    normalize_observations=normalize_observations,
    policy_hidden_layer_size=policy_hidden_layer_size,
    critic_hidden_layer_size=critic_hidden_layer_size,
)
agent = PBTSAC(config=config, action_size=env.action_size)
# get agent config = PBTSacConfig( batch_size=batch_size, episode_length=episode_length, tau=tau, normalize_observations=normalize_observations, alpha_init=alpha_init, critic_hidden_layer_size=critic_hidden_layer_size, policy_hidden_layer_size=policy_hidden_layer_size, fix_alpha=fix_alpha, ) agent = PBTSAC(config=config, action_size=env.action_size)
In [ ]:
Copied!
# Initial training states and replay buffers for the
# `population_size_per_device` agents on each device.
agent_init_fn = agent.get_init_fn(
    population_size=population_size_per_device,
    buffer_size=buffer_size,
    observation_size=env.observation_size,
    action_size=env.action_size,
)

# The pmapped initializer is fed raw key data rather than typed PRNG keys,
# working around github.com/jax-ml/jax/issues/23647.
keys = jax.random.key_data(keys)

init_on_devices = jax.pmap(agent_init_fn, axis_name="p", devices=devices)
training_states, replay_buffers = init_on_devices(keys)
# get the initial training states and replay buffers agent_init_fn = agent.get_init_fn( population_size=population_size_per_device, action_size=env.action_size, observation_size=env.observation_size, buffer_size=buffer_size, ) # Need to convert to PRNGKey because of github.com/jax-ml/jax/issues/23647 keys = jax.random.key_data(keys) training_states, replay_buffers = jax.pmap( agent_init_fn, axis_name="p", devices=devices )(keys)
In [ ]:
Copied!
# Evaluation function built on `eval_env`, pmapped over the same devices as
# the training states.
eval_policy = jax.pmap(
    agent.get_eval_fn(eval_env),
    axis_name="p",
    devices=devices,
)
# get eval policy function eval_policy = jax.pmap(agent.get_eval_fn(eval_env), axis_name="p", devices=devices)
In [ ]:
Copied!
# Evaluate the freshly initialized population before any training.
population_returns, _ = eval_policy(training_states, eval_env_first_states)
# Merge the (device, member) axes into a single population axis.
population_returns = population_returns.reshape((population_size,))
print(
    f"Evaluation over {env_batch_size} episodes,"
    f" Population mean return: {jnp.mean(population_returns)},"
    f" max return: {jnp.max(population_returns)}"
)
# eval policy before training population_returns, _ = eval_policy(training_states, eval_env_first_states) population_returns = jnp.reshape(population_returns, (population_size,)) print( f"Evaluation over {env_batch_size} episodes," f" Population mean return: {jnp.mean(population_returns)}," f" max return: {jnp.max(population_returns)}" )
In [ ]:
Copied!
# Training function: each call runs `num_iterations` iterations with
# `env_batch_size` environments per agent, pmapped over all devices.
num_iterations = num_steps // env_batch_size

train_fn = jax.pmap(
    agent.get_train_fn(
        env=env,
        env_batch_size=env_batch_size,
        num_iterations=num_iterations,
        grad_updates_per_step=grad_updates_per_step,
    ),
    axis_name="p",
    devices=devices,
)
# get training function num_iterations = num_steps // env_batch_size train_fn = agent.get_train_fn( env=env, num_iterations=num_iterations, env_batch_size=env_batch_size, grad_updates_per_step=grad_updates_per_step, ) train_fn = jax.pmap(train_fn, axis_name="p", devices=devices)
In [ ]:
Copied!
# PBT selection runs inside `jax.pmap`, one copy per device, and the global
# replacement counts are divided across devices below. Integer division
# silently rounds down (e.g. 1 // 2 == 0), which would hand each device zero
# agents to copy from / replace, so require exact multiples up front.
if num_best_to_replace_from % num_devices or num_worse_to_replace % num_devices:
    raise ValueError(
        "num_best_to_replace_from and num_worse_to_replace must be multiples of "
        f"num_devices={num_devices}; got {num_best_to_replace_from} and "
        f"{num_worse_to_replace}."
    )

pbt = PBT(
    population_size=population_size,
    num_best_to_replace_from=num_best_to_replace_from // num_devices,
    num_worse_to_replace=num_worse_to_replace // num_devices,
)
select_fn = jax.pmap(pbt.update_states_and_buffer_pmap, axis_name="p", devices=devices)
pbt = PBT( population_size=population_size, num_best_to_replace_from=num_best_to_replace_from // num_devices, num_worse_to_replace=num_worse_to_replace // num_devices, ) select_fn = jax.pmap(pbt.update_states_and_buffer_pmap, axis_name="p", devices=devices)
In [ ]:
Copied!
def unshard_fn(sharded_tree):
    """Move a pmap-sharded pytree to the host CPU and merge its leading axes.

    Each leaf is expected to have leading axes ``(num_devices,
    population_size_per_device)``, as produced by the pmapped initialization.
    Those two axes are merged into one population axis, so a leaf of shape
    ``(d, p, *rest)`` becomes ``(d * p, *rest)``.

    Args:
        sharded_tree: pytree of arrays with leading ``(device, member)`` axes.

    Returns:
        The same pytree with every leaf placed on the first CPU device and
        reshaped to ``(d * p, *rest)``.
    """
    # `jax.device_put` takes a Device (or Sharding), not a platform name.
    # This runs eagerly: it is called once, so jit would buy nothing, and
    # keeping the host transfer outside a traced function avoids relying on
    # device_put semantics under jit.
    cpu_device = jax.devices("cpu")[0]

    def _to_cpu_and_flatten(leaf):
        leaf = jax.device_put(leaf, cpu_device)
        # (d, p, *rest) -> (d * p, *rest); equals (population_size, *rest).
        return jnp.reshape(leaf, (-1,) + leaf.shape[2:])

    return jax.tree.map(_to_cpu_and_flatten, sharded_tree)
@jax.jit def unshard_fn(sharded_tree): tree = jax.tree.map(lambda x: jax.device_put(x, "cpu"), sharded_tree) tree = jax.tree.map( lambda x: jnp.reshape(x, (population_size,) + x.shape[2:]), tree ) return tree
In [ ]:
Copied!
for i in tqdm(range(num_loops), total=num_loops):

    # Update for num_steps
    (training_states, env_states, replay_buffers), metrics = train_fn(
        training_states, env_states, replay_buffers
    )

    # Eval policy after training
    population_returns, _ = eval_policy(training_states, eval_env_first_states)
    population_returns_flatten = jnp.reshape(population_returns, (population_size,))

    if i % print_freq == 0:
        print(
            f"Evaluation over {env_batch_size} episodes,"
            f" Population mean return: {jnp.mean(population_returns_flatten)},"
            f" max return: {jnp.max(population_returns_flatten)}"
        )

    # PBT selection. Skipped on the last round, so after the loop
    # `population_returns` still matches the final `training_states`.
    if i < (num_loops - 1):
        # Draw fresh per-device keys every round; reusing one fixed set of
        # keys would make every selection step repeat the same random draws.
        key, *selection_keys = jax.random.split(key, num=1 + num_devices)
        # Raw key data, as for the pmapped initializer
        # (github.com/jax-ml/jax/issues/23647).
        selection_keys = jax.random.key_data(jnp.stack(selection_keys))
        training_states, replay_buffers = select_fn(
            selection_keys, population_returns, training_states, replay_buffers
        )
for i in tqdm(range(num_loops), total=num_loops): # Update for num_steps (training_states, env_states, replay_buffers), metrics = train_fn( training_states, env_states, replay_buffers ) # Eval policy after training population_returns, _ = eval_policy(training_states, eval_env_first_states) population_returns_flatten = jnp.reshape(population_returns, (population_size,)) if i % print_freq == 0: print( f"Evaluation over {env_batch_size} episodes," f" Population mean return: {jnp.mean(population_returns_flatten)}," f" max return: {jnp.max(population_returns_flatten)}" ) # PBT selection if i < (num_loops-1): training_states, replay_buffers = select_fn( keys, population_returns, training_states, replay_buffers )

Visualize learnt behaviors

In [ ]:
Copied!
training_states = unshard_fn(training_states)
# argmax over the (device, member) returns yields a row-major flat index,
# which matches the flattened population axis produced by unshard_fn.
best_idx = jnp.argmax(jnp.ravel(population_returns))
best_training_state = jax.tree.map(lambda leaf: leaf[best_idx], training_states)
training_states = unshard_fn(training_states) best_idx = jnp.argmax(population_returns) best_training_state = jax.tree.map(lambda x: x[best_idx], training_states)
In [ ]:
Copied!
# Re-create the environment for visualization, this time without a
# `batch_size` argument.
env = environments.create(
    env_name,
    episode_length=episode_length,
)
env = environments.create(env_name, episode_length=episode_length)
In [ ]:
Copied!
# Deterministic evaluation step of the agent, pmapped over a single device so
# the best agent can be rolled out in the visualization environment.
play_step_fn = jax.pmap(
    functools.partial(
        agent.play_step_fn,
        env=env,
        deterministic=True,
        evaluation=True,
    ),
    devices=devices[:1],
    axis_name="p",
)
play_step_fn = jax.pmap( functools.partial(agent.play_step_fn, env=env, deterministic=True, evaluation=True), axis_name="p", devices=devices[:1], )
In [ ]:
Copied!
# Roll out the best-performing individual selected above.
training_state = best_training_state
training_state = best_training_state
In [ ]:
Copied!
key, subkey = jax.random.split(key)
env_state = jax.jit(env.reset)(rng=subkey)

# play_step_fn is pmapped over one device, so every leaf needs a leading
# device axis of size 1.
training_state, env_state = jax.tree.map(
    lambda leaf: jnp.expand_dims(leaf, axis=0), (training_state, env_state)
)

rollout = []
for _ in range(episode_length):
    rollout.append(env_state)  # record the state *before* stepping
    env_state, training_state, _ = play_step_fn(env_state, training_state)

print(f"The trajectory of this individual contains {len(rollout)} transitions.")
rollout = [] key, subkey = jax.random.split(key) env_state = jax.jit(env.reset)(rng=subkey) training_state, env_state = jax.tree.map( lambda x: jnp.expand_dims(x, axis=0), (training_state, env_state) ) for _ in range(episode_length): rollout.append(env_state) env_state, training_state, _ = play_step_fn(env_state, training_state) print(f"The trajectory of this individual contains {len(rollout)} transitions.")
In [ ]:
Copied!
# Drop the size-1 device axis and bring every recorded state to the host CPU.
cpu_device = jax.devices("cpu")[0]
rollout = [
    jax.tree.map(lambda leaf: jax.device_put(leaf[0], cpu_device), state)
    for state in rollout
]
rollout = [ jax.tree.map(lambda x: jax.device_put(x[0], jax.devices("cpu")[0]), env_state) for env_state in rollout ]
In [ ]:
Copied!
# Brax's HTML visualizer (brax is already required by qdax.tasks.brax).
from brax.io import html

# Render the recorded trajectory inline. Brax v2 states expose the physics
# state as `pipeline_state`; `qp` was the Brax v1 attribute.
HTML(html.render(env.sys, [s.pipeline_state for s in rollout[:episode_length]]))
HTML(html.render(env.sys, [s.qp for s in rollout[:episode_length]]))
Previous Next

Built with MkDocs using a theme provided by Read the Docs.