Installation
You will need Python 3.11 or later, and a working JAX installation. For example, you can install JAX with:
In [ ]:
Copied!
%pip install -U "jax[cuda]"
%pip install -U "jax[cuda]"
Then, install QDax from PyPI:
In [ ]:
Copied!
%pip install -U "qdax[examples]"
%pip install -U "qdax[examples]"
In [ ]:
Copied!
import functools
import jax
import jax.numpy as jnp
from IPython.display import HTML
from tqdm import tqdm
import qdax.tasks.brax as environments
from qdax.baselines.pbt import PBT
from qdax.baselines.sac_pbt import PBTSAC, PBTSacConfig
import functools
import jax
import jax.numpy as jnp
from IPython.display import HTML
from tqdm import tqdm
import qdax.tasks.brax as environments
from qdax.baselines.pbt import PBT
from qdax.baselines.sac_pbt import PBTSAC, PBTSacConfig
In [ ]:
Copied!
jax.config.update("jax_platform_name", "cpu")
jax.config.update("jax_platform_name", "cpu")
In [ ]:
Copied!
# Devices the population is spread over (replace "gpu" with "tpu" if needed).
devices = jax.devices("gpu")
num_devices = len(devices)
print(f"Detected the following {num_devices} device(s): {devices}")
In [ ]:
Copied!
# --- Task and population layout ---
env_name = "humanoidtrap"
seed = 0
env_batch_size = 250  # batch entries per individual (see init_environments)
population_size_per_device = 10
population_size = population_size_per_device * num_devices
num_steps = 10000  # floor-divided by env_batch_size to get training iterations
buffer_size = 100000  # forwarded to agent.get_init_fn

# --- PBT config ---
num_best_to_replace_from = 1
num_worse_to_replace = 1

# --- SAC config ---
batch_size = 256
episode_length = 1000
grad_updates_per_step = 1.0
tau = 0.005
alpha_init = 1.0
critic_hidden_layer_size = (256, 256)
policy_hidden_layer_size = (256, 256)
fix_alpha = False
normalize_observations = False

# --- Outer loop ---
num_loops = 10  # number of train / evaluate / select rounds
print_freq = 1  # returns are printed every `print_freq` rounds
In [ ]:
Copied!
# The training and evaluation environments are created with identical
# settings: a flat batch of env_batch_size entries for each of the
# population_size_per_device individuals hosted on one device.
_env_kwargs = dict(
    env_name=env_name,
    batch_size=env_batch_size * population_size_per_device,
    episode_length=episode_length,
    auto_reset=True,
)
env = environments.create(**_env_kwargs)
eval_env = environments.create(**_env_kwargs)
In [ ]:
Copied!
@jax.jit
def init_environments(key):
    """Reset both environments and give every individual its own batch axis.

    Args:
        key: JAX PRNG key, passed as ``rng`` to both ``env.reset`` and
            ``eval_env.reset`` (the same key is used for the two resets).

    Returns:
        Tuple ``(env_states, eval_env_first_states)``. In both pytrees the
        leading axis of every leaf is reshaped into
        ``(population_size_per_device, env_batch_size)``; the remaining axes
        are left untouched.
    """

    def split_batch_axis(leaf):
        # The environments are created with a flat batch of
        # population_size_per_device * env_batch_size entries.
        return jnp.reshape(
            leaf, (population_size_per_device, env_batch_size) + leaf.shape[1:]
        )

    # The whole function is already traced by the outer jax.jit, so the
    # resets and the reshape do not need jit wrappers of their own.
    env_states = jax.tree.map(split_batch_axis, env.reset(rng=key))
    eval_env_first_states = jax.tree.map(split_batch_axis, eval_env.reset(rng=key))
    return env_states, eval_env_first_states
In [ ]:
Copied!
# Derive one PRNG key per device and keep `key` as the carry-over key for
# later cells.
key, *keys = jax.random.split(jax.random.key(seed), num=1 + num_devices)
keys = jnp.stack(keys)

