Training a population on Jumanji-Snake with QDax
This notebook shows how to use either MAP-Elites or a simple (non-QD) genetic algorithm to train a population of agents that play the game of Snake from Jumanji.
This notebook can also serve as a starting point for interacting with other Jumanji environments.
Installation
You will need Python 3.11 or later, and a working JAX installation. For example, you can install JAX with:
In [ ]:
Copied!
# Install (or upgrade) JAX with the `cuda` extra.
%pip install -U "jax[cuda]"
%pip install -U "jax[cuda]"
Then, install QDax from PyPI:
In [ ]:
Copied!
# Install (or upgrade) QDax together with the dependencies of its examples.
%pip install -U "qdax[examples]"
%pip install -U "qdax[examples]"
In [ ]:
Copied!
from functools import partial
from typing import Tuple, Type
import jax
import jax.numpy as jnp
import functools
import jumanji
import numpy as np
from qdax.baselines.genetic_algorithm import GeneticAlgorithm
from qdax.core.map_elites import MAPElites
from qdax.core.containers.mapelites_repertoire import compute_euclidean_centroids
from qdax.core.neuroevolution.buffers.buffer import QDTransition
from qdax.core.neuroevolution.networks.networks import MLP
from qdax.tasks.jumanji_envs import jumanji_scoring_function
from qdax.core.emitters.mutation_operators import isoline_variation
from qdax.core.emitters.standard_emitters import MixingEmitter
from qdax.custom_types import ExtraScores, Fitness, RNGKey, Descriptor
from qdax.utils.metrics import default_ga_metrics, default_qd_metrics
from functools import partial
from typing import Tuple, Type
import jax
import jax.numpy as jnp
import functools
import jumanji
import numpy as np
from qdax.baselines.genetic_algorithm import GeneticAlgorithm
from qdax.core.map_elites import MAPElites
from qdax.core.containers.mapelites_repertoire import compute_euclidean_centroids
from qdax.core.neuroevolution.buffers.buffer import QDTransition
from qdax.core.neuroevolution.networks.networks import MLP
from qdax.tasks.jumanji_envs import jumanji_scoring_function
from qdax.core.emitters.mutation_operators import isoline_variation
from qdax.core.emitters.standard_emitters import MixingEmitter
from qdax.custom_types import ExtraScores, Fitness, RNGKey, Descriptor
from qdax.utils.metrics import default_ga_metrics, default_qd_metrics
Define hyperparameters
In [ ]:
Copied!
# --- Reproducibility & policy architecture ---
seed = 0
policy_hidden_layer_sizes = (128, 128)  # widths of the MLP's hidden layers

# --- Evaluation budget ---
episode_length = 200  # environment steps per rollout
population_size = 100
batch_size = population_size  # batch size is tied to the population size

# --- Optimisation ---
num_iterations = 1000
iso_sigma, line_sigma = 0.005, 0.05  # noise scales of the isoline variation

# --- Reproducibility & policy architecture ---
seed = 0
policy_hidden_layer_sizes = (128, 128)  # widths of the MLP's hidden layers

# --- Evaluation budget ---
episode_length = 200  # environment steps per rollout
population_size = 100
batch_size = population_size  # batch size is tied to the population size

# --- Optimisation ---
num_iterations = 1000
iso_sigma, line_sigma = 0.005, 0.05  # noise scales of the isoline variation
Instantiate the Snake environment
In [ ]:
Copied!
# Build the Snake environment from Jumanji's registry.
env = jumanji.make('Snake-v1')

# Dummy action value generated from the action spec, only used to
# smoke-test a single `step` call below.
action = env.action_spec().generate_value()

# `reset` and `step` are jit-able: reset from a fresh subkey, then step once.
key = jax.random.key(seed)
key, subkey = jax.random.split(key)
state, timestep = jax.jit(env.reset)(subkey)
state, timestep = jax.jit(env.step)(state, action)

# Build the Snake environment from Jumanji's registry.
env = jumanji.make('Snake-v1')

# Dummy action value generated from the action spec, only used to
# smoke-test a single `step` call below.
action = env.action_spec().generate_value()

