Installation¶
You will need Python 3.11 or later, and a working JAX installation. For example, you can install JAX with:
In [ ]:
Copied!
%pip install -U "jax[cuda]"
%pip install -U "jax[cuda]"
Then, install QDax from PyPI:
In [ ]:
Copied!
%pip install -U "qdax[examples]"
%pip install -U "qdax[examples]"
In [ ]:
Copied!
import functools
import time
from typing import Optional, Tuple
import jax
import jax.numpy as jnp
import optax
from IPython.display import HTML
from tqdm import tqdm
import qdax.tasks.brax as environments
from qdax.baselines.sac_pbt import PBTSAC, PBTSacConfig, PBTSacTrainingState
from qdax.core.containers.mapelites_repertoire import compute_cvt_centroids
from qdax.core.distributed_map_elites import DistributedMAPElites
from qdax.core.emitters.pbt_me_emitter import PBTEmitter, PBTEmitterConfig
from qdax.core.emitters.pbt_variation_operators import sac_pbt_variation_fn
from qdax.custom_types import Descriptor, ExtraScores, Fitness, Genotype, RNGKey
from qdax.utils.metrics import CSVLogger, default_qd_metrics
from qdax.utils.plotting import plot_map_elites_results
import functools
import time
from typing import Optional, Tuple
import jax
import jax.numpy as jnp
import optax
from IPython.display import HTML
from tqdm import tqdm
import qdax.tasks.brax as environments
from qdax.baselines.sac_pbt import PBTSAC, PBTSacConfig, PBTSacTrainingState
from qdax.core.containers.mapelites_repertoire import compute_cvt_centroids
from qdax.core.distributed_map_elites import DistributedMAPElites
from qdax.core.emitters.pbt_me_emitter import PBTEmitter, PBTEmitterConfig
from qdax.core.emitters.pbt_variation_operators import sac_pbt_variation_fn
from qdax.custom_types import Descriptor, ExtraScores, Fitness, Genotype, RNGKey
from qdax.utils.metrics import CSVLogger, default_qd_metrics
from qdax.utils.plotting import plot_map_elites_results
In [ ]:
Copied!
jax.config.update("jax_platform_name", "cpu")
jax.config.update("jax_platform_name", "cpu")
In [ ]:
Copied!
# Get devices (replace 'gpu' with 'tpu' if needed)
devices = jax.devices('gpu')
num_devices = len(devices)
print(f"Detected the following {num_devices} device(s): {devices}")
# Get devices (replace 'gpu' with 'tpu' if needed)
devices = jax.devices('gpu')
num_devices = len(devices)
print(f"Detected the following {num_devices} device(s): {devices}")
In [ ]:
Copied!
env_name = "anttrap"
seed = 0
# SAC config
batch_size = 256
episode_length = 1000
tau = 0.005
alpha_init = 1.0
critic_hidden_layer_size = (256, 256)
policy_hidden_layer_size = (256, 256)
fix_alpha = False
normalize_observations = False
# Emitter config
buffer_size = 100000
pg_population_size_per_device = 10
ga_population_size_per_device = 30
num_training_iterations = 10000
env_batch_size = 250
grad_updates_per_step = 1.0
iso_sigma = 0.005
line_sigma = 0.05
fraction_best_to_replace_from = 0.1
fraction_to_replace_from_best = 0.2
fraction_to_replace_from_samples = 0.4
eval_env_batch_size = 1
# MAP-Elites config
num_init_cvt_samples = 50000
num_centroids = 128
log_period = 1
num_loops = 10
env_name = "anttrap"
seed = 0
# SAC config
batch_size = 256
episode_length = 1000
tau = 0.005
alpha_init = 1.0
critic_hidden_layer_size = (256, 256)
policy_hidden_layer_size = (256, 256)
fix_alpha = False
normalize_observations = False
# Emitter config
buffer_size = 100000
pg_population_size_per_device = 10
ga_population_size_per_device = 30
num_training_iterations = 10000
env_batch_size = 250
grad_updates_per_step = 1.0
iso_sigma = 0.005
line_sigma = 0.05
fraction_best_to_replace_from = 0.1
fraction_to_replace_from_best = 0.2
fraction_to_replace_from_samples = 0.4
eval_env_batch_size = 1
# MAP-Elites config
num_init_cvt_samples = 50000
num_centroids = 128
log_period = 1
num_loops = 10
In [ ]:
Copied!
# Initialize environments
env = environments.create(
env_name=env_name,
batch_size=env_batch_size * pg_population_size_per_device,
episode_length=episode_length,
auto_reset=True,
)
eval_env = environments.create(
env_name=env_name,
batch_size=eval_env_batch_size,
episode_length=episode_length,
auto_reset=True,
)
# Initialize environments
env = environments.create(
env_name=env_name,
batch_size=env_batch_size * pg_population_size_per_device,
episode_length=episode_length,
auto_reset=True,
)
eval_env = environments.create(
env_name=env_name,
batch_size=eval_env_batch_size,
episode_length=episode_length,
auto_reset=True,
)
In [ ]:
Copied!
min_descriptor, max_descriptor = env.descriptor_limits
min_descriptor, max_descriptor = env.descriptor_limits
In [ ]:
Copied!
key = jax.random.key(seed)
key, subkey_1, subkey_2 = jax.random.split(key, 3)
env_states = jax.jit(env.reset)(rng=subkey_1)
eval_env_first_states = jax.jit(eval_env.reset)(rng=subkey_2)
key = jax.random.key(seed)
key, subkey_1, subkey_2 = jax.random.split(key, 3)
env_states = jax.jit(env.reset)(rng=subkey_1)
eval_env_first_states = jax.jit(eval_env.reset)(rng=subkey_2)
In [ ]:
Copied!
# get agent
config = PBTSacConfig(
batch_size=batch_size,
episode_length=episode_length,
tau=tau,
normalize_observations=normalize_observations,
alpha_init=alpha_init,
critic_hidden_layer_size=critic_hidden_layer_size,
policy_hidden_layer_size=policy_hidden_layer_size,
fix_alpha=fix_alpha,
)
agent = PBTSAC(config=config, action_size=env.action_size)
# get agent
config = PBTSacConfig(
batch_size=batch_size,
episode_length=episode_length,
tau=tau,
normalize_observations=normalize_observations,
alpha_init=alpha_init,
critic_hidden_layer_size=critic_hidden_layer_size,
policy_hidden_layer_size=policy_hidden_layer_size,
fix_alpha=fix_alpha,
)
agent = PBTSAC(config=config, action_size=env.action_size)
In [ ]:
Copied!
# init emitter
emitter_config = PBTEmitterConfig(
buffer_size=buffer_size,
num_training_iterations=num_training_iterations // env_batch_size,
env_batch_size=env_batch_size,
grad_updates_per_step=grad_updates_per_step,
pg_population_size_per_device=pg_population_size_per_device,
ga_population_size_per_device=ga_population_size_per_device,
num_devices=num_devices,
fraction_best_to_replace_from=fraction_best_to_replace_from,
fraction_to_replace_from_best=fraction_to_replace_from_best,
fraction_to_replace_from_samples=fraction_to_replace_from_samples,
fraction_sort_exchange=0.1,
)
# init emitter
emitter_config = PBTEmitterConfig(
buffer_size=buffer_size,
num_training_iterations=num_training_iterations // env_batch_size,
env_batch_size=env_batch_size,
grad_updates_per_step=grad_updates_per_step,
pg_population_size_per_device=pg_population_size_per_device,
ga_population_size_per_device=ga_population_size_per_device,
num_devices=num_devices,
fraction_best_to_replace_from=fraction_best_to_replace_from,
fraction_to_replace_from_best=fraction_to_replace_from_best,
fraction_to_replace_from_samples=fraction_to_replace_from_samples,
fraction_sort_exchange=0.1,
)
In [ ]:
Copied!
variation_fn = functools.partial(
sac_pbt_variation_fn, iso_sigma=iso_sigma, line_sigma=line_sigma
)
variation_fn = functools.partial(
sac_pbt_variation_fn, iso_sigma=iso_sigma, line_sigma=line_sigma
)
In [ ]:
Copied!
emitter = PBTEmitter(
pbt_agent=agent,
config=emitter_config,
env=env,
variation_fn=variation_fn,
)
emitter = PBTEmitter(
pbt_agent=agent,
config=emitter_config,
env=env,
variation_fn=variation_fn,
)
In [ ]:
Copied!
# get scoring function
descriptor_extraction_fn = environments.descriptor_extractor[env_name]
eval_policy = agent.get_eval_qd_fn(eval_env, descriptor_extraction_fn=descriptor_extraction_fn)
def scoring_function(
    genotypes: Genotype, key: RNGKey
) -> Tuple[Fitness, Descriptor, ExtraScores]:
    """Score a whole population starting from one shared initial env state.

    Args:
        genotypes: pytree whose leaves are stacked along a leading
            population axis; the population size is read from the first leaf.
        key: random key. Accepted so the signature matches what the
            MAP-Elites loop passes in, but not used here: every individual is
            evaluated from the fixed ``eval_env_first_states``.

    Returns:
        ``(returns, descriptors, extra_scores)``: the first two outputs of
        ``eval_policy`` followed by an empty dict of extra scores.
    """
    population_size = jax.tree.leaves(genotypes)[0].shape[0]

    # Give the stored initial state a leading axis, then tile it so that each
    # individual of the population gets its own copy of the same start state.
    first_states = jax.tree.map(
        lambda x: jnp.expand_dims(x, axis=0), eval_env_first_states
    )
    first_states = jax.tree.map(
        lambda x: jnp.repeat(x, population_size, axis=0), first_states
    )

    population_returns, population_descriptors, _, _ = eval_policy(
        genotypes, first_states
    )
    return population_returns, population_descriptors, {}
