import functools

import jax
import jax.numpy as jnp

try:
    import brax
except:
    !pip install git+https://github.com/google/brax.git@v0.9.2 |tail -n 1
    import brax

try:
    import flax
except:
    !pip install --no-deps git+https://github.com/google/flax.git@v0.7.4 |tail -n 1
    import flax

try:
    import chex
except:
    !pip install --no-deps git+https://github.com/deepmind/chex.git@v0.1.83 |tail -n 1
    import chex

try:
    import jumanji
except:
    !pip install "jumanji==0.3.1"
    import jumanji

try:
    import haiku
except:
    !pip install git+https://github.com/deepmind/dm-haiku.git@v0.0.5 |tail -n 1
    import haiku

try:
    import qdax
except:
    !pip install --no-deps git+https://github.com/adaptive-intelligent-robotics/QDax@main |tail -n 1
    import qdax

from brax.v1.io import html
from IPython.display import HTML
from tqdm import tqdm

from qdax import environments
from qdax.baselines.pbt import PBT
from qdax.baselines.sac_pbt import PBTSAC, PBTSacConfig
# Make the CPU JAX's default platform; only the functions explicitly pmapped
# onto `devices` below are placed on the accelerators.
jax.config.update("jax_platform_name", "cpu")


def select_devices(preferred_backends=("tpu", "gpu", "cpu")):
    """Return the devices of the first available JAX backend.

    Args:
        preferred_backends: backend names, tried in order.

    Returns:
        The list returned by ``jax.devices(backend)`` for the first backend
        in ``preferred_backends`` that is available.

    Raises:
        RuntimeError: if none of ``preferred_backends`` is available.
    """
    for backend in preferred_backends:
        try:
            return jax.devices(backend)
        except RuntimeError:
            # jax.devices raises RuntimeError for a backend that is absent.
            continue
    raise RuntimeError(f"None of the JAX backends {preferred_backends} is available.")


# Prefer TPU, then fall back to GPU and finally CPU so the notebook also runs
# on hosts without a TPU.
devices = select_devices()
num_devices = len(devices)
print(f"Detected the following {num_devices} device(s): {devices}")
env_name = "humanoidtrap"
seed = 0
# Number of environments run in parallel for each individual.
env_batch_size = 250
population_size_per_device = 10
population_size = population_size_per_device * num_devices
num_steps = 10000
buffer_size = 100000

# PBT Config
# These totals are floor-divided by num_devices when PBT is built; presumably
# they should stay within the population size — TODO confirm against PBT.
num_best_to_replace_from = 20
num_worse_to_replace = 40

# SAC config
batch_size = 256
episode_length = 1000
grad_updates_per_step = 1.0
tau = 0.005
alpha_init = 1.0
critic_hidden_layer_size = (256, 256)
policy_hidden_layer_size = (256, 256)
fix_alpha = False
normalize_observations = False

# Number of train / evaluate / select rounds, and how often to print.
num_loops = 10
print_freq = 1
%%time
# Initialize environments
env = environments.create(
    env_name=env_name,
    batch_size=env_batch_size * population_size_per_device,
    episode_length=episode_length,
    auto_reset=True,
)

eval_env = environments.create(
    env_name=env_name,
    batch_size=env_batch_size * population_size_per_device,
    episode_length=episode_length,
    auto_reset=True,
)
@jax.jit
def init_environments(random_key):
    """Reset the training and evaluation environments for one device.

    Both environments are reset with the same ``random_key``. Each leaf of the
    resulting states is expected to have a leading batch axis of size
    ``population_size_per_device * env_batch_size``. That axis is split into
    ``(population_size_per_device, env_batch_size)`` so every individual of
    the local population gets its own slice of environments.

    Returns:
        ``(env_states, eval_env_first_states)`` with the split leading axes.
    """

    def split_population_axis(leaf):
        per_individual_shape = (population_size_per_device, env_batch_size)
        return jnp.reshape(leaf, per_individual_shape + leaf.shape[1:])

    training_states = jax.jit(env.reset)(rng=random_key)
    eval_states = jax.jit(eval_env.reset)(rng=random_key)

    return (
        jax.tree_map(split_population_axis, training_states),
        jax.tree_map(split_population_axis, eval_states),
    )