# `reset` and `step` are jit-able: reset from a fresh subkey, then step once.
key = jax.random.key(seed)
key, subkey = jax.random.split(key)
state, timestep = jax.jit(env.reset)(subkey)
state, timestep = jax.jit(env.step)(state, action)
Define the type of policy that will be used to solve the problem
In [ ]:
Copied!
# One output unit per discrete action; actions range over 0..maximum.
num_actions = env.action_spec().maximum + 1
policy_layer_sizes = (*policy_hidden_layer_sizes, num_actions)

# MLP whose softmax output is read as a distribution over actions.
policy_network = MLP(
    layer_sizes=policy_layer_sizes,
    kernel_init=jax.nn.initializers.lecun_uniform(),
    final_activation=jax.nn.softmax,
)

# One output unit per discrete action; actions range over 0..maximum.
num_actions = env.action_spec().maximum + 1
policy_layer_sizes = (*policy_hidden_layer_sizes, num_actions)

# MLP whose softmax output is read as a distribution over actions.
policy_network = MLP(
    layer_sizes=policy_layer_sizes,
    kernel_init=jax.nn.initializers.lecun_uniform(),
    final_activation=jax.nn.softmax,
)
Utilities to interact with the environment
Define how an observation is turned into a network input, and how to play one step in the environment given the parameters of `policy_network`.
In [ ]:
Copied!
def observation_processing(observation):
    """Flatten a Snake observation into a single 1-D network input.

    Layout: raveled grid, then the step count as one element, then the
    raveled action mask. ``jnp.concatenate`` promotes the pieces to a
    common dtype.
    """
    parts = (
        jnp.ravel(observation.grid),
        jnp.array([observation.step_count]),
        jnp.ravel(observation.action_mask),
    )
    return jnp.concatenate(parts)


def play_step_fn(
    env_state,
    timestep,
    policy_params,
    key,
):
    """Play one greedy environment step and record the transition.

    The action is the argmax of the policy output, so the step is
    deterministic; ``key`` is returned untouched. No state descriptors are
    recorded (both are ``None``).

    Returns:
        ``(next_state, next_timestep, policy_params, key, transition)``.
    """
    features = observation_processing(timestep.observation)
    chosen_action = jnp.argmax(policy_network.apply(policy_params, features))

    next_state, next_timestep = env.step(env_state, chosen_action)
    done_flag = jnp.where(next_timestep.last(), jnp.array(1), jnp.array(0))

    transition = QDTransition(
        obs=timestep.observation,
        next_obs=next_timestep.observation,
        rewards=next_timestep.reward,
        dones=done_flag,
        actions=chosen_action,
        truncations=jnp.array(0),
        state_desc=None,
        next_state_desc=None,
    )
    return next_state, next_timestep, policy_params, key, transition


def observation_processing(observation):
    """Flatten a Snake observation into a single 1-D network input.

    Layout: raveled grid, then the step count as one element, then the
    raveled action mask. ``jnp.concatenate`` promotes the pieces to a
    common dtype.
    """
    parts = (
        jnp.ravel(observation.grid),
        jnp.array([observation.step_count]),
        jnp.ravel(observation.action_mask),
    )
    return jnp.concatenate(parts)


def play_step_fn(
    env_state,
    timestep,
    policy_params,
    key,
):
    """Play one greedy environment step and record the transition.

    The action is the argmax of the policy output, so the step is
    deterministic; ``key`` is returned untouched. No state descriptors are
    recorded (both are ``None``).

    Returns:
        ``(next_state, next_timestep, policy_params, key, transition)``.
    """
    features = observation_processing(timestep.observation)
    chosen_action = jnp.argmax(policy_network.apply(policy_params, features))

    next_state, next_timestep = env.step(env_state, chosen_action)
    done_flag = jnp.where(next_timestep.last(), jnp.array(1), jnp.array(0))

    transition = QDTransition(
        obs=timestep.observation,
        next_obs=next_timestep.observation,
        rewards=next_timestep.reward,
        dones=done_flag,
        actions=chosen_action,
        truncations=jnp.array(0),
        state_desc=None,
        next_state_desc=None,
    )
    return next_state, next_timestep, policy_params, key, transition
Initialize a population of policies
Also initialize the environment states and timesteps that every rollout starts from.
In [ ]:
Copied!
# Size of the flat network input: grid + step count + action mask,
# matching the layout produced by `observation_processing`.
obs_spec = env.observation_spec()
observation_size = sum(
    int(np.prod(spec.shape))
    for spec in (obs_spec.grid, obs_spec.step_count, obs_spec.action_mask)
)

