SMERL classes for DIAYN and DADS

Bases: DIAYN

DIAYNSMERL refers to a family of methods that combine DIAYN's diversity reward with an extrinsic environment reward using the SMERL method; see https://arxiv.org/abs/2010.14484.

Most methods are inherited from the DIAYN algorithm; the only change is the way the reward is computed (a combination of the DIAYN reward and the extrinsic reward).

Source code in qdax/baselines/diayn_smerl.py
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
class DIAYNSMERL(DIAYN):
    """DIAYN trained with the SMERL reward-combination scheme.

    SMERL (https://arxiv.org/abs/2010.14484) adds a scaled diversity reward,
    here DIAYN's discriminator-based reward, to the environment `extrinsic`
    reward. The diversity term is only added for samples whose episode return
    is at least ``smerl_target - smerl_margin``.

    Apart from the reward computation (and an ``update`` step that also samples
    episode returns from the buffer), everything is inherited from DIAYN.
    """

    def __init__(self, config: DiaynSmerlConfig, action_size: int):
        super().__init__(config, action_size)
        # Re-annotate so type checkers see the SMERL-specific config fields.
        self._config: DiaynSmerlConfig = config

    def _compute_reward(  # type: ignore
        self,
        transition: QDTransition,
        training_state: DiaynTrainingState,
        returns: Reward,
    ) -> Reward:
        """Build the SMERL training reward for a batch of transitions.

        Args:
            transition: batch of transitions sampled from the replay buffer.
            training_state: current training state; only its discriminator
                parameters are read here.
            returns: episode return associated with each sampled transition.

        Returns:
            ``r_extrinsic + accept * diversity_reward_scale * r_diversity``,
            where ``accept`` flags samples whose return reaches
            ``smerl_target - smerl_margin``.
        """
        cfg = self._config

        # DIAYN intrinsic reward, requested with add_log_p_z=True.
        r_diversity = self._compute_diversity_reward(
            transition=transition,
            discriminator_params=training_state.discriminator_params,
            add_log_p_z=True,
        )

        assert (
            cfg.smerl_target is not None and cfg.smerl_margin is not None
        ), "Missing SMERL target and margin values"

        # The diversity bonus only applies once the return is close enough
        # to the target.
        threshold = cfg.smerl_target - cfg.smerl_margin
        accept = returns >= threshold
        bonus = accept * cfg.diversity_reward_scale * r_diversity
        return transition.rewards + bonus

    def update(
        self,
        training_state: DiaynTrainingState,
        replay_buffer: TrajectoryBuffer,
    ) -> Tuple[DiaynTrainingState, TrajectoryBuffer, Metrics]:
        """Run one SMERL training step on the policy, critic and discriminator.

        Args:
            training_state: current DIAYN training state.
            replay_buffer: buffer from which transitions and their episode
                returns are sampled; it is returned as-is.

        Returns:
            the updated DIAYN training state
            the replay buffer
            the training metrics
        """
        next_key, sample_key = jax.random.split(training_state.key)
        samples, returns = replay_buffer.sample_with_returns(
            sample_key,
            sample_size=self._config.batch_size,
        )
        training_state = training_state.replace(key=next_key)

        if self._config.descriptor_full_state:
            # Use the observation without its last `num_skills` columns
            # (presumably the appended skill encoding) as the descriptor.
            n_skills = self._config.num_skills
            samples = samples.replace(
                state_desc=samples.obs[:, :-n_skills],
                next_state_desc=samples.next_obs[:, :-n_skills],
            )

        shaped_rewards = self._compute_reward(samples, training_state, returns)
        samples = samples.replace(rewards=shaped_rewards)

        new_training_state, metrics = self._update_networks(
            training_state, transitions=samples
        )
        return new_training_state, replay_buffer, metrics

update(training_state, replay_buffer)

Performs a training step to update the policy, the critic and the discriminator parameters.

Parameters:
  • training_state (DiaynTrainingState) –

    the current DIAYN training state

  • replay_buffer (TrajectoryBuffer) –

    the replay buffer

Returns:
  • DiaynTrainingState

    the updated DIAYN training state

  • TrajectoryBuffer

    the replay buffer

  • Metrics

    the training metrics

