SMERL classes for DIAYN and DADS
qdax.baselines.diayn_smerl.DIAYNSMERL (DIAYN)
DIAYNSMERL refers to a family of methods that combine DIAYN's diversity reward with the environment's extrinsic reward using the SMERL method (see https://arxiv.org/abs/2010.14484).
Most methods are inherited from the DIAYN algorithm; the only change is the way the reward is computed, as a combination of the DIAYN reward and the extrinsic reward.
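The SMERL combination reduces to a gating rule: the diversity reward only contributes when the episode return already reaches `smerl_target - smerl_margin`. Below is a minimal, self-contained sketch of that rule as a standalone JAX function; the argument names are illustrative and do not mirror the exact QDax signature.

```python
import jax.numpy as jnp


def smerl_reward(
    extrinsic_reward: jnp.ndarray,      # per-transition environment reward
    diversity_reward: jnp.ndarray,      # per-transition diversity reward
    episode_return: jnp.ndarray,        # return of the episode each transition belongs to
    smerl_target: float,
    smerl_margin: float,
    diversity_reward_scale: float,
) -> jnp.ndarray:
    # Diversity only counts once the task return is already good enough.
    accept = episode_return >= smerl_target - smerl_margin
    return extrinsic_reward + accept * diversity_reward_scale * diversity_reward
```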
Source code in qdax/baselines/diayn_smerl.py
```python
class DIAYNSMERL(DIAYN):
    """DIAYNSMERL refers to a family of methods that combine the DIAYN's diversity
    reward with some environment `extrinsic` reward, using SMERL method, see
    https://arxiv.org/abs/2010.14484.

    Most methods are inherited from the DIAYN algorithm, the only change is the
    way the reward is computed (a combination of the DIAYN reward and
    the `extrinsic` reward).
    """

    def __init__(self, config: DiaynSmerlConfig, action_size: int):
        super(DIAYNSMERL, self).__init__(config, action_size)
        self._config: DiaynSmerlConfig = config

    @partial(jax.jit, static_argnames=("self",))
    def _compute_reward(
        self,
        transition: QDTransition,
        training_state: DiaynTrainingState,
        returns: Reward,
    ) -> Reward:
        """Computes the reward to train the networks.

        Args:
            transition: a batch of transitions from the replay buffer
            training_state: the current training state
            returns: an array containing the episode's return for every sample

        Returns:
            the combined reward
        """
        # Compute diversity reward
        diversity_rewards = self._compute_diversity_reward(
            transition=transition,
            discriminator_params=training_state.discriminator_params,
            add_log_p_z=True,
        )

        # Compute SMERL reward
        assert (
            self._config.smerl_target is not None
            and self._config.smerl_margin is not None
        ), "Missing SMERL target and margin values"

        # is the return good enough to consider the diversity reward
        accept = returns >= self._config.smerl_target - self._config.smerl_margin

        # compute the new reward (r_extrinsic + accept * diversity_scale * r_diversity)
        rewards = (
            transition.rewards
            + accept * self._config.diversity_reward_scale * diversity_rewards
        )

        return rewards

    @partial(jax.jit, static_argnames=("self",))
    def update(
        self,
        training_state: DiaynTrainingState,
        replay_buffer: TrajectoryBuffer,
    ) -> Tuple[DiaynTrainingState, TrajectoryBuffer, Metrics]:
        """Performs a training step to update the policy, the critic and the
        discriminator parameters.

        Args:
            training_state: the current DIAYN training state
            replay_buffer: the replay buffer

        Returns:
            the updated DIAYN training state
            the replay buffer
            the training metrics
        """
        # Sample a batch of transitions in the buffer
        random_key = training_state.random_key
        samples, returns, random_key = replay_buffer.sample_with_returns(
            random_key,
            sample_size=self._config.batch_size,
        )

        # Optionally replace the state descriptor by the observation
        if self._config.descriptor_full_state:
            state_desc = samples.obs[:, : -self._config.num_skills]
            next_state_desc = samples.next_obs[:, : -self._config.num_skills]
            samples = samples.replace(
                state_desc=state_desc, next_state_desc=next_state_desc
            )

        # Compute the rewards
        rewards = self._compute_reward(samples, training_state, returns)
        samples = samples.replace(rewards=rewards)

        new_training_state, metrics = self._update_networks(
            training_state, transitions=samples
        )

        return new_training_state, replay_buffer, metrics
```
update(self, training_state, replay_buffer)
Performs a training step to update the policy, the critic and the discriminator parameters.
Parameters:

| Name | Type | Description |
|---|---|---|
| `training_state` | `DiaynTrainingState` | the current DIAYN training state |
| `replay_buffer` | `TrajectoryBuffer` | the replay buffer |

Returns:

| Type | Description |
|---|---|
| `Tuple[DiaynTrainingState, TrajectoryBuffer, Metrics]` | the updated DIAYN training state, the replay buffer and the training metrics |
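When `descriptor_full_state` is enabled, the state descriptor is rebuilt from the observation by stripping the skill encoding that DIAYN concatenates to the environment state. The small sketch below illustrates that slicing, under the assumption that the observation is laid out as `[state, one-hot skill]`; the concrete shapes are made up for the example.

```python
import jax.numpy as jnp

num_skills = 4
state = jnp.ones((2, 6))                         # batch of raw environment states
skill = jnp.eye(num_skills)[jnp.array([0, 2])]   # one-hot skills for the batch
obs = jnp.concatenate([state, skill], axis=-1)   # what the skill-conditioned policy sees

state_desc = obs[:, :-num_skills]                # strips the skill, recovers the state part
assert state_desc.shape == state.shape
```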
Source code in qdax/baselines/diayn_smerl.py
```python
@partial(jax.jit, static_argnames=("self",))
def update(
    self,
    training_state: DiaynTrainingState,
    replay_buffer: TrajectoryBuffer,
) -> Tuple[DiaynTrainingState, TrajectoryBuffer, Metrics]:
    """Performs a training step to update the policy, the critic and the
    discriminator parameters.

    Args:
        training_state: the current DIAYN training state
        replay_buffer: the replay buffer

    Returns:
        the updated DIAYN training state
        the replay buffer
        the training metrics
    """
    # Sample a batch of transitions in the buffer
    random_key = training_state.random_key
    samples, returns, random_key = replay_buffer.sample_with_returns(
        random_key,
        sample_size=self._config.batch_size,
    )

    # Optionally replace the state descriptor by the observation
    if self._config.descriptor_full_state:
        state_desc = samples.obs[:, : -self._config.num_skills]
        next_state_desc = samples.next_obs[:, : -self._config.num_skills]
        samples = samples.replace(
            state_desc=state_desc, next_state_desc=next_state_desc
        )

    # Compute the rewards
    rewards = self._compute_reward(samples, training_state, returns)
    samples = samples.replace(rewards=rewards)

    new_training_state, metrics = self._update_networks(
        training_state, transitions=samples
    )

    return new_training_state, replay_buffer, metrics
```
qdax.baselines.dads_smerl.DADSSMERL (DADS)
DADSSMERL refers to a family of methods that combine DADS's diversity reward with the environment's extrinsic reward using the SMERL method (see https://arxiv.org/abs/2010.14484).
Most of the methods are inherited from the DADS algorithm; the only change is the way the reward is computed, as a combination of the DADS reward and the extrinsic reward.
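Besides gating the reward as in DIAYNSMERL, the DADS update also rebuilds the target of the skill-dynamics model as the change in state descriptor between consecutive steps, optionally normalized by running statistics. The sketch below illustrates that target computation; the plain mean/std normalization here stands in for QDax's `normalize_with_rmstd` helper and is an assumption, not its exact behavior.

```python
import jax.numpy as jnp


def dads_dynamics_target(
    state_desc: jnp.ndarray,
    next_state_desc: jnp.ndarray,
    running_mean: jnp.ndarray,
    running_std: jnp.ndarray,
    normalize: bool,
) -> jnp.ndarray:
    # The skill-dynamics network is trained to predict the descriptor delta.
    delta = next_state_desc - state_desc
    if normalize:
        # Simplified normalization by running statistics (illustrative only).
        delta = (delta - running_mean) / (running_std + 1e-8)
    return delta
```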
Source code in qdax/baselines/dads_smerl.py
```python
class DADSSMERL(DADS):
    """DADSSMERL refers to a family of methods that combine the DADS's diversity
    reward with some environment `extrinsic` reward, using the proper SMERL method,
    see https://arxiv.org/abs/2010.14484.

    Most of the methods are inherited from the DADS algorithm, the only change is
    the way the reward is computed (a combination of the DADS reward and the `extrinsic`
    reward).
    """

    def __init__(self, config: DadsSmerlConfig, action_size: int, descriptor_size: int):
        super(DADSSMERL, self).__init__(config, action_size, descriptor_size)
        self._config: DadsSmerlConfig = config

    @partial(jax.jit, static_argnames=("self",))
    def _compute_reward(
        self,
        transition: QDTransition,
        training_state: DadsTrainingState,
        returns: Reward,
    ) -> Reward:
        """Computes the reward to train the networks.

        Args:
            transition: a batch of transitions from the replay buffer
            training_state: the current training state

        Returns:
            the reward
        """
        diversity_rewards = self._compute_diversity_reward(
            transition=transition, training_state=training_state
        )

        # Compute SMERL reward (r_extrinsic + accept * diversity_scale * r_diversity)
        assert (
            self._config.smerl_target is not None
            and self._config.smerl_margin is not None
        ), "Missing SMERL target and margin values"

        accept = returns >= self._config.smerl_target - self._config.smerl_margin
        rewards = (
            transition.rewards
            + accept * self._config.diversity_reward_scale * diversity_rewards
        )

        return rewards

    @partial(jax.jit, static_argnames=("self",))
    def update(
        self,
        training_state: DadsTrainingState,
        replay_buffer: TrajectoryBuffer,
    ) -> Tuple[DadsTrainingState, TrajectoryBuffer, Metrics]:
        """Performs a training step to update the policy, the critic and the
        dynamics network parameters.

        Args:
            training_state: the current DADS training state
            replay_buffer: the replay buffer

        Returns:
            the updated DADS training state
            the replay buffer
            the training metrics
        """
        # Sample a batch of transitions in the buffer
        random_key = training_state.random_key
        samples, returns, random_key = replay_buffer.sample_with_returns(
            random_key,
            sample_size=self._config.batch_size,
        )

        # Optionally replace the state descriptor by the observation
        if self._config.descriptor_full_state:
            _state_desc = samples.obs[:, : -self._config.num_skills]
            _next_state_desc = samples.next_obs[:, : -self._config.num_skills]
            samples = samples.replace(
                state_desc=_state_desc, next_state_desc=_next_state_desc
            )

        # Compute the reward
        rewards = self._compute_reward(
            transition=samples, training_state=training_state, returns=returns
        )

        # Compute the target and optionally normalize it for the training
        if self._config.normalize_target:
            next_state_desc = normalize_with_rmstd(
                samples.next_state_desc - samples.state_desc,
                training_state.normalization_running_stats,
            )
        else:
            next_state_desc = samples.next_state_desc - samples.state_desc

        # Update the transitions
        samples = samples.replace(next_state_desc=next_state_desc, rewards=rewards)

        new_training_state, metrics = self._update_networks(
            training_state, transitions=samples
        )

        return new_training_state, replay_buffer, metrics
```
update(self, training_state, replay_buffer)
Performs a training step to update the policy, the critic and the dynamics network parameters.
Parameters:

| Name | Type | Description |
|---|---|---|
| `training_state` | `DadsTrainingState` | the current DADS training state |
| `replay_buffer` | `TrajectoryBuffer` | the replay buffer |

Returns:

| Type | Description |
|---|---|
| `Tuple[DadsTrainingState, TrajectoryBuffer, Metrics]` | the updated DADS training state, the replay buffer and the training metrics |
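The only buffer-side requirement SMERL adds is that every sampled transition comes with the return of the episode it belongs to, which is why `sample_with_returns` is used instead of a plain sample. The snippet below is a rough, runnable illustration of how such per-transition returns can be derived from flat reward and done arrays; it is not the `TrajectoryBuffer` implementation.

```python
import jax.numpy as jnp


def per_transition_returns(rewards: jnp.ndarray, dones: jnp.ndarray) -> jnp.ndarray:
    """Assign to every transition the (undiscounted) return of its episode."""
    # Label each transition with the index of the episode it belongs to.
    episode_id = jnp.concatenate([jnp.zeros(1), jnp.cumsum(dones)[:-1]]).astype(int)
    num_episodes = int(episode_id[-1]) + 1
    # Sum rewards per episode, then broadcast back to transitions.
    episode_returns = jnp.zeros(num_episodes).at[episode_id].add(rewards)
    return episode_returns[episode_id]


rewards = jnp.array([1.0, 1.0, 1.0, 2.0, 2.0])
dones = jnp.array([0.0, 0.0, 1.0, 0.0, 1.0])   # two episodes: returns 3.0 and 4.0
print(per_transition_returns(rewards, dones))   # [3. 3. 3. 4. 4.]
```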
Source code in qdax/baselines/dads_smerl.py
```python
@partial(jax.jit, static_argnames=("self",))
def update(
    self,
    training_state: DadsTrainingState,
    replay_buffer: TrajectoryBuffer,
) -> Tuple[DadsTrainingState, TrajectoryBuffer, Metrics]:
    """Performs a training step to update the policy, the critic and the
    dynamics network parameters.

    Args:
        training_state: the current DADS training state
        replay_buffer: the replay buffer

    Returns:
        the updated DADS training state
        the replay buffer
        the training metrics
    """
    # Sample a batch of transitions in the buffer
    random_key = training_state.random_key
    samples, returns, random_key = replay_buffer.sample_with_returns(
        random_key,
        sample_size=self._config.batch_size,
    )

    # Optionally replace the state descriptor by the observation
    if self._config.descriptor_full_state:
        _state_desc = samples.obs[:, : -self._config.num_skills]
        _next_state_desc = samples.next_obs[:, : -self._config.num_skills]
        samples = samples.replace(
            state_desc=_state_desc, next_state_desc=_next_state_desc
        )

    # Compute the reward
    rewards = self._compute_reward(
        transition=samples, training_state=training_state, returns=returns
    )

    # Compute the target and optionally normalize it for the training
    if self._config.normalize_target:
        next_state_desc = normalize_with_rmstd(
            samples.next_state_desc - samples.state_desc,
            training_state.normalization_running_stats,
        )
    else:
        next_state_desc = samples.next_state_desc - samples.state_desc

    # Update the transitions
    samples = samples.replace(next_state_desc=next_state_desc, rewards=rewards)

    new_training_state, metrics = self._update_networks(
        training_state, transitions=samples
    )

    return new_training_state, replay_buffer, metrics
```