# %%time
# Split the seed key into one carry-over key plus one key per device, then
# reset the environments on every device in parallel.
key = jax.random.PRNGKey(seed)
split_keys = jax.random.split(key, num=1 + num_devices)
key, keys = split_keys[0], split_keys[1:]
env_states, eval_env_first_states = jax.pmap(
    init_environments, axis_name="p", devices=devices
)(keys)
# Build the configuration shared by every individual, then the PBT-SAC agent.
config = PBTSacConfig(
    # Optimisation / data settings
    batch_size=batch_size,
    episode_length=episode_length,
    normalize_observations=normalize_observations,
    # Network sizes
    critic_hidden_layer_size=critic_hidden_layer_size,
    policy_hidden_layer_size=policy_hidden_layer_size,
    # SAC coefficients
    tau=tau,
    alpha_init=alpha_init,
    fix_alpha=fix_alpha,
)
agent = PBTSAC(config=config, action_size=env.action_size)
%%time
# get the initial training states and replay buffers
agent_init_fn = agent.get_init_fn(
    population_size=population_size_per_device,
    action_size=env.action_size,
    observation_size=env.observation_size,
    buffer_size=buffer_size,
)
keys, training_states, replay_buffers = jax.pmap(
    agent_init_fn, axis_name="p", devices=devices
)(keys)
# get eval policy fonction
eval_policy = jax.pmap(agent.get_eval_fn(eval_env), axis_name="p", devices=devices)
%%time
# eval policy before training
population_returns, _ = eval_policy(training_states, eval_env_first_states)
population_returns = jnp.reshape(population_returns, (population_size,))
print(
    f"Evaluation over {env_batch_size} episodes,"
    f" Population mean return: {jnp.mean(population_returns)},"
    f" max return: {jnp.max(population_returns)}"
)
# Number of training iterations per call of train_fn. num_steps is divided by
# env_batch_size, presumably because each iteration steps env_batch_size
# environments in parallel — TODO confirm against agent.get_train_fn.
num_iterations = num_steps // env_batch_size

train_fn = jax.pmap(
    agent.get_train_fn(
        env=env,
        num_iterations=num_iterations,
        env_batch_size=env_batch_size,
        grad_updates_per_step=grad_updates_per_step,
    ),
    axis_name="p",
    devices=devices,
)

# PBT selection; the replacement counts are floor-divided by num_devices.
pbt = PBT(
    population_size=population_size,
    num_best_to_replace_from=num_best_to_replace_from // num_devices,
    num_worse_to_replace=num_worse_to_replace // num_devices,
)
select_fn = jax.pmap(pbt.update_states_and_buffer_pmap, axis_name="p", devices=devices)
def unshard_fn(sharded_tree):
    """Move a device-sharded pytree to the host CPU and merge its first two axes.

    Each leaf is expected to have shape ``(num_devices, per_device, ...)``, as
    produced by the ``jax.pmap`` calls in this notebook. It is returned
    committed to the first CPU device with shape
    ``(num_devices * per_device, ...)``, in row-major order.

    The previous version had two problems:
    - It called ``jax.device_put(x, "cpu")``. A string is not a valid
      placement target, and recent JAX rejects it.
    - It ran under ``jax.jit``, which cannot move data across backends and
      also baked the global ``population_size`` in at trace time.
    The transfer is therefore done eagerly, and the merged size is derived
    from each leaf.
    """
    cpu_device = jax.devices("cpu")[0]

    def _unshard_leaf(x):
        # Gather the shards on the host, then flatten (device, individual).
        x = jax.device_put(x, cpu_device)
        return jnp.reshape(x, (x.shape[0] * x.shape[1],) + x.shape[2:])

    return jax.tree_util.tree_map(_unshard_leaf, sharded_tree)
%%time
for i in tqdm(range(num_loops), total=num_loops):

    # Update for num_steps
    (training_states, env_states, replay_buffers), metrics = train_fn(
        training_states, env_states, replay_buffers
    )

    # Eval policy after training
    population_returns, _ = eval_policy(training_states, eval_env_first_states)
    population_returns_flatten = jnp.reshape(population_returns, (population_size,))

    if i % print_freq == 0:
        print(
            f"Evaluation over {env_batch_size} episodes,"
            f" Population mean return: {jnp.mean(population_returns_flatten)},"
            f" max return: {jnp.max(population_returns_flatten)}"
        )

    # PBT selection
    if i < (num_loops-1):
        keys, training_states, replay_buffers = select_fn(
            keys, population_returns, training_states, replay_buffers
        )
### Visualize learnt behaviors
# Gather the population on the host and pick the individual with the highest
# last evaluation return. jnp.argmax without an axis indexes the flattened
# returns, i.e. the same (device, individual) row-major order unshard_fn uses.
training_states = unshard_fn(training_states)
best_idx = jnp.argmax(population_returns)
best_training_state = jax.tree_map(lambda leaf: leaf[best_idx], training_states)

# Fresh environment for rendering; unlike the training environments, no
# batch_size is passed here.
env = environments.create(env_name, episode_length=episode_length)
play_step_fn = jax.pmap(
    functools.partial(
        agent.play_step_fn, env=env, deterministic=True, evaluation=True
    ),
    axis_name="p",
    devices=devices[:1],
)
training_state = best_training_state
%%time
rollout = []

rng = jax.random.PRNGKey(seed=1)
env_state = jax.jit(env.reset)(rng=rng)

training_state, env_state = jax.tree_map(
    lambda x: jnp.expand_dims(x, axis=0), (training_state, env_state)
)

for _ in range(episode_length):

    rollout.append(env_state)
    env_state, training_state, _ = play_step_fn(env_state, training_state)

print(f"The trajectory of this individual contains {len(rollout)} transitions.")
rollout = [
    jax.tree_map(lambda x: jax.device_put(x[0], jax.devices("cpu")[0]), env_state)
    for env_state in rollout
]
HTML(html.render(env.sys, [s.qp for s in rollout[:episode_length]]))