# One independent init key per controller in the batch.
key, subkey = jax.random.split(key)
keys = jax.random.split(subkey, num=batch_size)
fake_batch = jnp.zeros(shape=(batch_size, observation_size))
init_variables = jax.vmap(policy_network.init)(keys, fake_batch)

# The same reset key is repeated for every individual, so all of them are
# evaluated from an identical initial state.
reset_fn = jax.jit(jax.vmap(env.reset))
key, subkey = jax.random.split(key)
keys = jnp.repeat(jnp.expand_dims(subkey, axis=0), repeats=batch_size, axis=0)
init_states, init_timesteps = reset_fn(keys)

# Size of the flat network input: grid + step count + action mask,
# matching the layout produced by `observation_processing`.
obs_spec = env.observation_spec()
observation_size = sum(
    int(np.prod(spec.shape))
    for spec in (obs_spec.grid, obs_spec.step_count, obs_spec.action_mask)
)

# One independent init key per controller in the batch.
key, subkey = jax.random.split(key)
keys = jax.random.split(subkey, num=batch_size)
fake_batch = jnp.zeros(shape=(batch_size, observation_size))
init_variables = jax.vmap(policy_network.init)(keys, fake_batch)

# The same reset key is repeated for every individual, so all of them are
# evaluated from an identical initial state.
reset_fn = jax.jit(jax.vmap(env.reset))
key, subkey = jax.random.split(key)
keys = jnp.repeat(jnp.expand_dims(subkey, axis=0), repeats=batch_size, axis=0)
init_states, init_timesteps = reset_fn(keys)
Define a method to extract descriptors when relevant
In [ ]:
Copied!
# Prepare the scoring function
def descriptor_extraction(
    data: QDTransition, mask: jax.Array, linear_projection: jax.Array
) -> Descriptor:
    """Compute a 2-D descriptor from the observations of each rollout.

    Every observation is flattened with ``observation_processing``. The flat
    observations are averaged over the valid time steps of each rollout. The
    average is projected to two dimensions with ``linear_projection`` and
    squashed into [-1, 1]^2 with ``tanh``.

    Args:
        data: batched rollout transitions. ``data.obs`` is assumed to carry
            (batch, time) leading axes, which the double ``vmap`` relies on.
        mask: per-transition mask, assumed shape (batch, time).
            NOTE(review): assumed to follow the QDax convention of 1 for
            transitions recorded after the episode ended and 0 otherwise;
            confirm against ``jumanji_scoring_function``.
        linear_projection: projection matrix of shape (2, observation_size).

    Returns:
        One 2-D descriptor per rollout, with values in [-1, 1].
    """
    # pre-process the observation: (batch, time, observation_size)
    observation = jax.vmap(jax.vmap(observation_processing))(data.obs)

    # Average only over valid steps, so transitions played after the episode
    # ended do not bias the descriptor. The count is clamped to 1 to avoid a
    # division by zero for a rollout without any valid step.
    valid = 1.0 - jnp.expand_dims(mask, axis=-1)
    num_valid = jnp.maximum(jnp.sum(valid, axis=-2), 1.0)
    mean_observation = jnp.sum(observation * valid, axis=-2) / num_valid

    # project those in [-1, 1]^2
    descriptors = jnp.tanh(mean_observation @ linear_projection.T)
    return descriptors


# create a random projection to a two dim space
key, subkey = jax.random.split(key)
linear_projection = jax.random.uniform(
    subkey, (2, observation_size), minval=-1, maxval=1, dtype=jnp.float32
)
descriptor_extraction_fn = functools.partial(
    descriptor_extraction,
    linear_projection=linear_projection,
)
# define the scoring function
scoring_fn = functools.partial(
    jumanji_scoring_function,
    init_states=init_states,
    init_timesteps=init_timesteps,
    episode_length=episode_length,
    play_step_fn=play_step_fn,
    descriptor_extractor=descriptor_extraction_fn,
)