# get scoring function
descriptor_extraction_fn = environments.descriptor_extractor[env_name]
eval_policy = agent.get_eval_qd_fn(eval_env, descriptor_extraction_fn=descriptor_extraction_fn)
def scoring_function(
    genotypes: Genotype, key: RNGKey
) -> Tuple[Fitness, Descriptor, ExtraScores]:
    """Score a whole population starting from one shared initial env state.

    Args:
        genotypes: pytree whose leaves are stacked along a leading
            population axis; the population size is read from the first leaf.
        key: random key. Accepted so the signature matches what the
            MAP-Elites loop passes in, but not used here: every individual is
            evaluated from the fixed ``eval_env_first_states``.

    Returns:
        ``(returns, descriptors, extra_scores)``: the first two outputs of
        ``eval_policy`` followed by an empty dict of extra scores.
    """
    population_size = jax.tree.leaves(genotypes)[0].shape[0]

    # Give the stored initial state a leading axis, then tile it so that each
    # individual of the population gets its own copy of the same start state.
    first_states = jax.tree.map(
        lambda x: jnp.expand_dims(x, axis=0), eval_env_first_states
    )
    first_states = jax.tree.map(
        lambda x: jnp.repeat(x, population_size, axis=0), first_states
    )

    population_returns, population_descriptors, _, _ = eval_policy(
        genotypes, first_states
    )
    return population_returns, population_descriptors, {}
