Training a population on Jumanji-Snake with QDax¶
This notebook shows how to use either MAP-Elites or a simple (non-QD) genetic algorithm to train a population of agents that play the game of Snake from Jumanji.
It can also serve as a template for interacting with other Jumanji environments.
from functools import partial
from typing import Tuple, Type
import jax
import jax.numpy as jnp
try:
import brax
except:
!pip install git+https://github.com/google/brax.git@v0.9.2 |tail -n 1
import brax
try:
import flax
except:
!pip install --no-deps git+https://github.com/google/flax.git@v0.7.4 |tail -n 1
import flax
try:
import chex
except:
!pip install --no-deps git+https://github.com/deepmind/chex.git@v0.1.83 |tail -n 1
import chex
try:
import jumanji
except:
!pip install "jumanji==0.3.1"
import jumanji
try:
import qdax
except:
!pip install --no-deps git+https://github.com/adaptive-intelligent-robotics/QDax@main |tail -n 1
import qdax
import functools
import numpy as np
from qdax.baselines.genetic_algorithm import GeneticAlgorithm
from qdax.core.map_elites import MAPElites
from qdax.core.containers.mapelites_repertoire import compute_euclidean_centroids
from qdax.core.neuroevolution.buffers.buffer import QDTransition
from qdax.core.neuroevolution.networks.networks import MLP
from qdax.tasks.jumanji_envs import jumanji_scoring_function
from qdax.core.emitters.mutation_operators import isoline_variation
from qdax.core.emitters.standard_emitters import MixingEmitter
from qdax.types import ExtraScores, Fitness, RNGKey, Descriptor
from qdax.utils.metrics import default_ga_metrics, default_qd_metrics
Define hyperparameters¶
# Reproducibility
seed = 0

# Policy architecture: sizes of the hidden layers of the MLP
policy_hidden_layer_sizes = (128, 128)

# Evaluation budget: environment steps per episode and number of iterations
episode_length = 200
num_iterations = 5_000

# Population: one offspring is emitted per population slot at each iteration
population_size = 100
batch_size = population_size

# Isoline variation operator parameters
iso_sigma = 0.005
line_sigma = 0.05
Instantiate the snake environment¶
# Build the Snake environment through the Jumanji registry.
env = jumanji.make("Snake-v1")

# Placeholder action generated from the action spec (not a meaningful move).
action = env.action_spec().generate_value()

# Smoke-test the jit-compiled reset, then take one jit-compiled step.
key = jax.random.PRNGKey(0)
state, timestep = jax.jit(env.reset)(key)
state, timestep = jax.jit(env.step)(state, action)
Define the type of policy that will be used to solve the problem¶
# Master random key for the rest of the notebook
random_key = jax.random.PRNGKey(seed)

# The number of discrete actions is the largest action index plus one
num_actions = env.action_spec().maximum + 1

# MLP policy: the hidden layers followed by one output unit per action,
# with a softmax on the output layer.
policy_layer_sizes = (*policy_hidden_layer_sizes, num_actions)
policy_network = MLP(
    layer_sizes=policy_layer_sizes,
    kernel_init=jax.nn.initializers.lecun_uniform(),
    final_activation=jax.nn.softmax,
)
Utils to interact with the environment¶
Define a way to process the observation and define a way to play a step in the environment, given the parameters of a policy_network.
def observation_processing(observation):
    """Flatten an observation into a single 1-D network input.

    The observation may be a plain array or a structured observation (any
    pytree, e.g. a named tuple with ``grid``, ``step_count`` and
    ``action_mask`` fields). Every array leaf is flattened and the pieces are
    concatenated in pytree (field) order, so the output length is the sum of
    the leaves' element counts.

    Args:
        observation: an array or a pytree of arrays/scalars.

    Returns:
        A 1-D array holding all the observation values.
    """
    # `jnp.ravel` only accepts a single array-like, so it cannot be applied
    # directly to a structured observation: flatten leaf by leaf instead.
    leaves = jax.tree_util.tree_leaves(observation)
    # Scalar leaves become length-1 vectors so they can be concatenated;
    # `jnp.concatenate` promotes mixed dtypes to a common dtype.
    network_input = jnp.concatenate([jnp.ravel(leaf) for leaf in leaves])
    return network_input
