DADS class
qdax.baselines.dads.DADS (SAC)
Implements the DADS algorithm (https://arxiv.org/abs/1907.01657).
Note that the functions select_action, _update_alpha, _update_critic and _update_actor are inherited from the SAC algorithm.
In the current implementation, skills are fixed one-hot vectors; continuous skills are not supported yet.
Skills are also evaluated in parallel in a fixed manner: the batch of environments contains a multiple of the number of skills, so each skill is evaluated in the same number of environments to generate transitions. The skill sampling is therefore fixed and perfectly uniform.
We plan to add continuous skills as an option in the future. We also plan to lift the current constraint on the number of batched environments by sampling skills rather than using this fixed assignment.
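For illustration, the short sketch below shows how the fixed, perfectly uniform skill assignment described above can be built: each of the num_skills one-hot vectors is replicated over the batch of environments. The concrete sizes are placeholders, not values taken from the library.

import jax.numpy as jnp

num_skills = 5        # placeholder
env_batch_size = 20   # must be a multiple of num_skills in this implementation

# Replicate the identity matrix so every skill appears the same number of times.
skills = jnp.concatenate(
    [jnp.eye(num_skills)] * (env_batch_size // num_skills), axis=0
)  # shape: (env_batch_size, num_skills)

# play_step_fn then conditions the policy on the skill by concatenation:
# obs = jnp.concatenate([env_state.obs, skills], axis=1)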
Source code in qdax/baselines/dads.py
class DADS(SAC):
"""Implements DADS algorithm https://arxiv.org/abs/1907.01657.
Note that the functions select_action, _update_alpha, _update_critic and
_update_actor are inherited from SAC algorithm.
In the current implementation, we suppose that the skills are fixed one
hot vectors, and do not support continuous skills at the moment.
Also, we suppose that the skills are evaluated in parallel in a fixed
manner: a batch of environments, containing a multiple of the number
of skills, is used to evaluate the skills in the environment and hence
to generate transitions. The sampling is hence fixed and perfectly uniform.
We plan to add continuous skills as an option in the future. We also plan
to release the current constraint on the number of batched environments
by sampling from the skills rather than having this fixed setting.
"""
def __init__(self, config: DadsConfig, action_size: int, descriptor_size: int):
self._config: DadsConfig = config
if self._config.normalize_observations:
raise NotImplementedError("Normalization is not implemented for DADS yet")
# define the networks
self._policy, self._critic, self._dynamics = make_dads_networks(
action_size=action_size,
descriptor_size=descriptor_size,
omit_input_dynamics_dim=config.omit_input_dynamics_dim,
policy_hidden_layer_size=config.policy_hidden_layer_size,
critic_hidden_layer_size=config.critic_hidden_layer_size,
)
# define the action distribution
self._action_size = action_size
self._parametric_action_distribution = NormalTanhDistribution(
event_size=action_size
)
self._sample_action_fn = self._parametric_action_distribution.sample
# define the losses
(
self._alpha_loss_fn,
self._policy_loss_fn,
self._critic_loss_fn,
self._dynamics_loss_fn,
) = make_dads_loss_fn(
policy_fn=self._policy.apply,
critic_fn=self._critic.apply,
dynamics_fn=self._dynamics.apply,
reward_scaling=self._config.reward_scaling,
discount=self._config.discount,
action_size=action_size,
num_skills=self._config.num_skills,
parametric_action_distribution=self._parametric_action_distribution,
)
# define the optimizers
self._policy_optimizer = optax.adam(learning_rate=self._config.learning_rate)
self._critic_optimizer = optax.adam(learning_rate=self._config.learning_rate)
self._alpha_optimizer = optax.adam(learning_rate=self._config.learning_rate)
self._dynamics_optimizer = optax.adam(learning_rate=self._config.learning_rate)
def init( # type: ignore
self,
random_key: RNGKey,
action_size: int,
observation_size: int,
descriptor_size: int,
) -> DadsTrainingState:
"""Initialise the training state of the algorithm.
Args:
random_key: a jax random key
action_size: the size of the environment's action space
observation_size: the size of the environment's observation space
descriptor_size: the size of the environment's descriptor space (i.e. the
dimension of the dynamics network's input)
Returns:
the initial training state of DADS
"""
# Initialize params
dummy_obs = jnp.zeros((1, observation_size + self._config.num_skills))
dummy_action = jnp.zeros((1, action_size))
dummy_dyn_obs = jnp.zeros((1, descriptor_size))
dummy_skill = jnp.zeros((1, self._config.num_skills))
random_key, subkey = jax.random.split(random_key)
policy_params = self._policy.init(subkey, dummy_obs)
random_key, subkey = jax.random.split(random_key)
critic_params = self._critic.init(subkey, dummy_obs, dummy_action)
target_critic_params = jax.tree_util.tree_map(
lambda x: jnp.asarray(x.copy()), critic_params
)
random_key, subkey = jax.random.split(random_key)
dynamics_params = self._dynamics.init(
subkey,
obs=dummy_dyn_obs,
skill=dummy_skill,
target=dummy_dyn_obs,
)
policy_optimizer_state = self._policy_optimizer.init(policy_params)
critic_optimizer_state = self._critic_optimizer.init(critic_params)
dynamics_optimizer_state = self._dynamics_optimizer.init(dynamics_params)
log_alpha = jnp.asarray(jnp.log(self._config.alpha_init), dtype=jnp.float32)
alpha_optimizer_state = self._alpha_optimizer.init(log_alpha)
return DadsTrainingState(
policy_optimizer_state=policy_optimizer_state,
policy_params=policy_params,
critic_optimizer_state=critic_optimizer_state,
critic_params=critic_params,
alpha_optimizer_state=alpha_optimizer_state,
alpha_params=log_alpha,
target_critic_params=target_critic_params,
dynamics_optimizer_state=dynamics_optimizer_state,
dynamics_params=dynamics_params,
random_key=random_key,
normalization_running_stats=RunningMeanStdState(
mean=jnp.zeros(
descriptor_size,
),
var=jnp.ones(
descriptor_size,
),
count=jnp.zeros(()),
),
steps=jnp.array(0),
)
@partial(jax.jit, static_argnames=("self",))
def _compute_diversity_reward(
self, transition: QDTransition, training_state: DadsTrainingState
) -> Reward:
"""Computes the diversity reward of DADS.
Args:
transition: a batch of transitions from the replay buffer
training_state: the current training state
Returns:
the diversity reward
"""
active_skills = transition.obs[:, -self._config.num_skills :]
# Compute dynamics prob
next_state_desc = transition.next_state_desc
state_desc = transition.state_desc
target = next_state_desc - state_desc
if self._config.normalize_target:
target = normalize_with_rmstd(
target, training_state.normalization_running_stats
)
log_q_phi = self._dynamics.apply(
training_state.dynamics_params,
state_desc,
active_skills,
target,
)
# Estimate prior skill
skill_samples = jnp.tile(
jnp.eye(self._config.num_skills), (state_desc.shape[0], 1)
)
state_descriptors = jnp.repeat(state_desc, self._config.num_skills, axis=0)
target = jnp.repeat(target, self._config.num_skills, axis=0)
log_p_s = self._dynamics.apply(
training_state.dynamics_params,
state_descriptors,
skill_samples,
target,
)
log_p_s = log_p_s.reshape((-1, self._config.num_skills))
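# The intrinsic reward below approximates log q(s'|s, z) - log p(s'|s), where
# p(s'|s) is estimated by marginalising q over the uniform skill prior:
# r = log(num_skills) - log( sum_z' exp( log q(s'|s, z') - log q(s'|s, z) ) )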
# Compute the reward according to DADS official implementation
reward = jnp.log(self._config.num_skills) - jnp.log(
jnp.exp(jnp.clip(log_p_s - log_q_phi.reshape((-1, 1)), -50, 50)).sum(axis=1)
)
return reward
@partial(jax.jit, static_argnames=("self", "env", "deterministic", "evaluation"))
def play_step_fn(
self,
env_state: EnvState,
training_state: DadsTrainingState,
env: Env,
skills: Skill,
deterministic: bool = False,
evaluation: bool = False,
) -> Tuple[EnvState, DadsTrainingState, QDTransition]:
"""Plays a step in the environment. Concatenates skills to the observation
vector, selects an action according to SAC rule and performs the environment
step.
Args:
env_state: the current environment state
training_state: the DADS training state
skills: the skills concatenated to the observation vector
env: the environment
deterministic: whether to select the action deterministically.
Defaults to False.
evaluation: if True, collected transitions are not used to update training
state. Defaults to False.
Returns:
the new environment state
the new DADS training state
the played transition
"""
random_key = training_state.random_key
policy_params = training_state.policy_params
obs = jnp.concatenate([env_state.obs, skills], axis=1)
# If the env does not support state descriptor, we set it to (0,0)
if "state_descriptor" in env_state.info:
state_desc = env_state.info["state_descriptor"]
else:
state_desc = jnp.zeros((env_state.obs.shape[0], 2))
actions, random_key = self.select_action(
obs=obs,
policy_params=policy_params,
random_key=random_key,
deterministic=deterministic,
)
next_env_state = env.step(env_state, actions)
next_obs = jnp.concatenate([next_env_state.obs, skills], axis=1)
if "state_descriptor" in next_env_state.info:
next_state_desc = next_env_state.info["state_descriptor"]
else:
next_state_desc = jnp.zeros((next_env_state.obs.shape[0], 2))
if self._config.normalize_target:
if self._config.descriptor_full_state:
_state_desc = obs[:, : -self._config.num_skills]
_next_state_desc = next_obs[:, : -self._config.num_skills]
target = _next_state_desc - _state_desc
else:
target = next_state_desc - state_desc
target *= jnp.expand_dims(1 - next_env_state.done, -1)
normalization_running_stats = update_running_mean_std(
training_state.normalization_running_stats, target
)
else:
normalization_running_stats = training_state.normalization_running_stats
truncations = next_env_state.info["truncation"]
transition = QDTransition(
obs=obs,
next_obs=next_obs,
state_desc=state_desc,
next_state_desc=next_state_desc,
rewards=next_env_state.reward,
dones=next_env_state.done,
actions=actions,
truncations=truncations,
)
if not evaluation:
training_state = training_state.replace(
random_key=random_key,
normalization_running_stats=normalization_running_stats,
)
else:
training_state = training_state.replace(
random_key=random_key,
)
return next_env_state, training_state, transition
@partial(
jax.jit,
static_argnames=(
"self",
"play_step_fn",
"env_batch_size",
),
)
def eval_policy_fn(
self,
training_state: DadsTrainingState,
eval_env_first_state: EnvState,
play_step_fn: Callable[
[EnvState, Params, RNGKey],
Tuple[EnvState, Params, RNGKey, QDTransition],
],
env_batch_size: int,
) -> Tuple[Reward, Reward, Reward, StateDescriptor]:
"""Evaluates the agent's policy over an entire episode, across all batched
environments.
Args:
training_state: the DADS training state
eval_env_first_state: the initial state for evaluation
play_step_fn: the play_step function used to collect the evaluation episode
env_batch_size: the number of environments we play simultaneously
Returns:
true return averaged over batch dimension, shape: (1,)
true return per environment, shape: (env_batch_size,)
diversity return per environment, shape: (env_batch_size,)
state descriptors, shape: (episode_length, env_batch_size, descriptor_size)
"""
state, training_state, transitions = generate_unroll(
init_state=eval_env_first_state,
training_state=training_state,
episode_length=self._config.episode_length,
play_step_fn=play_step_fn,
)
transitions = get_first_episode(transitions)
true_returns = jnp.nansum(transitions.rewards, axis=0)
true_return = jnp.mean(true_returns, axis=-1)
reshaped_transitions = jax.tree_util.tree_map(
lambda x: x.reshape((self._config.episode_length * env_batch_size, -1)),
transitions,
)
if self._config.descriptor_full_state:
state_desc = reshaped_transitions.obs[:, : -self._config.num_skills]
next_state_desc = reshaped_transitions.next_obs[
:, : -self._config.num_skills
]
reshaped_transitions = reshaped_transitions.replace(
state_desc=state_desc, next_state_desc=next_state_desc
)
diversity_rewards = self._compute_diversity_reward(
transition=reshaped_transitions, training_state=training_state
).reshape((self._config.episode_length, env_batch_size))
diversity_returns = jnp.nansum(diversity_rewards, axis=0)
return true_return, true_returns, diversity_returns, transitions.state_desc
@partial(jax.jit, static_argnames=("self",))
def _compute_reward(
self, transition: QDTransition, training_state: DadsTrainingState
) -> Reward:
"""Computes the reward to train the networks.
Args:
transition: a batch of transitions from the replay buffer
training_state: the current training state
Returns:
the DADS diversity reward
"""
return self._compute_diversity_reward(
transition=transition, training_state=training_state
)
@partial(jax.jit, static_argnames=("self",))
def _update_dynamics(
self, operand: Tuple[DadsTrainingState, QDTransition]
) -> Tuple[Params, float, optax.OptState]:
"""Update the dynamics network, independently of other networks. Called every
`dynamics_update_freq` training steps.
"""
training_state, transitions = operand
dynamics_loss, dynamics_gradient = jax.value_and_grad(self._dynamics_loss_fn,)(
training_state.dynamics_params,
transitions=transitions,
)
(dynamics_updates, dynamics_optimizer_state,) = self._dynamics_optimizer.update(
dynamics_gradient, training_state.dynamics_optimizer_state
)
dynamics_params = optax.apply_updates(
training_state.dynamics_params, dynamics_updates
)
return (
dynamics_params,
dynamics_loss,
dynamics_optimizer_state,
)
@partial(jax.jit, static_argnames=("self",))
def _not_update_dynamics(
self, operand: Tuple[DadsTrainingState, QDTransition]
) -> Tuple[Params, float, optax.OptState]:
"""Fake update of the dynamics, called every time we don't want to update
the dynamics while we update the other networks.
"""
training_state, _transitions = operand
return (
training_state.dynamics_params,
jnp.nan,
training_state.dynamics_optimizer_state,
)
@partial(jax.jit, static_argnames=("self",))
def _update_networks(
self,
training_state: DadsTrainingState,
transitions: QDTransition,
) -> Tuple[DadsTrainingState, Metrics]:
"""Updates the networks involved in DADS.
Args:
training_state: the current training state of the algorithm.
transitions: transitions sampled from a replay buffer.
Returns:
The updated training state and training metrics.
"""
random_key = training_state.random_key
# Update skill-dynamics
(dynamics_params, dynamics_loss, dynamics_optimizer_state,) = jax.lax.cond(
training_state.steps % self._config.dynamics_update_freq == 0,
self._update_dynamics,
self._not_update_dynamics,
(training_state, transitions),
)
# update alpha
(
alpha_params,
alpha_optimizer_state,
alpha_loss,
random_key,
) = self._update_alpha(
alpha_lr=self._config.learning_rate,
training_state=training_state,
transitions=transitions,
random_key=random_key,
)
# update critic
(
critic_params,
target_critic_params,
critic_optimizer_state,
critic_loss,
random_key,
) = self._update_critic(
critic_lr=self._config.learning_rate,
reward_scaling=self._config.reward_scaling,
discount=self._config.discount,
training_state=training_state,
transitions=transitions,
random_key=random_key,
)
# update actor
(
policy_params,
policy_optimizer_state,
policy_loss,
random_key,
) = self._update_actor(
policy_lr=self._config.learning_rate,
training_state=training_state,
transitions=transitions,
random_key=random_key,
)
# Create new training state
new_training_state = DadsTrainingState(
policy_optimizer_state=policy_optimizer_state,
policy_params=policy_params,
critic_optimizer_state=critic_optimizer_state,
critic_params=critic_params,
alpha_optimizer_state=alpha_optimizer_state,
alpha_params=alpha_params,
target_critic_params=target_critic_params,
dynamics_optimizer_state=dynamics_optimizer_state,
dynamics_params=dynamics_params,
random_key=random_key,
normalization_running_stats=training_state.normalization_running_stats,
steps=training_state.steps + 1,
)
metrics = {
"actor_loss": policy_loss,
"critic_loss": critic_loss,
"dynamics_loss": dynamics_loss,
"alpha_loss": alpha_loss,
"alpha": jnp.exp(alpha_params),
"training_observed_reward_mean": jnp.mean(transitions.rewards),
"target_mean": jnp.mean(transitions.next_state_desc),
"target_std": jnp.std(transitions.next_state_desc),
}
return new_training_state, metrics
@partial(jax.jit, static_argnames=("self",))
def update(
self,
training_state: DadsTrainingState,
replay_buffer: ReplayBuffer,
) -> Tuple[DadsTrainingState, ReplayBuffer, Metrics]:
"""Performs a training step to update the policy, the critic and the
dynamics network parameters.
Args:
training_state: the current DADS training state
replay_buffer: the replay buffer
Returns:
the updated DADS training state
the replay buffer
the training metrics
"""
# Sample a batch of transitions in the buffer
random_key = training_state.random_key
transitions, random_key = replay_buffer.sample(
random_key,
sample_size=self._config.batch_size,
)
# Optionally replace the state descriptor by the observation
if self._config.descriptor_full_state:
_state_desc = transitions.obs[:, : -self._config.num_skills]
_next_state_desc = transitions.next_obs[:, : -self._config.num_skills]
transitions = transitions.replace(
state_desc=_state_desc, next_state_desc=_next_state_desc
)
# Compute the reward
rewards = self._compute_reward(
transition=transitions, training_state=training_state
)
# Compute the target and optionally normalize it for the training
if self._config.normalize_target:
next_state_desc = normalize_with_rmstd(
transitions.next_state_desc - transitions.state_desc,
training_state.normalization_running_stats,
)
else:
next_state_desc = transitions.next_state_desc - transitions.state_desc
# Update the transitions
transitions = transitions.replace(
next_state_desc=next_state_desc, rewards=rewards
)
new_training_state, metrics = self._update_networks(
training_state, transitions=transitions
)
return new_training_state, replay_buffer, metrics
init(self, random_key, action_size, observation_size, descriptor_size)
Initialise the training state of the algorithm.
Parameters:

Name | Type | Description
---|---|---
random_key | RNGKey | a jax random key
action_size | int | the size of the environment's action space
observation_size | int | the size of the environment's observation space
descriptor_size | int | the size of the environment's descriptor space (i.e. the dimension of the dynamics network's input)

Returns:

Type | Description
---|---
DadsTrainingState | the initial training state of DADS
Source code in qdax/baselines/dads.py
def init( # type: ignore
self,
random_key: RNGKey,
action_size: int,
observation_size: int,
descriptor_size: int,
) -> DadsTrainingState:
"""Initialise the training state of the algorithm.
Args:
random_key: a jax random key
action_size: the size of the environment's action space
observation_size: the size of the environment's observation space
descriptor_size: the size of the environment's descriptor space (i.e. the
dimension of the dynamics network's input)
Returns:
the initial training state of DADS
"""
# Initialize params
dummy_obs = jnp.zeros((1, observation_size + self._config.num_skills))
dummy_action = jnp.zeros((1, action_size))
dummy_dyn_obs = jnp.zeros((1, descriptor_size))
dummy_skill = jnp.zeros((1, self._config.num_skills))
random_key, subkey = jax.random.split(random_key)
policy_params = self._policy.init(subkey, dummy_obs)
random_key, subkey = jax.random.split(random_key)
critic_params = self._critic.init(subkey, dummy_obs, dummy_action)
target_critic_params = jax.tree_util.tree_map(
lambda x: jnp.asarray(x.copy()), critic_params
)
random_key, subkey = jax.random.split(random_key)
dynamics_params = self._dynamics.init(
subkey,
obs=dummy_dyn_obs,
skill=dummy_skill,
target=dummy_dyn_obs,
)
policy_optimizer_state = self._policy_optimizer.init(policy_params)
critic_optimizer_state = self._critic_optimizer.init(critic_params)
dynamics_optimizer_state = self._dynamics_optimizer.init(dynamics_params)
log_alpha = jnp.asarray(jnp.log(self._config.alpha_init), dtype=jnp.float32)
alpha_optimizer_state = self._alpha_optimizer.init(log_alpha)
return DadsTrainingState(
policy_optimizer_state=policy_optimizer_state,
policy_params=policy_params,
critic_optimizer_state=critic_optimizer_state,
critic_params=critic_params,
alpha_optimizer_state=alpha_optimizer_state,
alpha_params=log_alpha,
target_critic_params=target_critic_params,
dynamics_optimizer_state=dynamics_optimizer_state,
dynamics_params=dynamics_params,
random_key=random_key,
normalization_running_stats=RunningMeanStdState(
mean=jnp.zeros(
descriptor_size,
),
var=jnp.ones(
descriptor_size,
),
count=jnp.zeros(()),
),
steps=jnp.array(0),
)
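As a usage sketch: assuming config is an already-populated DadsConfig (its full field list is not reproduced on this page) and that the sizes below are placeholders for the target environment, the agent and its training state could be created as follows.

import jax
from qdax.baselines.dads import DADS

# `config` is assumed to be a fully populated DadsConfig instance.
dads = DADS(config=config, action_size=8, descriptor_size=2)

random_key = jax.random.PRNGKey(0)
training_state = dads.init(
    random_key=random_key,
    action_size=8,          # placeholder: the environment's action size
    observation_size=27,    # placeholder: the environment's observation size
    descriptor_size=2,      # placeholder: the environment's descriptor size
)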
play_step_fn(self, env_state, training_state, env, skills, deterministic=False, evaluation=False)
Plays a step in the environment: concatenates the skills to the observation vector, selects an action according to the SAC rule and performs the environment step.
Parameters:

Name | Type | Description
---|---|---
env_state | EnvState | the current environment state
training_state | DadsTrainingState | the DADS training state
env | Env | the environment
skills | Skill | the skills concatenated to the observation vector
deterministic | bool | whether to select the action deterministically (defaults to False)
evaluation | bool | if True, collected transitions are not used to update the training state (defaults to False)

Returns:

Type | Description
---|---
Tuple[EnvState, DadsTrainingState, QDTransition] | the new environment state, the new DADS training state and the played transition
Source code in qdax/baselines/dads.py
@partial(jax.jit, static_argnames=("self", "env", "deterministic", "evaluation"))
def play_step_fn(
self,
env_state: EnvState,
training_state: DadsTrainingState,
env: Env,
skills: Skill,
deterministic: bool = False,
evaluation: bool = False,
) -> Tuple[EnvState, DadsTrainingState, QDTransition]:
"""Plays a step in the environment. Concatenates skills to the observation
vector, selects an action according to SAC rule and performs the environment
step.
Args:
env_state: the current environment state
training_state: the DADS training state
skills: the skills concatenated to the observation vector
env: the environment
deterministic: whether to select the action deterministically.
Defaults to False.
evaluation: if True, collected transitions are not used to update training
state. Defaults to False.
Returns:
the new environment state
the new DADS training state
the played transition
"""
random_key = training_state.random_key
policy_params = training_state.policy_params
obs = jnp.concatenate([env_state.obs, skills], axis=1)
# If the env does not support state descriptor, we set it to (0,0)
if "state_descriptor" in env_state.info:
state_desc = env_state.info["state_descriptor"]
else:
state_desc = jnp.zeros((env_state.obs.shape[0], 2))
actions, random_key = self.select_action(
obs=obs,
policy_params=policy_params,
random_key=random_key,
deterministic=deterministic,
)
next_env_state = env.step(env_state, actions)
next_obs = jnp.concatenate([next_env_state.obs, skills], axis=1)
if "state_descriptor" in next_env_state.info:
next_state_desc = next_env_state.info["state_descriptor"]
else:
next_state_desc = jnp.zeros((next_env_state.obs.shape[0], 2))
if self._config.normalize_target:
if self._config.descriptor_full_state:
_state_desc = obs[:, : -self._config.num_skills]
_next_state_desc = next_obs[:, : -self._config.num_skills]
target = _next_state_desc - _state_desc
else:
target = next_state_desc - state_desc
target *= jnp.expand_dims(1 - next_env_state.done, -1)
normalization_running_stats = update_running_mean_std(
training_state.normalization_running_stats, target
)
else:
normalization_running_stats = training_state.normalization_running_stats
truncations = next_env_state.info["truncation"]
transition = QDTransition(
obs=obs,
next_obs=next_obs,
state_desc=state_desc,
next_state_desc=next_state_desc,
rewards=next_env_state.reward,
dones=next_env_state.done,
actions=actions,
truncations=truncations,
)
if not evaluation:
training_state = training_state.replace(
random_key=random_key,
normalization_running_stats=normalization_running_stats,
)
else:
training_state = training_state.replace(
random_key=random_key,
)
return next_env_state, training_state, transition
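A minimal sketch of collecting one batched transition, continuing from the sketches above: dads, training_state and skills are assumed to exist, and env is assumed to be a batched QDax/Brax-style environment with env_batch_size parallel environments (the reset call shown is an assumption about that interface).

import jax

# Reset the batched environment (interface assumed to follow Brax conventions).
reset_key = jax.random.PRNGKey(1)
env_state = jax.jit(env.reset)(reset_key)

# One environment step for every parallel environment / skill.
env_state, training_state, transition = dads.play_step_fn(
    env_state=env_state,
    training_state=training_state,
    env=env,
    skills=skills,
    deterministic=False,
)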
eval_policy_fn(self, training_state, eval_env_first_state, play_step_fn, env_batch_size)
Evaluates the agent's policy over an entire episode, across all batched environments.
Parameters:

Name | Type | Description
---|---|---
training_state | DadsTrainingState | the DADS training state
eval_env_first_state | EnvState | the initial state for evaluation
play_step_fn | Callable | the play_step function used to collect the evaluation episode
env_batch_size | int | the number of environments played simultaneously

Returns:

Type | Description
---|---
Tuple[Reward, Reward, Reward, StateDescriptor] | the true return averaged over the batch dimension, the true return per environment, the diversity return per environment, and the state descriptors
Source code in qdax/baselines/dads.py
@partial(
jax.jit,
static_argnames=(
"self",
"play_step_fn",
"env_batch_size",
),
)
def eval_policy_fn(
self,
training_state: DadsTrainingState,
eval_env_first_state: EnvState,
play_step_fn: Callable[
[EnvState, Params, RNGKey],
Tuple[EnvState, Params, RNGKey, QDTransition],
],
env_batch_size: int,
) -> Tuple[Reward, Reward, Reward, StateDescriptor]:
"""Evaluates the agent's policy over an entire episode, across all batched
environments.
Args:
training_state: the DADS training state
eval_env_first_state: the initial state for evaluation
play_step_fn: the play_step function used to collect the evaluation episode
env_batch_size: the number of environments we play simultaneously
Returns:
true return averaged over batch dimension, shape: (1,)
true return per environment, shape: (env_batch_size,)
diversity return per environment, shape: (env_batch_size,)
state descriptors, shape: (episode_length, env_batch_size, descriptor_size)
"""
state, training_state, transitions = generate_unroll(
init_state=eval_env_first_state,
training_state=training_state,
episode_length=self._config.episode_length,
play_step_fn=play_step_fn,
)
transitions = get_first_episode(transitions)
true_returns = jnp.nansum(transitions.rewards, axis=0)
true_return = jnp.mean(true_returns, axis=-1)
reshaped_transitions = jax.tree_util.tree_map(
lambda x: x.reshape((self._config.episode_length * env_batch_size, -1)),
transitions,
)
if self._config.descriptor_full_state:
state_desc = reshaped_transitions.obs[:, : -self._config.num_skills]
next_state_desc = reshaped_transitions.next_obs[
:, : -self._config.num_skills
]
reshaped_transitions = reshaped_transitions.replace(
state_desc=state_desc, next_state_desc=next_state_desc
)
diversity_rewards = self._compute_diversity_reward(
transition=reshaped_transitions, training_state=training_state
).reshape((self._config.episode_length, env_batch_size))
diversity_returns = jnp.nansum(diversity_rewards, axis=0)
return true_return, true_returns, diversity_returns, transitions.state_desc
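Since eval_policy_fn expects a play_step function that only takes the environment state and the training state, the remaining arguments of play_step_fn can be bound beforehand, for instance with functools.partial. In the sketch below, eval_env, eval_env_first_state and env_batch_size are assumed to come from the surrounding experiment script.

import functools

# Bind the static arguments so the evaluation unroll only has to pass
# (env_state, training_state) at each step.
play_eval_step = functools.partial(
    dads.play_step_fn,
    env=eval_env,
    skills=skills,
    deterministic=True,
    evaluation=True,
)

true_return, true_returns, diversity_returns, state_desc = dads.eval_policy_fn(
    training_state=training_state,
    eval_env_first_state=eval_env_first_state,
    play_step_fn=play_eval_step,
    env_batch_size=env_batch_size,
)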
update(self, training_state, replay_buffer)
Performs a training step to update the policy, the critic and the dynamics network parameters.
Parameters:

Name | Type | Description
---|---|---
training_state | DadsTrainingState | the current DADS training state
replay_buffer | ReplayBuffer | the replay buffer

Returns:

Type | Description
---|---
Tuple[DadsTrainingState, ReplayBuffer, Metrics] | the updated DADS training state, the replay buffer and the training metrics
Source code in qdax/baselines/dads.py
@partial(jax.jit, static_argnames=("self",))
def update(
self,
training_state: DadsTrainingState,
replay_buffer: ReplayBuffer,
) -> Tuple[DadsTrainingState, ReplayBuffer, Metrics]:
"""Performs a training step to update the policy, the critic and the
dynamics network parameters.
Args:
training_state: the current DADS training state
replay_buffer: the replay buffer
Returns:
the updated DADS training state
the replay buffer
the training metrics
"""
# Sample a batch of transitions in the buffer
random_key = training_state.random_key
transitions, random_key = replay_buffer.sample(
random_key,
sample_size=self._config.batch_size,
)
# Optionally replace the state descriptor by the observation
if self._config.descriptor_full_state:
_state_desc = transitions.obs[:, : -self._config.num_skills]
_next_state_desc = transitions.next_obs[:, : -self._config.num_skills]
transitions = transitions.replace(
state_desc=_state_desc, next_state_desc=_next_state_desc
)
# Compute the reward
rewards = self._compute_reward(
transition=transitions, training_state=training_state
)
# Compute the target and optionally normalize it for the training
if self._config.normalize_target:
next_state_desc = normalize_with_rmstd(
transitions.next_state_desc - transitions.state_desc,
training_state.normalization_running_stats,
)
else:
next_state_desc = transitions.next_state_desc - transitions.state_desc
# Update the transitions
transitions = transitions.replace(
next_state_desc=next_state_desc, rewards=rewards
)
new_training_state, metrics = self._update_networks(
training_state, transitions=transitions
)
return new_training_state, replay_buffer, metrics
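Putting the pieces together, one training iteration might interleave data collection and a gradient update as sketched below. The replay_buffer is assumed to be a QDax ReplayBuffer; only its sample method is used by update itself, and the insert call shown here is an assumption about how collected transitions are usually stored.

# Collect a batch of transitions with the current policy and skills.
env_state, training_state, transition = dads.play_step_fn(
    env_state=env_state,
    training_state=training_state,
    env=env,
    skills=skills,
)
replay_buffer = replay_buffer.insert(transition)  # assumed ReplayBuffer method

# Update the policy, critic, alpha and (periodically) the dynamics network.
training_state, replay_buffer, metrics = dads.update(
    training_state=training_state,
    replay_buffer=replay_buffer,
)
print(metrics["actor_loss"], metrics["critic_loss"], metrics["dynamics_loss"])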