# Reset the environments on every device in parallel.
_init_environments_pmap = jax.pmap(init_environments, axis_name="p", devices=devices)
env_states, eval_env_first_states = _init_environments_pmap(keys)
In [ ]:
Copied!
# Assemble the SAC-PBT configuration from the hyperparameters defined above,
# then build the agent for this environment's action size.
config = PBTSacConfig(
    batch_size=batch_size,
    episode_length=episode_length,
    tau=tau,
    alpha_init=alpha_init,
    fix_alpha=fix_alpha,
    normalize_observations=normalize_observations,
    critic_hidden_layer_size=critic_hidden_layer_size,
    policy_hidden_layer_size=policy_hidden_layer_size,
)
agent = PBTSAC(config=config, action_size=env.action_size)
In [ ]:
Copied!
# Initialiser for one device's share of the population (training states and
# replay buffers).
agent_init_fn = agent.get_init_fn(
    population_size=population_size_per_device,
    action_size=env.action_size,
    observation_size=env.observation_size,
    buffer_size=buffer_size,
)

# Pass the raw key data rather than typed key arrays through pmap
# (workaround for github.com/jax-ml/jax/issues/23647).
keys = jax.random.key_data(keys)
_agent_init_pmap = jax.pmap(agent_init_fn, axis_name="p", devices=devices)
training_states, replay_buffers = _agent_init_pmap(keys)
In [ ]:
Copied!
# Evaluation function bound to eval_env, mapped over all devices.
_eval_fn = agent.get_eval_fn(eval_env)
eval_policy = jax.pmap(_eval_fn, axis_name="p", devices=devices)
In [ ]:
Copied!
# Baseline: evaluate the freshly initialised population before any training.
population_returns, _ = eval_policy(training_states, eval_env_first_states)
# Collapse the per-device layout into one entry per individual.
population_returns = population_returns.reshape((population_size,))
print(
    f"Evaluation over {env_batch_size} episodes,"
    f" Population mean return: {jnp.mean(population_returns)},"
    f" max return: {jnp.max(population_returns)}"
)
In [ ]:
Copied!
# One call to train_fn runs num_steps // env_batch_size iterations.
num_iterations = num_steps // env_batch_size
_train_one_device = agent.get_train_fn(
    env=env,
    num_iterations=num_iterations,
    env_batch_size=env_batch_size,
    grad_updates_per_step=grad_updates_per_step,
)
# Map the per-device training function over all devices.
train_fn = jax.pmap(_train_one_device, axis_name="p", devices=devices)
In [ ]:
Copied!
# The PBT counts configured above are floor-divided by the device count
# before being handed to PBT.
# NOTE(review): with the defaults of 1 both quotients are 0 as soon as
# num_devices > 1 — confirm that PBT handles a count of 0 as intended.
_best_per_device = num_best_to_replace_from // num_devices
_worse_per_device = num_worse_to_replace // num_devices
pbt = PBT(
    population_size=population_size,
    num_best_to_replace_from=_best_per_device,
    num_worse_to_replace=_worse_per_device,
)
# Selection step, mapped over all devices under the axis name "p".
select_fn = jax.pmap(pbt.update_states_and_buffer_pmap, axis_name="p", devices=devices)
In [ ]:
Copied!
@jax.jit
def unshard_fn(sharded_tree):
    """Merge the two leading axes of every leaf into one population axis.

    Args:
        sharded_tree: Pytree whose leaves have two leading axes that together
            hold ``population_size`` entries (presumably device x individual,
            as produced by the pmapped functions — confirm against callers).

    Returns:
        Pytree with the same structure whose leaves have shape
        ``(population_size,) + leaf.shape[2:]``.
    """
    # jax.device_put expects a Device (or Sharding), not a platform string,
    # so look the CPU device up explicitly; the rollout is moved to the CPU
    # the same way later in this notebook.
    cpu_device = jax.devices("cpu")[0]

    def merge_leading_axes(leaf):
        leaf = jax.device_put(leaf, cpu_device)
        return jnp.reshape(leaf, (population_size,) + leaf.shape[2:])

    return jax.tree.map(merge_leading_axes, sharded_tree)
