import functools
import time
from typing import Optional, Tuple

import jax
import jax.numpy as jnp

try:
    import brax
except ImportError:
    !pip install git+https://github.com/google/brax.git@v0.9.2 |tail -n 1
    import brax

try:
    import flax
except ImportError:
    !pip install --no-deps git+https://github.com/google/flax.git@v0.7.4 |tail -n 1
    import flax

try:
    import chex
except ImportError:
    !pip install --no-deps git+https://github.com/deepmind/chex.git@v0.1.83 |tail -n 1
    import chex

try:
    import jumanji
except ImportError:
    !pip install "jumanji==0.3.1" |tail -n 1
    import jumanji

try:
    import haiku
except ImportError:
    !pip install git+https://github.com/deepmind/dm-haiku.git@v0.0.5 |tail -n 1
    import haiku

try:
    import qdax
except ImportError:
    !pip install --no-deps git+https://github.com/adaptive-intelligent-robotics/QDax@main |tail -n 1
    import qdax

import optax
from brax.v1.io import html
from IPython.display import HTML
from tqdm import tqdm

from qdax import environments
from qdax.baselines.sac_pbt import PBTSAC, PBTSacConfig, PBTSacTrainingState
from qdax.core.containers.mapelites_repertoire import compute_cvt_centroids
from qdax.core.distributed_map_elites import DistributedMAPElites
from qdax.core.emitters.pbt_me_emitter import PBTEmitter, PBTEmitterConfig
from qdax.core.emitters.pbt_variation_operators import sac_pbt_variation_fn
from qdax.types import Descriptor, ExtraScores, Fitness, Genotype, RNGKey
from qdax.utils.metrics import CSVLogger, default_qd_metrics
from qdax.utils.plotting import plot_map_elites_results
# Run computations on CPU by default; the pmapped training below explicitly targets the
# TPU devices retrieved here.
jax.config.update("jax_platform_name", "cpu")
devices = jax.devices("tpu")
num_devices = len(devices)
print(f"Detected the following {num_devices} device(s): {devices}")
env_name = "anttrap"

seed = 0

# SAC config
batch_size = 256
episode_length = 1000
tau = 0.005
alpha_init = 1.0
critic_hidden_layer_size = (256, 256) 
policy_hidden_layer_size = (256, 256) 
fix_alpha = False
normalize_observations = False

# Emitter config
buffer_size = 100000
pg_population_size_per_device = 10
ga_population_size_per_device = 30
num_training_iterations = 10000
env_batch_size = 250
grad_updates_per_step = 1.0
iso_sigma = 0.005
line_sigma = 0.05

fraction_best_to_replace_from = 0.1
fraction_to_replace_from_best = 0.2
fraction_to_replace_from_samples = 0.4

eval_env_batch_size = 1

# MAP-Elites config
num_init_cvt_samples = 50000
num_centroids = 128
log_period = 1
num_loops = 10
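# Illustrative sanity check (not in the original notebook): sizes implied by the config above.
print("Offspring per device:", pg_population_size_per_device + ga_population_size_per_device)
print("Parallel training envs per device:", env_batch_size * pg_population_size_per_device)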
%%time
# Initialize environments: each of the pg_population_size_per_device PG agents gets its own
# batch of env_batch_size training environments.
env = environments.create(
    env_name=env_name,
    batch_size=env_batch_size * pg_population_size_per_device,
    episode_length=episode_length,
    auto_reset=True,
)

eval_env = environments.create(
    env_name=env_name,
    batch_size=eval_env_batch_size,
    episode_length=episode_length,
    auto_reset=True,
)
min_bd, max_bd = env.behavior_descriptor_limits
%%time
key = jax.random.PRNGKey(seed)
key, subkey = jax.random.split(key)
env_states = jax.jit(env.reset)(rng=subkey)
eval_env_first_states = jax.jit(eval_env.reset)(rng=subkey)
# get agent
config = PBTSacConfig(
    batch_size=batch_size,
    episode_length=episode_length,
    tau=tau,
    normalize_observations=normalize_observations,
    alpha_init=alpha_init,
    critic_hidden_layer_size=critic_hidden_layer_size,
    policy_hidden_layer_size=policy_hidden_layer_size,
    fix_alpha=fix_alpha,
)

agent = PBTSAC(config=config, action_size=env.action_size)
# init emitter
emitter_config = PBTEmitterConfig(
    buffer_size=buffer_size,
    num_training_iterations=num_training_iterations // env_batch_size,
    env_batch_size=env_batch_size,
    grad_updates_per_step=grad_updates_per_step,
    pg_population_size_per_device=pg_population_size_per_device,
    ga_population_size_per_device=ga_population_size_per_device,
    num_devices=num_devices,
    fraction_best_to_replace_from=fraction_best_to_replace_from,
    fraction_to_replace_from_best=fraction_to_replace_from_best,
    fraction_to_replace_from_samples=fraction_to_replace_from_samples,
    fraction_sort_exchange=0.1,
)
variation_fn = functools.partial(
    sac_pbt_variation_fn, iso_sigma=iso_sigma, line_sigma=line_sigma
)
emitter = PBTEmitter(
    pbt_agent=agent,
    config=emitter_config,
    env=env,
    variation_fn=variation_fn,
)
# get scoring function
bd_extraction_fn = environments.behavior_descriptor_extractor[env_name]
eval_policy = agent.get_eval_qd_fn(eval_env, bd_extraction_fn=bd_extraction_fn)


def scoring_function(genotypes, random_key):
    """Evaluate a population of PBT-SAC training states and return their returns and
    behavior descriptors in the format expected by MAP-Elites."""
    population_size = jax.tree_util.tree_leaves(genotypes)[0].shape[0]
    # Broadcast the single evaluation-environment state to the whole population
    first_states = jax.tree_map(
        lambda x: jnp.expand_dims(x, axis=0), eval_env_first_states
    )
    first_states = jax.tree_map(
        lambda x: jnp.repeat(x, population_size, axis=0), first_states
    )
    population_returns, population_bds, _, _ = eval_policy(genotypes, first_states)
    return population_returns, population_bds, None, random_key
# Get the minimum reward value to make sure the QD score stays positive
reward_offset = environments.reward_offset[env_name]

# Define a metrics function
metrics_function = functools.partial(
    default_qd_metrics,
    qd_offset=reward_offset * episode_length,
)

# Get the MAP-Elites algorithm
map_elites = DistributedMAPElites(
    scoring_function=scoring_function,
    emitter=emitter,
    metrics_function=metrics_function,
)
%%time
centroids, key = compute_cvt_centroids(
    num_descriptors=env.behavior_descriptor_length,
    num_init_cvt_samples=num_init_cvt_samples,
    num_centroids=num_centroids,
    minval=min_bd,
    maxval=max_bd,
    random_key=key,
)
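# Illustrative check (not in the original notebook): the CVT centroids tessellate the
# descriptor space into one cell per centroid.
print("Centroid grid shape:", centroids.shape)  # expected (num_centroids, num_descriptors)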
key, *keys = jax.random.split(key, num=1 + num_devices)
keys = jnp.stack(keys)
%%time
# get the initial training states and replay buffers
agent_init_fn = agent.get_init_fn(
    population_size=pg_population_size_per_device + ga_population_size_per_device,
    action_size=env.action_size,
    observation_size=env.observation_size,
    buffer_size=buffer_size,
)
keys, training_states, _ = jax.pmap(agent_init_fn, axis_name="p", devices=devices)(keys)
%%time
# Empty the optimizer states so the genotypes stored in the repertoire stay lightweight
training_states = jax.pmap(
    jax.vmap(training_states.__class__.empty_optimizers_states),
    axis_name="p",
    devices=devices,
)(training_states)

# initialize map-elites
repertoire, emitter_state, keys = map_elites.get_distributed_init_fn(
    devices=devices, centroids=centroids
)(init_genotypes=training_states, random_key=keys)
update_fn = map_elites.get_distributed_update_fn(
    num_iterations=log_period, devices=devices
)
# Approximate number of environment steps consumed per MAP-Elites iteration: evaluation
# episodes for all offspring plus the PG agents' training steps, summed over devices.
env_step_multiplier = (
    (pg_population_size_per_device + ga_population_size_per_device)
    * eval_env_batch_size
    * episode_length
    + num_training_iterations * pg_population_size_per_device
) * num_devices
%%time
all_metrics = {}

for i in tqdm(range(num_loops // log_period), total=num_loops // log_period):
    start_time = time.time()

    repertoire, emitter_state, keys, metrics = update_fn(
        repertoire, emitter_state, keys
    )
    metrics_cpu = jax.tree_map(
        lambda x: jax.device_put(x, jax.devices("cpu")[0])[0], metrics
    )
    timelapse = time.time() - start_time

    # Log metrics, accumulating the history of each metric across loops
    for metric_name, value in metrics_cpu.items():
        if metric_name in all_metrics.keys():
            all_metrics[metric_name] = jnp.concatenate([all_metrics[metric_name], value])
        else:
            all_metrics[metric_name] = value
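# Optional sketch (not in the original loop): persist the accumulated metrics with the
# CSVLogger imported above. This assumes every metric in all_metrics is a 1D array with one
# entry per logged iteration; the filename is just an example.
csv_logger = CSVLogger(
    "pbt_me_metrics.csv",
    header=["iteration"] + list(all_metrics.keys()),
)
num_logged = len(next(iter(all_metrics.values())))
for it in range(num_logged):
    row = {"iteration": it + 1}
    row.update({name: float(values[it]) for name, values in all_metrics.items()})
    csv_logger.log(row)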
# Create the performance evolution plots and visualize final grid
repertoire_cpu = jax.tree_map(
    lambda x: jax.device_put(x, jax.devices("cpu")[0])[0], repertoire
)
env_steps = (jnp.arange(num_loops) + 1) * env_step_multiplier

fig, axes = plot_map_elites_results(
    env_steps=env_steps,
    metrics=all_metrics,
    repertoire=repertoire_cpu,
    min_bd=min_bd,
    max_bd=max_bd,
)

Visualize learnt behaviors

# Evaluate best individual of the repertoire
best_idx = jnp.argmax(repertoire_cpu.fitnesses)
best_fitness = jnp.max(repertoire_cpu.fitnesses)
best_bd = repertoire_cpu.descriptors[best_idx]
# Inspect the shape of the descriptors stored in the repertoire
repertoire_cpu.descriptors.shape
# Alternatively, evaluate the agent that travels the furthest along the first descriptor axis
# best_idx = jnp.argmax(repertoire_cpu.descriptors[:, 0])
# best_fitness = repertoire_cpu.fitnesses[best_idx]
# best_bd = repertoire_cpu.descriptors[best_idx]
print(
    f"Fitness of the selected agent: {best_fitness:.2f}\n",
    f"Behavior descriptor of the selected agent: {best_bd}\n",
    f"Index in the repertoire of this individual: {best_idx}\n",
)
env = environments.create(env_name, episode_length=episode_length)
play_step_fn = jax.pmap(
    functools.partial(agent.play_step_fn, env=env, deterministic=True, evaluation=True),
    axis_name="p",
    devices=devices[:1],
)
training_state = jax.tree_util.tree_map(lambda x: x[best_idx], repertoire_cpu.genotypes)
%%time
rollout = []

rng = jax.random.PRNGKey(seed=1)
env_state = jax.jit(env.reset)(rng=rng)

training_state, env_state = jax.tree_map(
    lambda x: jnp.expand_dims(x, axis=0), (training_state, env_state)
)

for _ in range(episode_length):
    rollout.append(env_state)
    env_state, training_state, _ = play_step_fn(env_state, training_state)

print(f"The trajectory of this individual contains {len(rollout)} transitions.")
rollout = [
    jax.tree_map(lambda x: jax.device_put(x[0], jax.devices("cpu")[0]), env_state)
    for env_state in rollout
]
HTML(html.render(env.sys, [s.qp for s in rollout[:episode_length]]))
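# Optional (not in the original notebook): html.render returns an HTML string, so the same
# visualization can be saved to a standalone file; the filename is just an example.
with open("best_agent_rollout.html", "w") as f:
    f.write(html.render(env.sys, [s.qp for s in rollout[:episode_length]]))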