QDax docs

Training a population on Jumanji-Snake with QDax

This notebook shows how to use either MAP-Elites or a simple (non-QD) genetic algorithm to train a population of agents that play the game of Snake from Jumanji.

It can also serve as a starting point for interacting with other Jumanji environments.

Installation

You will need Python 3.11 or later and a working JAX installation. For example, you can install JAX with:

%pip install -U "jax[cuda]"

Then, install QDax from PyPI:

%pip install -U "qdax[examples]"

Then import the modules used throughout this notebook:

import functools
from typing import Tuple

import jax
import jax.numpy as jnp
import jumanji
import numpy as np

from qdax.baselines.genetic_algorithm import GeneticAlgorithm
from qdax.core.containers.mapelites_repertoire import compute_euclidean_centroids
from qdax.core.emitters.mutation_operators import isoline_variation
from qdax.core.emitters.standard_emitters import MixingEmitter
from qdax.core.map_elites import MAPElites
from qdax.core.neuroevolution.buffers.buffer import QDTransition
from qdax.core.neuroevolution.networks.networks import MLP
from qdax.custom_types import Descriptor, ExtraScores, Fitness, RNGKey
from qdax.tasks.jumanji_envs import jumanji_scoring_function
from qdax.utils.metrics import default_ga_metrics, default_qd_metrics

Define hyperparameters

seed = 0
policy_hidden_layer_sizes = (128, 128)
episode_length = 200  # number of environment steps per evaluation rollout
population_size = 100
batch_size = population_size  # number of policies evaluated per iteration

num_iterations = 1000

# parameters of the Iso+LineDD variation operator
iso_sigma = 0.005
line_sigma = 0.05

Instantiate the snake environment

# Instantiate a Jumanji environment using the registry
env = jumanji.make('Snake-v1')

# Reset your (jit-able) environment
key = jax.random.key(seed)

key, subkey = jax.random.split(key)
state, timestep = jax.jit(env.reset)(subkey)

# Interact with the (jit-able) environment
action = env.action_spec().generate_value()  # Action selection (dummy value here)
state, timestep = jax.jit(env.step)(state, action)

Define the type of policy that will be used to solve the problem

# Number of discrete actions (maximum is the largest valid action index)
num_actions = env.action_spec().maximum + 1

# MLP policy: hidden layers followed by a softmax over the actions
policy_layer_sizes = policy_hidden_layer_sizes + (num_actions,)
policy_network = MLP(
    layer_sizes=policy_layer_sizes,
    kernel_init=jax.nn.initializers.lecun_uniform(),
    final_activation=jax.nn.softmax,
)
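To make the structure of the policy concrete, here is a hand-rolled JAX sketch of the same kind of network: fully connected hidden layers followed by a softmax over the actions. It is not QDax's MLP class (a Flax module); the ReLU hidden activation is an assumption about that class's default, and the 725-dimensional input is a hypothetical size used only for illustration.

```python
import jax
import jax.numpy as jnp


def init_mlp(key, layer_sizes, input_size):
    """Create (weights, bias) pairs: LeCun-uniform weights, zero biases."""
    sizes = (input_size,) + tuple(layer_sizes)
    keys = jax.random.split(key, len(layer_sizes))
    init = jax.nn.initializers.lecun_uniform()
    params = []
    for i, (n_in, n_out) in enumerate(zip(sizes[:-1], sizes[1:])):
        params.append((init(keys[i], (n_in, n_out)), jnp.zeros(n_out)))
    return params


def mlp_policy(params, x):
    """ReLU hidden layers (assumed), softmax output over the discrete actions."""
    for w, b in params[:-1]:
        x = jax.nn.relu(x @ w + b)
    w, b = params[-1]
    return jax.nn.softmax(x @ w + b)


# Hypothetical sizes: 725 inputs, two hidden layers of 128 units, 4 actions
params = init_mlp(jax.random.key(0), (128, 128, 4), input_size=725)
probs = mlp_policy(params, jnp.linspace(-1.0, 1.0, 725))
action = jnp.argmax(probs)  # greedy action selection, as in play_step_fn
```

Because softmax is monotonic, the argmax of the probabilities is the same as the argmax of the logits; the final softmax would only change behaviour if actions were sampled instead of chosen greedily.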

Utils to interact with the environment

Define how an observation is flattened into the network input, and how a single environment step is played given the parameters of policy_network.

def observation_processing(observation):
    """Flatten a Snake observation into a single input vector for the policy."""
    network_input = jnp.concatenate(
        [
            jnp.ravel(observation.grid),
            jnp.array([observation.step_count]),
            jnp.ravel(observation.action_mask),
        ]
    )
    return network_input


def play_step_fn(
    env_state,
    timestep,
    policy_params,
    key,
):
    """Play an environment step and return the updated state and the transition.

    Everything is deterministic in this simple example: the action is the
    argmax of the policy output, so the key is passed through unused.
    """
    network_input = observation_processing(timestep.observation)
    proba_action = policy_network.apply(policy_params, network_input)
    action = jnp.argmax(proba_action)

    next_state, next_timestep = env.step(env_state, action)

    # No state descriptors are tracked: descriptors are computed from the
    # observations by the descriptor extractor defined below
    transition = QDTransition(
        obs=timestep.observation,
        next_obs=next_timestep.observation,
        rewards=next_timestep.reward,
        dones=jnp.where(next_timestep.last(), jnp.array(1), jnp.array(0)),
        actions=action,
        truncations=jnp.array(0),
        state_desc=None,
        next_state_desc=None,
    )

    return next_state, next_timestep, policy_params, key, transition
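The flattening done by observation_processing can be checked in isolation. The sketch below uses a hypothetical stand-in for Jumanji's Snake observation with the three fields the notebook reads; the shapes (a 12x12 grid with 5 feature channels, a scalar step count and a 4-entry boolean action mask) are assumptions about the default environment, which is why the notebook itself derives observation_size from env.observation_spec() instead of hard-coding it.

```python
from typing import NamedTuple

import jax.numpy as jnp


class FakeSnakeObservation(NamedTuple):
    """Hypothetical stand-in exposing the fields used by observation_processing."""
    grid: jnp.ndarray         # assumed shape (12, 12, 5), float32
    step_count: jnp.ndarray   # scalar, int32
    action_mask: jnp.ndarray  # assumed shape (4,), bool


def flatten_observation(obs):
    # Same concatenation as observation_processing in the notebook
    return jnp.concatenate(
        [jnp.ravel(obs.grid), jnp.array([obs.step_count]), jnp.ravel(obs.action_mask)]
    )


obs = FakeSnakeObservation(
    grid=jnp.zeros((12, 12, 5), dtype=jnp.float32),
    step_count=jnp.array(7, dtype=jnp.int32),
    action_mask=jnp.array([True, True, False, True]),
)
x = flatten_observation(obs)  # 12 * 12 * 5 + 1 + 4 = 725 entries
```

jnp.concatenate promotes the float, integer and boolean parts to a common dtype (float32 here), so the boolean action mask ends up encoded as 0.0/1.0 in the network input.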

Init a population of policies

Also create the initial environment states and timesteps. Each policy gets its own initialization key, whereas every environment is reset with the same key, so all policies are evaluated from the same starting state.

# Init population of controllers: one initialization key per individual
key, subkey = jax.random.split(key)
keys = jax.random.split(subkey, num=batch_size)

# compute observation size from observation spec
obs_spec = env.observation_spec()
observation_size = int(
    np.prod(obs_spec.grid.shape)
    + np.prod(obs_spec.step_count.shape)
    + np.prod(obs_spec.action_mask.shape)
)

fake_batch = jnp.zeros(shape=(batch_size, observation_size))
init_variables = jax.vmap(policy_network.init)(keys, fake_batch)

# Create the initial environment states, all reset with the same key
key, subkey = jax.random.split(key)
keys = jnp.repeat(jnp.expand_dims(subkey, axis=0), repeats=batch_size, axis=0)
reset_fn = jax.jit(jax.vmap(env.reset))

init_states, init_timesteps = reset_fn(keys)
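The cell above deliberately uses two different key patterns: jax.random.split gives one distinct key per individual (so the policies start with different parameters), while jnp.repeat copies a single key (so every vmapped reset produces the same starting state). A small sketch, with uniform draws standing in for policy initialization and environment resets, makes the difference visible:

```python
import jax
import jax.numpy as jnp

batch_size = 5
key = jax.random.key(0)

# Pattern 1: one distinct key per individual
key, subkey = jax.random.split(key)
distinct_keys = jax.random.split(subkey, num=batch_size)

# Pattern 2: the same key repeated batch_size times
key, subkey = jax.random.split(key)
repeated_keys = jnp.repeat(jnp.expand_dims(subkey, axis=0), repeats=batch_size, axis=0)

# A vmapped random draw stands in for policy init / env reset
draw = jax.vmap(lambda k: jax.random.uniform(k, ()))
distinct_draws = draw(distinct_keys)  # five different values
repeated_draws = draw(repeated_keys)  # five identical values
```

Evaluating every policy from the same initial state makes fitnesses comparable across individuals, at the cost of not measuring how well a policy generalizes to other starting positions.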

Define a method to extract descriptors (relevant for MAP-Elites)

# Prepare the scoring function
def descriptor_extraction(
    data: QDTransition, mask: jax.Array, linear_projection: jax.Array
) -> Descriptor:
    """Compute a 2D descriptor from the observations of each rollout.

    The observations are flattened, averaged over time, then projected onto
    two dimensions with a fixed random linear projection. The mask is not
    used: the mean includes every step of the rollout.
    """

    # pre-process the observations (vmap over the batch and time axes)
    observation = jax.vmap(jax.vmap(observation_processing))(data.obs)

    # average over the time axis
    mean_observation = jnp.mean(observation, axis=-2)

    # project to two dimensions and squash into [-1, 1]^2
    descriptors = jnp.tanh(mean_observation @ linear_projection.T)

    return descriptors

# create a random projection to a two-dimensional space
key, subkey = jax.random.split(key)
linear_projection = jax.random.uniform(
    subkey, (2, observation_size), minval=-1, maxval=1, dtype=jnp.float32
)

descriptor_extraction_fn = functools.partial(
    descriptor_extraction,
    linear_projection=linear_projection
)

# define the scoring function
scoring_fn = functools.partial(
    jumanji_scoring_function,
    init_states=init_states,
    init_timesteps=init_timesteps,
    episode_length=episode_length,
    play_step_fn=play_step_fn,
    descriptor_extractor=descriptor_extraction_fn,
)
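descriptor_extraction receives a mask but averages over every step of the rollout, including steps played after the snake has crashed. If only the steps that were actually played should count, one option is a masked mean. The sketch below is a variant, not what the notebook does: it runs on random stand-in data and assumes the mask convention 1.0 for steps after the episode ended and 0.0 for valid steps, which you should check against jumanji_scoring_function before relying on it.

```python
import jax
import jax.numpy as jnp


def masked_projected_descriptor(observations, mask, projection):
    """observations: (batch, T, obs_size); mask: (batch, T), assumed 1.0 after the episode ended."""
    weights = 1.0 - mask                                    # 1.0 for steps that count
    summed = jnp.sum(observations * weights[..., None], axis=1)
    num_valid = jnp.maximum(jnp.sum(weights, axis=1, keepdims=True), 1.0)
    mean_obs = summed / num_valid                           # masked mean over time
    return jnp.tanh(mean_obs @ projection.T)                # (batch, 2), in [-1, 1]


key = jax.random.key(0)
k_obs, k_proj = jax.random.split(key)
batch, T, obs_size = 3, 6, 10
observations = jax.random.normal(k_obs, (batch, T, obs_size))
mask = jnp.zeros((batch, T)).at[0, 4:].set(1.0)  # episode 0 ends after 4 steps
projection = jax.random.uniform(k_proj, (2, obs_size), minval=-1, maxval=1)
descriptors = masked_projected_descriptor(observations, mask, projection)
```

The jnp.maximum guard avoids a division by zero for a rollout with no valid step.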

Define the scoring function

The genetic algorithm does not use descriptors and expects fitnesses with a trailing objective axis, so this wrapper drops the descriptors returned by scoring_fn and reshapes the fitnesses. MAP-Elites uses scoring_fn directly.

def scoring_function(
    genotypes: jax.Array, key: RNGKey
) -> Tuple[Fitness, ExtraScores]:
    fitnesses, _, extra_scores = scoring_fn(genotypes, key)
    return fitnesses.reshape(-1, 1), extra_scores
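The wrapper pattern above can be written generically: take a QD-style scoring function returning (fitnesses, descriptors, extra_scores) and return a GA-style one returning (fitnesses with an objective axis, extra_scores). In this sketch, toy_qd_scoring is a hypothetical scoring function on flat arrays, used only so the example runs without an environment.

```python
from typing import Callable

import jax
import jax.numpy as jnp


def toy_qd_scoring(genotypes, key):
    """Hypothetical QD scoring function: (fitnesses, descriptors, extra_scores)."""
    fitnesses = jnp.sum(genotypes, axis=-1)      # shape (batch,)
    descriptors = jnp.tanh(genotypes[:, :2])     # shape (batch, 2)
    return fitnesses, descriptors, {}


def to_ga_scoring(qd_scoring: Callable) -> Callable:
    """Drop the descriptors and add a trailing objective axis."""
    def ga_scoring(genotypes, key):
        fitnesses, _, extra_scores = qd_scoring(genotypes, key)
        return fitnesses.reshape(-1, 1), extra_scores  # shape (batch, 1)
    return ga_scoring


ga_scoring = to_ga_scoring(toy_qd_scoring)
fitnesses, extra_scores = ga_scoring(jnp.ones((100, 3)), jax.random.key(0))
```

The trailing axis indexes objectives (a single one here), presumably so that the same genetic-algorithm code can also handle several objectives.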

Define the emitter used

The emitter generates new candidate policies from the current population or repertoire. With variation_percentage=1.0, every offspring is produced by the variation operator (isoline_variation, which combines two parent policies), so no mutation function is needed.

# Define emitter
variation_fn = functools.partial(
    isoline_variation, iso_sigma=iso_sigma, line_sigma=line_sigma
)
mixing_emitter = MixingEmitter(
    mutation_fn=None,
    variation_fn=variation_fn,
    variation_percentage=1.0,
    batch_size=batch_size
)
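The Iso+LineDD operator behind isoline_variation creates an offspring from two parents x1 and x2 as x1 + iso_sigma * N(0, I) + line_sigma * N(0, 1) * (x2 - x1): a small isotropic perturbation plus a random step along the line joining the parents. The sketch below applies this to flat parameter vectors; it is a hand-rolled illustration, not QDax's isoline_variation, which works on batched pytrees of network parameters (as far as we know, with one scalar line coefficient per offspring shared across all parameter leaves, as assumed here).

```python
import jax
import jax.numpy as jnp


def iso_line(key, x1, x2, iso_sigma, line_sigma):
    """Iso+LineDD variation of one pair of flat parameter vectors (sketch)."""
    key_iso, key_line = jax.random.split(key)
    iso_noise = iso_sigma * jax.random.normal(key_iso, x1.shape)  # isotropic Gaussian noise
    line_noise = line_sigma * jax.random.normal(key_line, ())     # one scalar per offspring
    return x1 + iso_noise + line_noise * (x2 - x1)                # step along the x1 -> x2 line


key = jax.random.key(0)
x1, x2 = jnp.zeros(5), jnp.ones(5)
child = iso_line(key, x1, x2, iso_sigma=0.005, line_sigma=0.05)
child_line_only = iso_line(key, x1, x2, iso_sigma=0.0, line_sigma=0.05)
child_no_noise = iso_line(key, x1, x2, iso_sigma=0.0, line_sigma=0.0)
```

With iso_sigma set to 0 the offspring stays exactly on the line through the parents (here every coordinate moves by the same amount), and with both sigmas at 0 it is a copy of x1.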

Define the algorithm used and apply the initial step

Set use_map_elites to choose between a simple genetic algorithm and MAP-Elites.

use_map_elites = True

if not use_map_elites:
    # Genetic algorithm: uses the wrapped scoring function (no descriptors)
    algo_instance = GeneticAlgorithm(
        scoring_function=scoring_function,
        emitter=mixing_emitter,
        metrics_function=default_ga_metrics,
    )

    key, subkey = jax.random.split(key)
    repertoire, emitter_state, init_metrics = algo_instance.init(
        init_variables, population_size, subkey
    )

else:
    # Define a metrics function
    metrics_function = functools.partial(
        default_qd_metrics,
        qd_offset=0,
    )

    # Instantiate MAP-Elites (uses scoring_fn, which also returns descriptors)
    algo_instance = MAPElites(
        scoring_function=scoring_fn,
        emitter=mixing_emitter,
        metrics_function=metrics_function,
    )

    # Compute the centroids of a 50x50 grid over the [-1, 1]^2 descriptor space
    centroids = compute_euclidean_centroids(
        grid_shape=(50, 50),
        minval=-1,
        maxval=1,
    )

    # Compute initial repertoire and emitter state
    key, subkey = jax.random.split(key)
    repertoire, emitter_state, init_metrics = algo_instance.init(
        init_variables, centroids, subkey
    )
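With MAP-Elites, each evaluated policy is placed in the cell of the centroid nearest to its descriptor and kept only if it beats the current occupant of that cell. The sketch below shows that insertion rule on fitnesses alone (a real repertoire also stores genotypes and descriptors). grid_centroids and insert are hypothetical helpers: grid_centroids builds the centres of a regular grid over [-1, 1]^2 and is only meant to resemble what compute_euclidean_centroids produces. Empty cells are marked with -inf fitness.

```python
import jax.numpy as jnp


def grid_centroids(n_per_dim, minval=-1.0, maxval=1.0):
    """Centres of an n x n regular grid over [minval, maxval]^2."""
    edges = jnp.linspace(minval, maxval, n_per_dim + 1)
    centers = (edges[:-1] + edges[1:]) / 2
    xx, yy = jnp.meshgrid(centers, centers, indexing="ij")
    return jnp.stack([xx.ravel(), yy.ravel()], axis=-1)


def insert(rep_fitnesses, centroids, fitnesses, descriptors):
    """Keep, per cell, the best of the current fitness and the new candidates."""
    sq_dists = jnp.sum((descriptors[:, None, :] - centroids[None, :, :]) ** 2, axis=-1)
    cells = jnp.argmin(sq_dists, axis=1)                  # nearest centroid per candidate
    return rep_fitnesses.at[cells].max(fitnesses)         # duplicates resolved by max


centroids = grid_centroids(2)                             # 4 cells
rep = jnp.full(centroids.shape[0], -jnp.inf)              # empty repertoire
new_rep = insert(
    rep,
    centroids,
    fitnesses=jnp.array([1.0, 3.0, 2.0]),
    descriptors=jnp.array([[-0.6, -0.6], [-0.4, -0.4], [0.5, 0.5]]),
)
```

The first two candidates land in the same cell, so only the fitter one (3.0) is kept, while the cells nobody reached stay at -inf.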

Run the optimization loop

# Run the algorithm
(repertoire, emitter_state, key), metrics = jax.lax.scan(
    algo_instance.scan_update,
    (repertoire, emitter_state, key),
    (),
    length=num_iterations,
)
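jax.lax.scan with an empty xs and an explicit length simply calls the update function num_iterations times, threading the carry through and stacking the per-iteration metrics along a new leading axis. That is why metrics["max_fitness"] has one entry per iteration. A toy update function (hypothetical, standing in for algo_instance.scan_update) shows the pattern:

```python
import jax
import jax.numpy as jnp


def toy_scan_update(carry, _):
    """Stand-in for scan_update: update the carry and return per-iteration metrics."""
    best, key = carry
    key, subkey = jax.random.split(key)
    candidate = jax.random.uniform(subkey, ())
    best = jnp.maximum(best, candidate)
    return (best, key), {"max_fitness": best}


init_carry = (jnp.array(-jnp.inf, dtype=jnp.float32), jax.random.key(0))
(final_best, key), metrics = jax.lax.scan(toy_scan_update, init_carry, (), length=10)
```

Since the whole loop is traced once and compiled, the update function must keep the carry's shapes and dtypes fixed from one iteration to the next.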
# Best fitness reached at the last iteration
metrics["max_fitness"][-1]

# Fitnesses stored in the repertoire
repertoire.fitnesses

# Descriptors stored in the repertoire (only available when use_map_elites = True)
repertoire.descriptors

from qdax.utils.plotting import plot_map_elites_results

# create the x-axis array: each iteration evaluates batch_size policies
# for episode_length steps
env_steps = jnp.arange(num_iterations) * episode_length * batch_size

# create the plots and the grid (requires use_map_elites = True)
fig, axes = plot_map_elites_results(
    env_steps=env_steps,
    metrics=metrics,
    repertoire=repertoire,
    min_descriptor=-1.0,
    max_descriptor=1.0,
)
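The x-axis of these plots counts environment steps rather than iterations. With the hyperparameters of this notebook, and assuming every rollout runs for the full episode_length, the arithmetic is:

```python
import jax.numpy as jnp

num_iterations, episode_length, batch_size = 1000, 200, 100

env_steps_per_iteration = episode_length * batch_size       # 200 * 100 = 20,000
env_steps = jnp.arange(num_iterations) * env_steps_per_iteration
total_env_steps = num_iterations * env_steps_per_iteration  # 20,000,000
```

Note that env_steps starts at 0, so each point is placed at the number of steps simulated before that iteration.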

Play snake with the best policy

Retrieve the best policy from the repertoire and watch how it plays Snake.

best_idx = jnp.argmax(repertoire.fitnesses)
best_fitness = jnp.max(repertoire.fitnesses)

print(
    f"Best fitness in the repertoire: {best_fitness:.2f}\n"
    f"Index in the repertoire of this individual: {best_idx}\n"
)

my_params = jax.tree.map(
    lambda x: x[best_idx],
    repertoire.genotypes
)

init_state = jax.tree.map(
    lambda x: x[0],
    init_states
)

init_timestep = jax.tree.map(
    lambda x: x[0],
    init_timesteps
)

state = jax.tree.map(lambda x: x.copy(), init_state)
timestep = jax.tree.map(lambda x: x.copy(), init_timestep)

# Compile the step function once, outside the loop
step_fn = jax.jit(env.step)

for _ in range(100):
    # (Optional) Render the env state
    env.render(state)

    network_input = observation_processing(timestep.observation)
    proba_action = policy_network.apply(my_params, network_input)
    action = jnp.argmax(proba_action)

    state, timestep = step_fn(state, action)

    # Stop once the episode is over (e.g. the snake hit a wall or itself)
    if timestep.last():
        break
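Selecting my_params with jax.tree.map works because repertoire.genotypes is a pytree whose leaves all share a leading repertoire axis, so indexing every leaf at best_idx extracts a single policy. The toy example below uses a hypothetical four-cell repertoire with made-up parameter names and assumes, as in MAP-Elites repertoires, that empty cells hold a fitness of -inf so that argmax never selects them. When fitnesses has shape (n, 1), the flattened index returned by jnp.argmax equals the row index.

```python
import jax
import jax.numpy as jnp

# Hypothetical repertoire with 4 cells; cell 1 is empty (fitness -inf)
fitnesses = jnp.array([[1.0], [-jnp.inf], [5.0], [2.0]])
genotypes = {
    "layer_0": {
        "kernel": jnp.arange(24.0).reshape(4, 3, 2),  # leading axis = repertoire cell
        "bias": jnp.arange(8.0).reshape(4, 2),
    }
}

best_idx = jnp.argmax(fitnesses)  # index into the flattened (4, 1) array
best_params = jax.tree_util.tree_map(lambda x: x[best_idx], genotypes)
```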