def play_step_fn(
    env_state,
    timestep,
    policy_params,
    random_key,
):
    """Advance the environment by one step with the policy's greedy action.

    The action is the argmax of the policy output, so no randomness is
    consumed: ``policy_params`` and ``random_key`` are returned untouched.

    Returns:
        A tuple ``(next_state, next_timestep, policy_params, random_key,
        transition)`` where ``transition`` is the ``QDTransition`` recorded
        for this step.
    """
    flat_observation = observation_processing(timestep.observation)
    action = jnp.argmax(policy_network.apply(policy_params, flat_observation))

    next_state, next_timestep = env.step(env_state, action)

    # 1 when the new timestep ends the episode, 0 otherwise
    done_flag = jnp.where(next_timestep.last(), jnp.array(1), jnp.array(0))

    transition = QDTransition(
        obs=timestep.observation,
        next_obs=next_timestep.observation,
        rewards=next_timestep.reward,
        dones=done_flag,
        actions=action,
        truncations=jnp.array(0),
        # no state descriptor is recorded in this example
        state_desc=None,
        next_state_desc=None,
    )

    return next_state, next_timestep, policy_params, random_key, transition
Init a population of policies¶
Also initialize the initial environment states and timesteps.
# Init population of controllers: one independent key per individual
random_key, subkey = jax.random.split(random_key)
keys = jax.random.split(subkey, num=batch_size)

# Size of the observation once every field is flattened: the SUM of the
# per-field element counts. (Concatenating the shape tuples and taking a
# single product would multiply the field sizes together instead.)
obs_spec = env.observation_spec()
observation_size = sum(
    # dtype=int keeps the count an integer, including for scalar fields
    # whose shape is () and whose element count is therefore 1
    int(np.prod(field_spec.shape, dtype=int))
    for field_spec in (obs_spec.grid, obs_spec.step_count, obs_spec.action_mask)
)

# A dummy batch of inputs is enough to initialise the network parameters
fake_batch = jnp.zeros(shape=(batch_size, observation_size))
init_variables = jax.vmap(policy_network.init)(keys, fake_batch)

# Create the initial environment states.
# The same key is repeated for every individual, so `env.reset` receives
# identical keys across the whole batch.
random_key, subkey = jax.random.split(random_key)
keys = jnp.repeat(jnp.expand_dims(subkey, axis=0), repeats=batch_size, axis=0)
reset_fn = jax.jit(jax.vmap(env.reset))
init_states, init_timesteps = reset_fn(keys)
Define a method to extract behavior descriptor when relevant¶
# Prepare the scoring function
def bd_extraction(data: QDTransition, mask: jnp.ndarray, linear_projection: jnp.ndarray) -> Descriptor:
    """Compute a descriptor from the mean observation of each trajectory.

    Every stored observation is flattened with ``observation_processing``,
    the result is averaged over the second-to-last axis, multiplied by the
    transpose of ``linear_projection`` and squashed with ``tanh``, so each
    descriptor component lies in (-1, 1).

    Args:
        data: transitions whose ``obs`` field holds the observations.
            Assumes ``obs`` has two leading axes (batch, time) - TODO confirm.
        mask: accepted for interface compatibility; it is not used here, so
            all timesteps contribute to the mean.
        linear_projection: projection matrix with one row per descriptor
            dimension.

    Returns:
        The projected, ``tanh``-squashed mean observations.
    """
    # flatten each observation (double vmap over the two leading axes)
    flat_observations = jax.vmap(jax.vmap(observation_processing))(data.obs)

    # average over the second-to-last axis, then project and squash
    projected_mean = jnp.mean(flat_observations, axis=-2) @ linear_projection.T
    return jnp.tanh(projected_mean)
# Draw a random linear map from observation space to a 2-D descriptor space,
# with entries sampled uniformly in [-1, 1).
random_key, subkey = jax.random.split(random_key)
linear_projection = jax.random.uniform(
    key=subkey,
    shape=(2, observation_size),
    dtype=jnp.float32,
    minval=-1,
    maxval=1,
)

# Bind the projection so that the extractor only takes (data, mask).
bd_extraction_fn = partial(bd_extraction, linear_projection=linear_projection)

# Scoring function: everything except the genotypes and the random key is
# fixed here.
scoring_fn = partial(
    jumanji_scoring_function,
    init_states=init_states,
    init_timesteps=init_timesteps,
    episode_length=episode_length,
    play_step_fn=play_step_fn,
    behavior_descriptor_extractor=bd_extraction_fn,
)
Define the scoring function¶
def scoring_function(
    genotypes: jnp.ndarray, random_key: RNGKey
) -> Tuple[Fitness, ExtraScores, RNGKey]:
    """Score genotypes for the genetic algorithm, ignoring descriptors.

    Calls ``scoring_fn`` and drops the second returned value (the
    descriptors); the fitnesses are reshaped into a single column.

    Returns:
        A tuple ``(fitnesses, extra_scores, random_key)`` with ``fitnesses``
        of shape ``(-1, 1)``.
    """
    raw_fitnesses, _unused_descriptors, extras, next_key = scoring_fn(
        genotypes, random_key
    )
    column_fitnesses = raw_fitnesses.reshape(-1, 1)
    return column_fitnesses, extras, next_key
