QDax docs

Optimizing with PGA-AURORA in JAX

This notebook shows how to use QDax to find diverse, high-performing controllers in MDPs with PGA-AURORA. It can be run locally or on Google Colab, and we recommend using a GPU. This notebook will show:

  • how to define the problem
  • how to create an emitter
  • how to create an AURORA instance and mix it with the right emitter to define PGA-AURORA
  • which functions must be defined before training
  • how to launch a certain number of training steps
  • how to visualise the optimization process

Installation

You will need Python 3.11 or later, and a working JAX installation. For example, you can install JAX with:

%pip install -U "jax[cuda]"

Then, install QDax from PyPI:

%pip install -U "qdax[examples]"
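
Before going further, it can help to confirm that JAX actually sees an accelerator, since the training below is considerably slower on CPU. A minimal check using only standard JAX calls:

```python
import jax

# Name of the default backend JAX will run on: "cpu", "gpu" or "tpu".
backend = jax.default_backend()
# Devices JAX can see on that backend.
devices = jax.devices()
print(f"JAX backend: {backend}, devices: {devices}")

if backend == "cpu":
    print("No accelerator found: training will run on the CPU and be much slower.")
```

If this reports "cpu" on a machine with a GPU, the CUDA-enabled JAX install above most likely did not take effect.
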
import os

from IPython.display import clear_output
import functools
from typing import Dict, Any

import jax
import jax.numpy as jnp

from qdax.core.aurora import AURORA
from qdax.core.containers.unstructured_repertoire import UnstructuredRepertoire
import qdax.tasks.brax as environments
from qdax.tasks.brax.env_creators import (
    create_default_brax_task_components,
    get_aurora_scoring_fn,
)
from qdax.tasks.brax.descriptor_extractors import (
    AuroraExtraInfoNormalization,
    get_aurora_encoding,
)
from qdax.core.neuroevolution.buffers.buffer import QDTransition
from qdax.core.neuroevolution.networks.networks import MLP
from qdax.core.emitters.mutation_operators import isoline_variation
from qdax.core.emitters.pga_me_emitter import PGAMEConfig, PGAMEEmitter

from qdax.custom_types import Observation
from qdax.utils import train_seq2seq
#@title QD Training Definitions Fields
#@markdown ---
env_batch_size = 100 #@param {type:"number"}
env_name = 'walker2d_uni'#@param['ant_uni', 'hopper_uni', 'walker2d_uni', 'halfcheetah_uni', 'humanoid_uni', 'ant_omni', 'humanoid_omni']
episode_length = 250 #@param {type:"integer"}
max_iterations = 50 #@param {type:"integer"}
seed = 42 #@param {type:"integer"}
policy_hidden_layer_sizes = (64, 64) #@param {type:"raw"}
iso_sigma = 0.005 #@param {type:"number"}
line_sigma = 0.05 #@param {type:"number"}
num_init_cvt_samples = 50000 #@param {type:"integer"}
num_centroids = 1024 #@param {type:"integer"}
min_descriptor = 0. #@param {type:"number"}
max_descriptor = 1.0 #@param {type:"number"}

lstm_batch_size = 128 #@param {type:"integer"}

observation_option = "no_sd" #@param['no_sd', 'only_sd', 'full']
hidden_size = 5 #@param {type:"integer"}
l_value_init = 0.2 #@param {type:"number"}

traj_sampling_freq = 10 #@param {type:"integer"}
max_observation_size = 25 #@param {type:"integer"}
prior_descriptor_dim = 2 #@param {type:"integer"}

proportion_mutation_ga = 0.5 #@param {type:"number"}

# TD3 params
replay_buffer_size = 1000000 #@param {type:"number"}
critic_hidden_layer_size = (256, 256) #@param {type:"raw"}
critic_learning_rate = 3e-4 #@param {type:"number"}
greedy_learning_rate = 3e-4 #@param {type:"number"}
policy_learning_rate = 1e-3 #@param {type:"number"}
noise_clip = 0.5 #@param {type:"number"}
policy_noise = 0.2 #@param {type:"number"}
discount = 0.99 #@param {type:"number"}
reward_scaling = 1.0 #@param {type:"number"}
transitions_batch_size = 256 #@param {type:"number"}
soft_tau_update = 0.005 #@param {type:"number"}
num_critic_training_steps = 300 #@param {type:"number"}
num_pg_training_steps = 100 #@param {type:"number"}
policy_delay = 2 #@param {type:"number"}

log_freq = 5 #@param {type:"integer"}

# Custom observations key that will be used to store the observations in the
# extra_scores of the repertoire
aurora_observations_key = "observations"

#@markdown ---
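
Several of these parameters interact. The short computation below derives, for the default values above, the length of the sub-sampled observation sequences fed to the LSTM autoencoder, and the number of environment steps performed per logging block and in total (the per-block count is the same quantity the main loop adds to `current_step_estimation`). It re-declares the relevant defaults so that it runs on its own.

```python
# Defaults copied from the parameter cell above.
env_batch_size = 100
episode_length = 250
max_iterations = 50
log_freq = 5
traj_sampling_freq = 10

# Time steps kept after sub-sampling every `traj_sampling_freq` steps
# (the first entry of `observations_dims` later in the notebook).
sequence_length = episode_length // traj_sampling_freq

# Each update evaluates `env_batch_size` policies for a full episode;
# the main loop scans `log_freq` updates per block, `max_iterations` times.
steps_per_block = env_batch_size * episode_length * log_freq
total_env_steps = steps_per_block * max_iterations

print(sequence_length, steps_per_block, total_env_steps)
```
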

Init environment, policy, population params, init states of the env

Define the environment in which the policies will be trained. In this notebook, we focus on controllers that learn to move a robot in a physical simulation. We also define the shared policy network that every individual in the population uses. Once the network is defined, each individual is fully described by its parameters, which correspond to its genotype.

# Init environment
env = environments.create(env_name, episode_length=episode_length)

# Init a random key
key = jax.random.key(seed)

# Init policy network
policy_layer_sizes = policy_hidden_layer_sizes + (env.action_size,)
policy_network = MLP(
    layer_sizes=policy_layer_sizes,
    kernel_init=jax.nn.initializers.lecun_uniform(),
    final_activation=jnp.tanh,
)

# Init population of controllers
key, subkey = jax.random.split(key)
keys = jax.random.split(subkey, num=env_batch_size)
fake_batch = jnp.zeros(shape=(env_batch_size, env.observation_size))
init_variables = jax.vmap(policy_network.init)(keys, fake_batch)


# Create the initial environment states
key, subkey = jax.random.split(key)
keys = jnp.repeat(jnp.expand_dims(subkey, axis=0), repeats=env_batch_size, axis=0)
reset_fn = jax.jit(jax.vmap(env.reset))
init_states = reset_fn(keys)
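
The population cell above relies on one pattern: a single network definition, plus one parameter pytree per individual obtained by `jax.vmap`-ing the initialiser over a batch of keys. The sketch below reproduces that pattern in plain JAX instead of QDax's `MLP`, so the leading population axis of every parameter array is visible. It is a simplified stand-in (LeCun-uniform weights, zero biases, ReLU hidden layers, `tanh` output, which we assume approximates the configuration above), and the observation, action and population sizes are illustrative.

```python
import jax
import jax.numpy as jnp

# Illustrative sizes (assumptions, not read from any environment).
OBS_SIZE, ACTION_SIZE, POP_SIZE = 17, 6, 4
HIDDEN_SIZES = (64, 64)


def init_policy(key):
    """Initialise one individual's parameters as a list of {"w", "b"} layers."""
    sizes = (OBS_SIZE,) + HIDDEN_SIZES + (ACTION_SIZE,)
    layer_keys = jax.random.split(key, len(sizes) - 1)
    params = []
    for i, (n_in, n_out) in enumerate(zip(sizes[:-1], sizes[1:])):
        limit = (3.0 / n_in) ** 0.5  # LeCun-uniform bound
        w = jax.random.uniform(layer_keys[i], (n_in, n_out), minval=-limit, maxval=limit)
        params.append({"w": w, "b": jnp.zeros(n_out)})
    return params


def apply_policy(params, obs):
    """Shared forward pass: ReLU hidden layers, tanh output in [-1, 1]."""
    x = obs
    for layer in params[:-1]:
        x = jax.nn.relu(x @ layer["w"] + layer["b"])
    return jnp.tanh(x @ params[-1]["w"] + params[-1]["b"])


# One genotype per individual: vmap the initialiser over a batch of keys.
keys = jax.random.split(jax.random.key(0), POP_SIZE)
population = jax.vmap(init_policy)(keys)

# Same network, different parameters: vmap the forward pass over the population.
obs = jnp.zeros((POP_SIZE, OBS_SIZE))
actions = jax.vmap(apply_policy)(population, obs)

# Genotype length = total number of parameters of a single individual.
single = jax.tree_util.tree_map(lambda leaf: leaf[0], population)
genotype_size = sum(leaf.size for leaf in jax.tree_util.tree_leaves(single))
print(actions.shape, population[0]["w"].shape, genotype_size)
```

Every leaf of `population` has shape `(POP_SIZE, ...)`, which is what lets emitters mutate and evaluate the whole population in one vectorised call.
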

Define the way the policy interacts with the env

Now that the environment and policy have been defined, we need a function that describes how the policy interacts with the environment and which transition data are stored at each step.

# Define the function to play a step with the policy in the environment
def play_step_fn(
    env_state,
    policy_params,
    key,
):
    """
    Play an environment step and return the updated state and the transition.
    """

    actions = policy_network.apply(policy_params, env_state.obs)

    state_desc = env_state.info["state_descriptor"]
    next_state = env.step(env_state, actions)

    transition = QDTransition(
        obs=env_state.obs,
        next_obs=next_state.obs,
        rewards=next_state.reward,
        dones=next_state.done,
        actions=actions,
        truncations=next_state.info["truncation"],
        state_desc=state_desc,
        next_state_desc=next_state.info["state_descriptor"],
    )

    return next_state, policy_params, key, transition
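
A step function like `play_step_fn` only advances the environment once; to score an individual it has to be unrolled over a whole episode. The sketch below shows one way to do that with `jax.lax.scan`, which stacks the per-step transitions along a time axis, and then vectorises the rollout over a population with `jax.vmap`. The point-mass environment (`toy_reset`, `toy_env_step`), the linear policy and `ToyTransition` are hypothetical stand-ins, not Brax or QDax APIs. The resulting `(population, time, ...)` layout is the one that `observation_extractor_fn` slices in the next section.

```python
import functools
from typing import NamedTuple

import jax
import jax.numpy as jnp


class ToyState(NamedTuple):
    obs: jnp.ndarray  # what the policy observes
    xy: jnp.ndarray   # x/y position, playing the role of the state descriptor


class ToyTransition(NamedTuple):
    obs: jnp.ndarray
    actions: jnp.ndarray
    rewards: jnp.ndarray
    state_desc: jnp.ndarray


def toy_reset():
    xy = jnp.zeros(2)
    return ToyState(obs=xy, xy=xy)


def toy_env_step(state, action):
    xy = state.xy + 0.1 * action  # the action moves a point mass
    return ToyState(obs=xy, xy=xy)


def play_step_fn(env_state, policy_params, key):
    """Same (state, params, key) -> (state, params, key, transition) contract."""
    actions = jnp.tanh(env_state.obs @ policy_params["w"] + policy_params["b"])
    next_state = toy_env_step(env_state, actions)
    transition = ToyTransition(
        obs=env_state.obs,
        actions=actions,
        rewards=-jnp.sum(actions**2),  # toy reward: penalise control effort
        state_desc=env_state.xy,
    )
    return next_state, policy_params, key, transition


def unroll(policy_params, key, episode_length):
    """Scan play_step_fn over one episode; transition fields gain a time axis."""

    def scan_step(carry, _):
        state, params, key = carry
        state, params, key, transition = play_step_fn(state, params, key)
        return (state, params, key), transition

    carry = (toy_reset(), policy_params, key)
    _, transitions = jax.lax.scan(scan_step, carry, None, length=episode_length)
    return transitions


pop_size, episode_length = 8, 250
key_params, key_rollout = jax.random.split(jax.random.key(0))
params = {
    "w": jax.random.normal(key_params, (pop_size, 2, 2)),
    "b": jnp.ones((pop_size, 2)),
}
rollout_keys = jax.random.split(key_rollout, pop_size)

# vmap over individuals; episode_length is bound statically with functools.partial.
batched_unroll = jax.vmap(functools.partial(unroll, episode_length=episode_length))
transitions = batched_unroll(params, rollout_keys)
print(transitions.obs.shape, transitions.rewards.shape)
```
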

Define the scoring function and the way metrics are computed

The scoring function is used in the evaluation step to determine the fitness and descriptor of each individual.
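
In AURORA the descriptor is not hand-designed: the scoring function also stores raw observations (under `aurora_observations_key`), which the autoencoder later encodes into a descriptor. The extractor defined in the next cell keeps every `traj_sampling_freq`-th time step and at most `max_observation_size` observation features. The sketch below applies the same slicing to dummy arrays to show the resulting shape for each `observation_option`; the batch size of 4 and observation size of 17 are illustrative assumptions.

```python
import jax.numpy as jnp

batch_size, episode_length, obs_size = 4, 250, 17  # illustrative sizes
traj_sampling_freq, max_observation_size = 10, 25  # defaults from the notebook

# Dummy rollout data laid out as (batch, time, features), like QDTransition fields.
obs = jnp.ones((batch_size, episode_length, obs_size))
state_desc = jnp.zeros((batch_size, episode_length, 2))  # x/y position

# Same slicing as observation_extractor_fn.
state_obs = obs[:, ::traj_sampling_freq, :max_observation_size]
sub_desc = state_desc[:, ::traj_sampling_freq]

shapes = {
    "no_sd": state_obs.shape,
    "only_sd": sub_desc.shape,
    "full": jnp.concatenate([sub_desc, state_obs], axis=-1).shape,
}
for option, shape in shapes.items():
    print(option, shape)
```

Because 17 is below `max_observation_size`, the feature slice keeps all 17 features, matching `min(observation_size, max_observation_size)`, which is how the notebook sizes the autoencoder input.
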

# Prepare the scoring function
key, subkey = jax.random.split(key)
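env, policy_network, scoring_fn = create_default_brax_task_components(
    env_name=env_name,
    key=subkey,
    # keep rollouts consistent with the episode_length used for observations and metrics
    episode_length=episode_length,
)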

def observation_extractor_fn(
    data: QDTransition,
) -> Observation:
    """Extract observation from the state."""
    state_obs = data.obs[:, ::traj_sampling_freq, :max_observation_size]

    # add the x/y position - (batch_size, traj_length, 2)
    state_desc = data.state_desc[:, ::traj_sampling_freq]

    if observation_option == "full":
        observations = jnp.concatenate([state_desc, state_obs], axis=-1)
    elif observation_option == "no_sd":
        observations = state_obs
    elif observation_option == "only_sd":
        observations = state_desc
    else:
        raise ValueError("Unknown observation option.")

    return observations

# Prepare the scoring function
aurora_scoring_fn = get_aurora_scoring_fn(
    scoring_fn=scoring_fn,
    observation_extractor_fn=observation_extractor_fn,
    observations_key=aurora_observations_key,
)

# Get the per-step reward offset that makes the qd_score positive
reward_offset = environments.reward_offset[env_name]

# Define a metrics function
def metrics_fn(repertoire: UnstructuredRepertoire) -> Dict:

    # Get metrics
    grid_empty = repertoire.fitnesses == -jnp.inf
    qd_score = jnp.sum(repertoire.fitnesses, where=~grid_empty)
    # Add offset for positive qd_score
    qd_score += reward_offset * episode_length * jnp.sum(1.0 - grid_empty)
    coverage = 100 * jnp.mean(1.0 - grid_empty)
    max_fitness = jnp.max(repertoire.fitnesses)

    return {"qd_score": qd_score, "max_fitness": max_fitness, "coverage": coverage}
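
To make the offset in `metrics_fn` concrete, here is the same computation on a hand-made fitness array with four slots, two of them empty (`-inf`). The `reward_offset` of 1.0 is a made-up value; the notebook reads the real one from `environments.reward_offset[env_name]`. Each stored individual contributes its fitness plus `reward_offset * episode_length`, which keeps every contribution non-negative under the assumption that the offset bounds the negative per-step reward.

```python
import jax.numpy as jnp

fitnesses = jnp.array([-jnp.inf, 10.0, -5.0, -jnp.inf])
reward_offset = 1.0  # hypothetical value, for illustration only
episode_length = 250

grid_empty = fitnesses == -jnp.inf
qd_score = jnp.sum(fitnesses, where=~grid_empty)  # 10 + (-5) = 5
qd_score += reward_offset * episode_length * jnp.sum(1.0 - grid_empty)  # + 1 * 250 * 2
coverage = 100 * jnp.mean(1.0 - grid_empty)  # 2 of 4 slots filled
max_fitness = jnp.max(fitnesses)

print(float(qd_score), float(coverage), float(max_fitness))
```
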

Define the emitter

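The emitter generates the new candidate solutions evaluated at each iteration. PGA-AURORA uses the PGA-ME emitter: a fraction `proportion_mutation_ga` of each batch is produced by a genetic variation operator (here, isoline variation), and the rest by policy-gradient updates driven by a TD3 critic trained on the collected transitions.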

# Define the PG-emitter config
pga_emitter_config = PGAMEConfig(
    env_batch_size=env_batch_size,
    batch_size=transitions_batch_size,
    proportion_mutation_ga=proportion_mutation_ga,
    critic_hidden_layer_size=critic_hidden_layer_size,
    critic_learning_rate=critic_learning_rate,
    greedy_learning_rate=greedy_learning_rate,
    policy_learning_rate=policy_learning_rate,
    noise_clip=noise_clip,
    policy_noise=policy_noise,
    discount=discount,
    reward_scaling=reward_scaling,
    replay_buffer_size=replay_buffer_size,
    soft_tau_update=soft_tau_update,
    num_critic_training_steps=num_critic_training_steps,
    num_pg_training_steps=num_pg_training_steps,
    policy_delay=policy_delay,
)
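
Most of these fields are TD3 hyperparameters for the critic behind the policy-gradient half of the emitter. As a reminder of what they control, the sketch below computes a textbook TD3 critic target (target-policy smoothing with `policy_noise` clipped to `noise_clip`, clipped double-Q with `discount` and `reward_scaling`) and the Polyak update of target parameters with `soft_tau_update`. It is an illustrative re-implementation on toy functions, not QDax's internal code; `target_actor` and `target_critics` are hypothetical stand-ins. `policy_delay`, not shown, sets how many critic updates happen per actor update.

```python
import jax
import jax.numpy as jnp

policy_noise, noise_clip = 0.2, 0.5
discount, reward_scaling, soft_tau_update = 0.99, 1.0, 0.005


def target_actor(obs):
    # hypothetical target policy with actions in [-1, 1]
    return jnp.tanh(obs[..., :2])


def target_critics(obs, actions):
    # hypothetical pair of target critics (TD3 keeps two and uses the minimum)
    q1 = jnp.sum(obs, axis=-1) + jnp.sum(actions, axis=-1)
    return q1, q1 - 0.1


def td3_target(key, rewards, next_obs, dones):
    # Target-policy smoothing: clipped Gaussian noise on the target action.
    next_actions = target_actor(next_obs)
    noise = jnp.clip(
        jax.random.normal(key, next_actions.shape) * policy_noise, -noise_clip, noise_clip
    )
    next_actions = jnp.clip(next_actions + noise, -1.0, 1.0)
    # Clipped double-Q: take the smaller of the two target estimates.
    next_q = jnp.minimum(*target_critics(next_obs, next_actions))
    return reward_scaling * rewards + (1.0 - dones) * discount * next_q


def polyak_update(target_params, online_params, tau=soft_tau_update):
    # target <- tau * online + (1 - tau) * target, leaf by leaf
    return jax.tree_util.tree_map(
        lambda t, o: tau * o + (1.0 - tau) * t, target_params, online_params
    )


batch = 256
targets = td3_target(
    jax.random.key(0), jnp.ones(batch), jnp.zeros((batch, 3)), jnp.ones(batch)
)
new_target = polyak_update({"w": jnp.zeros(3)}, {"w": jnp.ones(3)})
print(targets[:3], new_target["w"])
```

With every transition marked as terminal (`dones = 1`), the bootstrap term vanishes and the target reduces to the scaled reward; the small `soft_tau_update` makes target parameters move only 0.5% of the way toward the online ones per update.
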
# Get the emitter
variation_fn = functools.partial(
    isoline_variation, iso_sigma=iso_sigma, line_sigma=line_sigma
)

pg_emitter = PGAMEEmitter(
    config=pga_emitter_config,
    policy_network=policy_network,
    env=env,
    variation_fn=variation_fn,
)
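
The `isoline_variation` operator follows the Iso+LineDD idea: an offspring is a parent perturbed by isotropic Gaussian noise (scale `iso_sigma`) plus a random step along the line joining it to a second parent (scale `line_sigma`). The sketch below re-implements that formula on flat parameter vectors for illustration only; we assume QDax's version applies the same rule to parameter pytrees, and this is not its code. It also shows how one batch of `env_batch_size` offspring would be split between the GA operator and policy-gradient updates, assuming the GA share is rounded down.

```python
import jax
import jax.numpy as jnp

iso_sigma, line_sigma = 0.005, 0.05
env_batch_size, proportion_mutation_ga = 100, 0.5

# Split of one emitter batch between the two variation mechanisms.
ga_batch_size = int(proportion_mutation_ga * env_batch_size)
pg_batch_size = env_batch_size - ga_batch_size


def iso_line(key, x1, x2):
    """Iso+LineDD on flat genotypes of shape (batch, dim)."""
    key_iso, key_line = jax.random.split(key)
    iso_noise = jax.random.normal(key_iso, x1.shape) * iso_sigma
    # One scalar line coefficient per offspring, broadcast over the genotype.
    line_noise = jax.random.normal(key_line, (x1.shape[0], 1)) * line_sigma
    return x1 + iso_noise + line_noise * (x2 - x1)


parents_1 = jnp.zeros((ga_batch_size, 10))
parents_2 = jnp.ones((ga_batch_size, 10))
offspring = iso_line(jax.random.key(0), parents_1, parents_2)
print(ga_batch_size, pg_batch_size, offspring.shape)
```

The line term lets offspring move along directions where existing elites already differ, which tends to be more productive than purely isotropic noise in high-dimensional policy parameter spaces.
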

Instantiate and initialise the AURORA algorithm

aurora_dims = hidden_size
centroids = jnp.zeros(shape=(num_centroids, aurora_dims))

@jax.jit
def update_scan_fn(carry: Any, _: Any) -> Any:
    """Scan the update function."""
    (
        repertoire,
        emitter_state,
        key,
        aurora_extra_info
    ) = carry

    # update
    key, subkey = jax.random.split(key)
    repertoire, emitter_state, metrics = aurora.update(
        repertoire,
        emitter_state,
        subkey,
        aurora_extra_info=aurora_extra_info,
    )

    return (
        (repertoire, emitter_state, key, aurora_extra_info),
        metrics,
    )

# Init algorithm
# AutoEncoder Params and INIT
obs_dim = min(env.observation_size, max_observation_size)  # Python int, used in array shapes below
if observation_option == "full":
    observations_dims = (
        episode_length // traj_sampling_freq,
        obs_dim + prior_descriptor_dim,
    )
elif observation_option == "no_sd":
    observations_dims = (
        episode_length // traj_sampling_freq,
        obs_dim,
    )
elif observation_option == "only_sd":
    observations_dims = (episode_length // traj_sampling_freq, prior_descriptor_dim)
else:
    raise ValueError("The chosen option is not correct.")

# Define the seq2seq model
model = train_seq2seq.get_model(
    observations_dims[-1], True, hidden_size=hidden_size
)

# Init the model params
key, subkey = jax.random.split(key)
model_params = train_seq2seq.get_initial_params(
    model, subkey, (1, *observations_dims)
)

# Define the encoder function
encoder_fn = jax.jit(
    functools.partial(
        get_aurora_encoding,
        model=model,
    )
)

# Define the training function
train_fn = functools.partial(
    train_seq2seq.lstm_ae_train,
    model=model,
    batch_size=lstm_batch_size,
)

# Instantiate AURORA
aurora = AURORA(
    scoring_function=aurora_scoring_fn,
    emitter=pg_emitter,
    metrics_function=metrics_fn,
    encoder_function=encoder_fn,
    training_function=train_fn,
    observations_key=aurora_observations_key,
)

# init the model params
key, subkey = jax.random.split(key)
model_params = train_seq2seq.get_initial_params(
    model, subkey, (1, *observations_dims)
)

# define arbitrary observation's mean/std
mean_observations = jnp.zeros(observations_dims[-1])
std_observations = jnp.ones(observations_dims[-1])

# init all the information needed by AURORA to compute encodings
aurora_extra_info = AuroraExtraInfoNormalization.create(
    model_params,
    mean_observations,
    std_observations,
)

# init step of the aurora algorithm
key, subkey = jax.random.split(key)
repertoire, emitter_state, init_metrics, aurora_extra_info = aurora.init(
    init_variables,
    aurora_extra_info,
    jnp.asarray(l_value_init),
    max_observation_size,
    subkey,
)

# initializing means and stds and AURORA
key, subkey = jax.random.split(key)
repertoire, aurora_extra_info = aurora.train(
    repertoire, model_params, iteration=0, key=subkey
)

# design aurora's schedule
default_update_base = 10
update_base = int(jnp.ceil(default_update_base / log_freq))
schedules = jnp.cumsum(jnp.arange(update_base, 1000, update_base))
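
The schedule controls when the autoencoder is retrained. With `log_freq = 5`, `update_base` is `ceil(10 / 5) = 2`, and the cumulative sum of `2, 4, 6, ...` gives `2, 6, 12, 20, ...`, so retraining becomes progressively less frequent. The snippet below recomputes the schedule with NumPy (the same arithmetic as the cell above) and lists the values of `iteration + 1` that trigger training within the default `max_iterations = 50`.

```python
import math

import numpy as np

log_freq, max_iterations = 5, 50
default_update_base = 10

update_base = int(math.ceil(default_update_base / log_freq))
schedules = np.cumsum(np.arange(update_base, 1000, update_base))

# The main loop retrains the autoencoder when (iteration + 1) is in `schedules`.
training_iterations = [i + 1 for i in range(max_iterations) if (i + 1) in schedules]
print(update_base, schedules[:6], training_iterations)
```

With these defaults the autoencoder is retrained six times; on the other even iterations the main loop runs container size control instead.
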

Launch AURORA iterations

current_step_estimation = 0
num_iterations = 0

# Main loop
n_target = 1024

previous_error = jnp.sum(repertoire.fitnesses != -jnp.inf) - n_target
container_size_control_fn = jax.jit(aurora.container_size_control)

iteration = 0
while iteration < max_iterations:

    (
        (repertoire, emitter_state, key, aurora_extra_info),
        metrics,
    ) = jax.lax.scan(
        update_scan_fn,
        (repertoire, emitter_state, key, aurora_extra_info),
        (),
        length=log_freq,
    )

    num_iterations = iteration * log_freq

    # update nb steps estimation
    current_step_estimation += env_batch_size * episode_length * log_freq

    # autoencoder steps and CVC
    if (iteration + 1) in schedules:
        # train the autoencoder
        key, subkey = jax.random.split(key)
        repertoire, aurora_extra_info = aurora.train(
            repertoire, model_params, iteration, subkey
        )

    elif iteration % 2 == 0:
        repertoire, previous_error = container_size_control_fn(
            repertoire,
            target_size=n_target,
            previous_error=previous_error,
        )


    iteration += 1
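
Between autoencoder updates, `container_size_control` adjusts the repertoire's distance threshold `l_value` (initialised from `l_value_init`) so that the number of stored individuals stays near `n_target`. In an unstructured repertoire a new solution is typically accepted only if it lies far enough from existing ones (or outperforms its close neighbour), so raising the threshold makes insertion harder. The sketch below is a hypothetical proportional-derivative controller written to illustrate this feedback loop; the gain and the exact update rule are assumptions, not QDax's implementation. Like the loop above, it returns the current error so it can be fed back as `previous_error`.

```python
import jax.numpy as jnp


def size_control(l_value, num_individuals, target_size, previous_error, gain=1e-5):
    """Hypothetical PD-style update of the insertion threshold l_value."""
    error = num_individuals - target_size  # > 0 means the repertoire is too large
    change = error - previous_error        # derivative term
    new_l_value = l_value + gain * error + gain * change
    return new_l_value, error


l_value, previous_error = jnp.asarray(0.2), 0
for num_individuals in (1500, 1400, 1200, 1024):
    l_value, previous_error = size_control(l_value, num_individuals, 1024, previous_error)
    print(num_individuals, float(l_value))
```

An oversized repertoire first pushes the threshold up; as the size falls back toward the target, the derivative term starts pulling it down again, which damps oscillations around `n_target`.
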
for k, v in metrics.items():
    print(k, " - ", v[-1])