In [ ]:
Copied!
# Get the minimum reward value to make sure the QD score is positive
reward_offset = environments.reward_offset[env_name]
# Define a metrics function
metrics_function = functools.partial(
default_qd_metrics,
qd_offset=reward_offset * episode_length,
)
# Get the MAP-Elites algorithm
map_elites = DistributedMAPElites(
scoring_function=scoring_function,
emitter=emitter,
metrics_function=metrics_function,
)
# Get the minimum reward value to make sure the QD score is positive
reward_offset = environments.reward_offset[env_name]
# Define a metrics function
metrics_function = functools.partial(
default_qd_metrics,
qd_offset=reward_offset * episode_length,
)
# Get the MAP-Elites algorithm
map_elites = DistributedMAPElites(
scoring_function=scoring_function,
emitter=emitter,
metrics_function=metrics_function,
)
In [ ]:
Copied!
key, subkey = jax.random.split(key)
centroids = compute_cvt_centroids(
num_descriptors=env.descriptor_length,
num_init_cvt_samples=num_init_cvt_samples,
num_centroids=num_centroids,
minval=min_descriptor,
maxval=max_descriptor,
key=subkey,
)
key, subkey = jax.random.split(key)
centroids = compute_cvt_centroids(
num_descriptors=env.descriptor_length,
num_init_cvt_samples=num_init_cvt_samples,
num_centroids=num_centroids,
minval=min_descriptor,
maxval=max_descriptor,
key=subkey,
)
In [ ]:
Copied!
key, *keys = jax.random.split(key, num=1 + num_devices)
keys = jnp.stack(keys)
key, *keys = jax.random.split(key, num=1 + num_devices)
keys = jnp.stack(keys)
In [ ]:
Copied!
# get the initial training states and replay buffers
agent_init_fn = agent.get_init_fn(
population_size=pg_population_size_per_device + ga_population_size_per_device,
action_size=env.action_size,
observation_size=env.observation_size,
buffer_size=buffer_size,
)
# Need to convert to PRNGKey because of github.com/jax-ml/jax/issues/23647
keys = jax.random.key_data(keys)
training_states, _ = jax.pmap(agent_init_fn, axis_name="p", devices=devices)(keys)
# get the initial training states and replay buffers
agent_init_fn = agent.get_init_fn(
population_size=pg_population_size_per_device + ga_population_size_per_device,
action_size=env.action_size,
observation_size=env.observation_size,
buffer_size=buffer_size,
)
# Need to convert to PRNGKey because of github.com/jax-ml/jax/issues/23647
keys = jax.random.key_data(keys)
training_states, _ = jax.pmap(agent_init_fn, axis_name="p", devices=devices)(keys)
In [ ]:
Copied!
# empty the optimizer states so that the repertoires do not get too heavy
training_states = jax.pmap(
jax.vmap(training_states.__class__.empty_optimizers_states),
axis_name="p",
devices=devices,
)(training_states)
# initialize map-elites
repertoire, emitter_state, init_metrics = map_elites.get_distributed_init_fn(
devices=devices, centroids=centroids
)(genotypes=training_states, key=keys)
# empty the optimizer states so that the repertoires do not get too heavy
training_states = jax.pmap(
jax.vmap(training_states.__class__.empty_optimizers_states),
axis_name="p",
devices=devices,
)(training_states)
# initialize map-elites
repertoire, emitter_state, init_metrics = map_elites.get_distributed_init_fn(
devices=devices, centroids=centroids
)(genotypes=training_states, key=keys)
In [ ]:
Copied!
update_fn = map_elites.get_distributed_update_fn(
num_iterations=log_period, devices=devices
)
update_fn = map_elites.get_distributed_update_fn(
num_iterations=log_period, devices=devices
)
In [ ]:
Copied!
env_step_multiplier = (
(pg_population_size_per_device + ga_population_size_per_device)
* eval_env_batch_size
* episode_length
+ num_training_iterations * pg_population_size_per_device
) * num_devices
env_step_multiplier = (
(pg_population_size_per_device + ga_population_size_per_device)
* eval_env_batch_size
* episode_length
+ num_training_iterations * pg_population_size_per_device
) * num_devices
In [ ]:
Copied!
# Accumulate metrics over the whole run, seeded with the metrics returned by
# the MAP-Elites initialisation.
# NOTE(review): init_metrics comes from the distributed init function and may
# carry a leading per-device axis; confirm each entry contributes exactly one
# logged step, since the plotting cell assumes num_loops + 1 points in total.
all_metrics = dict(init_metrics.items())