Source code in qdax/baselines/diayn_smerl.py
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
def update(
    self,
    training_state: DiaynTrainingState,
    replay_buffer: TrajectoryBuffer,
) -> Tuple[DiaynTrainingState, TrajectoryBuffer, Metrics]:
    """Run one SMERL training step on the policy, critic and discriminator.

    Args:
        training_state: current DIAYN training state.
        replay_buffer: buffer from which transitions and their episode
            returns are sampled; it is returned as-is.

    Returns:
        the updated DIAYN training state
        the replay buffer
        the training metrics
    """
    cfg = self._config

    # Split off a key for sampling and keep the other for future steps.
    next_key, sample_key = jax.random.split(training_state.key)
    batch, episode_returns = replay_buffer.sample_with_returns(
        sample_key,
        sample_size=cfg.batch_size,
    )
    training_state = training_state.replace(key=next_key)

    if cfg.descriptor_full_state:
        # Drop the last `num_skills` observation columns (presumably the
        # appended skill encoding) and use the rest as the descriptor.
        n_skills = cfg.num_skills
        batch = batch.replace(
            state_desc=batch.obs[:, :-n_skills],
            next_state_desc=batch.next_obs[:, :-n_skills],
        )

    batch = batch.replace(
        rewards=self._compute_reward(batch, training_state, episode_returns)
    )

    new_training_state, metrics = self._update_networks(
        training_state, transitions=batch
    )
    return new_training_state, replay_buffer, metrics

Bases: DADS

DADSSMERL refers to a family of methods that combine DADS's diversity reward with an extrinsic environment reward using the SMERL method; see https://arxiv.org/abs/2010.14484.

Most of the methods are inherited from the DADS algorithm; the only change is the way the reward is computed (a combination of the DADS reward and the extrinsic reward).

Source code in qdax/baselines/dads_smerl.py
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
class DADSSMERL(DADS):
    """DADS trained with the SMERL reward-combination scheme.

    SMERL (https://arxiv.org/abs/2010.14484) adds a scaled diversity reward,
    here DADS's dynamics-based reward, to the environment `extrinsic` reward.
    The diversity term is only added for samples whose episode return is at
    least ``smerl_target - smerl_margin``.

    Most of the methods are inherited from the DADS algorithm; the only change
    is the way the reward is computed (a combination of the DADS reward and
    the `extrinsic` reward), plus an ``update`` step that also samples episode
    returns from the buffer.
    """

    def __init__(self, config: DadsSmerlConfig, action_size: int, descriptor_size: int):
        super().__init__(config, action_size, descriptor_size)
        # Re-annotate so type checkers see the SMERL-specific config fields.
        self._config: DadsSmerlConfig = config

    def _compute_reward(  # type: ignore
        self,
        transition: QDTransition,
        training_state: DadsTrainingState,
        returns: Reward,
    ) -> Reward:
        """Computes the reward to train the networks.

        Args:
            transition: a batch of transitions from the replay buffer
            training_state: the current training state
            returns: an array containing the episode's return for every sample

        Returns:
            the combined reward
            ``r_extrinsic + accept * diversity_reward_scale * r_diversity``,
            where ``accept`` flags samples whose return reaches
            ``smerl_target - smerl_margin``.
        """

        diversity_rewards = self._compute_diversity_reward(
            transition=transition, training_state=training_state
        )
        # Compute SMERL reward (r_extrinsic + accept * diversity_scale * r_diversity)
        assert (
            self._config.smerl_target is not None
            and self._config.smerl_margin is not None
        ), "Missing SMERL target and margin values"

        # The diversity bonus only applies once the return is close enough
        # to the target.
        accept = returns >= self._config.smerl_target - self._config.smerl_margin
        rewards = (
            transition.rewards
            + accept * self._config.diversity_reward_scale * diversity_rewards
        )

        return rewards

    def update(
        self,
        training_state: DadsTrainingState,
        replay_buffer: TrajectoryBuffer,
    ) -> Tuple[DadsTrainingState, TrajectoryBuffer, Metrics]:
        """Performs a training step to update the policy, the critic and the
        dynamics network parameters.

        Args:
            training_state: the current DADS training state
            replay_buffer: the replay buffer

        Returns:
            the updated DADS training state
            the replay buffer
            the training metrics
        """
        key = training_state.key

        # Sample a batch of transitions in the buffer
        key, subkey = jax.random.split(key)
        samples, returns = replay_buffer.sample_with_returns(
            subkey,
            sample_size=self._config.batch_size,
        )

        training_state = training_state.replace(key=key)

        # Optionally replace the state descriptor by the observation, without
        # its last `num_skills` columns (presumably the skill encoding).
        if self._config.descriptor_full_state:
            obs_state_desc = samples.obs[:, : -self._config.num_skills]
            obs_next_state_desc = samples.next_obs[:, : -self._config.num_skills]
            samples = samples.replace(
                state_desc=obs_state_desc, next_state_desc=obs_next_state_desc
            )

        # Compute the reward (uses the descriptors before they are turned
        # into the dynamics target below)
        rewards = self._compute_reward(
            transition=samples, training_state=training_state, returns=returns
        )

        # The dynamics network target is the change in descriptor, optionally
        # normalized with the running statistics.
        descriptor_delta = samples.next_state_desc - samples.state_desc
        if self._config.normalize_target:
            next_state_desc = normalize_with_rmstd(
                descriptor_delta,
                training_state.normalization_running_stats,
            )
        else:
            next_state_desc = descriptor_delta

        # Update the transitions
        samples = samples.replace(next_state_desc=next_state_desc, rewards=rewards)

        new_training_state, metrics = self._update_networks(
            training_state, transitions=samples
        )
        return new_training_state, replay_buffer, metrics

