Training a population on Jumanji-Snake with QDax

This notebook shows how to use either MAP-Elites or a simple (non-QD) genetic algorithm to train a population of agents that play the game of Snake from Jumanji.

It can also serve as a starting point for interacting with other environments from Jumanji.

from functools import partial
from typing import Tuple, Type

import jax
import jax.numpy as jnp

try:
    import brax
except ImportError:
    !pip install git+https://github.com/google/brax.git@v0.9.2 |tail -n 1
    import brax

try:
    import flax
except ImportError:
    !pip install --no-deps git+https://github.com/google/flax.git@v0.7.4 |tail -n 1
    import flax

try:
    import chex
except ImportError:
    !pip install --no-deps git+https://github.com/deepmind/chex.git@v0.1.83 |tail -n 1
    import chex

try:
    import jumanji
except ImportError:
    !pip install "jumanji==0.3.1"
    import jumanji

try:
    import qdax
except ImportError:
    !pip install --no-deps git+https://github.com/adaptive-intelligent-robotics/QDax@main |tail -n 1
    import qdax

import functools

import numpy as np

from qdax.baselines.genetic_algorithm import GeneticAlgorithm
from qdax.core.map_elites import MAPElites
from qdax.core.containers.mapelites_repertoire import compute_euclidean_centroids


from qdax.core.neuroevolution.buffers.buffer import QDTransition
from qdax.core.neuroevolution.networks.networks import MLP

from qdax.tasks.jumanji_envs import jumanji_scoring_function

from qdax.core.emitters.mutation_operators import isoline_variation

from qdax.core.emitters.standard_emitters import MixingEmitter
from qdax.types import ExtraScores, Fitness, RNGKey, Descriptor
from qdax.utils.metrics import default_ga_metrics, default_qd_metrics

Define hyperparameters

seed = 0
policy_hidden_layer_sizes = (128, 128)
episode_length = 200
population_size = 100
batch_size = population_size

num_iterations = 5000

iso_sigma = 0.005
line_sigma = 0.05

Instantiate the snake environment

# Instantiate a Jumanji environment using the registry
env = jumanji.make('Snake-v1')

# Reset your (jit-able) environment
key = jax.random.PRNGKey(0)
state, timestep = jax.jit(env.reset)(key)

# Interact with the (jit-able) environment
action = env.action_spec().generate_value()          # Action selection (dummy value here)
state, timestep = jax.jit(env.step)(state, action)
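
Before defining a policy, it can help to print the environment specs: the observation is structured (for Snake it includes a grid, a step count and an action mask, which are used further below) and the action space is discrete. A small optional check:

# Optional: inspect the specs to see the observation structure and the action space
print(env.observation_spec())
print(env.action_spec())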

Define the type of policy that will be used to solve the problem

# Init a random key
random_key = jax.random.PRNGKey(seed)

# get number of actions
num_actions = env.action_spec().maximum + 1

policy_layer_sizes = policy_hidden_layer_sizes + (num_actions,)
policy_network = MLP(
    layer_sizes=policy_layer_sizes,
    kernel_init=jax.nn.initializers.lecun_uniform(),
    final_activation=jax.nn.softmax,
)
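
As a quick sanity check, the network can be initialized on a dummy flattened observation and applied once; thanks to the final softmax activation, the output is a probability vector over the num_actions actions. The input size below is an arbitrary placeholder for illustration only (the real observation size is computed later in the notebook):

# Sanity check on a dummy input (placeholder size, for illustration only)
dummy_obs = jnp.zeros((32,))
dummy_params = policy_network.init(jax.random.PRNGKey(42), dummy_obs)
dummy_probs = policy_network.apply(dummy_params, dummy_obs)
print(dummy_probs.shape, dummy_probs.sum())  # (num_actions,) and ~1.0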

Utils to interact with the environment

Define a function to process the observation and a function to play one step in the environment, given the parameters of a policy_network.

def observation_processing(observation):
    # Flatten the structured Snake observation (grid, step_count, action_mask) into a single vector
    network_input = jnp.concatenate(
        [jnp.ravel(observation.grid), jnp.ravel(observation.step_count), jnp.ravel(observation.action_mask)]
    )
    return network_input


def play_step_fn(
    env_state,
    timestep,
    policy_params,
    random_key,
):
    """Play an environment step and return the updated state and the transition.
    Everything is deterministic in this simple example.
    """

    network_input = observation_processing(timestep.observation)

    proba_action = policy_network.apply(policy_params, network_input)

    action = jnp.argmax(proba_action)

    state_desc = None
    next_state, next_timestep = env.step(env_state, action)

    # next_state_desc=next_state.info["state_descriptor"]
    next_state_desc = None

    transition = QDTransition(
        obs=timestep.observation,
        next_obs=next_timestep.observation,
        rewards=next_timestep.reward,
        dones=jnp.where(next_timestep.last(), x=jnp.array(1), y=jnp.array(0)),
        actions=action,
        truncations=jnp.array(0),
        state_desc=state_desc,
        next_state_desc=next_state_desc,
    )

    return next_state, next_timestep, policy_params, random_key, transition
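
The scoring function used later unrolls play_step_fn over a full episode. The sketch below illustrates how such an unroll can be written with jax.lax.scan; it is only an illustration of the mechanism, not the actual implementation of jumanji_scoring_function.

def rollout_sketch(env_state, timestep, policy_params, random_key, length=10):
    """Illustrative unroll of play_step_fn with jax.lax.scan (sketch only)."""

    def _scan_step(carry, _):
        next_state, next_timestep, policy_params, random_key, transition = play_step_fn(*carry)
        return (next_state, next_timestep, policy_params, random_key), transition

    _, transitions = jax.lax.scan(
        _scan_step,
        (env_state, timestep, policy_params, random_key),
        None,
        length=length,
    )
    return transitions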

Init a population of policies

Also initialize the initial environment states and timesteps.

# Init population of controllers
random_key, subkey = jax.random.split(random_key)
keys = jax.random.split(subkey, num=batch_size)

# compute the flattened observation size from the observation spec
obs_spec = env.observation_spec()
observation_size = int(
    np.prod(obs_spec.grid.shape) + np.prod(obs_spec.step_count.shape) + np.prod(obs_spec.action_mask.shape)
)

fake_batch = jnp.zeros(shape=(batch_size, observation_size))
init_variables = jax.vmap(policy_network.init)(keys, fake_batch)

# Create the initial environment states
random_key, subkey = jax.random.split(random_key)
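# the same subkey is repeated below so that every policy in the batch starts from an identical initial state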
keys = jnp.repeat(jnp.expand_dims(subkey, axis=0), repeats=batch_size, axis=0)
reset_fn = jax.jit(jax.vmap(env.reset))

init_states, init_timesteps = reset_fn(keys)

Define a function to extract a behavior descriptor, which is only exploited when running MAP-Elites

# Prepare the scoring function
def bd_extraction(data: QDTransition, mask: jnp.ndarray, linear_projection: jnp.ndarray) -> Descriptor:
    """Compute feet contact time proportion.

    This function suppose that state descriptor is the feet contact, as it
    just computes the mean of the state descriptors given.
    """

    # pre-process the observation
    observation = jax.vmap(jax.vmap(observation_processing))(data.obs)

    # average the observations over the episode (time axis)
    mean_observation = jnp.mean(observation, axis=-2)

    # project into [-1, 1]^2 with the random linear projection
    descriptors = jnp.tanh(mean_observation @ linear_projection.T)

    return descriptors

# create a random projection to a two dim space
random_key, subkey = jax.random.split(random_key)
linear_projection = jax.random.uniform(
    subkey, (2, observation_size), minval=-1, maxval=1, dtype=jnp.float32
)

bd_extraction_fn = functools.partial(
    bd_extraction,
    linear_projection=linear_projection
)

# define the scoring function
scoring_fn = functools.partial(
    jumanji_scoring_function,
    init_states=init_states,
    init_timesteps=init_timesteps,
    episode_length=episode_length,
    play_step_fn=play_step_fn,
    behavior_descriptor_extractor=bd_extraction_fn,
)
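
Optionally, the scoring function can be evaluated once on the initial population to check the shapes it returns; as in the wrapper defined below, it is assumed to return a (fitnesses, descriptors, extra_scores, random_key) tuple. Note that this simulates a full batch of episodes, so it takes a moment.

# Optional sanity check: evaluate the initial population once and inspect the output shapes
fitnesses, descriptors, extra_scores, random_key = scoring_fn(init_variables, random_key)
print(fitnesses.shape, descriptors.shape)  # one fitness per individual, 2-dimensional descriptors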

Define a fitness-only scoring function, used in the genetic algorithm case (the behavior descriptors are discarded)

def scoring_function(
    genotypes: jnp.ndarray, random_key: RNGKey
) -> Tuple[Fitness, ExtraScores, RNGKey]:
    fitnesses, _, extra_scores, random_key = scoring_fn(genotypes, random_key)
    return fitnesses.reshape(-1, 1), extra_scores, random_key

Define the emitter used

# Define emitter
variation_fn = functools.partial(
    isoline_variation, iso_sigma=iso_sigma, line_sigma=line_sigma
)
mixing_emitter = MixingEmitter(
    mutation_fn=None,
    variation_fn=variation_fn,
    variation_percentage=1.0,
    batch_size=batch_size
)
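
The iso+line variation builds an offspring from two parents x1 and x2 by adding isotropic Gaussian noise of scale iso_sigma plus Gaussian noise of scale line_sigma along the direction x2 - x1. The sketch below calls the operator directly on dummy genotypes, assuming variation_fn(x1, x2, random_key) returns the offspring together with an updated key, which is the convention the MixingEmitter relies on:

# Illustrative direct call of the iso+line operator on dummy genotypes
# (the emitter applies it to pairs of parents sampled from the repertoire)
x1 = {"w": jnp.zeros((4,))}
x2 = {"w": jnp.ones((4,))}
random_key, subkey = jax.random.split(random_key)
offspring, subkey = variation_fn(x1, x2, subkey)
print(offspring)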

Define the algorithm used and apply the initial step

One can use either a simple genetic algorithm or MAP-Elites.

use_map_elites = True

if not use_map_elites:
    algo_instance = GeneticAlgorithm(
        scoring_function=scoring_function,
        emitter=mixing_emitter,
        metrics_function=default_ga_metrics,
    )

    repertoire, emitter_state, random_key = algo_instance.init(
        init_variables, population_size, random_key
    )

else:
    # Define a metrics function
    metrics_function = functools.partial(
        default_qd_metrics,
        qd_offset=0,
    )

    # Instantiate MAP-Elites
    algo_instance = MAPElites(
        scoring_function=scoring_fn,
        emitter=mixing_emitter,
        metrics_function=metrics_function,
    )

    # Compute the centroids
    centroids = compute_euclidean_centroids(
        grid_shape=(50, 50),
        minval=-1,
        maxval=1,
    )

    # Compute initial repertoire and emitter state
    repertoire, emitter_state, random_key = algo_instance.init(init_variables, centroids, random_key)

Run the optimization loop

%%time

# Run the algorithm
(repertoire, emitter_state, random_key,), metrics = jax.lax.scan(
    algo_instance.scan_update,
    (repertoire, emitter_state, random_key),
    (),
    length=num_iterations,
)
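
# The inspection and plots below assume the MAP-Elites case (use_map_elites=True)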
metrics["max_fitness"][-1]
repertoire.fitnesses
repertoire.descriptors
# create the x-axis array
env_steps = jnp.arange(num_iterations) * episode_length * batch_size

from qdax.utils.plotting import plot_map_elites_results

# create the plots and the grid
fig, axes = plot_map_elites_results(
    env_steps=env_steps, metrics=metrics, repertoire=repertoire, min_bd=-1., max_bd=1.
)

Play snake with the best policy

Retrieve one of the best policies from the repertoire and show how it does on the Snake environment.

best_idx = jnp.argmax(repertoire.fitnesses)
best_fitness = jnp.max(repertoire.fitnesses)
print(
    f"Best fitness in the repertoire: {best_fitness:.2f}\n",
    f"Index in the repertoire of this individual: {best_idx}\n"
)
my_params = jax.tree_util.tree_map(
    lambda x: x[best_idx],
    repertoire.genotypes
)
init_state = jax.tree_util.tree_map(
    lambda x: x[0],
    init_states
)
init_timestep = jax.tree_util.tree_map(
    lambda x: x[0],
    init_timesteps
)
state = jax.tree_util.tree_map(lambda x: x.copy(), init_state)
timestep = jax.tree_util.tree_map(lambda x: x.copy(), init_timestep)

for _ in range(100):
    # (Optional) Render the env state
    env.render(state)

    network_input = observation_processing(timestep.observation)

    proba_action = policy_network.apply(my_params, network_input)

    action = jnp.argmax(proba_action)


    state, timestep = jax.jit(env.step)(state, action)
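
Finally, the trained repertoire can be persisted for later reuse. This is a minimal sketch assuming the MAP-Elites repertoire exposes QDax's save(path=...) method, which writes flat .npy files into the given directory:

import os

# Save the repertoire to disk (MAP-Elites case); the target directory must exist
repertoire_path = "./last_repertoire/"
os.makedirs(repertoire_path, exist_ok=True)
repertoire.save(path=repertoire_path)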