Population Based Training (PBT)
PBT is an optimization method that jointly optimises a population of models and their hyperparameters to maximize performance.
To use PBT in QDax to train SAC, one can use the two following components, documented below together with a short usage sketch (see the examples for how to use them appropriately):
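As a quick orientation, here is a minimal, hedged sketch of how these two components are typically combined. The PBTSacConfig field names are taken from the PBTSAC constructor shown further down, the values are placeholders, and env is assumed to be an already-created Brax environment; refer to the examples for the complete, up-to-date workflow.

from qdax.baselines.pbt import PBT
from qdax.baselines.sac_pbt import PBTSAC, PBTSacConfig

# Placeholder values; PBTSacConfig may expose additional fields
# (e.g. hyperparameter sampling ranges) not shown in the constructor below.
config = PBTSacConfig(
    batch_size=256,
    episode_length=1000,
    tau=0.005,
    normalize_observations=False,
    alpha_init=1.0,
    policy_hidden_layer_size=(256, 256),
    critic_hidden_layer_size=(256, 256),
    fix_alpha=False,
)
agent = PBTSAC(config=config, action_size=env.action_size)

# Generic PBT component handling selection/replacement of agents.
pbt = PBT(
    population_size=10,
    num_best_to_replace_from=3,
    num_worse_to_replace=2,
)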
qdax.baselines.sac_pbt.PBTSAC (SAC)
Source code in qdax/baselines/sac_pbt.py
class PBTSAC(SAC):
def __init__(self, config: PBTSacConfig, action_size: int) -> None:
sac_config = SacConfig(
batch_size=config.batch_size,
episode_length=config.episode_length,
tau=config.tau,
normalize_observations=config.normalize_observations,
alpha_init=config.alpha_init,
policy_hidden_layer_size=config.policy_hidden_layer_size,
critic_hidden_layer_size=config.critic_hidden_layer_size,
fix_alpha=config.fix_alpha,
# unused default values for parameters that will be learnt as part of PBT
learning_rate=3e-4,
discount=0.97,
reward_scaling=1.0,
)
SAC.__init__(self, config=sac_config, action_size=action_size)
def init(
self, random_key: RNGKey, action_size: int, observation_size: int
) -> PBTSacTrainingState:
"""Initialise the training state of the algorithm.
Args:
random_key: a jax random key
action_size: the size of the environment's action space
observation_size: the size of the environment's observation space
Returns:
the initial training state of PBT-SAC
"""
sac_training_state = SAC.init(self, random_key, action_size, observation_size)
training_state = PBTSacTrainingState(
policy_optimizer_state=sac_training_state.policy_optimizer_state,
policy_params=sac_training_state.policy_params,
critic_optimizer_state=sac_training_state.critic_optimizer_state,
critic_params=sac_training_state.critic_params,
alpha_optimizer_state=sac_training_state.alpha_optimizer_state,
alpha_params=sac_training_state.alpha_params,
target_critic_params=sac_training_state.target_critic_params,
normalization_running_stats=sac_training_state.normalization_running_stats,
random_key=sac_training_state.random_key,
steps=sac_training_state.steps,
discount=None,
policy_lr=None,
critic_lr=None,
alpha_lr=None,
reward_scaling=None,
)
# Sample hyper-params
training_state = PBTSacTrainingState.resample_hyperparams(training_state)
return training_state # type: ignore
@partial(jax.jit, static_argnames=("self"))
def update(
self,
training_state: PBTSacTrainingState,
replay_buffer: ReplayBuffer,
) -> Tuple[PBTSacTrainingState, ReplayBuffer, Metrics]:
"""Performs a training step to update the policy and the critic parameters.
Args:
training_state: the current PBT-SAC training state
replay_buffer: the replay buffer
Returns:
the updated PBT-SAC training state
the replay buffer
the training metrics
"""
# sample a batch of transitions in the buffer
random_key = training_state.random_key
transitions, random_key = replay_buffer.sample(
random_key,
sample_size=self._config.batch_size,
)
# normalise observations if necessary
if self._config.normalize_observations:
normalization_running_stats = training_state.normalization_running_stats
normalized_obs = normalize_with_rmstd(
transitions.obs, normalization_running_stats
)
normalized_next_obs = normalize_with_rmstd(
transitions.next_obs, normalization_running_stats
)
transitions = transitions.replace(
obs=normalized_obs, next_obs=normalized_next_obs
)
# update alpha
(
alpha_params,
alpha_optimizer_state,
alpha_loss,
random_key,
) = self._update_alpha(
alpha_lr=training_state.alpha_lr,
training_state=training_state,
transitions=transitions,
random_key=random_key,
)
# update critic
(
critic_params,
target_critic_params,
critic_optimizer_state,
critic_loss,
random_key,
) = self._update_critic(
critic_lr=training_state.critic_lr,
reward_scaling=training_state.reward_scaling,
discount=training_state.discount,
training_state=training_state,
transitions=transitions,
random_key=random_key,
)
# update actor
(
policy_params,
policy_optimizer_state,
policy_loss,
random_key,
) = self._update_actor(
policy_lr=training_state.policy_lr,
training_state=training_state,
transitions=transitions,
random_key=random_key,
)
# create new training state
new_training_state = PBTSacTrainingState(
policy_optimizer_state=policy_optimizer_state,
policy_params=policy_params,
critic_optimizer_state=critic_optimizer_state,
critic_params=critic_params,
alpha_optimizer_state=alpha_optimizer_state,
alpha_params=alpha_params,
normalization_running_stats=training_state.normalization_running_stats,
target_critic_params=target_critic_params,
random_key=random_key,
steps=training_state.steps + 1,
discount=training_state.discount,
policy_lr=training_state.policy_lr,
critic_lr=training_state.critic_lr,
alpha_lr=training_state.alpha_lr,
reward_scaling=training_state.reward_scaling,
)
metrics = {
"actor_loss": policy_loss,
"critic_loss": critic_loss,
"alpha_loss": alpha_loss,
"obs_mean": jnp.mean(transitions.obs),
"obs_std": jnp.std(transitions.obs),
}
return new_training_state, replay_buffer, metrics
def get_init_fn(
self,
population_size: int,
action_size: int,
observation_size: int,
buffer_size: int,
) -> Callable:
"""
Returns a function to initialize the population.
Args:
population_size: size of the population.
action_size: action space size.
observation_size: observation space size.
buffer_size: replay buffer size.
Returns:
a function that takes as input a random key and returns a new random
key, the PBT population training state and the replay buffers
"""
def _init_fn(
random_key: RNGKey,
) -> Tuple[RNGKey, PBTSacTrainingState, ReplayBuffer]:
random_key, *keys = jax.random.split(random_key, num=1 + population_size)
keys = jnp.stack(keys)
init_dummy_transition = partial(
Transition.init_dummy,
observation_dim=observation_size,
action_dim=action_size,
)
init_dummy_transition = jax.vmap(
init_dummy_transition, axis_size=population_size
)
dummy_transitions = init_dummy_transition()
replay_buffer_init = partial(
ReplayBuffer.init,
buffer_size=buffer_size,
)
replay_buffer_init = jax.vmap(replay_buffer_init)
replay_buffers = replay_buffer_init(transition=dummy_transitions)
agent_init = partial(
self.init, action_size=action_size, observation_size=observation_size
)
training_states = jax.vmap(agent_init)(keys)
return random_key, training_states, replay_buffers
return _init_fn
def get_eval_fn(
self,
eval_env: Env,
) -> Callable:
"""
Returns the function the evaluation the PBT population.
Args:
eval_env: evaluation environment. Might be different from training env
if needed.
Returns:
The function to evaluate the population. It takes as input the population
training state as well as first eval environment states and returns the
population agents mean returns over episodes as well as all returns from all
agents over all episodes.
"""
play_eval_step = partial(
self.play_step_fn,
env=eval_env,
deterministic=True,
)
eval_policy = partial(
self.eval_policy_fn,
play_step_fn=play_eval_step,
)
return jax.vmap(eval_policy) # type: ignore
def get_eval_qd_fn(
self,
eval_env: Env,
bd_extraction_fn: Callable[[QDTransition, Mask], Descriptor],
) -> Callable:
"""
Returns the function the evaluation the PBT population.
Args:
eval_env: evaluation environment. Might be different from training env
if needed.
bd_extraction_fn: function to extract the bd from an episode.
Returns:
The function to evaluate the population. It takes as input the population
training state as well as first eval environment states and returns the
population agents mean returns and mean bds over episodes as well as all
returns and bds from all agents over all episodes.
"""
play_eval_step = partial(
self.play_qd_step_fn,
env=eval_env,
deterministic=True,
)
eval_policy = partial(
self.eval_qd_policy_fn,
play_step_fn=play_eval_step,
bd_extraction_fn=bd_extraction_fn,
)
return jax.vmap(eval_policy) # type: ignore
def get_train_fn(
self,
env: Env,
num_iterations: int,
env_batch_size: int,
grad_updates_per_step: float,
) -> Callable:
"""
Returns the function to update the population of agents.
Args:
env: training environment.
num_iterations: number of training iterations to perform.
env_batch_size: number of batched environments.
grad_updates_per_step: number of gradient to apply per step in the
environment.
Returns:
the function to update the population which takes as input the population
training state, environment starting states and replay buffers and returns
updated training states, environment states, replay buffers and metrics.
"""
play_step = partial(
self.play_step_fn,
env=env,
deterministic=False,
)
do_iteration = partial(
do_iteration_fn,
env_batch_size=env_batch_size,
grad_updates_per_step=grad_updates_per_step,
play_step_fn=play_step,
update_fn=self.update,
)
def _scan_do_iteration(
carry: Tuple[PBTSacTrainingState, EnvState, ReplayBuffer],
unused_arg: Any,
) -> Tuple[Tuple[PBTSacTrainingState, EnvState, ReplayBuffer], Any]:
(
training_state,
env_state,
replay_buffer,
metrics,
) = do_iteration(*carry)
return (training_state, env_state, replay_buffer), metrics
def train_fn(
training_state: PBTSacTrainingState,
env_state: EnvState,
replay_buffer: ReplayBuffer,
) -> Tuple[Tuple[PBTSacTrainingState, EnvState, ReplayBuffer], Metrics]:
(training_state, env_state, replay_buffer), metrics = jax.lax.scan(
_scan_do_iteration,
(training_state, env_state, replay_buffer),
None,
length=num_iterations,
)
return (training_state, env_state, replay_buffer), metrics
return jax.vmap(train_fn) # type: ignore
init(self, random_key, action_size, observation_size)
Initialise the training state of the algorithm.
Source code in qdax/baselines/sac_pbt.py
def init(
self, random_key: RNGKey, action_size: int, observation_size: int
) -> PBTSacTrainingState:
"""Initialise the training state of the algorithm.
Args:
random_key: a jax random key
action_size: the size of the environment's action space
observation_size: the size of the environment's observation space
Returns:
the initial training state of PBT-SAC
"""
sac_training_state = SAC.init(self, random_key, action_size, observation_size)
training_state = PBTSacTrainingState(
policy_optimizer_state=sac_training_state.policy_optimizer_state,
policy_params=sac_training_state.policy_params,
critic_optimizer_state=sac_training_state.critic_optimizer_state,
critic_params=sac_training_state.critic_params,
alpha_optimizer_state=sac_training_state.alpha_optimizer_state,
alpha_params=sac_training_state.alpha_params,
target_critic_params=sac_training_state.target_critic_params,
normalization_running_stats=sac_training_state.normalization_running_stats,
random_key=sac_training_state.random_key,
steps=sac_training_state.steps,
discount=None,
policy_lr=None,
critic_lr=None,
alpha_lr=None,
reward_scaling=None,
)
# Sample hyper-params
training_state = PBTSacTrainingState.resample_hyperparams(training_state)
return training_state # type: ignore
update(self, training_state, replay_buffer)
Performs a training step to update the policy and the critic parameters.
Source code in qdax/baselines/sac_pbt.py
@partial(jax.jit, static_argnames=("self"))
def update(
self,
training_state: PBTSacTrainingState,
replay_buffer: ReplayBuffer,
) -> Tuple[PBTSacTrainingState, ReplayBuffer, Metrics]:
"""Performs a training step to update the policy and the critic parameters.
Args:
training_state: the current PBT-SAC training state
replay_buffer: the replay buffer
Returns:
the updated PBT-SAC training state
the replay buffer
the training metrics
"""
# sample a batch of transitions in the buffer
random_key = training_state.random_key
transitions, random_key = replay_buffer.sample(
random_key,
sample_size=self._config.batch_size,
)
# normalise observations if necessary
if self._config.normalize_observations:
normalization_running_stats = training_state.normalization_running_stats
normalized_obs = normalize_with_rmstd(
transitions.obs, normalization_running_stats
)
normalized_next_obs = normalize_with_rmstd(
transitions.next_obs, normalization_running_stats
)
transitions = transitions.replace(
obs=normalized_obs, next_obs=normalized_next_obs
)
# update alpha
(
alpha_params,
alpha_optimizer_state,
alpha_loss,
random_key,
) = self._update_alpha(
alpha_lr=training_state.alpha_lr,
training_state=training_state,
transitions=transitions,
random_key=random_key,
)
# update critic
(
critic_params,
target_critic_params,
critic_optimizer_state,
critic_loss,
random_key,
) = self._update_critic(
critic_lr=training_state.critic_lr,
reward_scaling=training_state.reward_scaling,
discount=training_state.discount,
training_state=training_state,
transitions=transitions,
random_key=random_key,
)
# update actor
(
policy_params,
policy_optimizer_state,
policy_loss,
random_key,
) = self._update_actor(
policy_lr=training_state.policy_lr,
training_state=training_state,
transitions=transitions,
random_key=random_key,
)
# create new training state
new_training_state = PBTSacTrainingState(
policy_optimizer_state=policy_optimizer_state,
policy_params=policy_params,
critic_optimizer_state=critic_optimizer_state,
critic_params=critic_params,
alpha_optimizer_state=alpha_optimizer_state,
alpha_params=alpha_params,
normalization_running_stats=training_state.normalization_running_stats,
target_critic_params=target_critic_params,
random_key=random_key,
steps=training_state.steps + 1,
discount=training_state.discount,
policy_lr=training_state.policy_lr,
critic_lr=training_state.critic_lr,
alpha_lr=training_state.alpha_lr,
reward_scaling=training_state.reward_scaling,
)
metrics = {
"actor_loss": policy_loss,
"critic_loss": critic_loss,
"alpha_loss": alpha_loss,
"obs_mean": jnp.mean(transitions.obs),
"obs_std": jnp.std(transitions.obs),
}
return new_training_state, replay_buffer, metrics
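A hedged call sketch for a single agent, assuming agent is the PBTSAC instance from the first sketch and training_state / replay_buffer are one agent's state and buffer: note that the learning rates, discount and reward scaling used by this update are read from the training state (they are the hyperparameters evolved by PBT), not from the static config.

# One gradient step; per-agent hyperparameters live in training_state.
training_state, replay_buffer, metrics = agent.update(training_state, replay_buffer)
# metrics contains "actor_loss", "critic_loss", "alpha_loss", "obs_mean", "obs_std".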
get_init_fn(self, population_size, action_size, observation_size, buffer_size)
Returns a function to initialize the population.
Source code in qdax/baselines/sac_pbt.py
def get_init_fn(
self,
population_size: int,
action_size: int,
observation_size: int,
buffer_size: int,
) -> Callable:
"""
Returns a function to initialize the population.
Args:
population_size: size of the population.
action_size: action space size.
observation_size: observation space size.
buffer_size: replay buffer size.
Returns:
a function that takes as input a random key and returns a new random
key, the PBT population training state and the replay buffers
"""
def _init_fn(
random_key: RNGKey,
) -> Tuple[RNGKey, PBTSacTrainingState, ReplayBuffer]:
random_key, *keys = jax.random.split(random_key, num=1 + population_size)
keys = jnp.stack(keys)
init_dummy_transition = partial(
Transition.init_dummy,
observation_dim=observation_size,
action_dim=action_size,
)
init_dummy_transition = jax.vmap(
init_dummy_transition, axis_size=population_size
)
dummy_transitions = init_dummy_transition()
replay_buffer_init = partial(
ReplayBuffer.init,
buffer_size=buffer_size,
)
replay_buffer_init = jax.vmap(replay_buffer_init)
replay_buffers = replay_buffer_init(transition=dummy_transitions)
agent_init = partial(
self.init, action_size=action_size, observation_size=observation_size
)
training_states = jax.vmap(agent_init)(keys)
return random_key, training_states, replay_buffers
return _init_fn
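A hedged sketch of initialising the population and its replay buffers with the function returned above; the sizes are placeholders and env is the same assumed Brax environment as in the first sketch.

import jax

init_fn = agent.get_init_fn(
    population_size=10,
    action_size=env.action_size,
    observation_size=env.observation_size,
    buffer_size=100_000,
)
random_key = jax.random.PRNGKey(0)
random_key, training_states, replay_buffers = init_fn(random_key)
# training_states and replay_buffers are pytrees with a leading population axis.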
get_eval_fn(self, eval_env)
Returns the function to evaluate the PBT population.
Source code in qdax/baselines/sac_pbt.py
def get_eval_fn(
self,
eval_env: Env,
) -> Callable:
"""
Returns the function the evaluation the PBT population.
Args:
eval_env: evaluation environment. Might be different from training env
if needed.
Returns:
The function to evaluate the population. It takes as input the population
training state as well as first eval environment states and returns the
population agents mean returns over episodes as well as all returns from all
agents over all episodes.
"""
play_eval_step = partial(
self.play_step_fn,
env=eval_env,
deterministic=True,
)
eval_policy = partial(
self.eval_policy_fn,
play_step_fn=play_eval_step,
)
return jax.vmap(eval_policy) # type: ignore
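A hedged sketch of evaluating the whole population with the vmapped function returned above. eval_env and eval_env_first_states are assumptions: an evaluation environment and a batch of its reset states with a leading population axis (one initial state batch per agent), passed in the order described in the docstring.

eval_fn = agent.get_eval_fn(eval_env=eval_env)
# First output: mean return per agent; second output: all episode returns.
population_returns, all_returns = eval_fn(training_states, eval_env_first_states)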
get_eval_qd_fn(self, eval_env, bd_extraction_fn)
Returns the function to evaluate the PBT population.
Source code in qdax/baselines/sac_pbt.py
def get_eval_qd_fn(
self,
eval_env: Env,
bd_extraction_fn: Callable[[QDTransition, Mask], Descriptor],
) -> Callable:
"""
Returns the function the evaluation the PBT population.
Args:
eval_env: evaluation environment. Might be different from training env
if needed.
bd_extraction_fn: function to extract the bd from an episode.
Returns:
The function to evaluate the population. It takes as input the population
training state as well as first eval environment states and returns the
population agents mean returns and mean bds over episodes as well as all
returns and bds from all agents over all episodes.
"""
play_eval_step = partial(
self.play_qd_step_fn,
env=eval_env,
deterministic=True,
)
eval_policy = partial(
self.eval_qd_policy_fn,
play_step_fn=play_eval_step,
bd_extraction_fn=bd_extraction_fn,
)
return jax.vmap(eval_policy) # type: ignore
get_train_fn(self, env, num_iterations, env_batch_size, grad_updates_per_step)
Returns the function to update the population of agents.
Source code in qdax/baselines/sac_pbt.py
def get_train_fn(
self,
env: Env,
num_iterations: int,
env_batch_size: int,
grad_updates_per_step: float,
) -> Callable:
"""
Returns the function to update the population of agents.
Args:
env: training environment.
num_iterations: number of training iterations to perform.
env_batch_size: number of batched environments.
grad_updates_per_step: number of gradient to apply per step in the
environment.
Returns:
the function to update the population which takes as input the population
training state, environment starting states and replay buffers and returns
updated training states, environment states, replay buffers and metrics.
"""
play_step = partial(
self.play_step_fn,
env=env,
deterministic=False,
)
do_iteration = partial(
do_iteration_fn,
env_batch_size=env_batch_size,
grad_updates_per_step=grad_updates_per_step,
play_step_fn=play_step,
update_fn=self.update,
)
def _scan_do_iteration(
carry: Tuple[PBTSacTrainingState, EnvState, ReplayBuffer],
unused_arg: Any,
) -> Tuple[Tuple[PBTSacTrainingState, EnvState, ReplayBuffer], Any]:
(
training_state,
env_state,
replay_buffer,
metrics,
) = do_iteration(*carry)
return (training_state, env_state, replay_buffer), metrics
def train_fn(
training_state: PBTSacTrainingState,
env_state: EnvState,
replay_buffer: ReplayBuffer,
) -> Tuple[Tuple[PBTSacTrainingState, EnvState, ReplayBuffer], Metrics]:
(training_state, env_state, replay_buffer), metrics = jax.lax.scan(
_scan_do_iteration,
(training_state, env_state, replay_buffer),
None,
length=num_iterations,
)
return (training_state, env_state, replay_buffer), metrics
return jax.vmap(train_fn) # type: ignore
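A hedged sketch of training every agent in the population with the vmapped function returned above. env_first_states is an assumption: a batch of reset training-environment states with a leading population axis, matching the vmapped training states and replay buffers created earlier; the other values are placeholders.

train_fn = agent.get_train_fn(
    env=env,
    num_iterations=1000,
    env_batch_size=250,
    grad_updates_per_step=1.0,
)
(training_states, env_states, replay_buffers), metrics = train_fn(
    training_states, env_first_states, replay_buffers
)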
and
qdax.baselines.pbt.PBT
This class serves as a template for algorithms that want to implement the standard Population Based Training (PBT) scheme.
Source code in qdax/baselines/pbt.py
class PBT:
"""
This class serves as a template for algorithm that want to implement the standard
Population Based Training (PBT) scheme.
"""
def __init__(
self,
population_size: int,
num_best_to_replace_from: int,
num_worse_to_replace: int,
):
"""
Args:
population_size: Size of the PBT population.
num_best_to_replace_from: Number of top performing individuals to sample
from when replacing low performers at each iteration.
num_worse_to_replace: Number of low-performing individuals to replace at
each iteration.
"""
if num_best_to_replace_from + num_worse_to_replace > population_size:
raise ValueError(
"The sum of best number of individuals to replace "
"from and worse individuals to replace exceeds the population size."
)
self._population_size = population_size
self._num_best_to_replace_from = num_best_to_replace_from
self._num_worse_to_replace = num_worse_to_replace
@partial(jax.jit, static_argnames=("self",))
def update_states_and_buffer(
self,
random_key: RNGKey,
population_returns: jnp.ndarray,
training_state: PBTTrainingState,
replay_buffer: ReplayBuffer,
) -> Tuple[RNGKey, PBTTrainingState, ReplayBuffer]:
"""
Updates the agents of the population states as well as
their shared replay buffer.
Args:
random_key: Random RNG key.
population_returns: Returns of the agents in the populations.
training_state: The training state of the PBT scheme.
replay_buffer: Shared replay buffer by the agents.
Returns:
Updated random key, updated PBT training state and updated replay buffer.
"""
indices_sorted = jax.numpy.argsort(-population_returns)
best_indices = indices_sorted[: self._num_best_to_replace_from]
indices_to_replace = indices_sorted[-self._num_worse_to_replace :]
random_key, key = jax.random.split(random_key)
indices_used_to_replace = jax.random.choice(
key, best_indices, shape=(self._num_worse_to_replace,), replace=True
)
training_state = jax.tree_util.tree_map(
lambda x, y: x.at[indices_to_replace].set(y[indices_used_to_replace]),
training_state,
jax.vmap(training_state.__class__.resample_hyperparams)(training_state),
)
replay_buffer = jax.tree_util.tree_map(
lambda x, y: x.at[indices_to_replace].set(y[indices_used_to_replace]),
replay_buffer,
replay_buffer,
)
return random_key, training_state, replay_buffer
@partial(jax.jit, static_argnames=("self",))
def update_states_and_buffer_pmap(
self,
random_key: RNGKey,
population_returns: jnp.ndarray,
training_state: PBTTrainingState,
replay_buffer: ReplayBuffer,
) -> Tuple[RNGKey, PBTTrainingState, ReplayBuffer]:
"""
Updates the agents of the population states as well as
their shared replay buffer. This is the version of the function to be
used within jax.pmap. It makes the population is spread over several devices
and implement a parallel update through communication between the devices.
Args:
random_key: Random RNG key.
population_returns: Returns of the agents in the populations.
training_state: The training state of the PBT scheme.
replay_buffer: Shared replay buffer by the agents.
Returns:
Updated random key, updated PBT training state and updated replay buffer.
"""
indices_sorted = jax.numpy.argsort(-population_returns)
best_indices = indices_sorted[: self._num_best_to_replace_from]
indices_to_replace = indices_sorted[-self._num_worse_to_replace :]
best_individuals, best_buffers, best_returns = jax.tree_util.tree_map(
lambda x: x[best_indices],
(training_state, replay_buffer, population_returns),
)
(
gathered_best_individuals,
gathered_best_buffers,
gathered_best_returns,
) = jax.tree_util.tree_map(
lambda x: jnp.concatenate(jax.lax.all_gather(x, axis_name="p"), axis=0),
(best_individuals, best_buffers, best_returns),
)
pop_indices_sorted = jax.numpy.argsort(-gathered_best_returns)
best_pop_indices = pop_indices_sorted[: self._num_best_to_replace_from]
random_key, key = jax.random.split(random_key)
indices_used_to_replace = jax.random.choice(
key, best_pop_indices, shape=(self._num_worse_to_replace,), replace=True
)
training_state = jax.tree_util.tree_map(
lambda x, y: x.at[indices_to_replace].set(y[indices_used_to_replace]),
training_state,
jax.vmap(gathered_best_individuals.__class__.resample_hyperparams)(
gathered_best_individuals
),
)
replay_buffer = jax.tree_util.tree_map(
lambda x, y: x.at[indices_to_replace].set(y[indices_used_to_replace]),
replay_buffer,
gathered_best_buffers,
)
return random_key, training_state, replay_buffer
__init__(self, population_size, num_best_to_replace_from, num_worse_to_replace)
Source code in qdax/baselines/pbt.py
def __init__(
self,
population_size: int,
num_best_to_replace_from: int,
num_worse_to_replace: int,
):
"""
Args:
population_size: Size of the PBT population.
num_best_to_replace_from: Number of top performing individuals to sample
from when replacing low performers at each iteration.
num_worse_to_replace: Number of low-performing individuals to replace at
each iteration.
"""
if num_best_to_replace_from + num_worse_to_replace > population_size:
raise ValueError(
"The sum of best number of individuals to replace "
"from and worse individuals to replace exceeds the population size."
)
self._population_size = population_size
self._num_best_to_replace_from = num_best_to_replace_from
self._num_worse_to_replace = num_worse_to_replace
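As a concrete, hedged illustration of these arguments: with the settings below, each call to update_states_and_buffer overwrites the 2 lowest-return agents with copies of agents sampled (with replacement) from the 3 best, resampling the hyperparameters of the copies; the constructor raises a ValueError whenever num_best_to_replace_from + num_worse_to_replace exceeds population_size.

pbt = PBT(
    population_size=10,
    num_best_to_replace_from=3,
    num_worse_to_replace=2,  # 3 + 2 <= 10, so this configuration is accepted
)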
update_states_and_buffer(self, random_key, population_returns, training_state, replay_buffer)
Updates the states of the agents in the population as well as their shared replay buffer.
Source code in qdax/baselines/pbt.py
@partial(jax.jit, static_argnames=("self",))
def update_states_and_buffer(
self,
random_key: RNGKey,
population_returns: jnp.ndarray,
training_state: PBTTrainingState,
replay_buffer: ReplayBuffer,
) -> Tuple[RNGKey, PBTTrainingState, ReplayBuffer]:
"""
Updates the agents of the population states as well as
their shared replay buffer.
Args:
random_key: Random RNG key.
population_returns: Returns of the agents in the populations.
training_state: The training state of the PBT scheme.
replay_buffer: Shared replay buffer by the agents.
Returns:
Updated random key, updated PBT training state and updated replay buffer.
"""
indices_sorted = jax.numpy.argsort(-population_returns)
best_indices = indices_sorted[: self._num_best_to_replace_from]
indices_to_replace = indices_sorted[-self._num_worse_to_replace :]
random_key, key = jax.random.split(random_key)
indices_used_to_replace = jax.random.choice(
key, best_indices, shape=(self._num_worse_to_replace,), replace=True
)
training_state = jax.tree_util.tree_map(
lambda x, y: x.at[indices_to_replace].set(y[indices_used_to_replace]),
training_state,
jax.vmap(training_state.__class__.resample_hyperparams)(training_state),
)
replay_buffer = jax.tree_util.tree_map(
lambda x, y: x.at[indices_to_replace].set(y[indices_used_to_replace]),
replay_buffer,
replay_buffer,
)
return random_key, training_state, replay_buffer
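A hedged sketch of one outer PBT iteration that ties the pieces together, reusing the names introduced in the sketches above: train all agents, evaluate them, then let PBT replace the worst performers with resampled copies of the best. The environment states are those produced by the previous iteration (or the initial reset states on the first iteration).

(training_states, env_states, replay_buffers), _ = train_fn(
    training_states, env_states, replay_buffers
)
population_returns, _ = eval_fn(training_states, eval_env_first_states)
random_key, training_states, replay_buffers = pbt.update_states_and_buffer(
    random_key, population_returns, training_states, replay_buffers
)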
update_states_and_buffer_pmap(self, random_key, population_returns, training_state, replay_buffer)
Updates the states of the agents in the population as well as their shared replay buffer. This is the version of the function to be used within jax.pmap: it assumes the population is spread over several devices and implements a parallel update through communication between the devices.
Source code in qdax/baselines/pbt.py
@partial(jax.jit, static_argnames=("self",))
def update_states_and_buffer_pmap(
self,
random_key: RNGKey,
population_returns: jnp.ndarray,
training_state: PBTTrainingState,
replay_buffer: ReplayBuffer,
) -> Tuple[RNGKey, PBTTrainingState, ReplayBuffer]:
"""
Updates the agents of the population states as well as
their shared replay buffer. This is the version of the function to be
used within jax.pmap. It makes the population is spread over several devices
and implement a parallel update through communication between the devices.
Args:
random_key: Random RNG key.
population_returns: Returns of the agents in the populations.
training_state: The training state of the PBT scheme.
replay_buffer: Shared replay buffer by the agents.
Returns:
Updated random key, updated PBT training state and updated replay buffer.
"""
indices_sorted = jax.numpy.argsort(-population_returns)
best_indices = indices_sorted[: self._num_best_to_replace_from]
indices_to_replace = indices_sorted[-self._num_worse_to_replace :]
best_individuals, best_buffers, best_returns = jax.tree_util.tree_map(
lambda x: x[best_indices],
(training_state, replay_buffer, population_returns),
)
(
gathered_best_individuals,
gathered_best_buffers,
gathered_best_returns,
) = jax.tree_util.tree_map(
lambda x: jnp.concatenate(jax.lax.all_gather(x, axis_name="p"), axis=0),
(best_individuals, best_buffers, best_returns),
)
pop_indices_sorted = jax.numpy.argsort(-gathered_best_returns)
best_pop_indices = pop_indices_sorted[: self._num_best_to_replace_from]
random_key, key = jax.random.split(random_key)
indices_used_to_replace = jax.random.choice(
key, best_pop_indices, shape=(self._num_worse_to_replace,), replace=True
)
training_state = jax.tree_util.tree_map(
lambda x, y: x.at[indices_to_replace].set(y[indices_used_to_replace]),
training_state,
jax.vmap(gathered_best_individuals.__class__.resample_hyperparams)(
gathered_best_individuals
),
)
replay_buffer = jax.tree_util.tree_map(
lambda x, y: x.at[indices_to_replace].set(y[indices_used_to_replace]),
replay_buffer,
gathered_best_buffers,
)
return random_key, training_state, replay_buffer
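When the population is sharded across several devices, the pmap variant above can be used instead. A hedged sketch, assuming every argument is a pytree shard with a leading device axis and that the mapped axis is named "p", which is the axis name the function uses internally for jax.lax.all_gather:

import jax

pbt_update = jax.pmap(pbt.update_states_and_buffer_pmap, axis_name="p")
device_keys, device_training_states, device_replay_buffers = pbt_update(
    device_keys, device_returns, device_training_states, device_replay_buffers
)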
To use PBT to train TD3 agents, use the PBTTD3 class:
qdax.baselines.td3_pbt.PBTTD3 (TD3)
Source code in qdax/baselines/td3_pbt.py
class PBTTD3(TD3):
def __init__(self, config: PBTTD3Config, action_size: int):
td3_config = TD3Config(
episode_length=config.episode_length,
batch_size=config.batch_size,
policy_delay=config.policy_delay,
reward_scaling=config.reward_scaling,
soft_tau_update=config.soft_tau_update,
critic_hidden_layer_size=config.critic_hidden_layer_size,
policy_hidden_layer_size=config.policy_hidden_layer_size,
)
TD3.__init__(self, td3_config, action_size)
def init(
self, random_key: RNGKey, action_size: int, observation_size: int
) -> PBTTD3TrainingState:
"""Initialise the training state of the PBT-TD3 algorithm, through creation
of optimizer states and params.
Args:
random_key: a random key used for random operations.
action_size: the size of the action array needed to interact with the
environment.
observation_size: the size of the observation array retrieved from the
environment.
Returns:
the initial training state.
"""
training_state = TD3.init(self, random_key, action_size, observation_size)
# Initial training state
training_state = PBTTD3TrainingState(
policy_optimizer_state=training_state.policy_optimizer_state,
policy_params=training_state.policy_params,
critic_optimizer_state=training_state.critic_optimizer_state,
critic_params=training_state.critic_params,
target_policy_params=training_state.target_policy_params,
target_critic_params=training_state.target_critic_params,
random_key=training_state.random_key,
steps=training_state.steps,
discount=None,
policy_lr=None,
critic_lr=None,
noise_clip=None,
policy_noise=None,
expl_noise=None,
)
# Sample hyperparameters
training_state = PBTTD3TrainingState.resample_hyperparams(training_state)
return training_state # type: ignore
@partial(jax.jit, static_argnames=("self", "env", "deterministic"))
def play_step_fn(
self,
env_state: EnvState,
training_state: TD3TrainingState,
env: Env,
deterministic: bool = False,
) -> Tuple[EnvState, TD3TrainingState, Transition]:
"""Plays a step in the environment. Selects an action according to TD3 rule and
performs the environment step.
Args:
env_state: the current environment state
training_state: the PBT-TD3 training state
env: the environment
deterministic: whether to select action in a deterministic way.
Defaults to False.
Returns:
the new environment state
the new PBT-TD3 training state
the played transition
"""
actions, random_key = self.select_action(
obs=env_state.obs,
policy_params=training_state.policy_params,
random_key=training_state.random_key,
expl_noise=training_state.expl_noise,
deterministic=deterministic,
)
training_state = training_state.replace(
random_key=random_key,
)
next_env_state = env.step(env_state, actions)
transition = Transition(
obs=env_state.obs,
next_obs=next_env_state.obs,
rewards=next_env_state.reward,
dones=next_env_state.done,
truncations=next_env_state.info["truncation"],
actions=actions,
)
return next_env_state, training_state, transition
@partial(jax.jit, static_argnames=("self",))
def update(
self,
training_state: PBTTD3TrainingState,
replay_buffer: ReplayBuffer,
) -> Tuple[PBTTD3TrainingState, ReplayBuffer, Metrics]:
"""Performs a single training step: updates policy params and critic params
through gradient descent.
Args:
training_state: the current training state, containing the optimizer states
and the params of the policy and critic.
replay_buffer: the replay buffer, filled with transitions experienced in
the environment.
Returns:
A new training state, the buffer with new transitions and metrics about the
training process.
"""
# Sample a batch of transitions in the buffer
random_key = training_state.random_key
samples, random_key = replay_buffer.sample(
random_key, sample_size=self._config.batch_size
)
# Update Critic
random_key, subkey = jax.random.split(random_key)
critic_loss, critic_gradient = jax.value_and_grad(td3_critic_loss_fn)(
training_state.critic_params,
target_policy_params=training_state.target_policy_params,
target_critic_params=training_state.target_critic_params,
policy_fn=self._policy.apply,
critic_fn=self._critic.apply,
policy_noise=training_state.policy_noise,
noise_clip=training_state.noise_clip,
reward_scaling=self._config.reward_scaling,
discount=self._config.discount,
transitions=samples,
random_key=subkey,
)
critic_optimizer = optax.adam(learning_rate=training_state.critic_lr)
critic_updates, critic_optimizer_state = critic_optimizer.update(
critic_gradient, training_state.critic_optimizer_state
)
critic_params = optax.apply_updates(
training_state.critic_params, critic_updates
)
# Soft update of target critic network
target_critic_params = jax.tree_util.tree_map(
lambda x1, x2: (1.0 - self._config.soft_tau_update) * x1
+ self._config.soft_tau_update * x2,
training_state.target_critic_params,
critic_params,
)
# Update policy
policy_loss, policy_gradient = jax.value_and_grad(td3_policy_loss_fn)(
training_state.policy_params,
critic_params=training_state.critic_params,
policy_fn=self._policy.apply,
critic_fn=self._critic.apply,
transitions=samples,
)
def update_policy_step() -> Tuple[Params, Params, optax.OptState]:
policy_optimizer = optax.adam(learning_rate=training_state.policy_lr)
(policy_updates, policy_optimizer_state,) = policy_optimizer.update(
policy_gradient, training_state.policy_optimizer_state
)
policy_params = optax.apply_updates(
training_state.policy_params, policy_updates
)
# Soft update of target policy
target_policy_params = jax.tree_util.tree_map(
lambda x1, x2: (1.0 - self._config.soft_tau_update) * x1
+ self._config.soft_tau_update * x2,
training_state.target_policy_params,
policy_params,
)
return policy_params, target_policy_params, policy_optimizer_state
# Delayed update
current_policy_state = (
training_state.policy_params,
training_state.target_policy_params,
training_state.policy_optimizer_state,
)
policy_params, target_policy_params, policy_optimizer_state = jax.lax.cond(
training_state.steps % self._config.policy_delay == 0,
lambda _: update_policy_step(),
lambda _: current_policy_state,
operand=None,
)
# Create new training state
new_training_state = training_state.replace(
critic_params=critic_params,
critic_optimizer_state=critic_optimizer_state,
policy_params=policy_params,
policy_optimizer_state=policy_optimizer_state,
target_critic_params=target_critic_params,
target_policy_params=target_policy_params,
random_key=random_key,
steps=training_state.steps + 1,
)
metrics = {
"actor_loss": policy_loss,
"critic_loss": critic_loss,
}
return new_training_state, replay_buffer, metrics
def get_init_fn(
self,
population_size: int,
action_size: int,
observation_size: int,
buffer_size: int,
) -> Callable:
"""
Returns a function to initialize the population.
Args:
population_size: size of the population.
action_size: action space size.
observation_size: observation space size.
buffer_size: replay buffer size.
Returns:
a function that takes as input a random key and returns a new random
key, the PBT population training state and the replay buffers
"""
def _init_fn(
random_key: RNGKey,
) -> Tuple[RNGKey, PBTTD3TrainingState, ReplayBuffer]:
random_key, *keys = jax.random.split(random_key, num=1 + population_size)
keys = jnp.stack(keys)
init_dummy_transition = partial(
Transition.init_dummy,
observation_dim=observation_size,
action_dim=action_size,
)
init_dummy_transition = jax.vmap(
init_dummy_transition, axis_size=population_size
)
dummy_transitions = init_dummy_transition()
replay_buffer_init = partial(
ReplayBuffer.init,
buffer_size=buffer_size,
)
replay_buffer_init = jax.vmap(replay_buffer_init)
replay_buffers = replay_buffer_init(transition=dummy_transitions)
agent_init = partial(
self.init, action_size=action_size, observation_size=observation_size
)
training_states = jax.vmap(agent_init)(keys)
return random_key, training_states, replay_buffers
return _init_fn
def get_eval_fn(
self,
eval_env: Env,
) -> Callable:
"""
Returns the function the evaluation the PBT population.
Args:
eval_env: evaluation environment. Might be different from training env
if needed.
Returns:
The function to evaluate the population. It takes as input the population
training state as well as first eval environment states and returns the
population agents mean returns over episodes as well as all returns from all
agents over all episodes.
"""
play_eval_step = partial(
self.play_step_fn,
env=eval_env,
deterministic=True,
)
eval_policy = partial(
self.eval_policy_fn,
play_step_fn=play_eval_step,
)
return jax.vmap(eval_policy) # type: ignore
def get_eval_qd_fn(
self,
eval_env: Env,
bd_extraction_fn: Callable[[QDTransition, Mask], Descriptor],
) -> Callable:
"""
Returns the function the evaluation the PBT population.
Args:
eval_env: evaluation environment. Might be different from training env
if needed.
bd_extraction_fn: function to extract the bd from an episode.
Returns:
The function to evaluate the population. It takes as input the population
training state as well as first eval environment states and returns the
population agents mean returns and mean bds over episodes as well as all
returns and bds from all agents over all episodes.
"""
play_eval_step = partial(
self.play_qd_step_fn,
env=eval_env,
deterministic=True,
)
eval_policy = partial(
self.eval_qd_policy_fn,
play_step_fn=play_eval_step,
bd_extraction_fn=bd_extraction_fn,
)
return jax.vmap(eval_policy) # type: ignore
def get_train_fn(
self,
env: Env,
num_iterations: int,
env_batch_size: int,
grad_updates_per_step: float,
) -> Callable:
"""
Returns the function to update the population of agents.
Args:
env: training environment.
num_iterations: number of training iterations to perform.
env_batch_size: number of batched environments.
grad_updates_per_step: number of gradient to apply per step in the
environment.
Returns:
the function to update the population which takes as input the population
training state, environment starting states and replay buffers and returns
updated training states, environment states, replay buffers and metrics.
"""
play_step = partial(
self.play_step_fn,
env=env,
deterministic=False,
)
do_iteration = partial(
do_iteration_fn,
env_batch_size=env_batch_size,
grad_updates_per_step=grad_updates_per_step,
play_step_fn=play_step,
update_fn=self.update,
)
def _scan_do_iteration(
carry: Tuple[PBTTD3TrainingState, EnvState, ReplayBuffer],
unused_arg: Any,
) -> Tuple[Tuple[PBTTD3TrainingState, EnvState, ReplayBuffer], Any]:
(
training_state,
env_state,
replay_buffer,
metrics,
) = do_iteration(*carry)
return (training_state, env_state, replay_buffer), metrics
def train_fn(
training_state: PBTTD3TrainingState,
env_state: EnvState,
replay_buffer: ReplayBuffer,
) -> Tuple[Tuple[PBTTD3TrainingState, EnvState, ReplayBuffer], Metrics]:
(training_state, env_state, replay_buffer), metrics = jax.lax.scan(
_scan_do_iteration,
(training_state, env_state, replay_buffer),
None,
length=num_iterations,
)
return (training_state, env_state, replay_buffer), metrics
return jax.vmap(train_fn) # type: ignore
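PBTTD3 mirrors the PBTSAC workflow above (same get_init_fn / get_train_fn / get_eval_fn pattern); only the configuration differs. A hedged sketch, with field names taken from the constructor above and placeholder values:

from qdax.baselines.td3_pbt import PBTTD3, PBTTD3Config

td3_config = PBTTD3Config(
    episode_length=1000,
    batch_size=256,
    policy_delay=2,
    reward_scaling=1.0,
    soft_tau_update=0.005,
    critic_hidden_layer_size=(256, 256),
    policy_hidden_layer_size=(256, 256),
)
td3_agent = PBTTD3(config=td3_config, action_size=env.action_size)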
init(self, random_key, action_size, observation_size)
Initialise the training state of the PBT-TD3 algorithm, through creation of optimizer states and params.
Source code in qdax/baselines/td3_pbt.py
def init(
self, random_key: RNGKey, action_size: int, observation_size: int
) -> PBTTD3TrainingState:
"""Initialise the training state of the PBT-TD3 algorithm, through creation
of optimizer states and params.
Args:
random_key: a random key used for random operations.
action_size: the size of the action array needed to interact with the
environment.
observation_size: the size of the observation array retrieved from the
environment.
Returns:
the initial training state.
"""
training_state = TD3.init(self, random_key, action_size, observation_size)
# Initial training state
training_state = PBTTD3TrainingState(
policy_optimizer_state=training_state.policy_optimizer_state,
policy_params=training_state.policy_params,
critic_optimizer_state=training_state.critic_optimizer_state,
critic_params=training_state.critic_params,
target_policy_params=training_state.target_policy_params,
target_critic_params=training_state.target_critic_params,
random_key=training_state.random_key,
steps=training_state.steps,
discount=None,
policy_lr=None,
critic_lr=None,
noise_clip=None,
policy_noise=None,
expl_noise=None,
)
# Sample hyperparameters
training_state = PBTTD3TrainingState.resample_hyperparams(training_state)
return training_state # type: ignore
play_step_fn(self, env_state, training_state, env, deterministic=False)
Plays a step in the environment. Selects an action according to the TD3 rule and performs the environment step.
Source code in qdax/baselines/td3_pbt.py
@partial(jax.jit, static_argnames=("self", "env", "deterministic"))
def play_step_fn(
self,
env_state: EnvState,
training_state: TD3TrainingState,
env: Env,
deterministic: bool = False,
) -> Tuple[EnvState, TD3TrainingState, Transition]:
"""Plays a step in the environment. Selects an action according to TD3 rule and
performs the environment step.
Args:
env_state: the current environment state
training_state: the PBT-TD3 training state
env: the environment
deterministic: whether to select action in a deterministic way.
Defaults to False.
Returns:
the new environment state
the new PBT-TD3 training state
the played transition
"""
actions, random_key = self.select_action(
obs=env_state.obs,
policy_params=training_state.policy_params,
random_key=training_state.random_key,
expl_noise=training_state.expl_noise,
deterministic=deterministic,
)
training_state = training_state.replace(
random_key=random_key,
)
next_env_state = env.step(env_state, actions)
transition = Transition(
obs=env_state.obs,
next_obs=next_env_state.obs,
rewards=next_env_state.reward,
dones=next_env_state.done,
truncations=next_env_state.info["truncation"],
actions=actions,
)
return next_env_state, training_state, transition
update(self, training_state, replay_buffer)
Performs a single training step: updates policy params and critic params through gradient descent.
Source code in qdax/baselines/td3_pbt.py
@partial(jax.jit, static_argnames=("self",))
def update(
self,
training_state: PBTTD3TrainingState,
replay_buffer: ReplayBuffer,
) -> Tuple[PBTTD3TrainingState, ReplayBuffer, Metrics]:
"""Performs a single training step: updates policy params and critic params
through gradient descent.
Args:
training_state: the current training state, containing the optimizer states
and the params of the policy and critic.
replay_buffer: the replay buffer, filled with transitions experienced in
the environment.
Returns:
A new training state, the buffer with new transitions and metrics about the
training process.
"""
# Sample a batch of transitions in the buffer
random_key = training_state.random_key
samples, random_key = replay_buffer.sample(
random_key, sample_size=self._config.batch_size
)
# Update Critic
random_key, subkey = jax.random.split(random_key)
critic_loss, critic_gradient = jax.value_and_grad(td3_critic_loss_fn)(
training_state.critic_params,
target_policy_params=training_state.target_policy_params,
target_critic_params=training_state.target_critic_params,
policy_fn=self._policy.apply,
critic_fn=self._critic.apply,
policy_noise=training_state.policy_noise,
noise_clip=training_state.noise_clip,
reward_scaling=self._config.reward_scaling,
discount=self._config.discount,
transitions=samples,
random_key=subkey,
)
critic_optimizer = optax.adam(learning_rate=training_state.critic_lr)
critic_updates, critic_optimizer_state = critic_optimizer.update(
critic_gradient, training_state.critic_optimizer_state
)
critic_params = optax.apply_updates(
training_state.critic_params, critic_updates
)
# Soft update of target critic network
target_critic_params = jax.tree_util.tree_map(
lambda x1, x2: (1.0 - self._config.soft_tau_update) * x1
+ self._config.soft_tau_update * x2,
training_state.target_critic_params,
critic_params,
)
# Update policy
policy_loss, policy_gradient = jax.value_and_grad(td3_policy_loss_fn)(
training_state.policy_params,
critic_params=training_state.critic_params,
policy_fn=self._policy.apply,
critic_fn=self._critic.apply,
transitions=samples,
)
def update_policy_step() -> Tuple[Params, Params, optax.OptState]:
policy_optimizer = optax.adam(learning_rate=training_state.policy_lr)
(policy_updates, policy_optimizer_state,) = policy_optimizer.update(
policy_gradient, training_state.policy_optimizer_state
)
policy_params = optax.apply_updates(
training_state.policy_params, policy_updates
)
# Soft update of target policy
target_policy_params = jax.tree_util.tree_map(
lambda x1, x2: (1.0 - self._config.soft_tau_update) * x1
+ self._config.soft_tau_update * x2,
training_state.target_policy_params,
policy_params,
)
return policy_params, target_policy_params, policy_optimizer_state
# Delayed update
current_policy_state = (
training_state.policy_params,
training_state.target_policy_params,
training_state.policy_optimizer_state,
)
policy_params, target_policy_params, policy_optimizer_state = jax.lax.cond(
training_state.steps % self._config.policy_delay == 0,
lambda _: update_policy_step(),
lambda _: current_policy_state,
operand=None,
)
# Create new training state
new_training_state = training_state.replace(
critic_params=critic_params,
critic_optimizer_state=critic_optimizer_state,
policy_params=policy_params,
policy_optimizer_state=policy_optimizer_state,
target_critic_params=target_critic_params,
target_policy_params=target_policy_params,
random_key=random_key,
steps=training_state.steps + 1,
)
metrics = {
"actor_loss": policy_loss,
"critic_loss": critic_loss,
}
return new_training_state, replay_buffer, metrics
get_init_fn(self, population_size, action_size, observation_size, buffer_size)
Returns a function to initialize the population.
Source code in qdax/baselines/td3_pbt.py
def get_init_fn(
self,
population_size: int,
action_size: int,
observation_size: int,
buffer_size: int,
) -> Callable:
"""
Returns a function to initialize the population.
Args:
population_size: size of the population.
action_size: action space size.
observation_size: observation space size.
buffer_size: replay buffer size.
Returns:
a function that takes as input a random key and returns a new random
key, the PBT population training state and the replay buffers
"""
def _init_fn(
random_key: RNGKey,
) -> Tuple[RNGKey, PBTTD3TrainingState, ReplayBuffer]:
random_key, *keys = jax.random.split(random_key, num=1 + population_size)
keys = jnp.stack(keys)
init_dummy_transition = partial(
Transition.init_dummy,
observation_dim=observation_size,
action_dim=action_size,
)
init_dummy_transition = jax.vmap(
init_dummy_transition, axis_size=population_size
)
dummy_transitions = init_dummy_transition()
replay_buffer_init = partial(
ReplayBuffer.init,
buffer_size=buffer_size,
)
replay_buffer_init = jax.vmap(replay_buffer_init)
replay_buffers = replay_buffer_init(transition=dummy_transitions)
agent_init = partial(
self.init, action_size=action_size, observation_size=observation_size
)
training_states = jax.vmap(agent_init)(keys)
return random_key, training_states, replay_buffers
return _init_fn
get_eval_fn(self, eval_env)
Returns the function to evaluate the PBT population.
Source code in qdax/baselines/td3_pbt.py
def get_eval_fn(
self,
eval_env: Env,
) -> Callable:
"""
Returns the function the evaluation the PBT population.
Args:
eval_env: evaluation environment. Might be different from training env
if needed.
Returns:
The function to evaluate the population. It takes as input the population
training state as well as first eval environment states and returns the
population agents mean returns over episodes as well as all returns from all
agents over all episodes.
"""
play_eval_step = partial(
self.play_step_fn,
env=eval_env,
deterministic=True,
)
eval_policy = partial(
self.eval_policy_fn,
play_step_fn=play_eval_step,
)
return jax.vmap(eval_policy) # type: ignore
get_eval_qd_fn(self, eval_env, bd_extraction_fn)
Returns the function to evaluate the PBT population.
Source code in qdax/baselines/td3_pbt.py
def get_eval_qd_fn(
self,
eval_env: Env,
bd_extraction_fn: Callable[[QDTransition, Mask], Descriptor],
) -> Callable:
"""
Returns the function the evaluation the PBT population.
Args:
eval_env: evaluation environment. Might be different from training env
if needed.
bd_extraction_fn: function to extract the bd from an episode.
Returns:
The function to evaluate the population. It takes as input the population
training state as well as first eval environment states and returns the
population agents mean returns and mean bds over episodes as well as all
returns and bds from all agents over all episodes.
"""
play_eval_step = partial(
self.play_qd_step_fn,
env=eval_env,
deterministic=True,
)
eval_policy = partial(
self.eval_qd_policy_fn,
play_step_fn=play_eval_step,
bd_extraction_fn=bd_extraction_fn,
)
return jax.vmap(eval_policy) # type: ignore
get_train_fn(self, env, num_iterations, env_batch_size, grad_updates_per_step)
Returns the function to update the population of agents.
Source code in qdax/baselines/td3_pbt.py
def get_train_fn(
self,
env: Env,
num_iterations: int,
env_batch_size: int,
grad_updates_per_step: float,
) -> Callable:
"""
Returns the function to update the population of agents.
Args:
env: training environment.
num_iterations: number of training iterations to perform.
env_batch_size: number of batched environments.
grad_updates_per_step: number of gradient to apply per step in the
environment.
Returns:
the function to update the population which takes as input the population
training state, environment starting states and replay buffers and returns
updated training states, environment states, replay buffers and metrics.
"""
play_step = partial(
self.play_step_fn,
env=env,
deterministic=False,
)
do_iteration = partial(
do_iteration_fn,
env_batch_size=env_batch_size,
grad_updates_per_step=grad_updates_per_step,
play_step_fn=play_step,
update_fn=self.update,
)
def _scan_do_iteration(
carry: Tuple[PBTTD3TrainingState, EnvState, ReplayBuffer],
unused_arg: Any,
) -> Tuple[Tuple[PBTTD3TrainingState, EnvState, ReplayBuffer], Any]:
(
training_state,
env_state,
replay_buffer,
metrics,
) = do_iteration(*carry)
return (training_state, env_state, replay_buffer), metrics
def train_fn(
training_state: PBTTD3TrainingState,
env_state: EnvState,
replay_buffer: ReplayBuffer,
) -> Tuple[Tuple[PBTTD3TrainingState, EnvState, ReplayBuffer], Metrics]:
(training_state, env_state, replay_buffer), metrics = jax.lax.scan(
_scan_do_iteration,
(training_state, env_state, replay_buffer),
None,
length=num_iterations,
)
return (training_state, env_state, replay_buffer), metrics
return jax.vmap(train_fn) # type: ignore