# Prepare the scoring function
def descriptor_extraction(
    data: QDTransition, mask: jax.Array, linear_projection: jax.Array
) -> Descriptor:
    """Compute a 2-D descriptor from the observations of each rollout.

    Every observation is flattened with ``observation_processing``. The flat
    observations are averaged over the valid time steps of each rollout. The
    average is projected to two dimensions with ``linear_projection`` and
    squashed into [-1, 1]^2 with ``tanh``.

    Args:
        data: batched rollout transitions. ``data.obs`` is assumed to carry
            (batch, time) leading axes, which the double ``vmap`` relies on.
        mask: per-transition mask, assumed shape (batch, time).
            NOTE(review): assumed to follow the QDax convention of 1 for
            transitions recorded after the episode ended and 0 otherwise;
            confirm against ``jumanji_scoring_function``.
        linear_projection: projection matrix of shape (2, observation_size).

    Returns:
        One 2-D descriptor per rollout, with values in [-1, 1].
    """
    # pre-process the observation: (batch, time, observation_size)
    observation = jax.vmap(jax.vmap(observation_processing))(data.obs)

    # Average only over valid steps, so transitions played after the episode
    # ended do not bias the descriptor. The count is clamped to 1 to avoid a
    # division by zero for a rollout without any valid step.
    valid = 1.0 - jnp.expand_dims(mask, axis=-1)
    num_valid = jnp.maximum(jnp.sum(valid, axis=-2), 1.0)
    mean_observation = jnp.sum(observation * valid, axis=-2) / num_valid

    # project those in [-1, 1]^2
    descriptors = jnp.tanh(mean_observation @ linear_projection.T)
    return descriptors


# create a random projection to a two dim space
key, subkey = jax.random.split(key)
linear_projection = jax.random.uniform(
    subkey, (2, observation_size), minval=-1, maxval=1, dtype=jnp.float32
)
descriptor_extraction_fn = functools.partial(
    descriptor_extraction,
    linear_projection=linear_projection,
)
# define the scoring function
scoring_fn = functools.partial(
    jumanji_scoring_function,
    init_states=init_states,
    init_timesteps=init_timesteps,
    episode_length=episode_length,
    play_step_fn=play_step_fn,
    descriptor_extractor=descriptor_extraction_fn,
)
Define the scoring function
In [ ]:
Copied!
def scoring_function(
    genotypes: jax.Array, key: RNGKey
) -> Tuple[Fitness, ExtraScores]:
    """Fitness-only scoring for the genetic algorithm.

    Calls ``scoring_fn`` and discards the descriptors it returns.

    Args:
        genotypes: batch of policy parameters to evaluate.
        key: random key forwarded to ``scoring_fn``.

    Returns:
        The fitnesses reshaped into a single column (``reshape(-1, 1)``),
        and the extra scores returned by ``scoring_fn``. No key is returned.
    """
    fitnesses, _, extra_scores = scoring_fn(genotypes, key)
    return fitnesses.reshape(-1, 1), extra_scores


def scoring_function(
    genotypes: jax.Array, key: RNGKey
) -> Tuple[Fitness, ExtraScores]:
    """Fitness-only scoring for the genetic algorithm.

    Calls ``scoring_fn`` and discards the descriptors it returns.

    Args:
        genotypes: batch of policy parameters to evaluate.
        key: random key forwarded to ``scoring_fn``.

    Returns:
        The fitnesses reshaped into a single column (``reshape(-1, 1)``),
        and the extra scores returned by ``scoring_fn``. No key is returned.
    """
    fitnesses, _, extra_scores = scoring_fn(genotypes, key)
    return fitnesses.reshape(-1, 1), extra_scores
Define the emitter used
In [ ]:
Copied!
# Define emitter: offspring are produced by the isoline variation operator.
# No mutation operator is configured, and variation_percentage=1.0
# (presumably every offspring comes from `variation_fn`).
variation_fn = functools.partial(
    isoline_variation,
    iso_sigma=iso_sigma,
    line_sigma=line_sigma,
)
mixing_emitter = MixingEmitter(
    batch_size=batch_size,
    variation_percentage=1.0,
    variation_fn=variation_fn,
    mutation_fn=None,
)

# Define emitter: offspring are produced by the isoline variation operator.
# No mutation operator is configured, and variation_percentage=1.0
# (presumably every offspring comes from `variation_fn`).
variation_fn = functools.partial(
    isoline_variation,
    iso_sigma=iso_sigma,
    line_sigma=line_sigma,
)
mixing_emitter = MixingEmitter(
    batch_size=batch_size,
    variation_percentage=1.0,
    variation_fn=variation_fn,
    mutation_fn=None,
)
Define the algorithm used and apply the initial step
One can use either a simple genetic algorithm or MAP-Elites; set `use_map_elites` in the cell below to choose.
In [ ]:
Copied!
use_map_elites = True