# The CPU device does not change between iterations, so look it up once.
cpu_device = jax.devices("cpu")[0]

num_updates = num_loops // log_period
for i in tqdm(range(num_updates), total=num_updates):
    start_time = time.time()

    # One call performs `log_period` MAP-Elites iterations on every device.
    repertoire, emitter_state, metrics = update_fn(
        repertoire, emitter_state, keys
    )

    # Move the metrics to the CPU and keep index 0 along the leading axis
    # (presumably the per-device axis added by the distributed update).
    metrics_cpu = jax.tree.map(
        lambda x: jax.device_put(x, cpu_device)[0], metrics
    )

    # Wall-clock duration of this update; not consumed by the visible cells.
    timelapse = time.time() - start_time

    # Log metrics: append the new values to the history of each metric.
    for k, v in metrics_cpu.items():
        if k in all_metrics:
            all_metrics[k] = jnp.concatenate([all_metrics[k], v])
        else:
            all_metrics[k] = v
# Accumulate metrics over the whole run, seeded with the metrics returned by
# the MAP-Elites initialisation.
# NOTE(review): init_metrics comes from the distributed init function and may
# carry a leading per-device axis; confirm each entry contributes exactly one
# logged step, since the plotting cell assumes num_loops + 1 points in total.
all_metrics = dict(init_metrics.items())

