SAC class

qdax.baselines.sac.SAC

Source code in qdax/baselines/sac.py
class SAC:
    def __init__(self, config: SacConfig, action_size: int) -> None:
        self._config = config
        self._action_size = action_size

        # define the networks
        self._policy, self._critic = make_sac_networks(
            action_size=action_size,
            critic_hidden_layer_size=self._config.critic_hidden_layer_size,
            policy_hidden_layer_size=self._config.policy_hidden_layer_size,
        )

        # define the action distribution
        self._parametric_action_distribution = NormalTanhDistribution(
            event_size=action_size
        )
        self._sample_action_fn = self._parametric_action_distribution.sample

    def init(
        self, random_key: RNGKey, action_size: int, observation_size: int
    ) -> SacTrainingState:
        """Initialise the training state of the algorithm.

        Args:
            random_key: a jax random key
            action_size: the size of the environment's action space
            observation_size: the size of the environment's observation space

        Returns:
            the initial training state of SAC
        """

        # define policy and critic params
        dummy_obs = jnp.zeros((1, observation_size))
        dummy_action = jnp.zeros((1, action_size))

        random_key, subkey = jax.random.split(random_key)
        policy_params = self._policy.init(subkey, dummy_obs)

        random_key, subkey = jax.random.split(random_key)
        critic_params = self._critic.init(subkey, dummy_obs, dummy_action)

        target_critic_params = jax.tree_util.tree_map(
            lambda x: jnp.asarray(x.copy()), critic_params
        )

        # define initial optimizer states
        optimizer = optax.adam(learning_rate=1.0)
        policy_optimizer_state = optimizer.init(policy_params)
        critic_optimizer_state = optimizer.init(critic_params)

        log_alpha = jnp.asarray(jnp.log(self._config.alpha_init), dtype=jnp.float32)
        alpha_optimizer_state = optimizer.init(log_alpha)

        # create and retrieve the training state
        training_state = SacTrainingState(
            policy_optimizer_state=policy_optimizer_state,
            policy_params=policy_params,
            critic_optimizer_state=critic_optimizer_state,
            critic_params=critic_params,
            alpha_optimizer_state=alpha_optimizer_state,
            alpha_params=log_alpha,
            target_critic_params=target_critic_params,
            normalization_running_stats=RunningMeanStdState(
                mean=jnp.zeros(
                    observation_size,
                ),
                var=jnp.ones(
                    observation_size,
                ),
                count=jnp.zeros(()),
            ),
            random_key=random_key,
            steps=jnp.array(0),
        )

        return training_state

    @partial(jax.jit, static_argnames=("self", "deterministic"))
    def select_action(
        self,
        obs: Observation,
        policy_params: Params,
        random_key: RNGKey,
        deterministic: bool = False,
    ) -> Tuple[Action, RNGKey]:
        """Selects an action acording to SAC policy.

        Args:
            obs: agent observation(s)
            policy_params: parameters of the agent's policy
            random_key: jax random key
            deterministic: whether to select action in a deterministic way.
                Defaults to False.

        Returns:
            The selected action and a new random key.
        """

        dist_params = self._policy.apply(policy_params, obs)
        if not deterministic:
            random_key, key_sample = jax.random.split(random_key)
            actions = self._sample_action_fn(dist_params, key_sample)

        else:
            # The first half of parameters is for mean and the second half for variance
            actions = jax.nn.tanh(dist_params[..., : dist_params.shape[-1] // 2])

        return actions, random_key

    @partial(jax.jit, static_argnames=("self", "env", "deterministic", "evaluation"))
    def play_step_fn(
        self,
        env_state: EnvState,
        training_state: SacTrainingState,
        env: Env,
        deterministic: bool = False,
        evaluation: bool = False,
    ) -> Tuple[EnvState, SacTrainingState, Transition]:
        """Plays a step in the environment. Selects an action according to SAC rule and
        performs the environment step.

        Args:
            env_state: the current environment state
            training_state: the SAC training state
            env: the environment
            deterministic: the whether to select action in a deterministic way.
                Defaults to False.
            evaluation: if True, collected transitions are not used to update training
                state. Defaults to False.

        Returns:
            the new environment state
            the new SAC training state
            the played transition
        """
        random_key = training_state.random_key
        policy_params = training_state.policy_params
        obs = env_state.obs

        if self._config.normalize_observations:
            normalized_obs = normalize_with_rmstd(
                obs, training_state.normalization_running_stats
            )
            normalization_running_stats = update_running_mean_std(
                training_state.normalization_running_stats, obs
            )

        else:
            normalized_obs = obs
            normalization_running_stats = training_state.normalization_running_stats

        actions, random_key = self.select_action(
            obs=normalized_obs,
            policy_params=policy_params,
            random_key=random_key,
            deterministic=deterministic,
        )

        if not evaluation:
            training_state = training_state.replace(
                random_key=random_key,
                normalization_running_stats=normalization_running_stats,
            )
        else:
            training_state = training_state.replace(
                random_key=random_key,
            )

        next_env_state = env.step(env_state, actions)
        next_obs = next_env_state.obs

        truncations = next_env_state.info["truncation"]
        transition = Transition(
            obs=env_state.obs,
            next_obs=next_obs,
            rewards=next_env_state.reward,
            dones=next_env_state.done,
            actions=actions,
            truncations=truncations,
        )

        return (
            next_env_state,
            training_state,
            transition,
        )

    @partial(jax.jit, static_argnames=("self", "env", "deterministic", "evaluation"))
    def play_qd_step_fn(
        self,
        env_state: EnvState,
        training_state: SacTrainingState,
        env: Env,
        deterministic: bool = False,
        evaluation: bool = False,
    ) -> Tuple[EnvState, SacTrainingState, QDTransition]:
        """Plays a step in the environment. Selects an action according to SAC rule and
        performs the environment step.

        Args:
            env_state: the current environment state
            training_state: the SAC training state
            env: the environment
            deterministic: the whether to select action in a deterministic way.
                Defaults to False.
            evaluation: if True, collected transitions are not used to update training
                state. Defaults to False.

        Returns:
            the new environment state
            the new SAC training state
            the played transition
        """

        next_env_state, training_state, transition = self.play_step_fn(
            env_state, training_state, env, deterministic, evaluation
        )
        actions = transition.actions
        next_env_state = env.step(env_state, actions)
        next_obs = next_env_state.obs

        truncations = next_env_state.info["truncation"]

        transition = QDTransition(
            obs=env_state.obs,
            next_obs=next_obs,
            rewards=next_env_state.reward,
            dones=next_env_state.done,
            actions=actions,
            truncations=truncations,
            state_desc=env_state.info["state_descriptor"],
            next_state_desc=next_env_state.info["state_descriptor"],
        )

        return (
            next_env_state,
            training_state,
            transition,
        )

    @partial(
        jax.jit,
        static_argnames=(
            "self",
            "play_step_fn",
        ),
    )
    def eval_policy_fn(
        self,
        training_state: SacTrainingState,
        eval_env_first_state: EnvState,
        play_step_fn: Callable[
            [EnvState, Params, RNGKey],
            Tuple[EnvState, SacTrainingState, Transition],
        ],
    ) -> Tuple[Reward, Reward]:
        """Evaluates the agent's policy over an entire episode, across all batched
        environments.


        Args:
            training_state: the SAC training state
            eval_env_first_state: the initial state for evaluation
            play_step_fn: the play_step function used to collect the evaluation episode

        Returns:
            the true return averaged over batch dimension, shape: (1,)
            the true return per environment, shape: (env_batch_size,)

        """

        state, training_state, transitions = generate_unroll(
            init_state=eval_env_first_state,
            training_state=training_state,
            episode_length=self._config.episode_length,
            play_step_fn=play_step_fn,
        )

        transitions = get_first_episode(transitions)
        true_returns = jnp.nansum(transitions.rewards, axis=0)
        true_return = jnp.mean(true_returns, axis=-1)

        return true_return, true_returns

    @partial(
        jax.jit,
        static_argnames=(
            "self",
            "play_step_fn",
            "bd_extraction_fn",
        ),
    )
    def eval_qd_policy_fn(
        self,
        training_state: SacTrainingState,
        eval_env_first_state: EnvState,
        play_step_fn: Callable[
            [EnvState, Params, RNGKey],
            Tuple[EnvState, SacTrainingState, QDTransition],
        ],
        bd_extraction_fn: Callable[[QDTransition, Mask], Descriptor],
    ) -> Tuple[Reward, Descriptor, Reward, Descriptor]:
        """
        Evaluates the agent's policy over an entire episode, across all batched
        environments for QD environments. Averaged BDs are returned as well.


        Args:
            training_state: the SAC training state
            eval_env_first_state: the initial state for evaluation
            play_step_fn: the play_step function used to collect the evaluation episode
            bd_extraction_fn: the function used to extract the behavior
                descriptors from the collected transitions

        Returns:
            the true return averaged over batch dimension, shape: (1,)
            the descriptor averaged over batch dimension, shape: (num_descriptors,)
            the true return per environment, shape: (env_batch_size,)
            the descriptor per environment, shape: (env_batch_size, num_descriptors)

        """

        state, training_state, transitions = generate_unroll(
            init_state=eval_env_first_state,
            training_state=training_state,
            episode_length=self._config.episode_length,
            play_step_fn=play_step_fn,
        )
        transitions = get_first_episode(transitions)
        true_returns = jnp.nansum(transitions.rewards, axis=0)
        true_return = jnp.mean(true_returns, axis=-1)

        transitions = jax.tree_util.tree_map(
            lambda x: jnp.swapaxes(x, 0, 1), transitions
        )
        masks = jnp.isnan(transitions.rewards)
        bds = bd_extraction_fn(transitions, masks)

        mean_bd = jnp.mean(bds, axis=0)
        return true_return, mean_bd, true_returns, bds

    @partial(jax.jit, static_argnames=("self",))
    def _update_alpha(
        self,
        alpha_lr: float,
        training_state: SacTrainingState,
        transitions: Transition,
        random_key: RNGKey,
    ) -> Tuple[Params, optax.OptState, jnp.ndarray, RNGKey]:
        """Updates the alpha parameter if necessary. Else, it keeps the
        current value.

        Args:
            alpha_lr: alpha learning rate
            training_state: the current training state.
            transitions: a sample of transitions from the replay buffer.
            random_key: a random key to handle stochastic operations.

        Returns:
            New alpha params, optimizer state, loss and a new random key.
        """
        if not self._config.fix_alpha:
            # update alpha
            random_key, subkey = jax.random.split(random_key)
            alpha_loss, alpha_gradient = jax.value_and_grad(sac_alpha_loss_fn)(
                training_state.alpha_params,
                policy_fn=self._policy.apply,
                parametric_action_distribution=self._parametric_action_distribution,
                action_size=self._action_size,
                policy_params=training_state.policy_params,
                transitions=transitions,
                random_key=subkey,
            )
            alpha_optimizer = optax.adam(learning_rate=alpha_lr)
            (alpha_updates, alpha_optimizer_state,) = alpha_optimizer.update(
                alpha_gradient, training_state.alpha_optimizer_state
            )
            alpha_params = optax.apply_updates(
                training_state.alpha_params, alpha_updates
            )
        else:
            alpha_params = training_state.alpha_params
            alpha_optimizer_state = training_state.alpha_optimizer_state
            alpha_loss = jnp.array(0.0)

        return alpha_params, alpha_optimizer_state, alpha_loss, random_key

    @partial(jax.jit, static_argnames=("self",))
    def _update_critic(
        self,
        critic_lr: float,
        reward_scaling: float,
        discount: float,
        training_state: SacTrainingState,
        transitions: Transition,
        random_key: RNGKey,
    ) -> Tuple[Params, Params, optax.OptState, jnp.ndarray, RNGKey]:
        """Updates the critic following the method described in the
        Soft Actor Critic paper.

        Args:
            critic_lr: critic learning rate
            reward_scaling: coefficient to scale rewards
            discount: discount factor
            training_state: the current training state.
            transitions: a batch of transitions sampled from the replay buffer.
            random_key: a random key to handle stochastic operations.

        Returns:
            New parameters of the critic and its target. New optimizer state,
            loss and a new random key.
        """
        # update critic
        random_key, subkey = jax.random.split(random_key)
        critic_loss, critic_gradient = jax.value_and_grad(sac_critic_loss_fn)(
            training_state.critic_params,
            policy_fn=self._policy.apply,
            critic_fn=self._critic.apply,
            parametric_action_distribution=self._parametric_action_distribution,
            reward_scaling=reward_scaling,
            discount=discount,
            policy_params=training_state.policy_params,
            target_critic_params=training_state.target_critic_params,
            alpha=jnp.exp(training_state.alpha_params),
            transitions=transitions,
            random_key=subkey,
        )
        critic_optimizer = optax.adam(learning_rate=critic_lr)
        (critic_updates, critic_optimizer_state,) = critic_optimizer.update(
            critic_gradient, training_state.critic_optimizer_state
        )
        critic_params = optax.apply_updates(
            training_state.critic_params, critic_updates
        )
        target_critic_params = jax.tree_util.tree_map(
            lambda x1, x2: (1.0 - self._config.tau) * x1 + self._config.tau * x2,
            training_state.target_critic_params,
            critic_params,
        )

        return (
            critic_params,
            target_critic_params,
            critic_optimizer_state,
            critic_loss,
            random_key,
        )

    @partial(jax.jit, static_argnames=("self",))
    def _update_actor(
        self,
        policy_lr: float,
        training_state: SacTrainingState,
        transitions: Transition,
        random_key: RNGKey,
    ) -> Tuple[Params, optax.OptState, jnp.ndarray, RNGKey]:
        """Updates the actor parameters following the stochastic
        policy gradient theorem with the method introduced in SAC.

        Args:
            policy_lr: policy learning rate
            training_state: the current training state.
            transitions: a batch of transitions sampled from the replay
                buffer.
            random_key: a random key to handle stochastic operations.

        Returns:
            New params and optimizer state. Current loss. New random key.
        """
        random_key, subkey = jax.random.split(random_key)
        policy_loss, policy_gradient = jax.value_and_grad(sac_policy_loss_fn)(
            training_state.policy_params,
            policy_fn=self._policy.apply,
            critic_fn=self._critic.apply,
            parametric_action_distribution=self._parametric_action_distribution,
            critic_params=training_state.critic_params,
            alpha=jnp.exp(training_state.alpha_params),
            transitions=transitions,
            random_key=subkey,
        )
        policy_optimizer = optax.adam(learning_rate=policy_lr)
        (policy_updates, policy_optimizer_state,) = policy_optimizer.update(
            policy_gradient, training_state.policy_optimizer_state
        )
        policy_params = optax.apply_updates(
            training_state.policy_params, policy_updates
        )

        return policy_params, policy_optimizer_state, policy_loss, random_key

    @partial(jax.jit, static_argnames=("self",))
    def update(
        self,
        training_state: SacTrainingState,
        replay_buffer: ReplayBuffer,
    ) -> Tuple[SacTrainingState, ReplayBuffer, Metrics]:
        """Performs a training step to update the policy and the critic parameters.

        Args:
            training_state: the current SAC training state
            replay_buffer: the replay buffer

        Returns:
            the updated SAC training state
            the replay buffer
            the training metrics
        """

        # sample a batch of transitions in the buffer
        random_key = training_state.random_key
        transitions, random_key = replay_buffer.sample(
            random_key,
            sample_size=self._config.batch_size,
        )

        # normalise observations if necessary
        if self._config.normalize_observations:
            normalization_running_stats = training_state.normalization_running_stats
            normalized_obs = normalize_with_rmstd(
                transitions.obs, normalization_running_stats
            )
            normalized_next_obs = normalize_with_rmstd(
                transitions.next_obs, normalization_running_stats
            )
            transitions = transitions.replace(
                obs=normalized_obs, next_obs=normalized_next_obs
            )

        # update alpha
        (
            alpha_params,
            alpha_optimizer_state,
            alpha_loss,
            random_key,
        ) = self._update_alpha(
            alpha_lr=self._config.learning_rate,
            training_state=training_state,
            transitions=transitions,
            random_key=random_key,
        )

        # update critic
        (
            critic_params,
            target_critic_params,
            critic_optimizer_state,
            critic_loss,
            random_key,
        ) = self._update_critic(
            critic_lr=self._config.learning_rate,
            reward_scaling=self._config.reward_scaling,
            discount=self._config.discount,
            training_state=training_state,
            transitions=transitions,
            random_key=random_key,
        )

        # update actor
        (
            policy_params,
            policy_optimizer_state,
            policy_loss,
            random_key,
        ) = self._update_actor(
            policy_lr=self._config.learning_rate,
            training_state=training_state,
            transitions=transitions,
            random_key=random_key,
        )

        # create new training state
        new_training_state = SacTrainingState(
            policy_optimizer_state=policy_optimizer_state,
            policy_params=policy_params,
            critic_optimizer_state=critic_optimizer_state,
            critic_params=critic_params,
            alpha_optimizer_state=alpha_optimizer_state,
            alpha_params=alpha_params,
            normalization_running_stats=training_state.normalization_running_stats,
            target_critic_params=target_critic_params,
            random_key=random_key,
            steps=training_state.steps + 1,
        )
        metrics = {
            "actor_loss": policy_loss,
            "critic_loss": critic_loss,
            "alpha_loss": alpha_loss,
            "obs_mean": jnp.mean(transitions.obs),
            "obs_std": jnp.std(transitions.obs),
        }
        return new_training_state, replay_buffer, metrics
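
Below is a minimal construction sketch. It only uses the SacConfig fields referenced in the source above; the import path of SacConfig and the exact set of required fields are assumptions.

from qdax.baselines.sac import SAC, SacConfig  # SacConfig import path assumed

# Hypothetical configuration values; only fields referenced in the source above are shown.
config = SacConfig(
    learning_rate=3e-4,
    alpha_init=1.0,
    discount=0.99,
    reward_scaling=1.0,
    batch_size=256,
    tau=0.005,
    normalize_observations=False,
    episode_length=1000,
    fix_alpha=False,
    critic_hidden_layer_size=(256, 256),
    policy_hidden_layer_size=(256, 256),
)

sac = SAC(config=config, action_size=8)  # action_size is environment-dependent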

init(self, random_key, action_size, observation_size)

Initialise the training state of the algorithm.

Parameters:
  • random_key (Array) – a jax random key

  • action_size (int) – the size of the environment's action space

  • observation_size (int) – the size of the environment's observation space

Returns:
  • SacTrainingState – the initial training state of SAC

Source code in qdax/baselines/sac.py
def init(
    self, random_key: RNGKey, action_size: int, observation_size: int
) -> SacTrainingState:
    """Initialise the training state of the algorithm.

    Args:
        random_key: a jax random key
        action_size: the size of the environment's action space
        observation_size: the size of the environment's observation space

    Returns:
        the initial training state of SAC
    """

    # define policy and critic params
    dummy_obs = jnp.zeros((1, observation_size))
    dummy_action = jnp.zeros((1, action_size))

    random_key, subkey = jax.random.split(random_key)
    policy_params = self._policy.init(subkey, dummy_obs)

    random_key, subkey = jax.random.split(random_key)
    critic_params = self._critic.init(subkey, dummy_obs, dummy_action)

    target_critic_params = jax.tree_util.tree_map(
        lambda x: jnp.asarray(x.copy()), critic_params
    )

    # define initial optimizer states
    optimizer = optax.adam(learning_rate=1.0)
    policy_optimizer_state = optimizer.init(policy_params)
    critic_optimizer_state = optimizer.init(critic_params)

    log_alpha = jnp.asarray(jnp.log(self._config.alpha_init), dtype=jnp.float32)
    alpha_optimizer_state = optimizer.init(log_alpha)

    # create and retrieve the training state
    training_state = SacTrainingState(
        policy_optimizer_state=policy_optimizer_state,
        policy_params=policy_params,
        critic_optimizer_state=critic_optimizer_state,
        critic_params=critic_params,
        alpha_optimizer_state=alpha_optimizer_state,
        alpha_params=log_alpha,
        target_critic_params=target_critic_params,
        normalization_running_stats=RunningMeanStdState(
            mean=jnp.zeros(
                observation_size,
            ),
            var=jnp.ones(
                observation_size,
            ),
            count=jnp.zeros(()),
        ),
        random_key=random_key,
        steps=jnp.array(0),
    )

    return training_state
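
A usage sketch, assuming a SAC instance named sac created as in the constructor example above; the action and observation sizes are placeholder values for a hypothetical environment.

import jax

random_key = jax.random.PRNGKey(0)
training_state = sac.init(
    random_key=random_key,
    action_size=8,         # must match the action_size given to the constructor
    observation_size=27,   # size of the environment's observation space
)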

select_action(self, obs, policy_params, random_key, deterministic=False)

Selects an action according to the SAC policy.

Parameters:
  • obs (Array) – agent observation(s)

  • policy_params (Union[jax.Array, numpy.ndarray, numpy.bool_, numpy.number, Iterable[ArrayTree], Mapping[Any, ArrayTree]]) – parameters of the agent's policy

  • random_key (Array) – jax random key

  • deterministic (bool) – whether to select the action in a deterministic way. Defaults to False.

Returns:
  • Tuple[jax.Array, jax.Array] – The selected action and a new random key.

Source code in qdax/baselines/sac.py
@partial(jax.jit, static_argnames=("self", "deterministic"))
def select_action(
    self,
    obs: Observation,
    policy_params: Params,
    random_key: RNGKey,
    deterministic: bool = False,
) -> Tuple[Action, RNGKey]:
    """Selects an action acording to SAC policy.

    Args:
        obs: agent observation(s)
        policy_params: parameters of the agent's policy
        random_key: jax random key
        deterministic: whether to select action in a deterministic way.
            Defaults to False.

    Returns:
        The selected action and a new random key.
    """

    dist_params = self._policy.apply(policy_params, obs)
    if not deterministic:
        random_key, key_sample = jax.random.split(random_key)
        actions = self._sample_action_fn(dist_params, key_sample)

    else:
        # The first half of parameters is for mean and the second half for variance
        actions = jax.nn.tanh(dist_params[..., : dist_params.shape[-1] // 2])

    return actions, random_key
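
A sketch of calling select_action with the training state produced by init; the observation batch is a dummy placeholder.

import jax
import jax.numpy as jnp

dummy_obs = jnp.zeros((1, 27))  # placeholder batch of observations
key = jax.random.PRNGKey(1)
actions, key = sac.select_action(
    obs=dummy_obs,
    policy_params=training_state.policy_params,
    random_key=key,
    deterministic=True,  # take tanh of the mean instead of sampling
)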

play_step_fn(self, env_state, training_state, env, deterministic=False, evaluation=False)

Plays a step in the environment. Selects an action according to the SAC policy and performs the environment step.

Parameters:
  • env_state (State) – the current environment state

  • training_state (SacTrainingState) – the SAC training state

  • env (Env) – the environment

  • deterministic (bool) – whether to select the action in a deterministic way. Defaults to False.

  • evaluation (bool) – if True, collected transitions are not used to update training state. Defaults to False.

Returns:
  • Tuple[brax.envs.base.State, qdax.baselines.sac.SacTrainingState, qdax.core.neuroevolution.buffers.buffer.Transition] – the new environment state, the new SAC training state, and the played transition

Source code in qdax/baselines/sac.py
@partial(jax.jit, static_argnames=("self", "env", "deterministic", "evaluation"))
def play_step_fn(
    self,
    env_state: EnvState,
    training_state: SacTrainingState,
    env: Env,
    deterministic: bool = False,
    evaluation: bool = False,
) -> Tuple[EnvState, SacTrainingState, Transition]:
    """Plays a step in the environment. Selects an action according to SAC rule and
    performs the environment step.

    Args:
        env_state: the current environment state
        training_state: the SAC training state
        env: the environment
        deterministic: the whether to select action in a deterministic way.
            Defaults to False.
        evaluation: if True, collected transitions are not used to update training
            state. Defaults to False.

    Returns:
        the new environment state
        the new SAC training state
        the played transition
    """
    random_key = training_state.random_key
    policy_params = training_state.policy_params
    obs = env_state.obs

    if self._config.normalize_observations:
        normalized_obs = normalize_with_rmstd(
            obs, training_state.normalization_running_stats
        )
        normalization_running_stats = update_running_mean_std(
            training_state.normalization_running_stats, obs
        )

    else:
        normalized_obs = obs
        normalization_running_stats = training_state.normalization_running_stats

    actions, random_key = self.select_action(
        obs=normalized_obs,
        policy_params=policy_params,
        random_key=random_key,
        deterministic=deterministic,
    )

    if not evaluation:
        training_state = training_state.replace(
            random_key=random_key,
            normalization_running_stats=normalization_running_stats,
        )
    else:
        training_state = training_state.replace(
            random_key=random_key,
        )

    next_env_state = env.step(env_state, actions)
    next_obs = next_env_state.obs

    truncations = next_env_state.info["truncation"]
    transition = Transition(
        obs=env_state.obs,
        next_obs=next_obs,
        rewards=next_env_state.reward,
        dones=next_env_state.done,
        actions=actions,
        truncations=truncations,
    )

    return (
        next_env_state,
        training_state,
        transition,
    )
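
Since env, deterministic and evaluation are static arguments, a common pattern is to bind them with functools.partial so the resulting function only takes (env_state, training_state). The sketch below assumes env is a Brax environment created elsewhere.

import functools
import jax

play_step = functools.partial(
    sac.play_step_fn,
    env=env,                # a Brax environment, assumed defined elsewhere
    deterministic=False,
    evaluation=False,
)

env_state = env.reset(jax.random.PRNGKey(2))
env_state, training_state, transition = play_step(env_state, training_state)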

play_qd_step_fn(self, env_state, training_state, env, deterministic=False, evaluation=False)

Plays a step in the environment. Selects an action according to the SAC policy and performs the environment step.

Parameters:
  • env_state (State) – the current environment state

  • training_state (SacTrainingState) – the SAC training state

  • env (Env) – the environment

  • deterministic (bool) – whether to select the action in a deterministic way. Defaults to False.

  • evaluation (bool) – if True, collected transitions are not used to update training state. Defaults to False.

Returns:
  • Tuple[brax.envs.base.State, qdax.baselines.sac.SacTrainingState, qdax.core.neuroevolution.buffers.buffer.QDTransition] – the new environment state, the new SAC training state, and the played transition

Source code in qdax/baselines/sac.py
@partial(jax.jit, static_argnames=("self", "env", "deterministic", "evaluation"))
def play_qd_step_fn(
    self,
    env_state: EnvState,
    training_state: SacTrainingState,
    env: Env,
    deterministic: bool = False,
    evaluation: bool = False,
) -> Tuple[EnvState, SacTrainingState, QDTransition]:
    """Plays a step in the environment. Selects an action according to SAC rule and
    performs the environment step.

    Args:
        env_state: the current environment state
        training_state: the SAC training state
        env: the environment
        deterministic: the whether to select action in a deterministic way.
            Defaults to False.
        evaluation: if True, collected transitions are not used to update training
            state. Defaults to False.

    Returns:
        the new environment state
        the new SAC training state
        the played transition
    """

    next_env_state, training_state, transition = self.play_step_fn(
        env_state, training_state, env, deterministic, evaluation
    )
    actions = transition.actions
    next_env_state = env.step(env_state, actions)
    next_obs = next_env_state.obs

    truncations = next_env_state.info["truncation"]

    transition = QDTransition(
        obs=env_state.obs,
        next_obs=next_obs,
        rewards=next_env_state.reward,
        dones=next_env_state.done,
        actions=actions,
        truncations=truncations,
        state_desc=env_state.info["state_descriptor"],
        next_state_desc=next_env_state.info["state_descriptor"],
    )

    return (
        next_env_state,
        training_state,
        transition,
    )
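
The same partial-binding pattern applies to the QD variant; the resulting function is what eval_qd_policy_fn expects as play_step_fn (see that section below). Sketch under the same assumptions as above.

import functools

play_qd_step = functools.partial(
    sac.play_qd_step_fn,
    env=env,               # a Brax environment exposing "state_descriptor" in its info
    deterministic=True,
    evaluation=True,
)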

eval_policy_fn(self, training_state, eval_env_first_state, play_step_fn)

Evaluates the agent's policy over an entire episode, across all batched environments.

Parameters:
  • training_state (SacTrainingState) – the SAC training state

  • eval_env_first_state (State) – the initial state for evaluation

  • play_step_fn (Callable[[brax.envs.base.State, Union[jax.Array, numpy.ndarray, numpy.bool_, numpy.number, Iterable[ArrayTree], Mapping[Any, ArrayTree]], jax.Array], Tuple[brax.envs.base.State, qdax.baselines.sac.SacTrainingState, qdax.core.neuroevolution.buffers.buffer.Transition]]) – the play_step function used to collect the evaluation episode

Returns:
  • the true return averaged over the batch dimension, shape (1,); the true return per environment, shape (env_batch_size,)

Source code in qdax/baselines/sac.py
@partial(
    jax.jit,
    static_argnames=(
        "self",
        "play_step_fn",
    ),
)
def eval_policy_fn(
    self,
    training_state: SacTrainingState,
    eval_env_first_state: EnvState,
    play_step_fn: Callable[
        [EnvState, Params, RNGKey],
        Tuple[EnvState, SacTrainingState, Transition],
    ],
) -> Tuple[Reward, Reward]:
    """Evaluates the agent's policy over an entire episode, across all batched
    environments.


    Args:
        training_state: the SAC training state
        eval_env_first_state: the initial state for evaluation
        play_step_fn: the play_step function used to collect the evaluation episode

    Returns:
        the true return averaged over batch dimension, shape: (1,)
        the true return per environment, shape: (env_batch_size,)

    """

    state, training_state, transitions = generate_unroll(
        init_state=eval_env_first_state,
        training_state=training_state,
        episode_length=self._config.episode_length,
        play_step_fn=play_step_fn,
    )

    transitions = get_first_episode(transitions)
    true_returns = jnp.nansum(transitions.rewards, axis=0)
    true_return = jnp.mean(true_returns, axis=-1)

    return true_return, true_returns
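
An evaluation sketch, assuming eval_env is a batched Brax environment and eval_env_first_state comes from its reset; the play step is bound with deterministic=True and evaluation=True.

import functools

eval_play_step = functools.partial(
    sac.play_step_fn,
    env=eval_env,           # batched evaluation environment, assumed defined elsewhere
    deterministic=True,
    evaluation=True,
)

true_return, true_returns = sac.eval_policy_fn(
    training_state=training_state,
    eval_env_first_state=eval_env_first_state,
    play_step_fn=eval_play_step,
)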

eval_qd_policy_fn(self, training_state, eval_env_first_state, play_step_fn, bd_extraction_fn)

Evaluates the agent's policy over an entire episode, across all batched environments for QD environments. Averaged BDs are returned as well.

Parameters:
  • training_state (SacTrainingState) – the SAC training state

  • eval_env_first_state (State) – the initial state for evaluation

  • play_step_fn (Callable[[brax.envs.base.State, Union[jax.Array, numpy.ndarray, numpy.bool_, numpy.number, Iterable[ArrayTree], Mapping[Any, ArrayTree]], jax.Array], Tuple[brax.envs.base.State, qdax.baselines.sac.SacTrainingState, qdax.core.neuroevolution.buffers.buffer.QDTransition]]) – the play_step function used to collect the evaluation episode

  • bd_extraction_fn (Callable[[QDTransition, Mask], Descriptor]) – the function used to extract the behavior descriptors from the collected transitions

Returns:
  • the true return averaged over the batch dimension, shape (1,); the descriptor averaged over the batch dimension, shape (num_descriptors,); the true return per environment, shape (env_batch_size,); the descriptor per environment, shape (env_batch_size, num_descriptors)

Source code in qdax/baselines/sac.py
@partial(
    jax.jit,
    static_argnames=(
        "self",
        "play_step_fn",
        "bd_extraction_fn",
    ),
)
def eval_qd_policy_fn(
    self,
    training_state: SacTrainingState,
    eval_env_first_state: EnvState,
    play_step_fn: Callable[
        [EnvState, Params, RNGKey],
        Tuple[EnvState, SacTrainingState, QDTransition],
    ],
    bd_extraction_fn: Callable[[QDTransition, Mask], Descriptor],
) -> Tuple[Reward, Descriptor, Reward, Descriptor]:
    """
    Evaluates the agent's policy over an entire episode, across all batched
    environments for QD environments. Averaged BDs are returned as well.


    Args:
        training_state: the SAC training state
        eval_env_first_state: the initial state for evaluation
        play_step_fn: the play_step function used to collect the evaluation episode
        bd_extraction_fn: the function used to extract the behavior
            descriptors from the collected transitions

    Returns:
        the true return averaged over batch dimension, shape: (1,)
        the descriptor averaged over batch dimension, shape: (num_descriptors,)
        the true return per environment, shape: (env_batch_size,)
        the descriptor per environment, shape: (env_batch_size, num_descriptors)

    """

    state, training_state, transitions = generate_unroll(
        init_state=eval_env_first_state,
        training_state=training_state,
        episode_length=self._config.episode_length,
        play_step_fn=play_step_fn,
    )
    transitions = get_first_episode(transitions)
    true_returns = jnp.nansum(transitions.rewards, axis=0)
    true_return = jnp.mean(true_returns, axis=-1)

    transitions = jax.tree_util.tree_map(
        lambda x: jnp.swapaxes(x, 0, 1), transitions
    )
    masks = jnp.isnan(transitions.rewards)
    bds = bd_extraction_fn(transitions, masks)

    mean_bd = jnp.mean(bds, axis=0)
    return true_return, mean_bd, true_returns, bds
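
A QD evaluation sketch; bd_extraction_fn is a task-specific descriptor extractor (an assumed placeholder here) and play_qd_step is the bound step function shown in the play_qd_step_fn section above.

true_return, mean_bd, true_returns, bds = sac.eval_qd_policy_fn(
    training_state=training_state,
    eval_env_first_state=eval_env_first_state,
    play_step_fn=play_qd_step,
    bd_extraction_fn=bd_extraction_fn,  # task-specific, assumed defined elsewhere
)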

update(self, training_state, replay_buffer)

Performs a training step to update the policy and the critic parameters.

Parameters:
  • training_state (SacTrainingState) – the current SAC training state

  • replay_buffer (ReplayBuffer) – the replay buffer

Returns:
  • Tuple[qdax.baselines.sac.SacTrainingState, qdax.core.neuroevolution.buffers.buffer.ReplayBuffer, Dict[str, jax.Array]] – the updated SAC training state, the replay buffer, and the training metrics

Source code in qdax/baselines/sac.py
@partial(jax.jit, static_argnames=("self",))
def update(
    self,
    training_state: SacTrainingState,
    replay_buffer: ReplayBuffer,
) -> Tuple[SacTrainingState, ReplayBuffer, Metrics]:
    """Performs a training step to update the policy and the critic parameters.

    Args:
        training_state: the current SAC training state
        replay_buffer: the replay buffer

    Returns:
        the updated SAC training state
        the replay buffer
        the training metrics
    """

    # sample a batch of transitions in the buffer
    random_key = training_state.random_key
    transitions, random_key = replay_buffer.sample(
        random_key,
        sample_size=self._config.batch_size,
    )

    # normalise observations if necessary
    if self._config.normalize_observations:
        normalization_running_stats = training_state.normalization_running_stats
        normalized_obs = normalize_with_rmstd(
            transitions.obs, normalization_running_stats
        )
        normalized_next_obs = normalize_with_rmstd(
            transitions.next_obs, normalization_running_stats
        )
        transitions = transitions.replace(
            obs=normalized_obs, next_obs=normalized_next_obs
        )

    # update alpha
    (
        alpha_params,
        alpha_optimizer_state,
        alpha_loss,
        random_key,
    ) = self._update_alpha(
        alpha_lr=self._config.learning_rate,
        training_state=training_state,
        transitions=transitions,
        random_key=random_key,
    )

    # update critic
    (
        critic_params,
        target_critic_params,
        critic_optimizer_state,
        critic_loss,
        random_key,
    ) = self._update_critic(
        critic_lr=self._config.learning_rate,
        reward_scaling=self._config.reward_scaling,
        discount=self._config.discount,
        training_state=training_state,
        transitions=transitions,
        random_key=random_key,
    )

    # update actor
    (
        policy_params,
        policy_optimizer_state,
        policy_loss,
        random_key,
    ) = self._update_actor(
        policy_lr=self._config.learning_rate,
        training_state=training_state,
        transitions=transitions,
        random_key=random_key,
    )

    # create new training state
    new_training_state = SacTrainingState(
        policy_optimizer_state=policy_optimizer_state,
        policy_params=policy_params,
        critic_optimizer_state=critic_optimizer_state,
        critic_params=critic_params,
        alpha_optimizer_state=alpha_optimizer_state,
        alpha_params=alpha_params,
        normalization_running_stats=training_state.normalization_running_stats,
        target_critic_params=target_critic_params,
        random_key=random_key,
        steps=training_state.steps + 1,
    )
    metrics = {
        "actor_loss": policy_loss,
        "critic_loss": critic_loss,
        "alpha_loss": alpha_loss,
        "obs_mean": jnp.mean(transitions.obs),
        "obs_std": jnp.std(transitions.obs),
    }
    return new_training_state, replay_buffer, metrics
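
A training-step sketch, assuming replay_buffer is a qdax ReplayBuffer that has already been populated with transitions collected via play_step_fn.

training_state, replay_buffer, metrics = sac.update(
    training_state=training_state,
    replay_buffer=replay_buffer,
)

# metrics contains "actor_loss", "critic_loss", "alpha_loss", "obs_mean" and "obs_std"
print(metrics["critic_loss"])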