In [ ]:
Copied!
# Alternate training, evaluation and PBT selection for num_loops rounds.
for i in tqdm(range(num_loops), total=num_loops):
    # Update for num_steps
    (training_states, env_states, replay_buffers), metrics = train_fn(
        training_states, env_states, replay_buffers
    )

    # Eval policy after training
    population_returns, _ = eval_policy(training_states, eval_env_first_states)
    population_returns_flatten = jnp.reshape(population_returns, (population_size,))

    if i % print_freq == 0:
        print(
            f"Evaluation over {env_batch_size} episodes,"
            f" Population mean return: {jnp.mean(population_returns_flatten)},"
            f" max return: {jnp.max(population_returns_flatten)}"
        )

    # PBT selection (not performed after the final round)
    if i < (num_loops - 1):
        # Draw fresh per-device keys for every selection round instead of
        # reusing the initialisation keys, which would feed select_fn the
        # same randomness each time. Raw key data is passed, matching the
        # format handed to the other pmapped functions in this notebook.
        key, selection_key = jax.random.split(key)
        keys = jax.random.key_data(jax.random.split(selection_key, num_devices))
        training_states, replay_buffers = select_fn(
            keys, population_returns, training_states, replay_buffers
        )
Visualize learnt behaviors
In [ ]:
Copied!
# Flatten the per-device layout exactly once: unshard_fn merges the two
# leading axes, so applying it a second time to its own output would merge
# the wrong axes.
training_states = unshard_fn(training_states)

# argmax without an axis indexes the flattened returns, which lines up with
# the flattened population axis produced by unshard_fn.
best_idx = jnp.argmax(population_returns)
best_training_state = jax.tree.map(lambda x: x[best_idx], training_states)
In [ ]:
Copied!
# Re-create the environment without a batch_size argument for the rollout
# of the selected individual.
env = environments.create(env_name, episode_length=episode_length)
In [ ]:
Copied!
# Deterministic, evaluation-mode step function bound to the new `env` and
# mapped over the first device only.
_play_step = functools.partial(
    agent.play_step_fn, env=env, deterministic=True, evaluation=True
)
play_step_fn = jax.pmap(_play_step, axis_name="p", devices=devices[:1])
In [ ]:
Copied!
# Roll out the individual selected above.
training_state = best_training_state
In [ ]:
Copied!
rollout = []
key, subkey = jax.random.split(key)
env_state = jax.jit(env.reset)(rng=subkey)

# play_step_fn is pmapped over a single device, so every leaf gets a leading
# axis of size 1. This must happen exactly once: repeating it would add a
# second leading axis to training_state.
training_state, env_state = jax.tree.map(
    lambda x: jnp.expand_dims(x, axis=0), (training_state, env_state)
)

# Record the state before each of the episode_length steps.
for _ in range(episode_length):
    rollout.append(env_state)
    env_state, training_state, _ = play_step_fn(env_state, training_state)

print(f"The trajectory of this individual contains {len(rollout)} transitions.")
In [ ]:
Copied!
# Strip the size-1 leading axis added for pmap and place every recorded
# state on the first CPU device. This must run exactly once: a second pass
# would index away another axis with x[0].
_cpu_device = jax.devices("cpu")[0]  # looked up once, not once per leaf
rollout = [
    jax.tree.map(lambda x: jax.device_put(x[0], _cpu_device), env_state)
    for env_state in rollout
]
In [ ]:
Copied!
# Render the recorded states (their `qp` attribute) as an HTML animation.
# NOTE(review): `html` is not imported anywhere in the visible notebook —
# presumably brax's HTML renderer; confirm and add the import to the import
# cell, otherwise this line raises NameError.
HTML(html.render(env.sys, [s.qp for s in rollout[:episode_length]]))