# The CPU device does not change between iterations, so look it up once.
cpu_device = jax.devices("cpu")[0]

num_updates = num_loops // log_period
for i in tqdm(range(num_updates), total=num_updates):
    start_time = time.time()

    # One call performs `log_period` MAP-Elites iterations on every device.
    repertoire, emitter_state, metrics = update_fn(
        repertoire, emitter_state, keys
    )

    # Move the metrics to the CPU and keep index 0 along the leading axis
    # (presumably the per-device axis added by the distributed update).
    metrics_cpu = jax.tree.map(
        lambda x: jax.device_put(x, cpu_device)[0], metrics
    )

    # Wall-clock duration of this update; not consumed by the visible cells.
    timelapse = time.time() - start_time

    # Log metrics: append the new values to the history of each metric.
    for k, v in metrics_cpu.items():
        if k in all_metrics:
            all_metrics[k] = jnp.concatenate([all_metrics[k], v])
        else:
            all_metrics[k] = v
In [ ]:
Copied!
# Create the performance evolution plots and visualize final grid
repertoire_cpu = jax.tree.map(
lambda x: jax.device_put(x, jax.devices("cpu")[0])[0], repertoire
)
num_loops_with_init = num_loops + 1
env_steps = (jnp.arange(num_loops_with_init) + 1) * env_step_multiplier
fig, axes = plot_map_elites_results(
env_steps=env_steps,
metrics=all_metrics,
repertoire=repertoire_cpu,
min_descriptor=min_descriptor,
max_descriptor=max_descriptor,
)
# Create the performance evolution plots and visualize final grid
repertoire_cpu = jax.tree.map(
lambda x: jax.device_put(x, jax.devices("cpu")[0])[0], repertoire
)
num_loops_with_init = num_loops + 1
env_steps = (jnp.arange(num_loops_with_init) + 1) * env_step_multiplier
fig, axes = plot_map_elites_results(
env_steps=env_steps,
metrics=all_metrics,
repertoire=repertoire_cpu,
min_descriptor=min_descriptor,
max_descriptor=max_descriptor,
)
Visualize learnt behaviors¶
In [ ]:
Copied!
# Evaluate best individual of the repertoire
best_idx = jnp.argmax(repertoire_cpu.fitnesses)
best_fitness = jnp.max(repertoire_cpu.fitnesses)
best_descriptor = repertoire_cpu.descriptors[best_idx]
# Evaluate best individual of the repertoire
best_idx = jnp.argmax(repertoire_cpu.fitnesses)
best_fitness = jnp.max(repertoire_cpu.fitnesses)
best_descriptor = repertoire_cpu.descriptors[best_idx]
In [ ]:
Copied!
repertoire_cpu.descriptors.shape
repertoire_cpu.descriptors.shape
In [ ]:
Copied!
# Evaluate the agent that goes the furthest on the y-axis
# best_idx = jnp.argmax(repertoire.descriptors[:, 0])
# best_fitness = repertoire.fitnesses[best_idx]
# best_descriptor = repertoire.descriptors[best_idx]
# Evaluate the agent that goes the furthest on the y-axis
# best_idx = jnp.argmax(repertoire.descriptors[:, 0])
# best_fitness = repertoire.fitnesses[best_idx]
# best_descriptor = repertoire.descriptors[best_idx]
In [ ]:
Copied!
print(
f"Fitness of the selected agent: {best_fitness:.2f}\n",
f"Descriptor of the selected agent: {best_descriptor}\n",
f"Index in the repertoire of this individual: {best_idx}\n",
)
print(
f"Fitness of the selected agent: {best_fitness:.2f}\n",
f"Descriptor of the selected agent: {best_descriptor}\n",
f"Index in the repertoire of this individual: {best_idx}\n",
)
In [ ]:
Copied!
env = environments.create(env_name, episode_length=episode_length)
env = environments.create(env_name, episode_length=episode_length)
In [ ]:
Copied!
play_step_fn = jax.pmap(
functools.partial(agent.play_step_fn, env=env, deterministic=True, evaluation=True),
axis_name="p",
devices=devices[:1],
)
play_step_fn = jax.pmap(
functools.partial(agent.play_step_fn, env=env, deterministic=True, evaluation=True),
axis_name="p",
devices=devices[:1],
)
In [ ]:
Copied!
training_state = jax.tree.map(lambda x: x[best_idx], repertoire_cpu.genotypes)
training_state = jax.tree.map(lambda x: x[best_idx], repertoire_cpu.genotypes)
In [ ]:
Copied!
# Roll out the selected individual for one full episode, recording every
# environment state visited so the trajectory can be rendered afterwards.
rollout = []