if use_map_elites:
    # MAP-Elites: needs descriptors, so it uses `scoring_fn` directly.
    metrics_function = functools.partial(default_qd_metrics, qd_offset=0)
    algo_instance = MAPElites(
        scoring_function=scoring_fn,
        emitter=mixing_emitter,
        metrics_function=metrics_function,
    )
    # 50 x 50 grid of centroids over [-1, 1]^2, the range of the tanh descriptors.
    centroids = compute_euclidean_centroids(
        grid_shape=(50, 50),
        minval=-1,
        maxval=1,
    )
    # Initial repertoire and emitter state.
    key, subkey = jax.random.split(key)
    repertoire, emitter_state, init_metrics = algo_instance.init(
        init_variables, centroids, subkey
    )
else:
    # Plain genetic algorithm: fitness only, hence `scoring_function`
    # (which drops descriptors).
    algo_instance = GeneticAlgorithm(
        scoring_function=scoring_function,
        emitter=mixing_emitter,
        metrics_function=default_ga_metrics,
    )
    key, subkey = jax.random.split(key)
    repertoire, emitter_state, init_metrics = algo_instance.init(
        init_variables, population_size, subkey
    )

use_map_elites = True

if use_map_elites:
    # MAP-Elites: needs descriptors, so it uses `scoring_fn` directly.
    metrics_function = functools.partial(default_qd_metrics, qd_offset=0)
    algo_instance = MAPElites(
        scoring_function=scoring_fn,
        emitter=mixing_emitter,
        metrics_function=metrics_function,
    )
    # 50 x 50 grid of centroids over [-1, 1]^2, the range of the tanh descriptors.
    centroids = compute_euclidean_centroids(
        grid_shape=(50, 50),
        minval=-1,
        maxval=1,
    )
    # Initial repertoire and emitter state.
    key, subkey = jax.random.split(key)
    repertoire, emitter_state, init_metrics = algo_instance.init(
        init_variables, centroids, subkey
    )
else:
    # Plain genetic algorithm: fitness only, hence `scoring_function`
    # (which drops descriptors).
    algo_instance = GeneticAlgorithm(
        scoring_function=scoring_function,
        emitter=mixing_emitter,
        metrics_function=default_ga_metrics,
    )
    key, subkey = jax.random.split(key)
    repertoire, emitter_state, init_metrics = algo_instance.init(
        init_variables, population_size, subkey
    )
Run the optimization loop
In [ ]:
Copied!
# Run the algorithm: jax.lax.scan applies `scan_update` num_iterations times,
# threading (repertoire, emitter_state, key) through as the carry and
# stacking the per-iteration metrics along a new leading axis.
(repertoire, emitter_state, key), metrics = jax.lax.scan(
    f=algo_instance.scan_update,
    init=(repertoire, emitter_state, key),
    xs=(),
    length=num_iterations,
)

# Run the algorithm: jax.lax.scan applies `scan_update` num_iterations times,
# threading (repertoire, emitter_state, key) through as the carry and
# stacking the per-iteration metrics along a new leading axis.
(repertoire, emitter_state, key), metrics = jax.lax.scan(
    f=algo_instance.scan_update,
    init=(repertoire, emitter_state, key),
    xs=(),
    length=num_iterations,
)
In [ ]:
Copied!
# Max fitness recorded at the last iteration of the scan.
metrics["max_fitness"][-1]
metrics["max_fitness"][-1]
In [ ]:
Copied!
# Fitnesses stored in the final repertoire.
repertoire.fitnesses
repertoire.fitnesses
In [ ]:
Copied!
# Descriptors stored in the final repertoire.
# NOTE(review): presumably only available on the MAP-Elites repertoire; confirm for the GA path.
repertoire.descriptors
repertoire.descriptors
In [ ]:
Copied!
# The MAP-Elites plot needs a repertoire with descriptors and QD metrics.
# The GA path uses `default_ga_metrics` instead, so plot only for MAP-Elites.
if use_map_elites:
    from qdax.utils.plotting import plot_map_elites_results

    # create the x-axis array: environment steps consumed per iteration
    env_steps = jnp.arange(num_iterations) * episode_length * batch_size
    # create the plots and the grid; bounds match the centroids' [-1, 1] range
    fig, axes = plot_map_elites_results(
        env_steps=env_steps,
        metrics=metrics,
        repertoire=repertoire,
        min_descriptor=-1.,
        max_descriptor=1.,
    )