Define the emitter used¶
# Emitter: offspring are produced exclusively by isoline variation
# (no mutation operator, variation applied to 100% of the batch).
variation_fn = partial(
    isoline_variation,
    iso_sigma=iso_sigma,
    line_sigma=line_sigma,
)

mixing_emitter = MixingEmitter(
    batch_size=batch_size,
    variation_percentage=1.0,
    variation_fn=variation_fn,
    mutation_fn=None,
)
Define the algorithm used and apply the initial step¶
One can either use a simple genetic algorithm or use MAP-Elites.
# Switch between MAP-Elites (True) and the plain genetic algorithm (False)
use_map_elites = True

if use_map_elites:
    # QD metrics, with no offset applied to the QD score
    metrics_function = partial(default_qd_metrics, qd_offset=0)

    # MAP-Elites works on the full scoring function (fitness + descriptors)
    algo_instance = MAPElites(
        scoring_function=scoring_fn,
        emitter=mixing_emitter,
        metrics_function=metrics_function,
    )

    # 50 x 50 grid of centroids covering [-1, 1] on each descriptor axis
    centroids = compute_euclidean_centroids(
        grid_shape=(50, 50),
        minval=-1,
        maxval=1,
    )

    # Initial repertoire and emitter state
    repertoire, emitter_state, random_key = algo_instance.init(
        init_variables, centroids, random_key
    )
else:
    # The genetic algorithm uses the descriptor-free scoring wrapper
    algo_instance = GeneticAlgorithm(
        scoring_function=scoring_function,
        emitter=mixing_emitter,
        metrics_function=default_ga_metrics,
    )

    # Initial population and emitter state
    repertoire, emitter_state, random_key = algo_instance.init(
        init_variables, population_size, random_key
    )
Run the optimization loop¶
%%time
# Run `num_iterations` updates of the algorithm inside a single `lax.scan`;
# the carry is (repertoire, emitter_state, random_key) and the per-iteration
# metrics are stacked along the leading axis.
(repertoire, emitter_state, random_key), metrics = jax.lax.scan(
    f=algo_instance.scan_update,
    init=(repertoire, emitter_state, random_key),
    xs=(),
    length=num_iterations,
)
# Max fitness reported at the last iteration
metrics["max_fitness"][-1]
# Fitnesses stored in the final repertoire
repertoire.fitnesses
# Descriptors stored in the final repertoire
repertoire.descriptors
from qdax.utils.plotting import plot_map_elites_results

# x-axis: number of environment steps at each iteration index
env_steps = jnp.arange(num_iterations) * episode_length * batch_size

# Draw the metric curves together with the final repertoire grid
fig, axes = plot_map_elites_results(
    env_steps=env_steps,
    metrics=metrics,
    repertoire=repertoire,
    min_bd=-1.0,
    max_bd=1.0,
)
Play snake with the best policy¶
Retrieve one of the best policies from the repertoire and show how it does on the Snake environment.
# Locate the highest fitness stored in the repertoire and report it
best_idx = jnp.argmax(repertoire.fitnesses)
best_fitness = jnp.max(repertoire.fitnesses)

fitness_line = f"Best fitness in the repertoire: {best_fitness:.2f}\n"
index_line = f"Index in the repertoire of this individual: {best_idx}\n"
print(fitness_line, index_line)
# Parameters of the best individual: take entry `best_idx` of every leaf
my_params = jax.tree_util.tree_map(lambda leaf: leaf[best_idx], repertoire.genotypes)

# First initial state / timestep of the batch used during training
init_state = jax.tree_util.tree_map(lambda leaf: leaf[0], init_states)
init_timestep = jax.tree_util.tree_map(lambda leaf: leaf[0], init_timesteps)
# Work on copies so that the stored initial state/timestep stay untouched
state = jax.tree_util.tree_map(lambda x: x.copy(), init_state)
timestep = jax.tree_util.tree_map(lambda x: x.copy(), init_timestep)

# Wrap the step function once instead of re-wrapping it at every iteration
step_fn = jax.jit(env.step)

for _ in range(100):
    # (Optional) Render the env state
    env.render(state)

    # Stop once the terminal state has been rendered: stepping the
    # environment past the end of the episode is not meaningful.
    if timestep.last():
        break

    # Greedy action of the selected policy
    network_input = observation_processing(timestep.observation)
    proba_action = policy_network.apply(my_params, network_input)
    action = jnp.argmax(proba_action)

    state, timestep = step_fn(state, action)