key, subkey = jax.random.split(key)
env_state = jax.jit(env.reset)(rng=subkey)

# play_step_fn is pmapped over a single device, so both pytrees need a
# leading axis of size 1.
training_state, env_state = jax.tree.map(
    lambda x: jnp.expand_dims(x, axis=0), (training_state, env_state)
)

for _ in range(episode_length):
    # Store the state *before* stepping so the initial state is included.
    rollout.append(env_state)
    env_state, training_state, _ = play_step_fn(env_state, training_state)

print(f"The trajectory of this individual contains {len(rollout)} transitions.")
# Roll out the selected individual for one full episode, recording every
# environment state visited so the trajectory can be rendered afterwards.
rollout = []

key, subkey = jax.random.split(key)
env_state = jax.jit(env.reset)(rng=subkey)

# play_step_fn is pmapped over a single device, so both pytrees need a
# leading axis of size 1.
training_state, env_state = jax.tree.map(
    lambda x: jnp.expand_dims(x, axis=0), (training_state, env_state)
)

for _ in range(episode_length):
    # Store the state *before* stepping so the initial state is included.
    rollout.append(env_state)
    env_state, training_state, _ = play_step_fn(env_state, training_state)

print(f"The trajectory of this individual contains {len(rollout)} transitions.")
In [ ]:
Copied!
rollout = [
jax.tree.map(lambda x: jax.device_put(x[0], jax.devices("cpu")[0]), env_state)
for env_state in rollout
]
rollout = [
jax.tree.map(lambda x: jax.device_put(x[0], jax.devices("cpu")[0]), env_state)
for env_state in rollout
]
In [ ]:
Copied!
# Render the recorded trajectory as an HTML animation.
# NOTE(review): `html` is not imported in any cell shown here, so this call
# raises NameError unless it is defined elsewhere. It is presumably Brax's
# HTML rendering module; confirm and add the matching import.
HTML(html.render(env.sys, [s.qp for s in rollout[:episode_length]]))
# Render the recorded trajectory as an HTML animation.
# NOTE(review): `html` is not imported in any cell shown here, so this call
# raises NameError unless it is defined elsewhere. It is presumably Brax's
# HTML rendering module; confirm and add the matching import.
HTML(html.render(env.sys, [s.qp for s in rollout[:episode_length]]))
In [ ]:
Copied!