update(training_state, replay_buffer)

Performs a training step to update the policy, the critic and the dynamics network parameters.

Parameters:
  • training_state (DadsTrainingState) –

    the current DADS training state

  • replay_buffer (TrajectoryBuffer) –

    the replay buffer

Returns:
  • DadsTrainingState

    the updated DADS training state

  • TrajectoryBuffer

    the replay buffer

  • Metrics

    the training metrics

Source code in qdax/baselines/dads_smerl.py
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
def update(
    self,
    training_state: DadsTrainingState,
    replay_buffer: TrajectoryBuffer,
) -> Tuple[DadsTrainingState, TrajectoryBuffer, Metrics]:
    """Performs a training step to update the policy, the critic and the
    dynamics network parameters.

    Args:
        training_state: the current DADS training state
        replay_buffer: the replay buffer

    Returns:
        the updated DADS training state
        the replay buffer
        the training metrics
    """
    key = training_state.key

    # Sample a batch of transitions in the buffer
    key, subkey = jax.random.split(key)
    samples, returns = replay_buffer.sample_with_returns(
        subkey,
        sample_size=self._config.batch_size,
    )

    training_state = training_state.replace(key=key)

    # Optionally replace the state descriptor by the observation, without
    # its last `num_skills` columns (presumably the skill encoding).
    if self._config.descriptor_full_state:
        obs_state_desc = samples.obs[:, : -self._config.num_skills]
        obs_next_state_desc = samples.next_obs[:, : -self._config.num_skills]
        samples = samples.replace(
            state_desc=obs_state_desc, next_state_desc=obs_next_state_desc
        )

    # Compute the reward (uses the descriptors before they are turned into
    # the dynamics target below)
    rewards = self._compute_reward(
        transition=samples, training_state=training_state, returns=returns
    )

    # The dynamics network target is the change in descriptor, optionally
    # normalized with the running statistics.
    descriptor_delta = samples.next_state_desc - samples.state_desc
    if self._config.normalize_target:
        next_state_desc = normalize_with_rmstd(
            descriptor_delta,
            training_state.normalization_running_stats,
        )
    else:
        next_state_desc = descriptor_delta

    # Update the transitions
    samples = samples.replace(next_state_desc=next_state_desc, rewards=rewards)

    new_training_state, metrics = self._update_networks(
        training_state, transitions=samples
    )
    return new_training_state, replay_buffer, metrics