Training DIAYN SMERL with JAX
This notebook shows how to use QDax to train DIAYN SMERL on a Brax environment. It can be run locally or on Google Colab. We recommend using a GPU. This notebook will show:
- how to define an environment
- how to define a replay buffer
- how to create a DIAYN SMERL instance
- which functions must be defined before training
- how to launch a certain number of training steps
- how to visualise the final trajectories learned
#@title Installs and Imports
!pip install ipympl |tail -n 1
# %matplotlib widget
# from google.colab import output
# output.enable_custom_widget_manager()
import os
from IPython.display import clear_output
import functools
import jax
import jax.numpy as jnp
try:
    import brax
except ImportError:
    !pip install git+https://github.com/google/brax.git@v0.9.2 |tail -n 1
    import brax
try:
    import flax
except ImportError:
    !pip install --no-deps git+https://github.com/google/flax.git@v0.7.4 |tail -n 1
    import flax
try:
    import chex
except ImportError:
    !pip install --no-deps git+https://github.com/deepmind/chex.git@v0.1.83 |tail -n 1
    import chex
try:
    import jumanji
except ImportError:
    !pip install "jumanji==0.3.1"
    import jumanji
try:
    import haiku
except ImportError:
    !pip install git+https://github.com/deepmind/dm-haiku.git@v0.0.5 |tail -n 1
    import haiku
try:
    import qdax
except ImportError:
    !pip install --no-deps git+https://github.com/adaptive-intelligent-robotics/QDax@main |tail -n 1
    import qdax
from qdax import environments
from qdax.baselines.diayn_smerl import DIAYNSMERL, DiaynSmerlConfig, DiaynTrainingState
from qdax.core.neuroevolution.buffers.buffer import QDTransition
from qdax.core.neuroevolution.buffers.trajectory_buffer import TrajectoryBuffer
from qdax.core.neuroevolution.sac_td3_utils import do_iteration_fn, warmstart_buffer
from qdax.utils.plotting import plot_skills_trajectory
from IPython.display import HTML
from brax.v1.io import html
if "COLAB_TPU_ADDR" in os.environ:
    from jax.tools import colab_tpu
    colab_tpu.setup_tpu()
clear_output()
Hyperparameter choice
Most hyperparameters are similar to those introduced in the SAC, DIAYN and SMERL papers.
The parameter descriptor_full_state is less straightforward: it determines which information is used for diversity seeking and discrimination. In DIAYN, one can use the full state for diversity seeking, but one can also use a prior to focus on a specific aspect of the state. In practice, priors are often used in experiments, for instance focusing on the x/y position rather than the full position. When descriptor_full_state is set to True, the full state is used; when it is set to False, the 'state descriptor' provided by the environment is used instead, so the environment must define one. All the _uni and _omni environments do, as do anttrap, antmaze and pointmaze. In the future, we will add an option to apply a prior function directly to the full state.
#@title QD Training Definitions Fields
#@markdown ---
env_name = 'anttrap' #@param['ant_uni', 'hopper_uni', 'walker2d_uni', 'halfcheetah_uni', 'humanoid_uni', 'ant_omni', 'humanoid_omni', 'anttrap', 'antmaze', 'pointmaze']
seed = 0 #@param {type:"integer"}
env_batch_size = 250 #@param {type:"integer"}
num_steps = 1000000 #@param {type:"integer"}
warmup_steps = 0 #@param {type:"integer"}
buffer_size = 1000000 #@param {type:"integer"}
# SAC config
batch_size = 256 #@param {type:"integer"}
episode_length = 100 #@param {type:"integer"}
grad_updates_per_step = 0.25 #@param {type:"number"}
tau = 0.005 #@param {type:"number"}
learning_rate = 3e-4 #@param {type:"number"}
alpha_init = 1.0 #@param {type:"number"}
discount = 0.97 #@param {type:"number"}
reward_scaling = 1.0 #@param {type:"number"}
critic_hidden_layer_size = (256, 256) #@param {type:"raw"}
policy_hidden_layer_size = (256, 256) #@param {type:"raw"}
fix_alpha = False #@param {type:"boolean"}
normalize_observations = False #@param {type:"boolean"}
# DIAYN config
num_skills = 5 #@param {type:"integer"}
descriptor_full_state = False #@param {type:"boolean"}
# SMERL specific
diversity_reward_scale = 2.0 #@param {type:"number"}
smerl_target = 800 #@param {type:"number"}
smerl_margin = 800 #@param {type:"number"}
#@markdown ---
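In the SMERL paper, the diversity reward is only added for skills whose task return is close enough to a target return: r = r_env + λ · 1[R(τ) ≥ R* - ε] · r_div. We read diversity_reward_scale as λ, smerl_target as R* and smerl_margin as ε; that mapping is an assumption based on the parameter names, and the function below is a hypothetical sketch of the formula, not QDax's code. Note that with the values above (target 800, margin 800) the threshold is 0, so every trajectory with a non-negative return receives the diversity bonus.

```python
import jax.numpy as jnp


def smerl_reward(
    env_rewards: jnp.ndarray,
    diversity_rewards: jnp.ndarray,
    episode_returns: jnp.ndarray,
    diversity_reward_scale: float,
    smerl_target: float,
    smerl_margin: float,
) -> jnp.ndarray:
    """Hypothetical helper: SMERL reward mixing, following the SMERL paper.

    All arrays share the same shape; episode_returns holds, for each transition,
    the total task return of the trajectory it belongs to.
    """
    # Only trajectories that are good enough on the task get the diversity bonus
    near_optimal = episode_returns >= smerl_target - smerl_margin
    return env_rewards + diversity_reward_scale * near_optimal * diversity_rewards


env_r = jnp.array([1.0, 1.0])
div_r = jnp.array([0.5, 0.5])
ep_returns = jnp.array([900.0, -10.0])  # first trajectory passes, second does not
mixed = smerl_reward(env_r, div_r, ep_returns, 2.0, 800.0, 800.0)
```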
Initialize the environment and replay buffer
Define the environment in which the policies will be trained. In this notebook, we focus on controllers learning to move a robot in a physical simulation.
For DIAYN SMERL, we need to use a so-called trajectory replay buffer, because the SMERL reward depends on the accumulated reward of each trajectory.
# Initialize environments
assert (
env_batch_size % num_skills == 0
), "Parameter env_batch_size should be a multiple of num_skills"
num_env_per_skill = env_batch_size // num_skills
env = environments.create(
env_name=env_name,
batch_size=env_batch_size,
episode_length=episode_length,
auto_reset=True,
)
eval_env = environments.create(
env_name=env_name,
batch_size=env_batch_size,
episode_length=episode_length,
auto_reset=True,
eval_metrics=True,
)
key = jax.random.PRNGKey(seed)
env_state = jax.jit(env.reset)(rng=key)
eval_env_first_state = jax.jit(eval_env.reset)(rng=key)
# Initialize buffer
dummy_transition = QDTransition.init_dummy(
observation_dim=env.observation_size + num_skills,
action_dim=env.action_size,
descriptor_dim=env.behavior_descriptor_length,
)
# Use a trajectory replay buffer
replay_buffer = TrajectoryBuffer.init(
buffer_size=buffer_size,
transition=dummy_transition,
env_batch_size=env_batch_size,
episode_length=episode_length,
)
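Why a trajectory buffer: the SMERL gate needs the return of the whole episode each transition came from, which a buffer storing independent transitions cannot recover. The minimal sketch below assumes rewards laid out as (episodes, time steps) and uses a hypothetical helper; QDax's TrajectoryBuffer keeps its own internal layout, which this does not reproduce.

```python
import jax.numpy as jnp


def episode_returns_per_transition(rewards: jnp.ndarray) -> jnp.ndarray:
    """Hypothetical helper: rewards has shape (num_episodes, episode_length).

    Returns an array of the same shape in which every transition is tagged
    with the total return of the episode it belongs to.
    """
    totals = jnp.sum(rewards, axis=-1, keepdims=True)  # (num_episodes, 1)
    return jnp.broadcast_to(totals, rewards.shape)


demo_rewards = jnp.array([[1.0, 2.0, 3.0], [0.0, -1.0, 1.0]])
tagged_returns = episode_returns_per_transition(demo_rewards)
```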
Define the config, instantiate and initialize DIAYN SMERL
diayn_smerl_config = DiaynSmerlConfig(
# SAC config
batch_size=batch_size,
episode_length=episode_length,
tau=tau,
normalize_observations=normalize_observations,
learning_rate=learning_rate,
alpha_init=alpha_init,
discount=discount,
reward_scaling=reward_scaling,
critic_hidden_layer_size=critic_hidden_layer_size,
policy_hidden_layer_size=policy_hidden_layer_size,
fix_alpha=fix_alpha,
# DIAYN config
num_skills=num_skills,
descriptor_full_state=descriptor_full_state,
diversity_reward_scale=diversity_reward_scale,
smerl_margin=smerl_margin,
smerl_target=smerl_target,
)
# define an instance of DIAYN SMERL
diayn_smerl = DIAYNSMERL(config=diayn_smerl_config, action_size=env.action_size)
if descriptor_full_state:
descriptor_size = env.observation_size
else:
descriptor_size = env.behavior_descriptor_length
# get the initial training state
training_state = diayn_smerl.init(
key,
action_size=env.action_size,
observation_size=env.observation_size,
descriptor_size=descriptor_size,
)
Define the skills and the policy evaluation function
# replications of the same skill are evaluated in parallel
skills = jnp.concatenate(
[jnp.eye(num_skills)] * num_env_per_skill,
axis=0,
)
# Make play_step functions scannable by passing static args beforehand
play_eval_step = functools.partial(
diayn_smerl.play_step_fn,
skills=skills,
env=eval_env,
deterministic=True,
)
play_step = functools.partial(
diayn_smerl.play_step_fn,
skills=skills,
env=env,
deterministic=False,
)
eval_policy = functools.partial(
diayn_smerl.eval_policy_fn,
play_step_fn=play_eval_step,
eval_env_first_state=eval_env_first_state,
env_batch_size=env_batch_size,
)
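To see how parallel environments are assigned to skills, here is the same construction as the skills array above on small, made-up sizes. The skills repeat in a cycle, so each skill runs in num_env_per_skill environments, and the policy input is the observation with the one-hot skill appended (as in the buffer's observation_dim = env.observation_size + num_skills). The demo_* names are illustrative only.

```python
import jax.numpy as jnp

# Small hypothetical sizes, same construction as the skills array above
demo_num_skills, demo_env_per_skill, demo_obs_size = 3, 2, 4
demo_skills = jnp.concatenate([jnp.eye(demo_num_skills)] * demo_env_per_skill, axis=0)

# Which skill each parallel environment runs: 0, 1, 2, 0, 1, 2
demo_skill_ids = jnp.argmax(demo_skills, axis=-1)

# Policy input: observation followed by the one-hot skill vector
demo_obs = jnp.zeros((demo_num_skills * demo_env_per_skill, demo_obs_size))
demo_policy_inputs = jnp.concatenate([demo_obs, demo_skills], axis=-1)
```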
Warmstart the buffer
One can fill the replay buffer before training starts to reduce instabilities in the first training steps. This step is optional: with warmup_steps set to 0 above, no warm-up steps are run.
# warmstart the buffer
replay_buffer, env_state, training_state = warmstart_buffer(
replay_buffer=replay_buffer,
training_state=training_state,
env_state=env_state,
num_warmstart_steps=warmup_steps,
env_batch_size=env_batch_size,
play_step_fn=play_step,
)
Prepare the last utilities for the training loop
Many reinforcement learning algorithms have a similar training process, which can be divided into a single training step that is repeated many times. Most of the differences between algorithms are captured by the play_step and update functions. Hence, once those are defined, the iteration works the same way. For this reason, instead of writing the same function for each algorithm, we created do_iteration_fn, which can be used by most of them. In the training script, the user just has to partially apply this function (with functools.partial) to provide play_step, update and a few other parameters.
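A toy version of this pattern, with integer counters standing in for the networks, environment and buffer. It follows the structure described above (one play step, then a fixed number of gradient updates), but it is our simplification, not QDax's do_iteration_fn. In particular, the rule num_updates = int(grad_updates_per_step × env_batch_size) is an assumption; under it, the values above give int(0.25 × 250) = 62 updates per iteration. All names here are hypothetical.

```python
import functools
from typing import Callable, Tuple

import jax
import jax.numpy as jnp


def toy_do_iteration(
    params: jnp.ndarray,
    env_state: jnp.ndarray,
    num_stored: jnp.ndarray,
    *,
    env_batch_size: int,
    grad_updates_per_step: float,
    play_step_fn: Callable[[jnp.ndarray], jnp.ndarray],
    update_fn: Callable[[jnp.ndarray], jnp.ndarray],
) -> Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray]:
    """Hypothetical, simplified iteration: integers stand in for real objects."""
    # 1. Play one step in every parallel environment and store the transitions
    env_state = play_step_fn(env_state)
    num_stored = num_stored + env_batch_size
    # 2. Run a fixed number of gradient updates (assumed rule, see lead-in)
    num_updates = int(grad_updates_per_step * env_batch_size)

    def _update(p: jnp.ndarray, _: None) -> Tuple[jnp.ndarray, None]:
        return update_fn(p), None

    params, _ = jax.lax.scan(_update, params, None, length=num_updates)
    return params, env_state, num_stored


# Only the algorithm-specific pieces change; the iteration itself is reused
toy_iteration = functools.partial(
    toy_do_iteration,
    env_batch_size=250,
    grad_updates_per_step=0.25,
    play_step_fn=lambda s: s + 1,  # hypothetical: counts environment steps
    update_fn=lambda p: p + 1,  # hypothetical: counts gradient updates
)
toy_params, toy_env_state, toy_stored = toy_iteration(
    jnp.array(0), jnp.array(0), jnp.array(0)
)
```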
from typing import Tuple, Any
from brax.envs import State as EnvState
total_num_iterations = num_steps // env_batch_size
# fix static arguments - prepare for scan
do_iteration = functools.partial(
do_iteration_fn,
env_batch_size=env_batch_size,
grad_updates_per_step=grad_updates_per_step,
play_step_fn=play_step,
update_fn=diayn_smerl.update,
)
# define a function that enables do_iteration to be scanned
@jax.jit
def _scan_do_iteration(
    carry: Tuple[DiaynTrainingState, EnvState, TrajectoryBuffer],
    unused_arg: Any,
) -> Tuple[Tuple[DiaynTrainingState, EnvState, TrajectoryBuffer], Any]:
    (
        training_state,
        env_state,
        replay_buffer,
        metrics,
    ) = do_iteration(*carry)
    return (training_state, env_state, replay_buffer), metrics
Train
Training loop: this is a scan of the do_iteration_fn function.
%%time
# Main loop
(training_state, env_state, replay_buffer), metrics = jax.lax.scan(
_scan_do_iteration,
(training_state, env_state, replay_buffer),
(),
length=total_num_iterations,
)
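jax.lax.scan threads the carry (here training_state, env_state and replay_buffer) through total_num_iterations calls and stacks the per-call outputs along a new leading axis, so every array in metrics has a leading dimension of total_num_iterations. The toy sketch below uses a counter as the carry and a hypothetical body function; with num_steps = 1,000,000 and env_batch_size = 250, the loop runs 4,000 iterations.

```python
import jax
import jax.numpy as jnp


def toy_scan_body(carry: jnp.ndarray, _unused: tuple):
    # Stand-in for _scan_do_iteration: update the carry, emit per-step metrics
    new_carry = carry + 1
    return new_carry, {"iteration": new_carry}


# Same budget arithmetic as the notebook: num_steps // env_batch_size
total_iters = 1_000_000 // 250

final_carry, toy_metrics = jax.lax.scan(
    toy_scan_body, jnp.array(0), (), length=total_iters
)
```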
Plot the trajectories of the skills at the end of the training
This only works when the state descriptor is two-dimensional, and it is only really meaningful when that descriptor is the x/y position, i.e. on all "omni" tasks, pointmaze, anttrap and antmaze.
# Evaluation part
true_return, true_returns, diversity_returns, state_desc = eval_policy(
training_state=training_state
)
# plot the trajectory of the skills
fig, ax = plot_skills_trajectory(
trajectories=state_desc.T,
skills=skills,
min_values=[0, -8],
max_values=[30, 8],
)
Visualize the skills in the physical simulation
WARNING: this does not work with "pointmaze".
assert env_name != "pointmaze", "No visualisation available for pointmaze at the moment"
Choose a skill
my_skill = 2
my_params = training_state.policy_params
possible_skills = jnp.eye(num_skills)
skill = possible_skills[my_skill]
Create an environment and jit the step and inference functions
# create an environment that is not vectorized
visual_env = environments.create(
env_name=env_name,
episode_length=episode_length,
auto_reset=True,
)
# jit reset/step/inference functions
jit_env_reset = jax.jit(visual_env.reset)
jit_env_step = jax.jit(visual_env.step)
@jax.jit
def jit_inference_fn(params, observation, random_key):
    # condition the policy on the chosen skill
    obs = jnp.concatenate([observation, skill], axis=0)
    action, random_key = diayn_smerl.select_action(obs, params, random_key, deterministic=True)
    return action, random_key
Rollout in the environment and visualize
rollout = []
random_key = jax.random.PRNGKey(seed=1)
state = jit_env_reset(rng=random_key)
while not state.done:
    rollout.append(state)
    action, random_key = jit_inference_fn(my_params, state.obs, random_key)
    state = jit_env_step(state, action)
print(f"The trajectory of this individual contains {len(rollout)} transitions.")
HTML(html.render(visual_env.sys, [s.qp for s in rollout[:500]]))