else:
    print("Skipping MAP-Elites plots: the genetic algorithm was used.")

# The MAP-Elites plot needs a repertoire with descriptors and QD metrics.
# The GA path uses `default_ga_metrics` instead, so plot only for MAP-Elites.
if use_map_elites:
    from qdax.utils.plotting import plot_map_elites_results

    # create the x-axis array: environment steps consumed per iteration
    env_steps = jnp.arange(num_iterations) * episode_length * batch_size
    # create the plots and the grid; bounds match the centroids' [-1, 1] range
    fig, axes = plot_map_elites_results(
        env_steps=env_steps,
        metrics=metrics,
        repertoire=repertoire,
        min_descriptor=-1.,
        max_descriptor=1.,
    )
else:
    print("Skipping MAP-Elites plots: the genetic algorithm was used.")
Play Snake with the best policy
Retrieve one of the best policies from the repertoire and show how it does on the Snake environment.
In [ ]:
Copied!
# Flat index and value of the highest fitness in the repertoire.
best_idx = repertoire.fitnesses.argmax()
best_fitness = repertoire.fitnesses.max()
best_idx = repertoire.fitnesses.argmax()
best_fitness = repertoire.fitnesses.max()
In [ ]:
Copied!
# Adjacent f-strings are concatenated into one argument, so print() does not
# insert its " " separator at the start of the second line.
print(
    f"Best fitness in the repertoire: {best_fitness:.2f}\n"
    f"Index in the repertoire of this individual: {best_idx}\n"
)
print(
    f"Best fitness in the repertoire: {best_fitness:.2f}\n"
    f"Index in the repertoire of this individual: {best_idx}\n"
)
In [ ]:
Copied!
# Take the best individual's slice from every leaf of the batched genotypes.
my_params = jax.tree.map(lambda leaf: leaf[best_idx], repertoire.genotypes)
my_params = jax.tree.map(lambda leaf: leaf[best_idx], repertoire.genotypes)
In [ ]:
Copied!
# Environment state of the first element of the batched initial states.
init_state = jax.tree.map(lambda leaf: leaf[0], init_states)
init_state = jax.tree.map(lambda leaf: leaf[0], init_states)
In [ ]:
Copied!
# Timestep of the first element of the batched initial timesteps.
init_timestep = jax.tree.map(lambda leaf: leaf[0], init_timesteps)
init_timestep = jax.tree.map(lambda leaf: leaf[0], init_timesteps)
In [ ]:
Copied!
# Roll out the best policy from the first initial state, rendering each frame.
# jit-compile `env.step` once, outside the loop, instead of re-wrapping it on
# every iteration.
jitted_step = jax.jit(env.step)
state = jax.tree.map(lambda x: x.copy(), init_state)
timestep = jax.tree.map(lambda x: x.copy(), init_timestep)
for _ in range(100):
    # (Optional) Render the env state
    env.render(state)
    # Stop after rendering the terminal state instead of stepping past the
    # end of the episode.
    if timestep.last():
        break
    network_input = observation_processing(timestep.observation)
    proba_action = policy_network.apply(my_params, network_input)
    # Greedy action: the most probable one.
    action = jnp.argmax(proba_action)
    state, timestep = jitted_step(state, action)

# Roll out the best policy from the first initial state, rendering each frame.
# jit-compile `env.step` once, outside the loop, instead of re-wrapping it on
# every iteration.
jitted_step = jax.jit(env.step)
state = jax.tree.map(lambda x: x.copy(), init_state)
timestep = jax.tree.map(lambda x: x.copy(), init_timestep)
for _ in range(100):
    # (Optional) Render the env state
    env.render(state)
    # Stop after rendering the terminal state instead of stepping past the
    # end of the episode.
    if timestep.last():
        break
    network_input = observation_processing(timestep.observation)
    proba_action = policy_network.apply(my_params, network_input)
    # Greedy action: the most probable one.
    action = jnp.argmax(proba_action)
    state, timestep = jitted_step(state, action)