DIAYN class

qdax.baselines.diayn.DIAYN (SAC)

Implements the DIAYN algorithm (https://arxiv.org/abs/1802.06070).

Note that the functions select_action, _update_alpha, _update_critic and _update_actor are inherited from the SAC algorithm.

In the current implementation, the skills are assumed to be fixed one-hot vectors; continuous skills are not supported at the moment.

Also, the skills are evaluated in parallel in a fixed manner: a batch of environments, whose size is a multiple of the number of skills, is used to evaluate the skills in the environment and hence to generate transitions. The skill sampling is therefore fixed and perfectly uniform.
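
For illustration, here is a minimal sketch of such a fixed, uniform skill assignment; the names num_skills and env_batch_size are illustrative, and the batch size must be a multiple of the number of skills:

import jax.numpy as jnp

num_skills = 5
env_batch_size = 20  # must be a multiple of num_skills

# One one-hot vector per skill, tiled so that each skill is assigned to
# exactly env_batch_size // num_skills parallel environments.
skills = jnp.concatenate(
    [jnp.eye(num_skills)] * (env_batch_size // num_skills), axis=0
)
assert skills.shape == (env_batch_size, num_skills)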

Since the skills are categorical, the discriminator is currently trained with the categorical cross-entropy loss.
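
Concretely, the intrinsic reward derived from this objective for a transition with next-state descriptor s' and one-hot skill z is log q(z | s') - log p(z), where q is the discriminator and p is the uniform prior over skills, so that -log p(z) = log(num_skills). A self-contained sketch mirroring _compute_diversity_reward below:

import jax
import jax.numpy as jnp

def diversity_reward(discriminator_logits, skills, num_skills):
    # log q(z | s'): log-probability the discriminator assigns to the skill
    # that actually generated the transition (selected by the one-hot vector)
    log_q_z = jnp.sum(jax.nn.log_softmax(discriminator_logits) * skills, axis=1)
    # add -log p(z) = log(num_skills) for a uniform prior over discrete skills
    return log_q_z + jnp.log(num_skills)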

We plan to add continuous skills as an option in the future. We also plan to relax the current constraint on the number of batched environments by sampling the skills rather than using this fixed setting.

Source code in qdax/baselines/diayn.py
class DIAYN(SAC):
    """Implements DIAYN algorithm https://arxiv.org/abs/1802.06070.

    Note that the functions select_action, _update_alpha, _update_critic and
    _update_actor are inherited from SAC algorithm.

    In the current implementation, we suppose that the skills are fixed one
    hot vectors, and do not support continuous skills at the moment.

    Also, we suppose that the skills are evaluated in parallel in a fixed
    manner: a batch of environments, containing a multiple of the number
    of skills, is used to evaluate the skills in the environment and hence
    to generate transitions. The sampling is hence fixed and perfectly uniform.

    Since we are using categorical skills, the current loss function used
    to train the discriminator is the categorical cross entropy loss.

    We plan to add continuous skills as an option in the future. We also plan
    to relax the current constraint on the number of batched environments
    by sampling the skills rather than using this fixed setting.
    """

    def __init__(self, config: DiaynConfig, action_size: int):
        self._config: DiaynConfig = config
        if self._config.normalize_observations:
            raise NotImplementedError("Normalization is not implemented for DIAYN yet")

        # define the networks
        self._policy, self._critic, self._discriminator = make_diayn_networks(
            num_skills=self._config.num_skills,
            action_size=action_size,
            policy_hidden_layer_size=self._config.policy_hidden_layer_size,
            critic_hidden_layer_size=self._config.critic_hidden_layer_size,
        )

        # define the action distribution
        self._action_size = action_size
        self._parametric_action_distribution = NormalTanhDistribution(
            event_size=action_size
        )
        self._sample_action_fn = self._parametric_action_distribution.sample

        # define the losses
        (
            self._alpha_loss_fn,
            self._policy_loss_fn,
            self._critic_loss_fn,
            self._discriminator_loss_fn,
        ) = make_diayn_loss_fn(
            policy_fn=self._policy.apply,
            critic_fn=self._critic.apply,
            discriminator_fn=self._discriminator.apply,
            reward_scaling=self._config.reward_scaling,
            discount=self._config.discount,
            action_size=action_size,
            num_skills=self._config.num_skills,
            parametric_action_distribution=self._parametric_action_distribution,
        )

        # define the optimizers
        self._policy_optimizer = optax.adam(learning_rate=self._config.learning_rate)
        self._critic_optimizer = optax.adam(learning_rate=self._config.learning_rate)
        self._alpha_optimizer = optax.adam(learning_rate=self._config.learning_rate)
        self._discriminator_optimizer = optax.adam(
            learning_rate=self._config.learning_rate
        )

    def init(  # type: ignore
        self,
        random_key: RNGKey,
        action_size: int,
        observation_size: int,
        descriptor_size: int,
    ) -> DiaynTrainingState:
        """Initialise the training state of the algorithm.

        Args:
            random_key: a jax random key
            action_size: the size of the environment's action space
            observation_size: the size of the environment's observation space
            descriptor_size: the size of the environment's descriptor space (i.e. the
                dimension of the discriminator's input)

        Returns:
            the initial training state of DIAYN
        """

        # define policy and critic params
        dummy_obs = jnp.zeros((1, observation_size + self._config.num_skills))
        dummy_action = jnp.zeros((1, action_size))
        dummy_discriminator_obs = jnp.zeros((1, descriptor_size))

        random_key, subkey = jax.random.split(random_key)
        policy_params = self._policy.init(subkey, dummy_obs)

        random_key, subkey = jax.random.split(random_key)
        critic_params = self._critic.init(subkey, dummy_obs, dummy_action)

        target_critic_params = jax.tree_util.tree_map(
            lambda x: jnp.asarray(x.copy()), critic_params
        )

        random_key, subkey = jax.random.split(random_key)
        discriminator_params = self._discriminator.init(
            subkey, obs=dummy_discriminator_obs
        )

        policy_optimizer_state = self._policy_optimizer.init(policy_params)
        critic_optimizer_state = self._critic_optimizer.init(critic_params)
        discriminator_optimizer_state = self._discriminator_optimizer.init(
            discriminator_params
        )

        log_alpha = jnp.asarray(jnp.log(self._config.alpha_init), dtype=jnp.float32)
        alpha_optimizer_state = self._alpha_optimizer.init(log_alpha)

        return DiaynTrainingState(
            policy_optimizer_state=policy_optimizer_state,
            policy_params=policy_params,
            critic_optimizer_state=critic_optimizer_state,
            critic_params=critic_params,
            alpha_optimizer_state=alpha_optimizer_state,
            alpha_params=log_alpha,
            target_critic_params=target_critic_params,
            discriminator_optimizer_state=discriminator_optimizer_state,
            discriminator_params=discriminator_params,
            random_key=random_key,
            steps=jnp.array(0),
        )

    @partial(jax.jit, static_argnames=("self", "add_log_p_z"))
    def _compute_diversity_reward(
        self,
        transition: QDTransition,
        discriminator_params: Params,
        add_log_p_z: bool = True,
    ) -> Reward:
        """Computes the diversity reward of DIAYN.

        As we use discrete skills (one hot encoded vectors), the categorical
        cross entropy is used here.

        Args:
            transition: a batch of transitions from the replay buffer
            discriminator_params: the parameters of the discriminator
            add_log_p_z: whether or not to add (minus) the probability of the skills'
                prior distribution. Defaults to True.

        Returns:
            the diversity reward
        """

        next_state_desc = transition.next_state_desc
        skills = transition.next_obs[:, -self._config.num_skills :]
        reward = jnp.sum(
            jax.nn.log_softmax(
                self._discriminator.apply(discriminator_params, next_state_desc)
            )
            * skills,
            axis=1,
        )
        if add_log_p_z:
            reward += jnp.log(self._config.num_skills)
        return reward

    @partial(jax.jit, static_argnames=("self", "env", "deterministic"))
    def play_step_fn(
        self,
        env_state: EnvState,
        training_state: DiaynTrainingState,
        skills: Skill,
        env: Env,
        deterministic: bool = False,
    ) -> Tuple[EnvState, DiaynTrainingState, QDTransition]:
        """Plays a step in the environment. Concatenates skills to the observation
        vector, selects an action according to SAC rule and performs the environment
        step.

        Args:
            env_state: the current environment state
            training_state: the DIAYN training state
            skills: the skills concatenated to the observation vector
            env: the environment
            deterministic: whether or not to select action in a deterministic way.
                Defaults to False.

        Returns:
            the new environment state
            the new DIAYN training state
            the played transition
        """

        random_key = training_state.random_key
        policy_params = training_state.policy_params
        obs = jnp.concatenate([env_state.obs, skills], axis=1)

        # If the env does not support state descriptor, we set it to (0,0)
        if "state_descriptor" in env_state.info:
            state_desc = env_state.info["state_descriptor"]
        else:
            state_desc = jnp.zeros((env_state.obs.shape[0], 2))

        actions, random_key = self.select_action(
            obs=obs,
            policy_params=policy_params,
            random_key=random_key,
            deterministic=deterministic,
        )

        next_env_state = env.step(env_state, actions)
        next_obs = jnp.concatenate([next_env_state.obs, skills], axis=1)
        if "state_descriptor" in next_env_state.info:
            next_state_desc = next_env_state.info["state_descriptor"]
        else:
            next_state_desc = jnp.zeros((next_env_state.obs.shape[0], 2))
        truncations = next_env_state.info["truncation"]
        transition = QDTransition(
            obs=obs,
            next_obs=next_obs,
            state_desc=state_desc,
            next_state_desc=next_state_desc,
            rewards=next_env_state.reward,
            dones=next_env_state.done,
            actions=actions,
            truncations=truncations,
        )
        training_state = training_state.replace(random_key=random_key)

        return next_env_state, training_state, transition

    @partial(
        jax.jit,
        static_argnames=(
            "self",
            "play_step_fn",
            "env_batch_size",
        ),
    )
    def eval_policy_fn(
        self,
        training_state: DiaynTrainingState,
        eval_env_first_state: EnvState,
        play_step_fn: Callable[
            [EnvState, Params, RNGKey],
            Tuple[EnvState, Params, RNGKey, QDTransition],
        ],
        env_batch_size: int,
    ) -> Tuple[Reward, Reward, Reward, StateDescriptor]:
        """Evaluates the agent's policy over an entire episode, across all batched
        environments.


        Args:
            training_state: the DIAYN training state
            eval_env_first_state: the initial state for evaluation
            play_step_fn: the play_step function used to collect the evaluation episode
            env_batch_size: the number of environments we play simultaneously

        Returns:
            true return averaged over batch dimension, shape: (1,)
            true return per environment, shape: (env_batch_size,)
            diversity return per environment, shape: (env_batch_size,)
            state descriptors, shape: (episode_length, env_batch_size, descriptor_size)

        """

        state, training_state, transitions = generate_unroll(
            init_state=eval_env_first_state,
            training_state=training_state,
            episode_length=self._config.episode_length,
            play_step_fn=play_step_fn,
        )

        transitions = get_first_episode(transitions)
        true_returns = jnp.nansum(transitions.rewards, axis=0)
        true_return = jnp.mean(true_returns, axis=-1)

        reshaped_transitions = jax.tree_util.tree_map(
            lambda x: x.reshape((self._config.episode_length * env_batch_size, -1)),
            transitions,
        )

        if self._config.descriptor_full_state:
            state_desc = reshaped_transitions.obs[:, : -self._config.num_skills]
            next_state_desc = reshaped_transitions.next_obs[
                :, : -self._config.num_skills
            ]
            reshaped_transitions = reshaped_transitions.replace(
                state_desc=state_desc, next_state_desc=next_state_desc
            )

        diversity_rewards = self._compute_diversity_reward(
            transition=reshaped_transitions,
            discriminator_params=training_state.discriminator_params,
        ).reshape((self._config.episode_length, env_batch_size))

        diversity_returns = jnp.nansum(diversity_rewards, axis=0)

        return (
            true_return,
            true_returns,
            diversity_returns,
            transitions.state_desc,
        )

    @partial(jax.jit, static_argnames=("self",))
    def _compute_reward(
        self, transition: QDTransition, training_state: DiaynTrainingState
    ) -> Reward:
        """Computes the reward to train the networks.

        Args:
            transition: a batch of transitions from the replay buffer
            training_state: the current training state

        Returns:
            the DIAYN diversity reward
        """
        return self._compute_diversity_reward(
            transition=transition,
            discriminator_params=training_state.discriminator_params,
            add_log_p_z=True,
        )

    @partial(jax.jit, static_argnames=("self",))
    def _update_networks(
        self,
        training_state: DiaynTrainingState,
        transitions: QDTransition,
    ) -> Tuple[DiaynTrainingState, Metrics]:
        """Updates all the networks of the training state.

        Args:
            training_state: the current training state.
            transitions: transitions sampled from the replay buffer.

        Returns:
            The updated training state and the training metrics.
        """
        random_key = training_state.random_key

        # Compute discriminator loss and gradients
        discriminator_loss, discriminator_gradient = jax.value_and_grad(
            self._discriminator_loss_fn
        )(
            training_state.discriminator_params,
            transitions=transitions,
        )

        # update discriminator
        (
            discriminator_updates,
            discriminator_optimizer_state,
        ) = self._discriminator_optimizer.update(
            discriminator_gradient, training_state.discriminator_optimizer_state
        )
        discriminator_params = optax.apply_updates(
            training_state.discriminator_params, discriminator_updates
        )

        # update alpha
        (
            alpha_params,
            alpha_optimizer_state,
            alpha_loss,
            random_key,
        ) = self._update_alpha(
            alpha_lr=self._config.learning_rate,
            training_state=training_state,
            transitions=transitions,
            random_key=random_key,
        )

        # update critic
        (
            critic_params,
            target_critic_params,
            critic_optimizer_state,
            critic_loss,
            random_key,
        ) = self._update_critic(
            critic_lr=self._config.learning_rate,
            reward_scaling=self._config.reward_scaling,
            discount=self._config.discount,
            training_state=training_state,
            transitions=transitions,
            random_key=random_key,
        )

        # update actor
        (
            policy_params,
            policy_optimizer_state,
            policy_loss,
            random_key,
        ) = self._update_actor(
            policy_lr=self._config.learning_rate,
            training_state=training_state,
            transitions=transitions,
            random_key=random_key,
        )

        # Create new training state
        new_training_state = DiaynTrainingState(
            policy_optimizer_state=policy_optimizer_state,
            policy_params=policy_params,
            critic_optimizer_state=critic_optimizer_state,
            critic_params=critic_params,
            alpha_optimizer_state=alpha_optimizer_state,
            alpha_params=alpha_params,
            target_critic_params=target_critic_params,
            discriminator_optimizer_state=discriminator_optimizer_state,
            discriminator_params=discriminator_params,
            random_key=random_key,
            steps=training_state.steps + 1,
        )
        metrics = {
            "actor_loss": policy_loss,
            "critic_loss": critic_loss,
            "discriminator_loss": discriminator_loss,
            "alpha_loss": alpha_loss,
        }

        return new_training_state, metrics

    @partial(jax.jit, static_argnames=("self",))
    def update(
        self,
        training_state: DiaynTrainingState,
        replay_buffer: ReplayBuffer,
    ) -> Tuple[DiaynTrainingState, ReplayBuffer, Metrics]:
        """Performs a training step to update the policy, the critic and the
        discriminator parameters.

        Args:
            training_state: the current DIAYN training state
            replay_buffer: the replay buffer

        Returns:
            the updated DIAYN training state
            the replay buffer
            the training metrics
        """
        # Sample a batch of transitions in the buffer
        random_key = training_state.random_key
        transitions, random_key = replay_buffer.sample(
            random_key,
            sample_size=self._config.batch_size,
        )

        # Optionally replace the state descriptor by the observation
        if self._config.descriptor_full_state:
            state_desc = transitions.obs[:, : -self._config.num_skills]
            next_state_desc = transitions.next_obs[:, : -self._config.num_skills]
            transitions = transitions.replace(
                state_desc=state_desc, next_state_desc=next_state_desc
            )

        # Compute the rewards and replace transitions
        rewards = self._compute_reward(transitions, training_state)
        transitions = transitions.replace(rewards=rewards)

        # update params of networks in the training state
        new_training_state, metrics = self._update_networks(
            training_state, transitions=transitions
        )

        return new_training_state, replay_buffer, metrics
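
A typical way to construct the agent is from a DiaynConfig, assuming DiaynConfig is exposed alongside DIAYN in qdax.baselines.diayn. The field names below follow the attributes referenced in the source above (num_skills, learning_rate, alpha_init, etc.); the values, and the exact shape of the hidden-layer fields, are illustrative assumptions:

from qdax.baselines.diayn import DIAYN, DiaynConfig

# Illustrative configuration; field names follow the attributes used above,
# values are placeholders.
config = DiaynConfig(
    num_skills=5,
    batch_size=256,
    episode_length=100,
    learning_rate=3e-4,
    alpha_init=1.0,
    discount=0.97,
    reward_scaling=1.0,
    normalize_observations=False,
    descriptor_full_state=False,
    policy_hidden_layer_size=(256, 256),
    critic_hidden_layer_size=(256, 256),
)

action_size = 8  # the environment's action dimensionality (e.g. env.action_size)
diayn = DIAYN(config=config, action_size=action_size)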

init(self, random_key, action_size, observation_size, descriptor_size)

Initialise the training state of the algorithm.

Parameters:
  • random_key (Array) – a jax random key

  • action_size (int) – the size of the environment's action space

  • observation_size (int) – the size of the environment's observation space

  • descriptor_size (int) – the size of the environment's descriptor space (i.e. the dimension of the discriminator's input)

Returns:
  • DiaynTrainingState – the initial training state of DIAYN

Source code in qdax/baselines/diayn.py
def init(  # type: ignore
    self,
    random_key: RNGKey,
    action_size: int,
    observation_size: int,
    descriptor_size: int,
) -> DiaynTrainingState:
    """Initialise the training state of the algorithm.

    Args:
        random_key: a jax random key
        action_size: the size of the environment's action space
        observation_size: the size of the environment's observation space
        descriptor_size: the size of the environment's descriptor space (i.e. the
            dimension of the discriminator's input)

    Returns:
        the initial training state of DIAYN
    """

    # define policy and critic params
    dummy_obs = jnp.zeros((1, observation_size + self._config.num_skills))
    dummy_action = jnp.zeros((1, action_size))
    dummy_discriminator_obs = jnp.zeros((1, descriptor_size))

    random_key, subkey = jax.random.split(random_key)
    policy_params = self._policy.init(subkey, dummy_obs)

    random_key, subkey = jax.random.split(random_key)
    critic_params = self._critic.init(subkey, dummy_obs, dummy_action)

    target_critic_params = jax.tree_util.tree_map(
        lambda x: jnp.asarray(x.copy()), critic_params
    )

    random_key, subkey = jax.random.split(random_key)
    discriminator_params = self._discriminator.init(
        subkey, obs=dummy_discriminator_obs
    )

    policy_optimizer_state = self._policy_optimizer.init(policy_params)
    critic_optimizer_state = self._critic_optimizer.init(critic_params)
    discriminator_optimizer_state = self._discriminator_optimizer.init(
        discriminator_params
    )

    log_alpha = jnp.asarray(jnp.log(self._config.alpha_init), dtype=jnp.float32)
    alpha_optimizer_state = self._alpha_optimizer.init(log_alpha)

    return DiaynTrainingState(
        policy_optimizer_state=policy_optimizer_state,
        policy_params=policy_params,
        critic_optimizer_state=critic_optimizer_state,
        critic_params=critic_params,
        alpha_optimizer_state=alpha_optimizer_state,
        alpha_params=log_alpha,
        target_critic_params=target_critic_params,
        discriminator_optimizer_state=discriminator_optimizer_state,
        discriminator_params=discriminator_params,
        random_key=random_key,
        steps=jnp.array(0),
    )
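
As a usage sketch, continuing the construction example above (the sizes are illustrative and should come from the environment):

import jax

random_key = jax.random.PRNGKey(0)
training_state = diayn.init(
    random_key,
    action_size=8,          # env.action_size
    observation_size=27,    # env.observation_size, without the skill vector
    descriptor_size=2,      # dimension of the discriminator's input
)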

play_step_fn(self, env_state, training_state, skills, env, deterministic=False)

Plays a step in the environment. Concatenates the skills to the observation vector, selects an action according to the SAC rule, and performs the environment step.

Parameters:
  • env_state (State) – the current environment state

  • training_state (DiaynTrainingState) – the DIAYN training state

  • skills (Array) – the skills concatenated to the observation vector

  • env (Env) – the environment

  • deterministic (bool) – whether or not to select action in a deterministic way. Defaults to False.

Returns:
  • Tuple[brax.envs.base.State, qdax.baselines.diayn.DiaynTrainingState, qdax.core.neuroevolution.buffers.buffer.QDTransition] – the new environment state, the new DIAYN training state, and the played transition

Source code in qdax/baselines/diayn.py
@partial(jax.jit, static_argnames=("self", "env", "deterministic"))
def play_step_fn(
    self,
    env_state: EnvState,
    training_state: DiaynTrainingState,
    skills: Skill,
    env: Env,
    deterministic: bool = False,
) -> Tuple[EnvState, DiaynTrainingState, QDTransition]:
    """Plays a step in the environment. Concatenates skills to the observation
    vector, selects an action according to SAC rule and performs the environment
    step.

    Args:
        env_state: the current environment state
        training_state: the DIAYN training state
        skills: the skills concatenated to the observation vector
        env: the environment
        deterministic: whether or not to select action in a deterministic way.
            Defaults to False.

    Returns:
        the new environment state
        the new DIAYN training state
        the played transition
    """

    random_key = training_state.random_key
    policy_params = training_state.policy_params
    obs = jnp.concatenate([env_state.obs, skills], axis=1)

    # If the env does not support state descriptor, we set it to (0,0)
    if "state_descriptor" in env_state.info:
        state_desc = env_state.info["state_descriptor"]
    else:
        state_desc = jnp.zeros((env_state.obs.shape[0], 2))

    actions, random_key = self.select_action(
        obs=obs,
        policy_params=policy_params,
        random_key=random_key,
        deterministic=deterministic,
    )

    next_env_state = env.step(env_state, actions)
    next_obs = jnp.concatenate([next_env_state.obs, skills], axis=1)
    if "state_descriptor" in next_env_state.info:
        next_state_desc = next_env_state.info["state_descriptor"]
    else:
        next_state_desc = jnp.zeros((next_env_state.obs.shape[0], 2))
    truncations = next_env_state.info["truncation"]
    transition = QDTransition(
        obs=obs,
        next_obs=next_obs,
        state_desc=state_desc,
        next_state_desc=next_state_desc,
        rewards=next_env_state.reward,
        dones=next_env_state.done,
        actions=actions,
        truncations=truncations,
    )
    training_state = training_state.replace(random_key=random_key)

    return next_env_state, training_state, transition
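
In practice, the environment and the fixed skill matrix can be bound once, so that the resulting function only takes the environment state and the training state, which is the form expected by the rollout utilities. A sketch, assuming env is a batched Brax environment and skills is the one-hot matrix built earlier:

import functools

play_step = functools.partial(
    diayn.play_step_fn,
    skills=skills,
    env=env,
    deterministic=False,
)

# env_state comes from resetting the batched environment
env_state, training_state, transition = play_step(env_state, training_state)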

eval_policy_fn(self, training_state, eval_env_first_state, play_step_fn, env_batch_size)

Evaluates the agent's policy over an entire episode, across all batched environments.

Parameters:
  • training_state (DiaynTrainingState) – the DIAYN training state

  • eval_env_first_state (State) – the initial state for evaluation

  • play_step_fn (Callable[[EnvState, Params, RNGKey], Tuple[EnvState, Params, RNGKey, QDTransition]]) – the play_step function used to collect the evaluation episode

  • env_batch_size (int) – the number of environments we play simultaneously

Returns:
  • true return averaged over the batch dimension, shape: (1,)

  • true return per environment, shape: (env_batch_size,)

  • diversity return per environment, shape: (env_batch_size,)

  • state descriptors, shape: (episode_length, env_batch_size, descriptor_size)

Source code in qdax/baselines/diayn.py
@partial(
    jax.jit,
    static_argnames=(
        "self",
        "play_step_fn",
        "env_batch_size",
    ),
)
def eval_policy_fn(
    self,
    training_state: DiaynTrainingState,
    eval_env_first_state: EnvState,
    play_step_fn: Callable[
        [EnvState, Params, RNGKey],
        Tuple[EnvState, Params, RNGKey, QDTransition],
    ],
    env_batch_size: int,
) -> Tuple[Reward, Reward, Reward, StateDescriptor]:
    """Evaluates the agent's policy over an entire episode, across all batched
    environments.


    Args:
        training_state: the DIAYN training state
        eval_env_first_state: the initial state for evaluation
        play_step_fn: the play_step function used to collect the evaluation episode
        env_batch_size: the number of environments we play simultaneously

    Returns:
        true return averaged over batch dimension, shape: (1,)
        true return per environment, shape: (env_batch_size,)
        diversity return per environment, shape: (env_batch_size,)
        state descriptors, shape: (episode_length, env_batch_size, descriptor_size)

    """

    state, training_state, transitions = generate_unroll(
        init_state=eval_env_first_state,
        training_state=training_state,
        episode_length=self._config.episode_length,
        play_step_fn=play_step_fn,
    )

    transitions = get_first_episode(transitions)
    true_returns = jnp.nansum(transitions.rewards, axis=0)
    true_return = jnp.mean(true_returns, axis=-1)

    reshaped_transitions = jax.tree_util.tree_map(
        lambda x: x.reshape((self._config.episode_length * env_batch_size, -1)),
        transitions,
    )

    if self._config.descriptor_full_state:
        state_desc = reshaped_transitions.obs[:, : -self._config.num_skills]
        next_state_desc = reshaped_transitions.next_obs[
            :, : -self._config.num_skills
        ]
        reshaped_transitions = reshaped_transitions.replace(
            state_desc=state_desc, next_state_desc=next_state_desc
        )

    diversity_rewards = self._compute_diversity_reward(
        transition=reshaped_transitions,
        discriminator_params=training_state.discriminator_params,
    ).reshape((self._config.episode_length, env_batch_size))

    diversity_returns = jnp.nansum(diversity_rewards, axis=0)

    return (
        true_return,
        true_returns,
        diversity_returns,
        transitions.state_desc,
    )
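
A hedged evaluation sketch, reusing the bound play_step function from above but with deterministic action selection; eval_env_first_state is assumed to come from resetting the batched evaluation environment:

import functools

eval_play_step = functools.partial(
    diayn.play_step_fn,
    skills=skills,
    env=env,
    deterministic=True,
)

true_return, true_returns, diversity_returns, state_desc = diayn.eval_policy_fn(
    training_state=training_state,
    eval_env_first_state=eval_env_first_state,
    play_step_fn=eval_play_step,
    env_batch_size=env_batch_size,
)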

update(self, training_state, replay_buffer)

Performs a training step to update the policy, the critic and the discriminator parameters.

Parameters:
  • training_state (DiaynTrainingState) – the current DIAYN training state

  • replay_buffer (ReplayBuffer) – the replay buffer

Returns:
  • Tuple[qdax.baselines.diayn.DiaynTrainingState, qdax.core.neuroevolution.buffers.buffer.ReplayBuffer, Dict[str, jax.Array]] – the updated DIAYN training state, the replay buffer, and the training metrics

Source code in qdax/baselines/diayn.py
@partial(jax.jit, static_argnames=("self",))
def update(
    self,
    training_state: DiaynTrainingState,
    replay_buffer: ReplayBuffer,
) -> Tuple[DiaynTrainingState, ReplayBuffer, Metrics]:
    """Performs a training step to update the policy, the critic and the
    discriminator parameters.

    Args:
        training_state: the current DIAYN training state
        replay_buffer: the replay buffer

    Returns:
        the updated DIAYN training state
        the replay buffer
        the training metrics
    """
    # Sample a batch of transitions in the buffer
    random_key = training_state.random_key
    transitions, random_key = replay_buffer.sample(
        random_key,
        sample_size=self._config.batch_size,
    )

    # Optionally replace the state descriptor by the observation
    if self._config.descriptor_full_state:
        state_desc = transitions.obs[:, : -self._config.num_skills]
        next_state_desc = transitions.next_obs[:, : -self._config.num_skills]
        transitions = transitions.replace(
            state_desc=state_desc, next_state_desc=next_state_desc
        )

    # Compute the rewards and replace transitions
    rewards = self._compute_reward(transitions, training_state)
    transitions = transitions.replace(rewards=rewards)

    # update params of networks in the training state
    new_training_state, metrics = self._update_networks(
        training_state, transitions=transitions
    )

    return new_training_state, replay_buffer, metrics
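
A minimal training-loop sketch alternating environment interaction and gradient updates; num_training_steps is illustrative, and the replay buffer is assumed to expose an insert method, as in QDax's ReplayBuffer:

num_training_steps = 10_000  # illustrative

for _ in range(num_training_steps):
    # collect one batched transition and store it in the buffer
    env_state, training_state, transition = play_step(env_state, training_state)
    replay_buffer = replay_buffer.insert(transition)

    # one gradient step on the discriminator, alpha, critic and policy
    training_state, replay_buffer, metrics = diayn.update(
        training_state, replay_buffer
    )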