QDax docs

Training DADS with JAX

This notebook shows how to use QDax to train DADS (Dynamics-Aware Unsupervised Discovery of Skills) on a Brax environment. It can be run locally or on Google Colab. We recommend using a GPU. This notebook will show:

  • how to define an environment
  • how to define a replay buffer
  • how to create a DADS instance
  • which functions must be defined before training
  • how to run a given number of training steps
  • how to visualise the final learned trajectories

Installation

You will need Python 3.11 or later, and a working JAX installation. For example, you can install JAX with:

%pip install -U "jax[cuda]"

Then, install QDax from PyPI:

%pip install -U "qdax[examples]"
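
Since a GPU is recommended, it can be worth checking which backend JAX actually picked up after installation. The snippet below is a small sketch using jax.default_backend() and jax.devices(); on a correctly configured CUDA machine the backend should be "gpu", while on a CPU-only runtime (for example a Colab session without an accelerator) JAX falls back to "cpu".

```python
import jax

# Name of the default backend JAX will run on, e.g. "cpu", "gpu" or "tpu"
backend = jax.default_backend()
# Devices visible to JAX on that backend
devices = jax.devices()

print(f"JAX backend: {backend}")
print(f"Visible devices: {devices}")
if backend == "cpu":
    print("No accelerator found: training will run on CPU and be much slower.")
```
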

Then, import the modules used throughout this notebook:
import os

from IPython.display import clear_output
import functools

import jax
import jax.numpy as jnp

import qdax.tasks.brax as environments
from qdax.baselines.dads import DADS, DadsConfig, DadsTrainingState
from qdax.core.neuroevolution.buffers.buffer import QDTransition, ReplayBuffer
from qdax.core.neuroevolution.sac_td3_utils import do_iteration_fn, warmstart_buffer

from qdax.utils.plotting import plot_skills_trajectory

from IPython.display import HTML

Hyperparameters choice

Most hyperparameters are similar to those introduced in the SAC, DIAYN and DADS papers.

The parameter descriptor_full_state is less straightforward: it determines which information is used for diversity seeking and for learning the dynamics. In DADS, one can use the full state for diversity seeking, but one can also use a prior to focus on an interesting aspect of the state. In practice, priors are often used in experiments, for instance to focus on the x/y position rather than on the full position. When descriptor_full_state is set to True, the full state is used; when it is set to False, the 'state descriptor' provided by the environment is used instead, so the environment must define one. In the future, we will add an option to apply a prior function directly to the full state.
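
In the DADS paper, diversity seeking relies on a skill-conditioned dynamics model q(s' | s, z): a transition is rewarded when it is well predicted under the skill z that produced it but poorly predicted under the other skills, i.e. r(s, z, s') = log q(s' | s, z) - log((1/L) Σ_i q(s' | s, z_i)). With descriptor_full_state = False, s and s' are the environment's state descriptors rather than full observations. The sketch below computes this reward from given log-probabilities. The helper dads_intrinsic_reward is hypothetical: it averages over all num_skills one-hot skills (a uniform discrete prior) instead of sampled skills, and it is not QDax's internal implementation.

```python
import jax
import jax.numpy as jnp
from jax.scipy.special import logsumexp


def dads_intrinsic_reward(log_q_own: jax.Array, log_q_all: jax.Array) -> jax.Array:
    """Hypothetical helper computing the DADS intrinsic reward.

    log_q_own: log q(s' | s, z) under the skill z that generated the transition, shape (...)
    log_q_all: log q(s' | s, z_i) for each of the L skills z_i, shape (..., L)
    """
    num_skills = log_q_all.shape[-1]
    # Log of the average likelihood over skills, computed stably in log space
    log_mean_q = logsumexp(log_q_all, axis=-1) - jnp.log(num_skills)
    return log_q_own - log_mean_q


# Transition explained equally well by all 5 skills: the reward is 0
equal_log_q = jnp.zeros(5)
reward_equal = dads_intrinsic_reward(equal_log_q[0], equal_log_q)

# Transition explained only by its own skill (index 0): the reward is close to log(5)
distinct_log_q = jnp.array([0.0, -1000.0, -1000.0, -1000.0, -1000.0])
reward_distinct = dads_intrinsic_reward(distinct_log_q[0], distinct_log_q)
```

Because the skill's own term is part of the average, this reward is at most log(num_skills), reached when no other skill explains the transition.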

#@title QD Training Definitions Fields
#@markdown ---
env_name = 'ant_omni' #@param['ant_uni', 'hopper_uni', 'walker2d_uni', 'halfcheetah_uni', 'humanoid_uni', 'ant_omni', 'humanoid_omni']
seed = 0 #@param {type:"integer"}
env_batch_size = 250 #@param {type:"integer"}
num_steps = 1000000 #@param {type:"integer"}
warmup_steps = 0 #@param {type:"integer"}
buffer_size = 1000000 #@param {type:"integer"}

# SAC config
batch_size = 256 #@param {type:"integer"}
episode_length = 100 #@param {type:"integer"}
grad_updates_per_step = 0.25 #@param {type:"number"}
tau = 0.005 #@param {type:"number"}
learning_rate = 3e-4 #@param {type:"number"}
alpha_init = 1.0 #@param {type:"number"}
discount = 0.97 #@param {type:"number"}
reward_scaling = 1.0 #@param {type:"number"}
critic_hidden_layer_size = (256, 256) #@param {type:"raw"}
policy_hidden_layer_size = (256, 256) #@param {type:"raw"}
fix_alpha = False #@param {type:"boolean"}
normalize_observations = False #@param {type:"boolean"}
# DADS config
num_skills = 5 #@param {type:"integer"}
dynamics_update_freq = 1 #@param {type:"integer"}
normalize_target = True #@param {type:"boolean"}
descriptor_full_state = False #@param {type:"boolean"}
#@markdown ---

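Init environment and replay buffer

Define the environment in which the policies will be trained. In this notebook, we focus on controllers learning to move a robot in a physical simulation. env_batch_size must be a multiple of num_skills so that every skill runs in the same number of parallel environments. The replay buffer is then initialized from a dummy transition whose observation size is env.observation_size + num_skills, because the policy receives the observation concatenated with a one-hot skill vector.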

# Initialize environments
assert (
    env_batch_size % num_skills == 0
), "Parameter env_batch_size should be a multiple of num_skills"
num_env_per_skill = env_batch_size // num_skills

env = environments.create(
    env_name=env_name,
    batch_size=env_batch_size,
    episode_length=episode_length,
    auto_reset=True,
)

eval_env = environments.create(
    env_name=env_name,
    batch_size=env_batch_size,
    episode_length=episode_length,
    auto_reset=True,
    eval_metrics=True,
)

key = jax.random.key(seed)

key, subkey_1, subkey_2 = jax.random.split(key, 3)
env_state = jax.jit(env.reset)(rng=subkey_1)
eval_env_first_state = jax.jit(eval_env.reset)(rng=subkey_2)

# Initialize buffer
dummy_transition = QDTransition.init_dummy(
    observation_dim=env.observation_size + num_skills,
    action_dim=env.action_size,
    descriptor_dim=env.descriptor_length,
)
replay_buffer = ReplayBuffer.init(
    buffer_size=buffer_size, transition=dummy_transition
)
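
The buffer's observation_dim is env.observation_size + num_skills because the policy receives the environment observation concatenated with a one-hot encoding of its skill (the same concatenation appears later in jit_inference_fn). The sketch below reproduces this bookkeeping on dummy data. The helper condition_on_skill and the toy sizes are hypothetical, and jax.nn.one_hot is used instead of indexing jnp.eye, which yields the same one-hot vectors.

```python
import jax
import jax.numpy as jnp


def condition_on_skill(observations: jax.Array, skill_ids: jax.Array, num_skills: int) -> jax.Array:
    """Hypothetical helper: append a one-hot skill vector to each observation."""
    one_hot_skills = jax.nn.one_hot(skill_ids, num_skills)  # (batch, num_skills)
    return jnp.concatenate([observations, one_hot_skills], axis=-1)


# Toy sizes: 10 environments, 5 skills, 7-dimensional observations
toy_num_skills = 5
toy_batch_size = 10
assert toy_batch_size % toy_num_skills == 0  # same constraint as in the cell above

toy_observations = jnp.zeros((toy_batch_size, 7))
toy_skill_ids = jnp.arange(toy_batch_size) % toy_num_skills
conditioned = condition_on_skill(toy_observations, toy_skill_ids, toy_num_skills)
print(conditioned.shape)  # (10, 12)
```
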

Define the config, instantiate and initialize DADS

dads_config = DadsConfig(
    # SAC config
    batch_size=batch_size,
    episode_length=episode_length,
    tau=tau,
    normalize_observations=normalize_observations,
    learning_rate=learning_rate,
    alpha_init=alpha_init,
    discount=discount,
    reward_scaling=reward_scaling,
    critic_hidden_layer_size=critic_hidden_layer_size,
    policy_hidden_layer_size=policy_hidden_layer_size,
    fix_alpha=fix_alpha,
    # DADS config
    num_skills=num_skills,
    descriptor_full_state=descriptor_full_state,
    omit_input_dynamics_dim=env.descriptor_length,
    dynamics_update_freq=dynamics_update_freq,
    normalize_target=normalize_target,
)

if descriptor_full_state:
    descriptor_size = env.observation_size
else:
    descriptor_size = env.descriptor_length

# define an instance of DADS
dads = DADS(
    config=dads_config,
    action_size=env.action_size,
    descriptor_size=descriptor_size
)

# get the initial training state
key, subkey = jax.random.split(key)
training_state = dads.init(
    subkey,
    action_size=env.action_size,
    observation_size=env.observation_size,
    descriptor_size=descriptor_size,
)

Define the skills and the policy evaluation function

# replications of the same skill are evaluated in parallel
skills = jnp.concatenate(
    [jnp.eye(num_skills)] * num_env_per_skill,
    axis=0,
)

# Make play_step functions scannable by passing static args beforehand
play_eval_step = functools.partial(
    dads.play_step_fn,
    deterministic=True,
    env=eval_env,
    skills=skills,
    evaluation=True,  # needed by the normalization mechanism
)

play_step = functools.partial(
    dads.play_step_fn,
    skills=skills,
    env=env,
    deterministic=False,
)

eval_policy = functools.partial(
    dads.eval_policy_fn,
    play_step_fn=play_eval_step,
    eval_env_first_state=eval_env_first_state,
    env_batch_size=env_batch_size,
)
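
Because the identity matrix is stacked num_env_per_skill times, environment i runs skill i % num_skills: the skills are interleaved rather than grouped in contiguous blocks. The small sketch below checks this on toy sizes (3 skills, 2 environments per skill, chosen only for illustration) and shows one way to recover the environments that execute a given skill, which can help when inspecting evaluation outputs skill by skill.

```python
import jax.numpy as jnp

# Toy sizes for illustration (the notebook uses num_skills = 5 and 50 environments per skill)
toy_num_skills = 3
toy_num_env_per_skill = 2

# Same construction as `skills` above
toy_skills = jnp.concatenate([jnp.eye(toy_num_skills)] * toy_num_env_per_skill, axis=0)

# Index of the skill executed by each environment: [0, 1, 2, 0, 1, 2]
toy_skill_ids = jnp.argmax(toy_skills, axis=-1)

# Rows (environments) that execute skill 1: [1, 4]
rows_for_skill_1 = jnp.nonzero(toy_skill_ids == 1)[0]
print(toy_skills.shape, toy_skill_ids, rows_for_skill_1)
```
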

Warmstart the buffer

The replay buffer can be filled before training starts, which can reduce instabilities in the first training steps. This step is optional: with warmup_steps = 0, as set in the hyperparameters above, no warm-start steps are requested.

# warmstart the buffer
replay_buffer, env_state, training_state = warmstart_buffer(
    replay_buffer=replay_buffer,
    training_state=training_state,
    env_state=env_state,
    num_warmstart_steps=warmup_steps,
    env_batch_size=env_batch_size,
    play_step_fn=play_step,
)
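
Conceptually, warm-starting means running play_step with the untrained policy and inserting the resulting transitions into a fixed-size buffer before any gradient update. The sketch below is a toy circular buffer written with plain jax.numpy to illustrate that idea. It is not QDax's ReplayBuffer, whose internals may differ, and the names ToyBuffer, toy_buffer_init and toy_buffer_insert are hypothetical.

```python
from typing import NamedTuple

import jax
import jax.numpy as jnp


class ToyBuffer(NamedTuple):
    """Hypothetical fixed-size circular buffer, not QDax's ReplayBuffer."""

    data: jax.Array             # (capacity, transition_dim) storage
    insert_position: jax.Array  # index of the next row to write
    current_size: jax.Array     # number of rows holding real transitions


def toy_buffer_init(capacity: int, transition_dim: int) -> ToyBuffer:
    return ToyBuffer(
        data=jnp.zeros((capacity, transition_dim)),
        insert_position=jnp.array(0, dtype=jnp.int32),
        current_size=jnp.array(0, dtype=jnp.int32),
    )


@jax.jit
def toy_buffer_insert(buffer: ToyBuffer, transitions: jax.Array) -> ToyBuffer:
    capacity = buffer.data.shape[0]
    batch = transitions.shape[0]
    # Write positions wrap around: once full, the oldest rows are overwritten
    rows = (buffer.insert_position + jnp.arange(batch)) % capacity
    return ToyBuffer(
        data=buffer.data.at[rows].set(transitions),
        insert_position=(buffer.insert_position + batch) % capacity,
        current_size=jnp.minimum(buffer.current_size + batch, capacity),
    )


# "Warm start": 3 batches of 4 dummy transitions into a buffer of capacity 10,
# before any training update would happen
toy_buffer = toy_buffer_init(capacity=10, transition_dim=2)
for warmup_iteration in range(3):
    dummy_batch = jnp.full((4, 2), float(warmup_iteration + 1))
    toy_buffer = toy_buffer_insert(toy_buffer, dummy_batch)
# The third batch wrapped around: rows 8, 9, 0, 1 now hold its transitions
```
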

Prepare the last utilities for the training loop

Many reinforcement learning algorithms share a similar training process, which can be divided into a single training step that is repeated many times. Most of the differences between algorithms are captured in the play_step and update functions. Hence, once those are defined, the iteration works the same way. For this reason, instead of coding the same function for each algorithm, we have created do_iteration_fn, which can be used by most of them. In the training script, the user only has to partially apply this function (with functools.partial) to provide play_step, update and a few other parameters.

from typing import Tuple, Any
from brax.envs import State as EnvState

total_num_iterations = num_steps // env_batch_size

# fix static arguments - prepare for scan
do_iteration = functools.partial(
    do_iteration_fn,
    env_batch_size=env_batch_size,
    grad_updates_per_step=grad_updates_per_step,
    play_step_fn=play_step,
    update_fn=dads.update,
)

# define a function that enables do_iteration to be scanned
@jax.jit
def _scan_do_iteration(
    carry: Tuple[DadsTrainingState, EnvState, ReplayBuffer],
    unused_arg: Any,
) -> Tuple[Tuple[DadsTrainingState, EnvState, ReplayBuffer], Any]:
    (
        training_state,
        env_state,
        replay_buffer,
        metrics,
    ) = do_iteration(*carry)
    return (training_state, env_state, replay_buffer), metrics
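
The pattern above (fix the static arguments with functools.partial, then wrap the result so that it has the (carry, x) -> (carry, y) signature expected by jax.lax.scan) can be illustrated with toy functions. In this sketch, toy_play_step, toy_update and toy_do_iteration are hypothetical stand-ins for play_step, dads.update and do_iteration_fn, a running sum replaces the replay buffer, and the number of gradient updates per iteration is taken as int(grad_updates_per_step * env_batch_size), which is an assumption about how do_iteration_fn uses that parameter.

```python
import functools

import jax
import jax.numpy as jnp


def toy_play_step(env_state, training_state):
    # Stand-in for play_step: advance a counter "environment" and emit 4 fake transitions
    env_state = env_state + 1
    transitions = jnp.ones(4) * env_state
    return env_state, training_state, transitions


def toy_update(training_state, buffer_sum):
    # Stand-in for dads.update: one "gradient step" increments a counter
    return training_state + 1, buffer_sum, {"loss": jnp.asarray(0.0)}


def toy_do_iteration(
    training_state, env_state, buffer_sum,
    env_batch_size, grad_updates_per_step, play_step_fn, update_fn,
):
    # 1) interact with the environment
    env_state, training_state, transitions = play_step_fn(env_state, training_state)
    # 2) store the data (a running sum stands in for inserting into a replay buffer)
    buffer_sum = buffer_sum + transitions.sum()
    # 3) several gradient updates, themselves scanned (assumed count)
    num_updates = int(grad_updates_per_step * env_batch_size)

    def _update(carry, unused_arg):
        ts, bs = carry
        ts, bs, metrics = update_fn(ts, bs)
        return (ts, bs), metrics

    (training_state, buffer_sum), metrics = jax.lax.scan(
        _update, (training_state, buffer_sum), (), length=num_updates
    )
    return training_state, env_state, buffer_sum, metrics


# Fix the static arguments once, as done for do_iteration above
toy_iteration = functools.partial(
    toy_do_iteration,
    env_batch_size=8,
    grad_updates_per_step=0.25,  # -> 2 updates per iteration
    play_step_fn=toy_play_step,
    update_fn=toy_update,
)


@jax.jit
def _scan_toy_iteration(carry, unused_arg):
    training_state, env_state, buffer_sum, metrics = toy_iteration(*carry)
    return (training_state, env_state, buffer_sum), metrics


toy_carry = (jnp.array(0, dtype=jnp.int32), jnp.array(0, dtype=jnp.int32), jnp.array(0.0, dtype=jnp.float32))
(toy_ts, toy_env, toy_buffer_sum), toy_metrics = jax.lax.scan(
    _scan_toy_iteration, toy_carry, (), length=3
)
# toy_ts == 6 (3 iterations x 2 updates), toy_env == 3, toy_metrics["loss"].shape == (3, 2)
```

Metrics returned by the inner and outer scans are stacked along new leading axes, which is why the metrics obtained after training carry one entry per iteration.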

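Train

Training loop: a jax.lax.scan of _scan_do_iteration (the jitted wrapper around do_iteration_fn) over total_num_iterations = num_steps // env_batch_size iterations.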

# Main loop
(training_state, env_state, replay_buffer), metrics = jax.lax.scan(
    _scan_do_iteration,
    (training_state, env_state, replay_buffer),
    (),
    length=total_num_iterations,
)
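
As a quick sanity check on the training budget, the arithmetic behind total_num_iterations can be written out. The helper training_budget below is hypothetical; it assumes that each iteration collects env_batch_size transitions (one step in each parallel environment) and performs int(grad_updates_per_step * env_batch_size) gradient updates, which is an assumption about do_iteration_fn rather than a documented guarantee. Note that integer division drops any leftover steps when num_steps is not a multiple of env_batch_size.

```python
def training_budget(num_steps: int, env_batch_size: int, grad_updates_per_step: float) -> dict:
    """Hypothetical helper: derive the loop sizes used by the training scan."""
    # Number of scan iterations, as computed in the notebook
    total_num_iterations = num_steps // env_batch_size
    # Assumed: one environment step per parallel environment and iteration
    env_steps = total_num_iterations * env_batch_size
    # Assumed number of gradient updates per iteration
    updates_per_iteration = int(grad_updates_per_step * env_batch_size)
    return {
        "iterations": total_num_iterations,
        "env_steps": env_steps,
        "updates_per_iteration": updates_per_iteration,
        "total_updates": total_num_iterations * updates_per_iteration,
    }


# Values from the hyperparameter cell above
budget = training_budget(num_steps=1_000_000, env_batch_size=250, grad_updates_per_step=0.25)
print(budget)
# {'iterations': 4000, 'env_steps': 1000000, 'updates_per_iteration': 62, 'total_updates': 248000}
```
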

Plot the trajectories of the skills at the end of the training

This only works when the state descriptor is two-dimensional, and it is only really meaningful when this state descriptor is the x/y position.

# Evaluation part
true_return, true_returns, diversity_returns, state_desc = eval_policy(
    training_state=training_state
)
# plot the trajectory of the skills
fig, ax = plot_skills_trajectory(
    trajectories=state_desc.T,
    skills=skills,
    min_values=[-20, -20],
    max_values=[20, 20],
)
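
When the state descriptor is the x/y position, a compact numerical complement to the plot is the average final position reached by each skill. The sketch below assumes, hypothetically, that the descriptors have been arranged as an array of shape (env_batch_size, episode_length, 2) with rows in the same order as skills; check the actual shape of state_desc before adapting it. The helper mean_final_position_per_skill is hypothetical; averaging through the one-hot skills matrix avoids a Python loop over skills.

```python
import jax
import jax.numpy as jnp


def mean_final_position_per_skill(trajectories: jax.Array, skills: jax.Array) -> jax.Array:
    """Hypothetical helper.

    trajectories: (batch, episode_length, 2) x/y descriptors, one row per environment
    skills: (batch, num_skills) one-hot skills, same row order as trajectories
    returns: (num_skills, 2) average final x/y position of each skill
    """
    final_positions = trajectories[:, -1, :]  # (batch, 2)
    counts = skills.sum(axis=0)               # number of environments per skill
    return (skills.T @ final_positions) / counts[:, None]


# Toy example: 2 skills, 2 environments each, 3 time steps
toy_skills = jnp.concatenate([jnp.eye(2)] * 2, axis=0)  # rows run skills 0, 1, 0, 1
toy_final = jnp.array([[1.0, 0.0], [0.0, 1.0], [3.0, 0.0], [0.0, 3.0]])
toy_trajectories = jnp.zeros((4, 3, 2)).at[:, -1, :].set(toy_final)
toy_means = mean_final_position_per_skill(toy_trajectories, toy_skills)
# expected: skill 0 -> (2, 0), skill 1 -> (0, 2)
```
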

Visualize the skills in the physical simulation

Choose a skill

my_skill = 0
my_params = training_state.policy_params

possible_skills = jnp.eye(num_skills)
skill = possible_skills[my_skill]

Create an environment and jit the step and inference functions

# create an environment that is not vectorized
visual_env = environments.create(
    env_name=env_name,
    episode_length=episode_length,
    auto_reset=True,
)

# jit reset/step/inference functions
jit_env_reset = jax.jit(visual_env.reset)
jit_env_step = jax.jit(visual_env.step)

@jax.jit
def jit_inference_fn(params, observation, key):
    obs = jnp.concatenate([observation, skill], axis=0)
    action = dads.select_action(obs, params, key, deterministic=True)
    return action
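
jit_inference_fn reads skill from the enclosing notebook scope. Under jax.jit, such a global array is captured as a constant when the function is first traced, so if you later pick a different my_skill and recompute skill without re-running the cell that defines jit_inference_fn, the cached compiled function keeps using the old skill. The toy functions below (hypothetical stand-ins, not part of QDax) show this behaviour and one alternative: passing the skill as an explicit argument.

```python
import jax
import jax.numpy as jnp

num_toy_skills = 3
toy_skill = jnp.eye(num_toy_skills)[0]


@jax.jit
def closure_inference(observation):
    # `toy_skill` is read from the enclosing scope and baked in at trace time
    return jnp.concatenate([observation, toy_skill], axis=0)


@jax.jit
def explicit_inference(observation, skill):
    # The skill is a traced argument, so new values are used without re-tracing
    return jnp.concatenate([observation, skill], axis=0)


toy_obs = jnp.zeros(2)
first = closure_inference(toy_obs)             # traced with toy_skill = [1, 0, 0]
toy_skill = jnp.eye(num_toy_skills)[1]         # pick another skill
stale = closure_inference(toy_obs)             # cached trace: still uses [1, 0, 0]
fresh = explicit_inference(toy_obs, toy_skill) # uses [0, 1, 0]
```
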

Rollout in the environment and visualize

rollout = []
key = jax.random.key(seed=1)
state = jit_env_reset(rng=key)
while not state.done:
    rollout.append(state)
    key, subkey = jax.random.split(key)
    action = jit_inference_fn(my_params, state.obs, subkey)
    state = jit_env_step(state, action)

print(f"The trajectory of this individual contains {len(rollout)} transitions.")

from brax.io import html

# Brax v2 environment states store the physics state in pipeline_state (older Brax versions used .qp)
HTML(html.render(visual_env.sys, [s.pipeline_state for s in rollout[:500]]))

Built with MkDocs using a theme provided by Read the Docs.