Neuro-evolution components

qdax.core.neuroevolution special

buffers special

buffer

Transition (PyTreeNode) dataclass

Stores data corresponding to a transition collected by a classic RL algorithm.

Source code in qdax/core/neuroevolution/buffers/buffer.py
class Transition(flax.struct.PyTreeNode):
    """Stores data corresponding to a transition collected by a classic RL algorithm."""

    obs: Observation
    next_obs: Observation
    rewards: Reward
    dones: Done
    truncations: jnp.ndarray  # Indicates if an episode has reached max time step
    actions: Action

    @property
    def observation_dim(self) -> int:
        """
        Returns:
            the dimension of the observation
        """
        return self.obs.shape[-1]  # type: ignore

    @property
    def action_dim(self) -> int:
        """
        Returns:
            the dimension of the action
        """
        return self.actions.shape[-1]  # type: ignore

    @property
    def flatten_dim(self) -> int:
        """
        Returns:
            the dimension of the transition once flattened.

        """
        flatten_dim = 2 * self.observation_dim + self.action_dim + 3
        return flatten_dim

    def flatten(self) -> jnp.ndarray:
        """
        Returns:
            a jnp.ndarray that corresponds to the flattened transition.
        """
        flatten_transition = jnp.concatenate(
            [
                self.obs,
                self.next_obs,
                jnp.expand_dims(self.rewards, axis=-1),
                jnp.expand_dims(self.dones, axis=-1),
                jnp.expand_dims(self.truncations, axis=-1),
                self.actions,
            ],
            axis=-1,
        )
        return flatten_transition

    @classmethod
    def from_flatten(
        cls,
        flattened_transition: jnp.ndarray,
        transition: Transition,
    ) -> Transition:
        """
        Creates a transition from a flattened transition in a jnp.ndarray.

        Args:
            flattened_transition: flattened transition in a jnp.ndarray of shape
                (batch_size, flatten_dim)
            transition: a transition object (might be a dummy one) to
                get the dimensions right

        Returns:
            a Transition object
        """
        obs_dim = transition.observation_dim
        action_dim = transition.action_dim
        obs = flattened_transition[:, :obs_dim]
        next_obs = flattened_transition[:, obs_dim : (2 * obs_dim)]
        rewards = jnp.ravel(flattened_transition[:, (2 * obs_dim) : (2 * obs_dim + 1)])
        dones = jnp.ravel(
            flattened_transition[:, (2 * obs_dim + 1) : (2 * obs_dim + 2)]
        )
        truncations = jnp.ravel(
            flattened_transition[:, (2 * obs_dim + 2) : (2 * obs_dim + 3)]
        )
        actions = flattened_transition[
            :, (2 * obs_dim + 3) : (2 * obs_dim + 3 + action_dim)
        ]
        return cls(
            obs=obs,
            next_obs=next_obs,
            rewards=rewards,
            dones=dones,
            truncations=truncations,
            actions=actions,
        )

    @classmethod
    def init_dummy(cls, observation_dim: int, action_dim: int) -> Transition:
        """
        Initialize a dummy transition that then can be passed to constructors to get
        all shapes right.

        Args:
            observation_dim: observation dimension
            action_dim: action dimension

        Returns:
            a dummy transition
        """
        dummy_transition = Transition(
            obs=jnp.zeros(shape=(1, observation_dim)),
            next_obs=jnp.zeros(shape=(1, observation_dim)),
            rewards=jnp.zeros(shape=(1,)),
            dones=jnp.zeros(shape=(1,)),
            truncations=jnp.zeros(shape=(1,)),
            actions=jnp.zeros(shape=(1, action_dim)),
        )
        return dummy_transition
observation_dim: int property readonly
Returns:
  • int – the dimension of the observation

action_dim: int property readonly
Returns:
  • int – the dimension of the action

flatten_dim: int property readonly
Returns:
  • int – the dimension of the transition once flattened.

flatten(self)
Returns:
  • jnp.ndarray – a jnp.ndarray that corresponds to the flattened transition.

Source code in qdax/core/neuroevolution/buffers/buffer.py
def flatten(self) -> jnp.ndarray:
    """
    Returns:
        a jnp.ndarray that corresponds to the flattened transition.
    """
    flatten_transition = jnp.concatenate(
        [
            self.obs,
            self.next_obs,
            jnp.expand_dims(self.rewards, axis=-1),
            jnp.expand_dims(self.dones, axis=-1),
            jnp.expand_dims(self.truncations, axis=-1),
            self.actions,
        ],
        axis=-1,
    )
    return flatten_transition
from_flatten(flattened_transition, transition) classmethod

Creates a transition from a flattened transition in a jnp.ndarray.

Parameters:
  • flattened_transition (jnp.ndarray) – flattened transition in a jnp.ndarray of shape (batch_size, flatten_dim)

  • transition (Transition) – a transition object (might be a dummy one) to get the dimensions right

Returns:
  • Transition – a Transition object

Source code in qdax/core/neuroevolution/buffers/buffer.py
@classmethod
def from_flatten(
    cls,
    flattened_transition: jnp.ndarray,
    transition: Transition,
) -> Transition:
    """
    Creates a transition from a flattened transition in a jnp.ndarray.

    Args:
        flattened_transition: flattened transition in a jnp.ndarray of shape
            (batch_size, flatten_dim)
        transition: a transition object (might be a dummy one) to
            get the dimensions right

    Returns:
        a Transition object
    """
    obs_dim = transition.observation_dim
    action_dim = transition.action_dim
    obs = flattened_transition[:, :obs_dim]
    next_obs = flattened_transition[:, obs_dim : (2 * obs_dim)]
    rewards = jnp.ravel(flattened_transition[:, (2 * obs_dim) : (2 * obs_dim + 1)])
    dones = jnp.ravel(
        flattened_transition[:, (2 * obs_dim + 1) : (2 * obs_dim + 2)]
    )
    truncations = jnp.ravel(
        flattened_transition[:, (2 * obs_dim + 2) : (2 * obs_dim + 3)]
    )
    actions = flattened_transition[
        :, (2 * obs_dim + 3) : (2 * obs_dim + 3 + action_dim)
    ]
    return cls(
        obs=obs,
        next_obs=next_obs,
        rewards=rewards,
        dones=dones,
        truncations=truncations,
        actions=actions,
    )
init_dummy(observation_dim, action_dim) classmethod

Initialize a dummy transition that then can be passed to constructors to get all shapes right.

Parameters:
  • observation_dim (int) – observation dimension

  • action_dim (int) – action dimension

Returns:
  • Transition – a dummy transition

Source code in qdax/core/neuroevolution/buffers/buffer.py
@classmethod
def init_dummy(cls, observation_dim: int, action_dim: int) -> Transition:
    """
    Initialize a dummy transition that then can be passed to constructors to get
    all shapes right.

    Args:
        observation_dim: observation dimension
        action_dim: action dimension

    Returns:
        a dummy transition
    """
    dummy_transition = Transition(
        obs=jnp.zeros(shape=(1, observation_dim)),
        next_obs=jnp.zeros(shape=(1, observation_dim)),
        rewards=jnp.zeros(shape=(1,)),
        dones=jnp.zeros(shape=(1,)),
        truncations=jnp.zeros(shape=(1,)),
        actions=jnp.zeros(shape=(1, action_dim)),
    )
    return dummy_transition
replace(self, **updates)

"Returns a new object replacing the specified fields with new values.

Source code in qdax/core/neuroevolution/buffers/buffer.py
def replace(self, **updates):
  """ "Returns a new object replacing the specified fields with new values."""
  return dataclasses.replace(self, **updates)
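As a usage illustration, here is a minimal round-trip sketch (the dimensions 4 and 2 are arbitrary, not tied to any environment). It shows that flatten lays the fields out as [obs | next_obs | reward | done | truncation | action], which is exactly what flatten_dim = 2 * observation_dim + action_dim + 3 accounts for:

from qdax.core.neuroevolution.buffers.buffer import Transition

# Build a dummy transition to fix the dimensions (batch size 1).
dummy = Transition.init_dummy(observation_dim=4, action_dim=2)
assert dummy.flatten_dim == 2 * 4 + 2 + 3  # obs, next_obs, reward, done, truncation, action

# Flatten to a (batch_size, flatten_dim) array, e.g. before storing it in a buffer.
flat = dummy.flatten()
assert flat.shape == (1, dummy.flatten_dim)

# Recover a Transition from the flattened array, using the dummy to get the shapes right.
recovered = Transition.from_flatten(flat, dummy)
assert recovered.obs.shape == (1, 4) and recovered.actions.shape == (1, 2)
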
QDTransition (Transition) dataclass

Stores data corresponding to a transition collected by a QD algorithm.

Source code in qdax/core/neuroevolution/buffers/buffer.py
class QDTransition(Transition):
    """Stores data corresponding to a transition collected by a QD algorithm."""

    state_desc: StateDescriptor
    next_state_desc: StateDescriptor

    @property
    def state_descriptor_dim(self) -> int:
        """
        Returns:
            the dimension of the state descriptors.

        """
        return self.state_desc.shape[-1]  # type: ignore

    @property
    def flatten_dim(self) -> int:
        """
        Returns:
            the dimension of the transition once flattened.

        """
        flatten_dim = (
            2 * self.observation_dim
            + self.action_dim
            + 3
            + 2 * self.state_descriptor_dim
        )
        return flatten_dim

    def flatten(self) -> jnp.ndarray:
        """
        Returns:
            a jnp.ndarray that corresponds to the flattened transition.
        """
        flatten_transition = jnp.concatenate(
            [
                self.obs,
                self.next_obs,
                jnp.expand_dims(self.rewards, axis=-1),
                jnp.expand_dims(self.dones, axis=-1),
                jnp.expand_dims(self.truncations, axis=-1),
                self.actions,
                self.state_desc,
                self.next_state_desc,
            ],
            axis=-1,
        )
        return flatten_transition

    @classmethod
    def from_flatten(
        cls,
        flattened_transition: jnp.ndarray,
        transition: QDTransition,
    ) -> QDTransition:
        """
        Creates a transition from a flattened transition in a jnp.ndarray.

        Args:
            flattened_transition: flattened transition in a jnp.ndarray of shape
                (batch_size, flatten_dim)
            transition: a transition object (might be a dummy one) to
                get the dimensions right

        Returns:
            a Transition object
        """
        obs_dim = transition.observation_dim
        action_dim = transition.action_dim
        desc_dim = transition.state_descriptor_dim

        obs = flattened_transition[:, :obs_dim]
        next_obs = flattened_transition[:, obs_dim : (2 * obs_dim)]
        rewards = jnp.ravel(flattened_transition[:, (2 * obs_dim) : (2 * obs_dim + 1)])
        dones = jnp.ravel(
            flattened_transition[:, (2 * obs_dim + 1) : (2 * obs_dim + 2)]
        )
        truncations = jnp.ravel(
            flattened_transition[:, (2 * obs_dim + 2) : (2 * obs_dim + 3)]
        )
        actions = flattened_transition[
            :, (2 * obs_dim + 3) : (2 * obs_dim + 3 + action_dim)
        ]
        state_desc = flattened_transition[
            :,
            (2 * obs_dim + 3 + action_dim) : (2 * obs_dim + 3 + action_dim + desc_dim),
        ]
        next_state_desc = flattened_transition[
            :,
            (2 * obs_dim + 3 + action_dim + desc_dim) : (
                2 * obs_dim + 3 + action_dim + 2 * desc_dim
            ),
        ]
        return cls(
            obs=obs,
            next_obs=next_obs,
            rewards=rewards,
            dones=dones,
            truncations=truncations,
            actions=actions,
            state_desc=state_desc,
            next_state_desc=next_state_desc,
        )

    @classmethod
    def init_dummy(  # type: ignore
        cls, observation_dim: int, action_dim: int, descriptor_dim: int
    ) -> QDTransition:
        """
        Initialize a dummy transition that then can be passed to constructors to get
        all shapes right.

        Args:
            observation_dim: observation dimension
            action_dim: action dimension

        Returns:
            a dummy transition
        """
        dummy_transition = QDTransition(
            obs=jnp.zeros(shape=(1, observation_dim)),
            next_obs=jnp.zeros(shape=(1, observation_dim)),
            rewards=jnp.zeros(shape=(1,)),
            dones=jnp.zeros(shape=(1,)),
            truncations=jnp.zeros(shape=(1,)),
            actions=jnp.zeros(shape=(1, action_dim)),
            state_desc=jnp.zeros(shape=(1, descriptor_dim)),
            next_state_desc=jnp.zeros(shape=(1, descriptor_dim)),
        )
        return dummy_transition
state_descriptor_dim: int property readonly
Returns:
  • int – the dimension of the state descriptors.

flatten_dim: int property readonly
Returns:
  • int – the dimension of the transition once flattened.

replace(self, **updates)

"Returns a new object replacing the specified fields with new values.

Source code in qdax/core/neuroevolution/buffers/buffer.py
def replace(self, **updates):
  """ "Returns a new object replacing the specified fields with new values."""
  return dataclasses.replace(self, **updates)
flatten(self)
Returns:
  • jnp.ndarray – a jnp.ndarray that corresponds to the flattened transition.

Source code in qdax/core/neuroevolution/buffers/buffer.py
def flatten(self) -> jnp.ndarray:
    """
    Returns:
        a jnp.ndarray that corresponds to the flattened transition.
    """
    flatten_transition = jnp.concatenate(
        [
            self.obs,
            self.next_obs,
            jnp.expand_dims(self.rewards, axis=-1),
            jnp.expand_dims(self.dones, axis=-1),
            jnp.expand_dims(self.truncations, axis=-1),
            self.actions,
            self.state_desc,
            self.next_state_desc,
        ],
        axis=-1,
    )
    return flatten_transition
from_flatten(flattened_transition, transition) classmethod

Creates a transition from a flattened transition in a jnp.ndarray.

Parameters:
  • flattened_transition (jnp.ndarray) – flattened transition in a jnp.ndarray of shape (batch_size, flatten_dim)

  • transition (QDTransition) – a transition object (might be a dummy one) to get the dimensions right

Returns:
  • QDTransition – a Transition object

Source code in qdax/core/neuroevolution/buffers/buffer.py
@classmethod
def from_flatten(
    cls,
    flattened_transition: jnp.ndarray,
    transition: QDTransition,
) -> QDTransition:
    """
    Creates a transition from a flattened transition in a jnp.ndarray.

    Args:
        flattened_transition: flattened transition in a jnp.ndarray of shape
            (batch_size, flatten_dim)
        transition: a transition object (might be a dummy one) to
            get the dimensions right

    Returns:
        a Transition object
    """
    obs_dim = transition.observation_dim
    action_dim = transition.action_dim
    desc_dim = transition.state_descriptor_dim

    obs = flattened_transition[:, :obs_dim]
    next_obs = flattened_transition[:, obs_dim : (2 * obs_dim)]
    rewards = jnp.ravel(flattened_transition[:, (2 * obs_dim) : (2 * obs_dim + 1)])
    dones = jnp.ravel(
        flattened_transition[:, (2 * obs_dim + 1) : (2 * obs_dim + 2)]
    )
    truncations = jnp.ravel(
        flattened_transition[:, (2 * obs_dim + 2) : (2 * obs_dim + 3)]
    )
    actions = flattened_transition[
        :, (2 * obs_dim + 3) : (2 * obs_dim + 3 + action_dim)
    ]
    state_desc = flattened_transition[
        :,
        (2 * obs_dim + 3 + action_dim) : (2 * obs_dim + 3 + action_dim + desc_dim),
    ]
    next_state_desc = flattened_transition[
        :,
        (2 * obs_dim + 3 + action_dim + desc_dim) : (
            2 * obs_dim + 3 + action_dim + 2 * desc_dim
        ),
    ]
    return cls(
        obs=obs,
        next_obs=next_obs,
        rewards=rewards,
        dones=dones,
        truncations=truncations,
        actions=actions,
        state_desc=state_desc,
        next_state_desc=next_state_desc,
    )
init_dummy(observation_dim, action_dim, descriptor_dim) classmethod

Initialize a dummy transition that then can be passed to constructors to get all shapes right.

Parameters:
  • observation_dim (int) – observation dimension

  • action_dim (int) – action dimension

  • descriptor_dim (int) – state descriptor dimension

Returns:
  • QDTransition – a dummy transition

Source code in qdax/core/neuroevolution/buffers/buffer.py
@classmethod
def init_dummy(  # type: ignore
    cls, observation_dim: int, action_dim: int, descriptor_dim: int
) -> QDTransition:
    """
    Initialize a dummy transition that then can be passed to constructors to get
    all shapes right.

    Args:
        observation_dim: observation dimension
        action_dim: action dimension

    Returns:
        a dummy transition
    """
    dummy_transition = QDTransition(
        obs=jnp.zeros(shape=(1, observation_dim)),
        next_obs=jnp.zeros(shape=(1, observation_dim)),
        rewards=jnp.zeros(shape=(1,)),
        dones=jnp.zeros(shape=(1,)),
        truncations=jnp.zeros(shape=(1,)),
        actions=jnp.zeros(shape=(1, action_dim)),
        state_desc=jnp.zeros(shape=(1, descriptor_dim)),
        next_state_desc=jnp.zeros(shape=(1, descriptor_dim)),
    )
    return dummy_transition
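For QDTransition the flattened layout is extended with the state descriptors, so flatten_dim grows by 2 * state_descriptor_dim. A short sketch with arbitrary illustrative dimensions:

from qdax.core.neuroevolution.buffers.buffer import QDTransition

# Arbitrary illustrative dimensions.
obs_dim, action_dim, desc_dim = 4, 2, 3

dummy = QDTransition.init_dummy(
    observation_dim=obs_dim, action_dim=action_dim, descriptor_dim=desc_dim
)
# Layout: [obs | next_obs | reward | done | truncation | action | state_desc | next_state_desc]
assert dummy.flatten_dim == 2 * obs_dim + action_dim + 3 + 2 * desc_dim

flat = dummy.flatten()
recovered = QDTransition.from_flatten(flat, dummy)
assert recovered.state_desc.shape == (1, desc_dim)
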
ReplayBuffer (PyTreeNode) dataclass

A replay buffer where transitions are flattened before being stored. Transitions are unflattened on the fly when sampled from the buffer. Data shape: (buffer_size, transition_concat_shape).

Source code in qdax/core/neuroevolution/buffers/buffer.py
class ReplayBuffer(flax.struct.PyTreeNode):
    """
    A replay buffer where transitions are flattened before being stored.
    Transitions are unflatenned on the fly when sampled in the buffer.
    data shape: (buffer_size, transition_concat_shape)
    """

    data: jnp.ndarray
    buffer_size: int = flax.struct.field(pytree_node=False)
    transition: Transition

    current_position: jnp.ndarray = flax.struct.field()
    current_size: jnp.ndarray = flax.struct.field()

    @classmethod
    def init(
        cls,
        buffer_size: int,
        transition: Transition,
    ) -> ReplayBuffer:
        """
        The constructor of the buffer.

        Note: We have to define a classmethod instead of just doing it in post_init
        because post_init is called every time the dataclass is tree_mapped. This is a
        workaround proposed in https://github.com/google/flax/issues/1628.

        Args:
            buffer_size: the size of the replay buffer, e.g. 1e6
            transition: a transition object (might be a dummy one) to get
                the dimensions right
        """
        flatten_dim = transition.flatten_dim
        data = jnp.ones((buffer_size, flatten_dim)) * jnp.nan
        current_size = jnp.array(0, dtype=int)
        current_position = jnp.array(0, dtype=int)
        return cls(
            data=data,
            current_size=current_size,
            current_position=current_position,
            buffer_size=buffer_size,
            transition=transition,
        )

    @partial(jax.jit, static_argnames=("sample_size",))
    def sample(
        self,
        random_key: RNGKey,
        sample_size: int,
    ) -> Tuple[Transition, RNGKey]:
        """
        Sample a batch of transitions in the replay buffer.
        """
        random_key, subkey = jax.random.split(random_key)
        idx = jax.random.randint(
            subkey,
            shape=(sample_size,),
            minval=0,
            maxval=self.current_size,
        )
        samples = jnp.take(self.data, idx, axis=0, mode="clip")
        transitions = self.transition.__class__.from_flatten(samples, self.transition)
        return transitions, random_key

    @jax.jit
    def insert(self, transitions: Transition) -> ReplayBuffer:
        """
        Insert a batch of transitions in the replay buffer. The transitions are
        flattened before insertion.

        Args:
            transitions: A transition object in which each field is assumed to have
                a shape (batch_size, field_dim).
        """
        flattened_transitions = transitions.flatten()
        flattened_transitions = flattened_transitions.reshape(
            (-1, flattened_transitions.shape[-1])
        )
        num_transitions = flattened_transitions.shape[0]
        max_replay_size = self.buffer_size

        # Make sure update is not larger than the maximum replay size.
        if num_transitions > max_replay_size:
            raise ValueError(
                "Trying to insert a batch of samples larger than the maximum replay "
                f"size. num_samples: {num_transitions}, "
                f"max replay size {max_replay_size}"
            )

        # get current position
        position = self.current_position

        # check if there is an overlap
        roll = jnp.minimum(0, max_replay_size - position - num_transitions)

        # roll the data to avoid overlap
        data = jnp.roll(self.data, roll, axis=0)

        # update the position accordingly
        new_position = position + roll

        # replace old data by the new one
        new_data = jax.lax.dynamic_update_slice_in_dim(
            data,
            flattened_transitions,
            start_index=new_position,
            axis=0,
        )

        # update the position and the size
        new_position = (new_position + num_transitions) % max_replay_size
        new_size = jnp.minimum(self.current_size + num_transitions, max_replay_size)

        # update the replay buffer
        replay_buffer = self.replace(
            current_position=new_position,
            current_size=new_size,
            data=new_data,
        )

        return replay_buffer  # type: ignore
replace(self, **updates)

"Returns a new object replacing the specified fields with new values.

Source code in qdax/core/neuroevolution/buffers/buffer.py
def replace(self, **updates):
  """ "Returns a new object replacing the specified fields with new values."""
  return dataclasses.replace(self, **updates)
init(buffer_size, transition) classmethod

The constructor of the buffer.

Note: We have to define a classmethod instead of just doing it in post_init because post_init is called every time the dataclass is tree_mapped. This is a workaround proposed in https://github.com/google/flax/issues/1628.

Parameters:
  • buffer_size (int) – the size of the replay buffer, e.g. 1e6

  • transition (Transition) – a transition object (might be a dummy one) to get the dimensions right

Source code in qdax/core/neuroevolution/buffers/buffer.py
@classmethod
def init(
    cls,
    buffer_size: int,
    transition: Transition,
) -> ReplayBuffer:
    """
    The constructor of the buffer.

    Note: We have to define a classmethod instead of just doing it in post_init
    because post_init is called every time the dataclass is tree_mapped. This is a
    workaround proposed in https://github.com/google/flax/issues/1628.

    Args:
        buffer_size: the size of the replay buffer, e.g. 1e6
        transition: a transition object (might be a dummy one) to get
            the dimensions right
    """
    flatten_dim = transition.flatten_dim
    data = jnp.ones((buffer_size, flatten_dim)) * jnp.nan
    current_size = jnp.array(0, dtype=int)
    current_position = jnp.array(0, dtype=int)
    return cls(
        data=data,
        current_size=current_size,
        current_position=current_position,
        buffer_size=buffer_size,
        transition=transition,
    )
sample(self, random_key, sample_size)

Sample a batch of transitions in the replay buffer.

Source code in qdax/core/neuroevolution/buffers/buffer.py
@partial(jax.jit, static_argnames=("sample_size",))
def sample(
    self,
    random_key: RNGKey,
    sample_size: int,
) -> Tuple[Transition, RNGKey]:
    """
    Sample a batch of transitions in the replay buffer.
    """
    random_key, subkey = jax.random.split(random_key)
    idx = jax.random.randint(
        subkey,
        shape=(sample_size,),
        minval=0,
        maxval=self.current_size,
    )
    samples = jnp.take(self.data, idx, axis=0, mode="clip")
    transitions = self.transition.__class__.from_flatten(samples, self.transition)
    return transitions, random_key
insert(self, transitions)

Insert a batch of transitions in the replay buffer. The transitions are flattened before insertion.

Parameters:
  • transitions (Transition) – A transition object in which each field is assumed to have a shape (batch_size, field_dim).

Source code in qdax/core/neuroevolution/buffers/buffer.py
@jax.jit
def insert(self, transitions: Transition) -> ReplayBuffer:
    """
    Insert a batch of transitions in the replay buffer. The transitions are
    flattened before insertion.

    Args:
        transitions: A transition object in which each field is assumed to have
            a shape (batch_size, field_dim).
    """
    flattened_transitions = transitions.flatten()
    flattened_transitions = flattened_transitions.reshape(
        (-1, flattened_transitions.shape[-1])
    )
    num_transitions = flattened_transitions.shape[0]
    max_replay_size = self.buffer_size

    # Make sure update is not larger than the maximum replay size.
    if num_transitions > max_replay_size:
        raise ValueError(
            "Trying to insert a batch of samples larger than the maximum replay "
            f"size. num_samples: {num_transitions}, "
            f"max replay size {max_replay_size}"
        )

    # get current position
    position = self.current_position

    # check if there is an overlap
    roll = jnp.minimum(0, max_replay_size - position - num_transitions)

    # roll the data to avoid overlap
    data = jnp.roll(self.data, roll, axis=0)

    # update the position accordingly
    new_position = position + roll

    # replace old data by the new one
    new_data = jax.lax.dynamic_update_slice_in_dim(
        data,
        flattened_transitions,
        start_index=new_position,
        axis=0,
    )

    # update the position and the size
    new_position = (new_position + num_transitions) % max_replay_size
    new_size = jnp.minimum(self.current_size + num_transitions, max_replay_size)

    # update the replay buffer
    replay_buffer = self.replace(
        current_position=new_position,
        current_size=new_size,
        data=new_data,
    )

    return replay_buffer  # type: ignore
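Putting the pieces together, here is a usage sketch (dimensions, buffer size, batch size and PRNG key are arbitrary) that initializes a buffer from a dummy transition, inserts a batch and samples from it. Note that sample draws indices uniformly with replacement from the already-filled part of the buffer, and that insert rolls the data when an insertion would run past the end of the buffer.

import jax
import jax.numpy as jnp
from qdax.core.neuroevolution.buffers.buffer import ReplayBuffer, Transition

obs_dim, action_dim = 4, 2  # arbitrary illustrative dimensions

# A dummy transition only fixes the per-field shapes.
dummy = Transition.init_dummy(observation_dim=obs_dim, action_dim=action_dim)
buffer = ReplayBuffer.init(buffer_size=1000, transition=dummy)

# A batch of 8 (here, zero-valued) transitions, each field shaped (batch_size, field_dim).
batch = Transition(
    obs=jnp.zeros((8, obs_dim)),
    next_obs=jnp.zeros((8, obs_dim)),
    rewards=jnp.zeros((8,)),
    dones=jnp.zeros((8,)),
    truncations=jnp.zeros((8,)),
    actions=jnp.zeros((8, action_dim)),
)
buffer = buffer.insert(batch)

# Sample 4 transitions; the PRNG key is threaded explicitly, JAX-style.
random_key = jax.random.PRNGKey(0)
transitions, random_key = buffer.sample(random_key, sample_size=4)
assert transitions.obs.shape == (4, obs_dim)
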

trajectory_buffer

TrajectoryBuffer (PyTreeNode) dataclass

A buffer that stores transitions in the form of trajectories. Like FlatBuffer, transitions are flattened before being stored and unflattened on the fly, and the data shape is (buffer_size, transition_concat_shape). The specificity lies in the additional episodic data buffer that maps transitions belonging to the same trajectory to their positions in the main buffer.

Examples:

Assume a buffer of size 6, into which we insert 3 transitions at a time (env_batch_size=3), with an episode length of 2. The dones data we insert is dones=[0,1,0].

Data (dones):
    [ 0.  1.  0. nan nan nan] # We inserted [0,1,0] contiguously
Episodic data:
    [[ 0. nan] # For episode 0, first timestep is at index 0 in the buffer
    [ 1. nan]  # For episode 1, first timestep is at index 1 in the buffer
    [ 2. nan]] # For episode 2, first timestep is at index 2 in the buffer
Trajectory positions:
    [0. 1. 0.] # For episode 0 and 2, done=0 so we stay in the same episode,
               # for episode 1, done=1 so we move to the next episode index
Timestep positions:
    [1. 0. 1.] # For episode 0 and 2: done=0 so we increment the timestep count-
               # er, for episode 1: done=1 so we reset the timestep counter

Now we subsequently add dones=[1,1,1]:

Data (dones):
    [0. 1. 0. 1. 1. 1.]
Episodic data:
    [[ 0.  3.]
    [ 4. nan] # Episode 1 has been reset
    [ 2.  5.]]
Trajectory positions:
    [1. 2. 1.]
Timestep positions:
    [0. 0. 0.] # All timestep counters have been reset

Source code in qdax/core/neuroevolution/buffers/trajectory_buffer.py
class TrajectoryBuffer(struct.PyTreeNode):
    """
    A buffer that stores transitions in the form of trajectories. Like `FlatBuffer`
    transitions are flattened before being stored and unflattened on the fly and the
    data shape is: (buffer_size, transition_concat_shape).
    The speicificity lies in the additional episodic data buffer that maps transitions
    that belong to the same trajectory to their position in the main buffer.

    Example:
    Assume we have a buffer of size 6, we insert 3 transitions at a time
    (env_batch_size=3) and the episode length is 2.
    The `dones` data we insert is dones=[0,1,0].

        Data (dones):
            [ 0.  1.  0. nan nan nan] # We inserted [0,1,0] contiguously
        Episodic data:
            [[ 0. nan] # For episode 0, first timestep is at index 0 in the buffer
            [ 1. nan]  # For episode 1, first timestep is at index 1 in the buffer
            [ 2. nan]] # For episode 2, first timestep is at index 2 in the buffer
        Trajectory positions:
            [0. 1. 0.] # For episode 0 and 2, done=0 so we stay in the same episode,
                       # for episode 1, done=1 so we move to the next episode index
        Timestep positions:
            [1. 0. 1.] # For episode 0 and 2: done=0 so we increment the timestep count-
                       # er, for episode 1: done=1 so we reset the timestep counter


    Now we subsequently add dones=[1,1,1]
        Data (dones):
            [0. 1. 0. 1. 1. 1.]
        Episodic data:
            [[ 0.  3.]
            [ 4. nan] # Episode 1 has been reset
            [ 2.  5.]]
        Trajectory positions:
            [1. 2. 1.]
        Timestep positions:
            [0. 0. 0.] # All timestep counters have been reset
    """

    data: jnp.ndarray
    buffer_size: int = struct.field(pytree_node=False)
    transition: Transition

    episode_length: int = struct.field(pytree_node=False)
    env_batch_size: int = struct.field(pytree_node=False)
    num_trajectories: int = struct.field(pytree_node=False)

    current_position: jnp.ndarray = struct.field()
    current_size: jnp.ndarray = struct.field()
    trajectory_positions: jnp.ndarray = struct.field()
    timestep_positions: jnp.ndarray = struct.field()
    episodic_data: jnp.ndarray = struct.field()
    current_episodic_data_size: jnp.ndarray = struct.field()
    returns: jnp.ndarray = struct.field()

    @classmethod
    def init(  # type: ignore
        cls,
        buffer_size: int,
        transition: Transition,
        env_batch_size: int,
        episode_length: int,
    ) -> TrajectoryBuffer:
        """
        The constructor of the buffer.

        Note: We have to define a classmethod instead of just doing it in post_init
        because post_init is called every time the dataclass is tree_mapped. This is a
        workaround proposed in https://github.com/google/flax/issues/1628.
        """
        flatten_dim = transition.flatten_dim
        data = jnp.ones((buffer_size, flatten_dim)) * jnp.nan
        num_trajectories = buffer_size // episode_length
        assert (
            num_trajectories % env_batch_size == 0
        ), "num_trajectories must be a multiple of env batch size"
        assert (
            buffer_size % episode_length == 0
        ), "buffer_size must be a multiple of episode_length"
        current_position = jnp.zeros((), dtype=int)
        trajectory_positions = jnp.zeros(env_batch_size, dtype=int)
        timestep_positions = jnp.zeros(env_batch_size, dtype=int)
        episodic_data = jnp.ones((num_trajectories, episode_length)) * jnp.nan
        current_size = jnp.array(0, dtype=int)
        current_episodic_data_size = jnp.array(0, dtype=int)
        returns = jnp.ones(
            buffer_size + 1,
        ) * (-jnp.inf)
        return cls(
            data=data,
            current_position=current_position,
            buffer_size=buffer_size,
            timestep_positions=timestep_positions,
            trajectory_positions=trajectory_positions,
            episode_length=episode_length,
            env_batch_size=env_batch_size,
            episodic_data=episodic_data,
            num_trajectories=num_trajectories,
            current_size=current_size,
            current_episodic_data_size=current_episodic_data_size,
            transition=transition,
            returns=returns,
        )

    @partial(jax.jit, static_argnames=("sample_size"))
    def sample(
        self,
        random_key: RNGKey,
        sample_size: int,
    ) -> Tuple[Transition, RNGKey]:
        """
        Sample transitions from the buffer. If sample_traj=False, returns stacked
        transitions in the shape (sample_size,), if sample_traj=True, return transitions
        in the shape (sample_size, episode_length,).
        """

        # Here we want to sample single transitions
        # We sample uniformly at random the indexes of valid transitions
        random_key, subkey = jax.random.split(random_key)
        idx = jax.random.randint(
            subkey,
            shape=(sample_size,),
            minval=0,
            maxval=self.current_size,
        )
        # We select the corresponding transitions
        samples = jnp.take(self.data, idx, axis=0, mode="clip")

        # (sample_size, concat_dim)
        transitions = self.transition.__class__.from_flatten(samples, self.transition)
        return transitions, random_key

    def sample_with_returns(
        self,
        random_key: RNGKey,
        sample_size: int,
    ) -> Tuple[Transition, Reward, RNGKey]:
        """Sample transitions and the return corresponding to their episode. The returns
        are compute by the method `compute_returns`.

        Args:
            random_key: a random key
            sample_size: the number of transitions

        Returns:
            The transitions, the associated returns and a new random key.
        """
        # Here we want to sample single transitions
        # We sample uniformly at random the indexes of valid transitions
        random_key, subkey = jax.random.split(random_key)
        idx = jax.random.randint(
            subkey,
            shape=(sample_size,),
            minval=0,
            maxval=self.current_size,
        )
        # We select the corresponding transitions
        samples = jnp.take(self.data, idx, axis=0, mode="clip")
        returns = jnp.take(self.returns, idx, mode="clip")
        # (sample_size, concat_dim)
        transitions = self.transition.__class__.from_flatten(samples, self.transition)
        return transitions, returns, random_key

    @jax.jit
    def insert(self, transitions: Transition) -> TrajectoryBuffer:
        """
        Scan over 'insert_one_transition', to add multiple transitions.
        """

        @jax.jit
        def insert_one_transition(
            replay_buffer: TrajectoryBuffer, flattened_transitions: jnp.ndarray
        ) -> Tuple[TrajectoryBuffer, Any]:
            """
            Insert a batch (one step over all envs) of transitions in the buffer.
            """
            # Step 1: reset episodes for override
            # We start by selecting the episodes that are currently being inserted
            active_trajectories_indexes = (
                jnp.arange(self.env_batch_size, dtype=int)
                + (replay_buffer.trajectory_positions % self.num_trajectories)
                * self.env_batch_size
            ) % self.num_trajectories

            current_episodes = replay_buffer.episodic_data[active_trajectories_indexes]

            # The override condition is: "if we want to insert à timestep 0, we clear
            # the corresponding row first"
            condition = replay_buffer.timestep_positions % self.episode_length == 0

            # Clear episodes that match the condition, don't modify others
            override_episodes = jnp.where(
                jnp.expand_dims(condition, -1),
                jnp.ones_like(current_episodes) * jnp.nan,
                current_episodes,
            )

            new_episodic_data = replay_buffer.episodic_data.at[
                active_trajectories_indexes
            ].set(override_episodes)

            # Step 2: insert data in main buffer and indexes in episodic buffer
            # Apply changes in the episodic_data array

            # Insert transitions in the buffer
            new_data = jax.lax.dynamic_update_slice_in_dim(
                replay_buffer.data,
                flattened_transitions,
                start_index=replay_buffer.current_position % self.buffer_size,
                axis=0,
            )

            # We inserted from current_position to current_position + env_batch_size
            inserted_indexes = (
                jnp.arange(
                    self.env_batch_size,
                )
                + replay_buffer.current_position
            )
            # We set the indexes of inserted episodes in the episodic_data array
            new_episodic_data = new_episodic_data.at[
                active_trajectories_indexes,
                replay_buffer.timestep_positions,
            ].set(inserted_indexes)

            # Step 3: update the counters
            dones = flattened_transitions[
                :, (2 * (self.transition.observation_dim) + 1)
            ].ravel()

            # Increment the trajectory counter if done
            new_trajectory_positions = replay_buffer.trajectory_positions + dones

            # Within a trajectory, increment position if not done, else reset position
            new_timestep_positions = jnp.where(
                dones, jnp.zeros_like(dones), 1 + replay_buffer.timestep_positions
            )

            # Update the insertion position in the main buffer
            new_current_position = (
                replay_buffer.current_position + self.env_batch_size
            ) % self.buffer_size

            # Update the size counter of the main buffer
            new_current_size = jnp.minimum(
                replay_buffer.current_size + self.env_batch_size, self.buffer_size
            )

            # Update the size of the episodic data buffer
            new_current_episodic_data_size = jnp.minimum(
                jnp.min(replay_buffer.trajectory_positions + 1) * self.env_batch_size,
                self.num_trajectories,
            )

            replay_buffer = replay_buffer.replace(
                timestep_positions=jnp.array(new_timestep_positions, dtype=int),
                trajectory_positions=jnp.array(new_trajectory_positions, dtype=int),
                data=new_data,
                current_position=jnp.array(new_current_position, dtype=int),
                episodic_data=new_episodic_data,
                current_size=new_current_size,
                current_episodic_data_size=jnp.array(
                    new_current_episodic_data_size, dtype=int
                ),
            )
            return replay_buffer, None

        flattened_transitions = transitions.flatten()

        flattened_transitions = flattened_transitions.reshape(
            (-1, self.env_batch_size, flattened_transitions.shape[-1])
        )

        replay_buffer, _ = jax.lax.scan(
            insert_one_transition,
            self,
            flattened_transitions,
        )

        replay_buffer = replay_buffer.compute_returns()
        return replay_buffer  # type: ignore

    def compute_returns(
        self,
    ) -> TrajectoryBuffer:
        """Computes the return for each episode in the buffer.

        Returns:
            The buffer with the returns updated.
        """

        reward_idx = 2 * self.transition.observation_dim
        indexes = self.episodic_data
        rewards = self.data[:, reward_idx]
        episodic_returns = jnp.where(
            jnp.isnan(indexes),
            0,
            rewards[jnp.array(jnp.nan_to_num(indexes, 0), dtype=int)],
        ).sum(axis=1)

        values = episodic_returns[:, None].repeat(self.episode_length, axis=1)
        returns = self.returns.at[
            jnp.array(jnp.nan_to_num(indexes, nan=-1), dtype=int)
        ].set(values)
        returns = returns.at[-1].set(jnp.nan)
        return self.replace(returns=returns)  # type: ignore
init(buffer_size, transition, env_batch_size, episode_length) classmethod

The constructor of the buffer.

Note: We have to define a classmethod instead of just doing it in post_init because post_init is called every time the dataclass is tree_mapped. This is a workaround proposed in https://github.com/google/flax/issues/1628.

Source code in qdax/core/neuroevolution/buffers/trajectory_buffer.py
@classmethod
def init(  # type: ignore
    cls,
    buffer_size: int,
    transition: Transition,
    env_batch_size: int,
    episode_length: int,
) -> TrajectoryBuffer:
    """
    The constructor of the buffer.

    Note: We have to define a classmethod instead of just doing it in post_init
    because post_init is called every time the dataclass is tree_mapped. This is a
    workaround proposed in https://github.com/google/flax/issues/1628.
    """
    flatten_dim = transition.flatten_dim
    data = jnp.ones((buffer_size, flatten_dim)) * jnp.nan
    num_trajectories = buffer_size // episode_length
    assert (
        num_trajectories % env_batch_size == 0
    ), "num_trajectories must be a multiple of env batch size"
    assert (
        buffer_size % episode_length == 0
    ), "buffer_size must be a multiple of episode_length"
    current_position = jnp.zeros((), dtype=int)
    trajectory_positions = jnp.zeros(env_batch_size, dtype=int)
    timestep_positions = jnp.zeros(env_batch_size, dtype=int)
    episodic_data = jnp.ones((num_trajectories, episode_length)) * jnp.nan
    current_size = jnp.array(0, dtype=int)
    current_episodic_data_size = jnp.array(0, dtype=int)
    returns = jnp.ones(
        buffer_size + 1,
    ) * (-jnp.inf)
    return cls(
        data=data,
        current_position=current_position,
        buffer_size=buffer_size,
        timestep_positions=timestep_positions,
        trajectory_positions=trajectory_positions,
        episode_length=episode_length,
        env_batch_size=env_batch_size,
        episodic_data=episodic_data,
        num_trajectories=num_trajectories,
        current_size=current_size,
        current_episodic_data_size=current_episodic_data_size,
        transition=transition,
        returns=returns,
    )
replace(self, **updates)

"Returns a new object replacing the specified fields with new values.

Source code in qdax/core/neuroevolution/buffers/trajectory_buffer.py
def replace(self, **updates):
  """ "Returns a new object replacing the specified fields with new values."""
  return dataclasses.replace(self, **updates)
sample(self, random_key, sample_size)

Sample transitions uniformly at random from the buffer. Transitions are returned stacked, with leading shape (sample_size,).

Source code in qdax/core/neuroevolution/buffers/trajectory_buffer.py
@partial(jax.jit, static_argnames=("sample_size"))
def sample(
    self,
    random_key: RNGKey,
    sample_size: int,
) -> Tuple[Transition, RNGKey]:
    """
    Sample transitions from the buffer. If sample_traj=False, returns stacked
    transitions in the shape (sample_size,), if sample_traj=True, return transitions
    in the shape (sample_size, episode_length,).
    """

    # Here we want to sample single transitions
    # We sample uniformly at random the indexes of valid transitions
    random_key, subkey = jax.random.split(random_key)
    idx = jax.random.randint(
        subkey,
        shape=(sample_size,),
        minval=0,
        maxval=self.current_size,
    )
    # We select the corresponding transitions
    samples = jnp.take(self.data, idx, axis=0, mode="clip")

    # (sample_size, concat_dim)
    transitions = self.transition.__class__.from_flatten(samples, self.transition)
    return transitions, random_key
sample_with_returns(self, random_key, sample_size)

Sample transitions along with the return of the episode they belong to. The returns are computed by the method compute_returns.

Parameters:
  • random_key (RNGKey) – a random key

  • sample_size (int) – the number of transitions

Returns:
  • Tuple[Transition, Reward, RNGKey] – The transitions, the associated returns and a new random key.

Source code in qdax/core/neuroevolution/buffers/trajectory_buffer.py
def sample_with_returns(
    self,
    random_key: RNGKey,
    sample_size: int,
) -> Tuple[Transition, Reward, RNGKey]:
    """Sample transitions and the return corresponding to their episode. The returns
    are compute by the method `compute_returns`.

    Args:
        random_key: a random key
        sample_size: the number of transitions

    Returns:
        The transitions, the associated returns and a new random key.
    """
    # Here we want to sample single transitions
    # We sample uniformly at random the indexes of valid transitions
    random_key, subkey = jax.random.split(random_key)
    idx = jax.random.randint(
        subkey,
        shape=(sample_size,),
        minval=0,
        maxval=self.current_size,
    )
    # We select the corresponding transitions
    samples = jnp.take(self.data, idx, axis=0, mode="clip")
    returns = jnp.take(self.returns, idx, mode="clip")
    # (sample_size, concat_dim)
    transitions = self.transition.__class__.from_flatten(samples, self.transition)
    return transitions, returns, random_key
insert(self, transitions)

Scan over 'insert_one_transition', to add multiple transitions.

Source code in qdax/core/neuroevolution/buffers/trajectory_buffer.py
@jax.jit
def insert(self, transitions: Transition) -> TrajectoryBuffer:
    """
    Scan over 'insert_one_transition', to add multiple transitions.
    """

    @jax.jit
    def insert_one_transition(
        replay_buffer: TrajectoryBuffer, flattened_transitions: jnp.ndarray
    ) -> Tuple[TrajectoryBuffer, Any]:
        """
        Insert a batch (one step over all envs) of transitions in the buffer.
        """
        # Step 1: reset episodes for override
        # We start by selecting the episodes that are currently being inserted
        active_trajectories_indexes = (
            jnp.arange(self.env_batch_size, dtype=int)
            + (replay_buffer.trajectory_positions % self.num_trajectories)
            * self.env_batch_size
        ) % self.num_trajectories

        current_episodes = replay_buffer.episodic_data[active_trajectories_indexes]

        # The override condition is: "if we want to insert à timestep 0, we clear
        # the corresponding row first"
        condition = replay_buffer.timestep_positions % self.episode_length == 0

        # Clear episodes that match the condition, don't modify others
        override_episodes = jnp.where(
            jnp.expand_dims(condition, -1),
            jnp.ones_like(current_episodes) * jnp.nan,
            current_episodes,
        )

        new_episodic_data = replay_buffer.episodic_data.at[
            active_trajectories_indexes
        ].set(override_episodes)

        # Step 2: insert data in main buffer and indexes in episodic buffer
        # Apply changes in the episodic_data array

        # Insert transitions in the buffer
        new_data = jax.lax.dynamic_update_slice_in_dim(
            replay_buffer.data,
            flattened_transitions,
            start_index=replay_buffer.current_position % self.buffer_size,
            axis=0,
        )

        # We inserted from current_position to current_position + env_batch_size
        inserted_indexes = (
            jnp.arange(
                self.env_batch_size,
            )
            + replay_buffer.current_position
        )
        # We set the indexes of inserted episodes in the episodic_data array
        new_episodic_data = new_episodic_data.at[
            active_trajectories_indexes,
            replay_buffer.timestep_positions,
        ].set(inserted_indexes)

        # Step 3: update the counters
        dones = flattened_transitions[
            :, (2 * (self.transition.observation_dim) + 1)
        ].ravel()

        # Increment the trajectory counter if done
        new_trajectory_positions = replay_buffer.trajectory_positions + dones

        # Within a trajectory, increment position if not done, else reset position
        new_timestep_positions = jnp.where(
            dones, jnp.zeros_like(dones), 1 + replay_buffer.timestep_positions
        )

        # Update the insertion position in the main buffer
        new_current_position = (
            replay_buffer.current_position + self.env_batch_size
        ) % self.buffer_size

        # Update the size counter of the main buffer
        new_current_size = jnp.minimum(
            replay_buffer.current_size + self.env_batch_size, self.buffer_size
        )

        # Update the size of the episodic data buffer
        new_current_episodic_data_size = jnp.minimum(
            jnp.min(replay_buffer.trajectory_positions + 1) * self.env_batch_size,
            self.num_trajectories,
        )

        replay_buffer = replay_buffer.replace(
            timestep_positions=jnp.array(new_timestep_positions, dtype=int),
            trajectory_positions=jnp.array(new_trajectory_positions, dtype=int),
            data=new_data,
            current_position=jnp.array(new_current_position, dtype=int),
            episodic_data=new_episodic_data,
            current_size=new_current_size,
            current_episodic_data_size=jnp.array(
                new_current_episodic_data_size, dtype=int
            ),
        )
        return replay_buffer, None

    flattened_transitions = transitions.flatten()

    flattened_transitions = flattened_transitions.reshape(
        (-1, self.env_batch_size, flattened_transitions.shape[-1])
    )

    replay_buffer, _ = jax.lax.scan(
        insert_one_transition,
        self,
        flattened_transitions,
    )

    replay_buffer = replay_buffer.compute_returns()
    return replay_buffer  # type: ignore
compute_returns(self)

Computes the return for each episode in the buffer.

Returns:
  • TrajectoryBuffer – The buffer with the returns updated.

Source code in qdax/core/neuroevolution/buffers/trajectory_buffer.py
def compute_returns(
    self,
) -> TrajectoryBuffer:
    """Computes the return for each episode in the buffer.

    Returns:
        The buffer with the returns updated.
    """

    reward_idx = 2 * self.transition.observation_dim
    indexes = self.episodic_data
    rewards = self.data[:, reward_idx]
    episodic_returns = jnp.where(
        jnp.isnan(indexes),
        0,
        rewards[jnp.array(jnp.nan_to_num(indexes, 0), dtype=int)],
    ).sum(axis=1)

    values = episodic_returns[:, None].repeat(self.episode_length, axis=1)
    returns = self.returns.at[
        jnp.array(jnp.nan_to_num(indexes, nan=-1), dtype=int)
    ].set(values)
    returns = returns.at[-1].set(jnp.nan)
    return self.replace(returns=returns)  # type: ignore
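A usage sketch for the trajectory buffer, mirroring the docstring example above (env_batch_size=3 and episode_length=2, so buffer_size=6 and num_trajectories=3; the observation and action dimensions are arbitrary). After insertion, sample_with_returns pairs each sampled transition with the return of its episode:

import jax
import jax.numpy as jnp
from qdax.core.neuroevolution.buffers.buffer import Transition
from qdax.core.neuroevolution.buffers.trajectory_buffer import TrajectoryBuffer

obs_dim, action_dim = 4, 2  # arbitrary illustrative dimensions
env_batch_size, episode_length = 3, 2

dummy = Transition.init_dummy(observation_dim=obs_dim, action_dim=action_dim)
buffer = TrajectoryBuffer.init(
    buffer_size=env_batch_size * episode_length,
    transition=dummy,
    env_batch_size=env_batch_size,
    episode_length=episode_length,
)

# One environment step over the 3 parallel envs, with dones=[0, 1, 0] as in the example above.
step = Transition(
    obs=jnp.zeros((env_batch_size, obs_dim)),
    next_obs=jnp.zeros((env_batch_size, obs_dim)),
    rewards=jnp.ones((env_batch_size,)),
    dones=jnp.array([0.0, 1.0, 0.0]),
    truncations=jnp.zeros((env_batch_size,)),
    actions=jnp.zeros((env_batch_size, action_dim)),
)
buffer = buffer.insert(step)  # also recomputes the per-episode returns

random_key = jax.random.PRNGKey(0)
transitions, returns, random_key = buffer.sample_with_returns(random_key, sample_size=2)
assert transitions.obs.shape == (2, obs_dim) and returns.shape == (2,)
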

losses special

dads_loss

make_dads_loss_fn(policy_fn, critic_fn, dynamics_fn, parametric_action_distribution, reward_scaling, discount, action_size, num_skills)

Creates the loss used in DADS.

Parameters:
  • policy_fn (Callable[[Params, Observation], jnp.ndarray]) – the apply function of the policy

  • critic_fn (Callable[[Params, Observation, Action], jnp.ndarray]) – the apply function of the critic

  • dynamics_fn (Callable[[Params, StateDescriptor, Skill, StateDescriptor], jnp.ndarray]) – the apply function of the dynamics network

  • parametric_action_distribution (ParametricDistribution) – the distribution over actions

  • reward_scaling (float) – a multiplicative factor applied to the reward

  • discount (float) – the discount factor

  • action_size (int) – the size of the environment's action space

  • num_skills (int) – the number of skills

Returns:
  • Tuple[Callable, Callable, Callable, Callable] – the loss for auto-tuning the entropy parameter (alpha), the loss of the policy, the loss of the critic, and the loss of the dynamics network

Source code in qdax/core/neuroevolution/losses/dads_loss.py
def make_dads_loss_fn(
    policy_fn: Callable[[Params, Observation], jnp.ndarray],
    critic_fn: Callable[[Params, Observation, Action], jnp.ndarray],
    dynamics_fn: Callable[
        [Params, StateDescriptor, Skill, StateDescriptor], jnp.ndarray
    ],
    parametric_action_distribution: ParametricDistribution,
    reward_scaling: float,
    discount: float,
    action_size: int,
    num_skills: int,
) -> Tuple[
    Callable[[jnp.ndarray, Params, QDTransition, RNGKey], jnp.ndarray],
    Callable[[Params, Params, jnp.ndarray, QDTransition, RNGKey], jnp.ndarray],
    Callable[[Params, Params, Params, QDTransition, RNGKey], jnp.ndarray],
    Callable[[Params, QDTransition, RNGKey], jnp.ndarray],
]:
    """Creates the loss used in DADS.

    Args:
        policy_fn: the apply function of the policy
        critic_fn: the apply function of the critic
        dynamics_fn: the apply function of the dynamics network
        parametric_action_distribution: the distribution over action
        reward_scaling: a multiplicative factor to the reward
        discount: the discount factor
        action_size: the size of the environment's action space
        num_skills: the number of skills set

    Returns:
        the loss of the entropy parameter auto-tuning
        the loss of the policy
        the loss of the critic
        the loss of the dynamics network
    """

    _alpha_loss_fn, _policy_loss_fn, _critic_loss_fn = make_sac_loss_fn(
        policy_fn=policy_fn,
        critic_fn=critic_fn,
        reward_scaling=reward_scaling,
        discount=discount,
        action_size=action_size,
        parametric_action_distribution=parametric_action_distribution,
    )

    _dynamics_loss_fn = functools.partial(
        dads_dynamics_loss_fn, dynamics_fn=dynamics_fn, num_skills=num_skills
    )

    return _alpha_loss_fn, _policy_loss_fn, _critic_loss_fn, _dynamics_loss_fn
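A minimal sketch of wiring make_dads_loss_fn. It assumes brax's NormalTanhDistribution as the ParametricDistribution and uses placeholder lambdas standing in for real policy, critic and dynamics apply functions (in practice these come from the agent's Flax modules); the call itself only binds the hyperparameters and returns the four loss callables used in the training step:

import jax.numpy as jnp
from brax.training.distribution import NormalTanhDistribution
from qdax.core.neuroevolution.losses.dads_loss import make_dads_loss_fn

action_size, num_skills = 2, 5  # arbitrary illustrative sizes

# Hypothetical stand-ins for the networks' apply functions (real ones are Flax apply fns).
policy_fn = lambda params, obs: jnp.zeros((obs.shape[0], 2 * action_size))
critic_fn = lambda params, obs, action: jnp.zeros((obs.shape[0], 2))
dynamics_fn = lambda params, obs, skill, target: jnp.zeros((obs.shape[0],))

alpha_loss_fn, policy_loss_fn, critic_loss_fn, dynamics_loss_fn = make_dads_loss_fn(
    policy_fn=policy_fn,
    critic_fn=critic_fn,
    dynamics_fn=dynamics_fn,
    parametric_action_distribution=NormalTanhDistribution(event_size=action_size),
    reward_scaling=1.0,
    discount=0.99,
    action_size=action_size,
    num_skills=num_skills,
)
# Each returned callable is then evaluated on network params and sampled QDTransition
# batches inside the training step (typically through jax.value_and_grad).
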
dads_dynamics_loss_fn(dynamics_params, dynamics_fn, num_skills, transitions)

Computes the loss used to train the dynamics network.

Parameters:
  • dynamics_params (Params) – the parameters of the neural network used to predict the dynamics.

  • dynamics_fn (Callable[[Params, StateDescriptor, Skill, StateDescriptor], jnp.ndarray]) – the apply function of the dynamics network

  • num_skills (int) – the number of skills.

  • transitions (QDTransition) – the batch of transitions used to train. They have been sampled from a replay buffer beforehand.

Returns:
  • jnp.ndarray – the loss obtained on the batch of transitions.

Source code in qdax/core/neuroevolution/losses/dads_loss.py
def dads_dynamics_loss_fn(
    dynamics_params: Params,
    dynamics_fn: Callable[
        [Params, StateDescriptor, Skill, StateDescriptor], jnp.ndarray
    ],
    num_skills: int,
    transitions: QDTransition,
) -> jnp.ndarray:
    """Computes the loss used to train the dynamics network.

    Args:
        dynamics_params: the parameters of the neural network
            used to predict the dynamics.
        dynamics_fn: the apply function of the dynamics network
        num_skills: the number of skills.
        transitions: the batch of transitions used to train. They
            have been sampled from a replay buffer beforehand.

    Returns:
        The loss obtained on the batch of transitions.
    """

    active_skills = transitions.obs[:, -num_skills:]
    target = transitions.next_state_desc
    log_prob = dynamics_fn(  # type: ignore
        dynamics_params,
        obs=transitions.state_desc,
        skill=active_skills,
        target=target,
    )

    # prevent training on malformed target
    loss = -jnp.mean(log_prob * (1 - transitions.dones))
    return loss

diayn_loss

make_diayn_loss_fn(policy_fn, critic_fn, discriminator_fn, parametric_action_distribution, reward_scaling, discount, action_size, num_skills)

Creates the loss used in DIAYN.

Parameters:
  • policy_fn (Callable[[Union[jax.Array, numpy.ndarray, numpy.bool_, numpy.number, Iterable[ArrayTree], Mapping[Any, ArrayTree]], jax.Array], jax.Array]) – the apply function of the policy

  • critic_fn (Callable[[Union[jax.Array, numpy.ndarray, numpy.bool_, numpy.number, Iterable[ArrayTree], Mapping[Any, ArrayTree]], jax.Array, jax.Array], jax.Array]) – the apply function of the critic

  • discriminator_fn (Callable[[Union[jax.Array, numpy.ndarray, numpy.bool_, numpy.number, Iterable[ArrayTree], Mapping[Any, ArrayTree]], jax.Array], jax.Array]) – the apply function of the discriminator

  • parametric_action_distribution (ParametricDistribution) – the distribution over actions

  • reward_scaling (float) – a multiplicative factor to the reward

  • discount (float) – the discount factor

  • action_size (int) – the size of the environment's action space

  • num_skills (int) – the number of skills set

Returns:
  • Tuple[Callable[[jnp.ndarray, Params, QDTransition, RNGKey], jnp.ndarray], Callable[[Params, Params, jnp.ndarray, QDTransition, RNGKey], jnp.ndarray], Callable[[Params, Params, Params, QDTransition, RNGKey], jnp.ndarray], Callable[[Params, QDTransition, RNGKey], jnp.ndarray]] – the loss of the entropy parameter auto-tuning, the loss of the policy, the loss of the critic, and the loss of the discriminator.

Source code in qdax/core/neuroevolution/losses/diayn_loss.py
def make_diayn_loss_fn(
    policy_fn: Callable[[Params, Observation], jnp.ndarray],
    critic_fn: Callable[[Params, Observation, Action], jnp.ndarray],
    discriminator_fn: Callable[[Params, StateDescriptor], jnp.ndarray],
    parametric_action_distribution: ParametricDistribution,
    reward_scaling: float,
    discount: float,
    action_size: int,
    num_skills: int,
) -> Tuple[
    Callable[[jnp.ndarray, Params, QDTransition, RNGKey], jnp.ndarray],
    Callable[[Params, Params, jnp.ndarray, QDTransition, RNGKey], jnp.ndarray],
    Callable[[Params, Params, Params, QDTransition, RNGKey], jnp.ndarray],
    Callable[[Params, QDTransition, RNGKey], jnp.ndarray],
]:
    """Creates the loss used in DIAYN.

    Args:
        policy_fn: the apply function of the policy
        critic_fn: the apply function of the critic
        discriminator_fn: the apply function of the discriminator
        parametric_action_distribution: the distribution over actions
        reward_scaling: a multiplicative factor to the reward
        discount: the discount factor
        action_size: the size of the environment's action space
        num_skills: the number of skills set

    Returns:
        the loss of the entropy parameter auto-tuning
        the loss of the policy
        the loss of the critic
        the loss of the discriminator
    """

    _alpha_loss_fn, _policy_loss_fn, _critic_loss_fn = make_sac_loss_fn(
        policy_fn=policy_fn,
        critic_fn=critic_fn,
        reward_scaling=reward_scaling,
        discount=discount,
        action_size=action_size,
        parametric_action_distribution=parametric_action_distribution,
    )

    _discriminator_loss_fn = functools.partial(
        diayn_discriminator_loss_fn,
        discriminator_fn=discriminator_fn,
        num_skills=num_skills,
    )

    return _alpha_loss_fn, _policy_loss_fn, _critic_loss_fn, _discriminator_loss_fn
diayn_discriminator_loss_fn(discriminator_params, discriminator_fn, num_skills, transitions)

Computes the loss used to train the discriminator. The discriminator is trained to predict the skill that was used to generate each transition. Since skills are one-hot encoded, the discriminator is trained like a multi-class classifier, using the categorical cross-entropy loss.

In this loss, the log-softmax outputs the log probability of each skill. Multiplying by the one-hot skill vectors filters out everything but the log probability of the skill that was actually used.
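
As a toy, self-contained illustration of that filtering step (the numbers are made up and the snippet is not qdax code):

import jax
import jax.numpy as jnp

logits = jnp.array([[2.0, 0.5, -1.0]])  # discriminator outputs for 3 skills
skills = jnp.array([[0.0, 1.0, 0.0]])   # one-hot encoding of the skill actually used
log_prob_true_skill = jnp.sum(jax.nn.log_softmax(logits) * skills, axis=1)
loss = -jnp.mean(log_prob_true_skill)   # categorical cross-entropy on this batch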

Parameters:
  • discriminator_params (Union[jax.Array, numpy.ndarray, numpy.bool_, numpy.number, Iterable[ArrayTree], Mapping[Any, ArrayTree]]) – the parameters of the neural network used to discriminate the skills.

  • discriminator_fn (Callable[[Union[jax.Array, numpy.ndarray, numpy.bool_, numpy.number, Iterable[ArrayTree], Mapping[Any, ArrayTree]], jax.Array], jax.Array]) – the apply function of the discriminator.

  • num_skills (int) – the number of skills.

  • transitions (QDTransition) – the transitions sampled from the replay buffer.

Returns:
  • Array – The loss of the discriminator on those transitions.

Source code in qdax/core/neuroevolution/losses/diayn_loss.py
def diayn_discriminator_loss_fn(
    discriminator_params: Params,
    discriminator_fn: Callable[[Params, StateDescriptor], jnp.ndarray],
    num_skills: int,
    transitions: QDTransition,
) -> jnp.ndarray:
    """Computes the loss used to train the discriminator. The
    discriminator is trained to predict the skill that has been
    used to generate transitions. In our case, skills are one
    hot encoded, the discriminator is hence trained like a
    multi-label classifier, using the categorical cross entropy
    loss.

    In this loss, log softmax outputs the log probabilities for
    each skill. By multiplying by the skills (that are one hot
    vectors), we filter to keep only the log probability of the
    true skill.

    Args:
        discriminator_params: the parameters of the neural network
            used to discriminate the skills.
        discriminator_fn: the apply function of the discriminator.
        num_skills: the number of skills.
        transitions: the transitions sampled from the replay buffer.

    Returns:
        The loss of the discriminator on those transitions.
    """

    state_desc = transitions.state_desc
    skills = transitions.obs[:, -num_skills:]
    logits = jnp.sum(
        jax.nn.log_softmax(discriminator_fn(discriminator_params, state_desc)) * skills,
        axis=1,
    )

    loss = -jnp.mean(logits)
    return loss

sac_loss

make_sac_loss_fn(policy_fn, critic_fn, parametric_action_distribution, reward_scaling, discount, action_size)

Creates the loss used in SAC.

Parameters:
  • policy_fn (Callable[[Union[jax.Array, numpy.ndarray, numpy.bool_, numpy.number, Iterable[ArrayTree], Mapping[Any, ArrayTree]], jax.Array], jax.Array]) – the apply function of the policy

  • critic_fn (Callable[[Union[jax.Array, numpy.ndarray, numpy.bool_, numpy.number, Iterable[ArrayTree], Mapping[Any, ArrayTree]], jax.Array, jax.Array], jax.Array]) – the apply function of the critic

  • parametric_action_distribution (ParametricDistribution) – the distribution over actions

  • reward_scaling (float) – a multiplicative factor to the reward

  • discount (float) – the discount factor

  • action_size (int) – the size of the environment's action space

Returns:
  • Tuple[Callable[[jnp.ndarray, Params, Transition, RNGKey], jnp.ndarray], Callable[[Params, Params, jnp.ndarray, Transition, RNGKey], jnp.ndarray], Callable[[Params, Params, Params, Transition, RNGKey], jnp.ndarray]] – the loss of the entropy parameter auto-tuning, the loss of the policy, and the loss of the critic.

Source code in qdax/core/neuroevolution/losses/sac_loss.py
def make_sac_loss_fn(
    policy_fn: Callable[[Params, Observation], jnp.ndarray],
    critic_fn: Callable[[Params, Observation, Action], jnp.ndarray],
    parametric_action_distribution: ParametricDistribution,
    reward_scaling: float,
    discount: float,
    action_size: int,
) -> Tuple[
    Callable[[jnp.ndarray, Params, Transition, RNGKey], jnp.ndarray],
    Callable[[Params, Params, jnp.ndarray, Transition, RNGKey], jnp.ndarray],
    Callable[[Params, Params, Params, Transition, RNGKey], jnp.ndarray],
]:
    """Creates the loss used in SAC.

    Args:
        policy_fn: the apply function of the policy
        critic_fn: the apply function of the critic
        parametric_action_distribution: the distribution over actions
        reward_scaling: a multiplicative factor to the reward
        discount: the discount factor
        action_size: the size of the environment's action space

    Returns:
        the loss of the entropy parameter auto-tuning
        the loss of the policy
        the loss of the critic
    """

    _policy_loss_fn = functools.partial(
        sac_policy_loss_fn,
        policy_fn=policy_fn,
        critic_fn=critic_fn,
        parametric_action_distribution=parametric_action_distribution,
    )

    _critic_loss_fn = functools.partial(
        sac_critic_loss_fn,
        policy_fn=policy_fn,
        critic_fn=critic_fn,
        parametric_action_distribution=parametric_action_distribution,
        reward_scaling=reward_scaling,
        discount=discount,
    )

    _alpha_loss_fn = functools.partial(
        sac_alpha_loss_fn,
        policy_fn=policy_fn,
        parametric_action_distribution=parametric_action_distribution,
        action_size=action_size,
    )

    return _alpha_loss_fn, _policy_loss_fn, _critic_loss_fn
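
A hedged usage sketch: policy and critic are assumed to be the haiku-transformed networks returned by make_sac_networks (documented further down this page), and policy_params, critic_params, alpha, transitions and random_key are assumed to already exist (transitions would typically be sampled from a replay buffer). Because the factory binds its own arguments by keyword, the returned losses are most safely called with keyword arguments; a gradient can then be taken through a small wrapper.

import jax

alpha_loss_fn, policy_loss_fn, critic_loss_fn = make_sac_loss_fn(
    policy_fn=policy.apply,
    critic_fn=critic.apply,
    parametric_action_distribution=parametric_action_distribution,
    reward_scaling=1.0,
    discount=0.99,
    action_size=action_size,
)

policy_gradient = jax.grad(
    lambda params: policy_loss_fn(
        policy_params=params,
        critic_params=critic_params,
        alpha=alpha,
        transitions=transitions,
        random_key=random_key,
    )
)(policy_params)
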
sac_policy_loss_fn(policy_params, policy_fn, critic_fn, parametric_action_distribution, critic_params, alpha, transitions, random_key)

Creates the policy loss used in SAC.

Parameters:
  • policy_params (Union[jax.Array, numpy.ndarray, numpy.bool_, numpy.number, Iterable[ArrayTree], Mapping[Any, ArrayTree]]) – parameters of the policy

  • policy_fn (Callable[[Union[jax.Array, numpy.ndarray, numpy.bool_, numpy.number, Iterable[ArrayTree], Mapping[Any, ArrayTree]], jax.Array], jax.Array]) – the apply function of the policy

  • critic_fn (Callable[[Union[jax.Array, numpy.ndarray, numpy.bool_, numpy.number, Iterable[ArrayTree], Mapping[Any, ArrayTree]], jax.Array, jax.Array], jax.Array]) – the apply function of the critic

  • parametric_action_distribution (ParametricDistribution) – the distribution over actions

  • critic_params (Union[jax.Array, numpy.ndarray, numpy.bool_, numpy.number, Iterable[ArrayTree], Mapping[Any, ArrayTree]]) – parameters of the critic

  • alpha (Array) – entropy coefficient value

  • transitions (Transition) – transitions collected by the agent

  • random_key (Array) – random key

Returns:
  • Array – the loss of the policy

Source code in qdax/core/neuroevolution/losses/sac_loss.py
def sac_policy_loss_fn(
    policy_params: Params,
    policy_fn: Callable[[Params, Observation], jnp.ndarray],
    critic_fn: Callable[[Params, Observation, Action], jnp.ndarray],
    parametric_action_distribution: ParametricDistribution,
    critic_params: Params,
    alpha: jnp.ndarray,
    transitions: Transition,
    random_key: RNGKey,
) -> jnp.ndarray:
    """
    Creates the policy loss used in SAC.

    Args:
        policy_params: parameters of the policy
        policy_fn: the apply function of the policy
        critic_fn: the apply function of the critic
        parametric_action_distribution: the distribution over actions
        critic_params: parameters of the critic
        alpha: entropy coefficient value
        transitions: transitions collected by the agent
        random_key: random key

    Returns:
        the loss of the policy
    """

    dist_params = policy_fn(policy_params, transitions.obs)
    action = parametric_action_distribution.sample_no_postprocessing(
        dist_params, random_key
    )
    log_prob = parametric_action_distribution.log_prob(dist_params, action)
    action = parametric_action_distribution.postprocess(action)
    q_action = critic_fn(critic_params, transitions.obs, action)
    min_q = jnp.min(q_action, axis=-1)
    actor_loss = alpha * log_prob - min_q

    return jnp.mean(actor_loss)
sac_critic_loss_fn(critic_params, policy_fn, critic_fn, parametric_action_distribution, reward_scaling, discount, policy_params, target_critic_params, alpha, transitions, random_key)

Creates the critic loss used in SAC.

Parameters:
  • critic_params (Union[jax.Array, numpy.ndarray, numpy.bool_, numpy.number, Iterable[ArrayTree], Mapping[Any, ArrayTree]]) – parameters of the critic

  • policy_fn (Callable[[Union[jax.Array, numpy.ndarray, numpy.bool_, numpy.number, Iterable[ArrayTree], Mapping[Any, ArrayTree]], jax.Array], jax.Array]) – the apply function of the policy

  • critic_fn (Callable[[Union[jax.Array, numpy.ndarray, numpy.bool_, numpy.number, Iterable[ArrayTree], Mapping[Any, ArrayTree]], jax.Array, jax.Array], jax.Array]) – the apply function of the critic

  • parametric_action_distribution (ParametricDistribution) – the distribution over actions

  • policy_params (Union[jax.Array, numpy.ndarray, numpy.bool_, numpy.number, Iterable[ArrayTree], Mapping[Any, ArrayTree]]) – parameters of the policy

  • target_critic_params (Union[jax.Array, numpy.ndarray, numpy.bool_, numpy.number, Iterable[ArrayTree], Mapping[Any, ArrayTree]]) – parameters of the target critic

  • alpha (Array) – entropy coefficient value

  • transitions (Transition) – transitions collected by the agent

  • random_key (Array) – random key

  • reward_scaling (float) – a multiplicative factor to the reward

  • discount (float) – the discount factor

Returns:
  • Array – the loss of the critic

Source code in qdax/core/neuroevolution/losses/sac_loss.py
def sac_critic_loss_fn(
    critic_params: Params,
    policy_fn: Callable[[Params, Observation], jnp.ndarray],
    critic_fn: Callable[[Params, Observation, Action], jnp.ndarray],
    parametric_action_distribution: ParametricDistribution,
    reward_scaling: float,
    discount: float,
    policy_params: Params,
    target_critic_params: Params,
    alpha: jnp.ndarray,
    transitions: Transition,
    random_key: RNGKey,
) -> jnp.ndarray:
    """
    Creates the critic loss used in SAC.

    Args:
        critic_params: parameters of the critic
        policy_fn: the apply function of the policy
        critic_fn: the apply function of the critic
        parametric_action_distribution: the distribution over actions
        policy_params: parameters of the policy
        target_critic_params: parameters of the target critic
        alpha: entropy coefficient value
        transitions: transitions collected by the agent
        random_key: random key
        reward_scaling: a multiplicative factor to the reward
        discount: the discount factor

    Returns:
        the loss of the critic
    """

    q_old_action = critic_fn(critic_params, transitions.obs, transitions.actions)
    next_dist_params = policy_fn(policy_params, transitions.next_obs)
    next_action = parametric_action_distribution.sample_no_postprocessing(
        next_dist_params, random_key
    )
    next_log_prob = parametric_action_distribution.log_prob(
        next_dist_params, next_action
    )
    next_action = parametric_action_distribution.postprocess(next_action)
    next_q = critic_fn(target_critic_params, transitions.next_obs, next_action)

    next_v = jnp.min(next_q, axis=-1) - alpha * next_log_prob

    target_q = jax.lax.stop_gradient(
        transitions.rewards * reward_scaling
        + (1.0 - transitions.dones) * discount * next_v
    )

    q_error = q_old_action - jnp.expand_dims(target_q, -1)
    q_error *= jnp.expand_dims(1 - transitions.truncations, -1)
    q_loss = 0.5 * jnp.mean(jnp.square(q_error))

    return q_loss
sac_alpha_loss_fn(log_alpha, policy_fn, parametric_action_distribution, action_size, policy_params, transitions, random_key)

Creates the alpha loss used in SAC. Eq 18 from https://arxiv.org/pdf/1812.05905.pdf.
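
In equation form (Eq. 18 of the cited paper), the objective minimised here is

J(\alpha) = \mathbb{E}_{a_t \sim \pi_t}\left[ -\alpha \log \pi_t(a_t \mid s_t) - \alpha \bar{\mathcal{H}} \right],

where \bar{\mathcal{H}} is the target entropy; in the code below it is set to -0.5 * action_size, and the log-probability term is wrapped in a stop_gradient so that only \alpha receives a gradient.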

Parameters:
  • log_alpha (Array) – entropy coefficient log value

  • policy_fn (Callable[[Union[jax.Array, numpy.ndarray, numpy.bool_, numpy.number, Iterable[ArrayTree], Mapping[Any, ArrayTree]], jax.Array], jax.Array]) – the apply function of the policy

  • parametric_action_distribution (ParametricDistribution) – the distribution over actions

  • policy_params (Union[jax.Array, numpy.ndarray, numpy.bool_, numpy.number, Iterable[ArrayTree], Mapping[Any, ArrayTree]]) – parameters of the policy

  • transitions (Transition) – transitions collected by the agent

  • random_key (Array) – random key

  • action_size (int) – the size of the environment's action space

Returns:
  • Array – the loss of the entropy parameter auto-tuning

Source code in qdax/core/neuroevolution/losses/sac_loss.py
def sac_alpha_loss_fn(
    log_alpha: jnp.ndarray,
    policy_fn: Callable[[Params, Observation], jnp.ndarray],
    parametric_action_distribution: ParametricDistribution,
    action_size: int,
    policy_params: Params,
    transitions: Transition,
    random_key: RNGKey,
) -> jnp.ndarray:
    """
    Creates the alpha loss used in SAC.
    Eq 18 from https://arxiv.org/pdf/1812.05905.pdf.

    Args:
        log_alpha: entropy coefficient log value
        policy_fn: the apply function of the policy
        parametric_action_distribution: the distribution over actions
        policy_params: parameters of the policy
        transitions: transitions collected by the agent
        random_key: random key
        action_size: the size of the environment's action space

    Returns:
        the loss of the entropy parameter auto-tuning
    """

    target_entropy = -0.5 * action_size

    dist_params = policy_fn(policy_params, transitions.obs)
    action = parametric_action_distribution.sample_no_postprocessing(
        dist_params, random_key
    )
    log_prob = parametric_action_distribution.log_prob(dist_params, action)
    alpha = jnp.exp(log_alpha)
    alpha_loss = alpha * jax.lax.stop_gradient(-log_prob - target_entropy)

    loss = jnp.mean(alpha_loss)
    return loss

td3_loss

Implements a function to create critic and actor losses for the TD3 algorithm.

make_td3_loss_fn(policy_fn, critic_fn, reward_scaling, discount, noise_clip, policy_noise)

Creates the loss functions for TD3.

Parameters:
  • policy_fn (Callable[[Union[jax.Array, numpy.ndarray, numpy.bool_, numpy.number, Iterable[ArrayTree], Mapping[Any, ArrayTree]], jax.Array], jax.Array]) – forward pass through the neural network defining the policy.

  • critic_fn (Callable[[Union[jax.Array, numpy.ndarray, numpy.bool_, numpy.number, Iterable[ArrayTree], Mapping[Any, ArrayTree]], jax.Array, jax.Array], jax.Array]) – forward pass through the neural network defining the critic.

  • reward_scaling (float) – value to multiply the reward given by the environment.

  • discount (float) – discount factor.

  • noise_clip (float) – value that clips the noise to avoid extreme values.

  • policy_noise (float) – noise applied to smooth the bootstrapping.

Returns:
  • Tuple[Callable[[Params, Params, Transition], jnp.ndarray], Callable[[Params, Params, Params, Transition, RNGKey], jnp.ndarray]] – the loss functions used to train the policy and the critic in TD3.

Source code in qdax/core/neuroevolution/losses/td3_loss.py
def make_td3_loss_fn(
    policy_fn: Callable[[Params, Observation], jnp.ndarray],
    critic_fn: Callable[[Params, Observation, Action], jnp.ndarray],
    reward_scaling: float,
    discount: float,
    noise_clip: float,
    policy_noise: float,
) -> Tuple[
    Callable[[Params, Params, Transition], jnp.ndarray],
    Callable[[Params, Params, Params, Transition, RNGKey], jnp.ndarray],
]:
    """Creates the loss functions for TD3.

    Args:
        policy_fn: forward pass through the neural network defining the policy.
        critic_fn: forward pass through the neural network defining the critic.
        reward_scaling: value to multiply the reward given by the environment.
        discount: discount factor.
        noise_clip: value that clips the noise to avoid extreme values.
        policy_noise: noise applied to smooth the bootstrapping.

    Returns:
        Return the loss functions used to train the policy and the critic in TD3.
    """

    @jax.jit
    def _policy_loss_fn(
        policy_params: Params,
        critic_params: Params,
        transitions: Transition,
    ) -> jnp.ndarray:
        """Policy loss function for TD3 agent"""

        action = policy_fn(policy_params, transitions.obs)
        q_value = critic_fn(
            critic_params, obs=transitions.obs, actions=action  # type: ignore
        )
        q1_action = jnp.take(q_value, jnp.asarray([0]), axis=-1)
        policy_loss = -jnp.mean(q1_action)
        return policy_loss

    @jax.jit
    def _critic_loss_fn(
        critic_params: Params,
        target_policy_params: Params,
        target_critic_params: Params,
        transitions: Transition,
        random_key: RNGKey,
    ) -> jnp.ndarray:
        """Critics loss function for TD3 agent"""
        noise = (
            jax.random.normal(random_key, shape=transitions.actions.shape)
            * policy_noise
        ).clip(-noise_clip, noise_clip)

        next_action = (
            policy_fn(target_policy_params, transitions.next_obs) + noise
        ).clip(-1.0, 1.0)
        next_q = critic_fn(  # type: ignore
            target_critic_params, obs=transitions.next_obs, actions=next_action
        )
        next_v = jnp.min(next_q, axis=-1)
        target_q = jax.lax.stop_gradient(
            transitions.rewards * reward_scaling
            + (1.0 - transitions.dones) * discount * next_v
        )
        q_old_action = critic_fn(  # type: ignore
            critic_params,
            obs=transitions.obs,
            actions=transitions.actions,
        )
        q_error = q_old_action - jnp.expand_dims(target_q, -1)

        # Better bootstrapping for truncated episodes.
        q_error = q_error * jnp.expand_dims(1.0 - transitions.truncations, -1)

        # compute the loss
        q_losses = jnp.mean(jnp.square(q_error), axis=-2)
        q_loss = jnp.sum(q_losses, axis=-1)

        return q_loss

    return _policy_loss_fn, _critic_loss_fn
td3_policy_loss_fn(policy_params, critic_params, policy_fn, critic_fn, transitions)

Policy loss function for TD3 agent.

Parameters:
  • policy_params (Union[jax.Array, numpy.ndarray, numpy.bool_, numpy.number, Iterable[ArrayTree], Mapping[Any, ArrayTree]]) – policy parameters.

  • critic_params (Union[jax.Array, numpy.ndarray, numpy.bool_, numpy.number, Iterable[ArrayTree], Mapping[Any, ArrayTree]]) – critic parameters.

  • policy_fn (Callable[[Union[jax.Array, numpy.ndarray, numpy.bool_, numpy.number, Iterable[ArrayTree], Mapping[Any, ArrayTree]], jax.Array], jax.Array]) – forward pass through the neural network defining the policy.

  • critic_fn (Callable[[Union[jax.Array, numpy.ndarray, numpy.bool_, numpy.number, Iterable[ArrayTree], Mapping[Any, ArrayTree]], jax.Array, jax.Array], jax.Array]) – forward pass through the neural network defining the critic.

  • transitions (Transition) – collected transitions.

Returns:
  • Array – Return the loss function used to train the policy in TD3.

Source code in qdax/core/neuroevolution/losses/td3_loss.py
def td3_policy_loss_fn(
    policy_params: Params,
    critic_params: Params,
    policy_fn: Callable[[Params, Observation], jnp.ndarray],
    critic_fn: Callable[[Params, Observation, Action], jnp.ndarray],
    transitions: Transition,
) -> jnp.ndarray:
    """Policy loss function for TD3 agent.

    Args:
        policy_params: policy parameters.
        critic_params: critic parameters.
        policy_fn: forward pass through the neural network defining the policy.
        critic_fn: forward pass through the neural network defining the critic.
        transitions: collected transitions.

    Returns:
        Return the loss function used to train the policy in TD3.
    """

    action = policy_fn(policy_params, transitions.obs)
    q_value = critic_fn(
        critic_params, obs=transitions.obs, actions=action  # type: ignore
    )
    q1_action = jnp.take(q_value, jnp.asarray([0]), axis=-1)
    policy_loss = -jnp.mean(q1_action)
    return policy_loss
td3_critic_loss_fn(critic_params, target_policy_params, target_critic_params, policy_fn, critic_fn, policy_noise, noise_clip, reward_scaling, discount, transitions, random_key)

Critics loss function for TD3 agent.

Parameters:
  • critic_params (Union[jax.Array, numpy.ndarray, numpy.bool_, numpy.number, Iterable[ArrayTree], Mapping[Any, ArrayTree]]) – critic parameters.

  • target_policy_params (Union[jax.Array, numpy.ndarray, numpy.bool_, numpy.number, Iterable[ArrayTree], Mapping[Any, ArrayTree]]) – target policy parameters.

  • target_critic_params (Union[jax.Array, numpy.ndarray, numpy.bool_, numpy.number, Iterable[ArrayTree], Mapping[Any, ArrayTree]]) – target critic parameters.

  • policy_fn (Callable[[Union[jax.Array, numpy.ndarray, numpy.bool_, numpy.number, Iterable[ArrayTree], Mapping[Any, ArrayTree]], jax.Array], jax.Array]) – forward pass through the neural network defining the policy.

  • critic_fn (Callable[[Union[jax.Array, numpy.ndarray, numpy.bool_, numpy.number, Iterable[ArrayTree], Mapping[Any, ArrayTree]], jax.Array, jax.Array], jax.Array]) – forward pass through the neural network defining the critic.

  • policy_noise (float) – policy noise.

  • noise_clip (float) – noise clip.

  • reward_scaling (float) – reward scaling coefficient.

  • discount (float) – discount factor.

  • transitions (Transition) – collected transitions.

Returns:
  • Array – Return the loss function used to train the critic in TD3.

Source code in qdax/core/neuroevolution/losses/td3_loss.py
def td3_critic_loss_fn(
    critic_params: Params,
    target_policy_params: Params,
    target_critic_params: Params,
    policy_fn: Callable[[Params, Observation], jnp.ndarray],
    critic_fn: Callable[[Params, Observation, Action], jnp.ndarray],
    policy_noise: float,
    noise_clip: float,
    reward_scaling: float,
    discount: float,
    transitions: Transition,
    random_key: RNGKey,
) -> jnp.ndarray:
    """Critics loss function for TD3 agent.

    Args:
        critic_params: critic parameters.
        target_policy_params: target policy parameters.
        target_critic_params: target critic parameters.
        policy_fn: forward pass through the neural network defining the policy.
        critic_fn: forward pass through the neural network defining the critic.
        policy_noise: policy noise.
        noise_clip: noise clip.
        reward_scaling: reward scaling coefficient.
        discount: discount factor.
        transitions: collected transitions.

    Returns:
        Return the loss function used to train the critic in TD3.
    """
    noise = (
        jax.random.normal(random_key, shape=transitions.actions.shape) * policy_noise
    ).clip(-noise_clip, noise_clip)

    next_action = (policy_fn(target_policy_params, transitions.next_obs) + noise).clip(
        -1.0, 1.0
    )
    next_q = critic_fn(  # type: ignore
        target_critic_params, obs=transitions.next_obs, actions=next_action
    )
    next_v = jnp.min(next_q, axis=-1)
    target_q = jax.lax.stop_gradient(
        transitions.rewards * reward_scaling
        + (1.0 - transitions.dones) * discount * next_v
    )
    q_old_action = critic_fn(  # type: ignore
        critic_params,
        obs=transitions.obs,
        actions=transitions.actions,
    )
    q_error = q_old_action - jnp.expand_dims(target_q, -1)

    # Better bootstrapping for truncated episodes.
    q_error = q_error * jnp.expand_dims(1.0 - transitions.truncations, -1)

    # compute the loss
    q_losses = jnp.mean(jnp.square(q_error), axis=-2)
    q_loss = jnp.sum(q_losses, axis=-1)

    return q_loss
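
In equation form, the bootstrapped target computed above is (with r the reward, d the done flag, \gamma the discount factor, c the reward_scaling coefficient, and the minimum taken over the two critics):

\tilde{a} = \operatorname{clip}\big(\pi_{\text{target}}(s') + \operatorname{clip}(\epsilon, -\text{noise\_clip}, \text{noise\_clip}),\, -1,\, 1\big), \qquad \epsilon \sim \mathcal{N}(0, \text{policy\_noise}^2)

y = c \, r + (1 - d) \, \gamma \, \min_{i \in \{1, 2\}} Q^{\text{target}}_i(s', \tilde{a})

The squared error between each critic's estimate and y is then masked on truncated steps before being averaged.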

networks special

dads_networks

GaussianMixture (Module)

Module that outputs a Gaussian Mixture Distribution.

Source code in qdax/core/neuroevolution/networks/dads_networks.py
class GaussianMixture(hk.Module):
    """Module that outputs a Gaussian Mixture Distribution."""

    def __init__(
        self,
        num_dimensions: int,
        num_components: int,
        reinterpreted_batch_ndims: Optional[int] = None,
        identity_covariance: bool = True,
        initializer: Optional[Initializer] = None,
        name: str = "GaussianMixture",
    ):
        """Module that outputs a Gaussian Mixture Distribution
        with identity covariance matrix."""

        super().__init__(name=name)
        if initializer is None:
            initializer = VarianceScaling(1.0, "fan_in", "uniform")
        self._num_dimensions = num_dimensions
        self._num_components = num_components
        self._reinterpreted_batch_ndims = reinterpreted_batch_ndims
        self._identity_covariance = identity_covariance
        self.initializer = initializer
        logits_size = self._num_components

        self.logit_layer = hk.Linear(logits_size, w_init=self.initializer)

        # Create two layers that outputs a location and a scale, respectively, for
        # each dimension and each component.
        self.loc_layer = hk.Linear(
            self._num_dimensions * self._num_components, w_init=self.initializer
        )
        if not self._identity_covariance:
            self.scale_layer = hk.Linear(
                self._num_dimensions * self._num_components, w_init=self.initializer
            )

    def __call__(self, inputs: jnp.ndarray) -> tfp.distributions.Distribution:
        # Compute logits, locs, and scales if necessary.
        logits = self.logit_layer(inputs)
        locs = self.loc_layer(inputs)

        shape = [-1, self._num_components, self._num_dimensions]  # [B, D, C]

        # Reshape the mixture's location and scale parameters appropriately.
        locs = locs.reshape(shape)
        if not self._identity_covariance:

            scales = self.scale_layer(inputs)
            scales = scales.reshape(shape)
        else:
            scales = jnp.ones_like(locs)

        # Create the mixture distribution
        components = tfp.distributions.MultivariateNormalDiag(
            loc=locs, scale_diag=scales
        )
        mixture = tfp.distributions.Categorical(logits=logits)
        distribution = tfp.distributions.MixtureSameFamily(
            mixture_distribution=mixture, components_distribution=components
        )

        return distribution
__init__(self, num_dimensions, num_components, reinterpreted_batch_ndims=None, identity_covariance=True, initializer=None, name='GaussianMixture') special

Module that outputs a Gaussian Mixture Distribution with identity covariance matrix.

Source code in qdax/core/neuroevolution/networks/dads_networks.py
def __init__(
    self,
    num_dimensions: int,
    num_components: int,
    reinterpreted_batch_ndims: Optional[int] = None,
    identity_covariance: bool = True,
    initializer: Optional[Initializer] = None,
    name: str = "GaussianMixture",
):
    """Module that outputs a Gaussian Mixture Distribution
    with identity covariance matrix."""

    super().__init__(name=name)
    if initializer is None:
        initializer = VarianceScaling(1.0, "fan_in", "uniform")
    self._num_dimensions = num_dimensions
    self._num_components = num_components
    self._reinterpreted_batch_ndims = reinterpreted_batch_ndims
    self._identity_covariance = identity_covariance
    self.initializer = initializer
    logits_size = self._num_components

    self.logit_layer = hk.Linear(logits_size, w_init=self.initializer)

    # Create two layers that outputs a location and a scale, respectively, for
    # each dimension and each component.
    self.loc_layer = hk.Linear(
        self._num_dimensions * self._num_components, w_init=self.initializer
    )
    if not self._identity_covariance:
        self.scale_layer = hk.Linear(
            self._num_dimensions * self._num_components, w_init=self.initializer
        )
DynamicsNetwork (Module)

Dynamics network (used in DADS).

Source code in qdax/core/neuroevolution/networks/dads_networks.py
class DynamicsNetwork(hk.Module):
    """Dynamics network (used in DADS)."""

    def __init__(
        self,
        hidden_layer_sizes: tuple,
        output_size: int,
        omit_input_dynamics_dim: int = 2,
        name: Optional[str] = None,
        identity_covariance: bool = True,
        initializer: Optional[Initializer] = None,
    ):
        super().__init__(name=name)
        if initializer is None:
            initializer = VarianceScaling(1.0, "fan_in", "uniform")

        self.distribution = GaussianMixture(
            output_size,
            num_components=4,
            reinterpreted_batch_ndims=None,
            identity_covariance=identity_covariance,
            initializer=initializer,
        )
        self.network = hk.Sequential(
            [
                hk.nets.MLP(
                    list(hidden_layer_sizes),
                    w_init=initializer,
                    activation=jax.nn.relu,
                    activate_final=True,
                ),
            ]
        )
        self._omit_input_dynamics_dim = omit_input_dynamics_dim

    def __call__(
        self, obs: StateDescriptor, skill: Skill, target: StateDescriptor
    ) -> jnp.ndarray:
        """Normalizes the observation, predicts a distribution probability conditioned
        on (obs,skill) and returns the log_prob of the target.
        """

        obs = obs[:, self._omit_input_dynamics_dim :]
        obs = jnp.concatenate((obs, skill), axis=1)
        out = self.network(obs)
        dist = self.distribution(out)
        return dist.log_prob(target)
make_dads_networks(action_size, descriptor_size, critic_hidden_layer_size=(256, 256), policy_hidden_layer_size=(256, 256), omit_input_dynamics_dim=2, identity_covariance=True, dynamics_initializer=None)

Creates networks used in DADS.

Parameters:
  • action_size (int) – the size of the environment's action space

  • descriptor_size (int) – the size of the environment's descriptor space (i.e. the dimension of the dynamics network's input)

  • critic_hidden_layer_size (Tuple[int, ...]) – the number of neurons per hidden layer of the critic network. Defaults to (256, 256).

  • policy_hidden_layer_size (Tuple[int, ...]) – the number of neurons per hidden layer of the policy network. Defaults to (256, 256).

  • omit_input_dynamics_dim (int) – how many descriptors we omit when creating the input of the dynamics networks. Defaults to 2.

  • identity_covariance (bool) – whether to fix the covariance matrix of the Gaussian models to identity. Defaults to True.

  • dynamics_initializer (Optional[Callable[[Sequence[int], Any], jax.Array]]) – the initializer of the dynamics layers. Defaults to None.

Returns:
  • Tuple[hk.Transformed, hk.Transformed, hk.Transformed] – the policy network, the critic network, and the dynamics network.

Source code in qdax/core/neuroevolution/networks/dads_networks.py
def make_dads_networks(
    action_size: int,
    descriptor_size: int,
    critic_hidden_layer_size: Tuple[int, ...] = (256, 256),
    policy_hidden_layer_size: Tuple[int, ...] = (256, 256),
    omit_input_dynamics_dim: int = 2,
    identity_covariance: bool = True,
    dynamics_initializer: Optional[Initializer] = None,
) -> Tuple[hk.Transformed, hk.Transformed, hk.Transformed]:
    """Creates networks used in DADS.

    Args:
        action_size: the size of the environment's action space
        descriptor_size: the size of the environment's descriptor space (i.e. the
            dimension of the dynamics network's input)
        hidden_layer_sizes: the number of neurons for hidden layers.
            Defaults to (256, 256).
        omit_input_dynamics_dim: how many descriptors we omit when creating the input
            of the dynamics networks. Defaults to 2.
        identity_covariance: whether to fix the covariance matrix of the Gaussian models
            to identity. Defaults to True.
        dynamics_initializer: the initializer of the dynamics layers. Defaults to None.

    Returns:
        the policy network
        the critic network
        the dynamics network
    """

    def _actor_fn(obs: Observation) -> jnp.ndarray:
        network = hk.Sequential(
            [
                hk.nets.MLP(
                    list(policy_hidden_layer_size) + [2 * action_size],
                    w_init=hk.initializers.VarianceScaling(1.0, "fan_in", "uniform"),
                    activation=jax.nn.relu,
                ),
            ]
        )
        return network(obs)

    def _critic_fn(obs: Observation, action: Action) -> jnp.ndarray:
        network1 = hk.Sequential(
            [
                hk.nets.MLP(
                    list(critic_hidden_layer_size) + [1],
                    w_init=hk.initializers.VarianceScaling(1.0, "fan_in", "uniform"),
                    activation=jax.nn.relu,
                ),
            ]
        )
        network2 = hk.Sequential(
            [
                hk.nets.MLP(
                    list(critic_hidden_layer_size) + [1],
                    w_init=hk.initializers.VarianceScaling(1.0, "fan_in", "uniform"),
                    activation=jax.nn.relu,
                ),
            ]
        )
        input_ = jnp.concatenate([obs, action], axis=-1)
        value1 = network1(input_)
        value2 = network2(input_)
        return jnp.concatenate([value1, value2], axis=-1)

    def _dynamics_fn(
        obs: StateDescriptor, skill: Skill, target: StateDescriptor
    ) -> jnp.ndarray:
        dynamics_network = DynamicsNetwork(
            critic_hidden_layer_size,
            descriptor_size,
            omit_input_dynamics_dim=omit_input_dynamics_dim,
            identity_covariance=identity_covariance,
            initializer=dynamics_initializer,
        )
        return dynamics_network(obs, skill, target)

    policy = hk.without_apply_rng(hk.transform(_actor_fn))
    critic = hk.without_apply_rng(hk.transform(_critic_fn))
    dynamics = hk.without_apply_rng(hk.transform(_dynamics_fn))

    return policy, critic, dynamics

diayn_networks

make_diayn_networks(action_size, num_skills, critic_hidden_layer_size=(256, 256), policy_hidden_layer_size=(256, 256))

Creates networks used in DIAYN.

Parameters:
  • action_size (int) – the size of the environment's action space

  • num_skills (int) – the number of skills set

  • critic_hidden_layer_size (Tuple[int, ...]) – the number of neurons per hidden layer of the critic network. Defaults to (256, 256).

  • policy_hidden_layer_size (Tuple[int, ...]) – the number of neurons per hidden layer of the policy network. Defaults to (256, 256).

Returns:
  • Tuple[hk.Transformed, hk.Transformed, hk.Transformed] – the policy network, the critic network, and the discriminator network.

Source code in qdax/core/neuroevolution/networks/diayn_networks.py
def make_diayn_networks(
    action_size: int,
    num_skills: int,
    critic_hidden_layer_size: Tuple[int, ...] = (256, 256),
    policy_hidden_layer_size: Tuple[int, ...] = (256, 256),
) -> Tuple[hk.Transformed, hk.Transformed, hk.Transformed]:
    """Creates networks used in DIAYN.

    Args:
        action_size: the size of the environment's action space
        num_skills: the number of skills set
        hidden_layer_sizes: the number of neurons for hidden layers.
            Defaults to (256, 256).

    Returns:
        the policy network
        the critic network
        the discriminator network
    """

    def _actor_fn(obs: Observation) -> jnp.ndarray:
        network = hk.Sequential(
            [
                hk.nets.MLP(
                    list(policy_hidden_layer_size) + [2 * action_size],
                    w_init=hk.initializers.VarianceScaling(1.0, "fan_in", "uniform"),
                    activation=jax.nn.relu,
                ),
            ]
        )
        return network(obs)

    def _critic_fn(obs: Observation, action: Action) -> jnp.ndarray:
        network1 = hk.Sequential(
            [
                hk.nets.MLP(
                    list(critic_hidden_layer_size) + [1],
                    w_init=hk.initializers.VarianceScaling(1.0, "fan_in", "uniform"),
                    activation=jax.nn.relu,
                ),
            ]
        )
        network2 = hk.Sequential(
            [
                hk.nets.MLP(
                    list(critic_hidden_layer_size) + [1],
                    w_init=hk.initializers.VarianceScaling(1.0, "fan_in", "uniform"),
                    activation=jax.nn.relu,
                ),
            ]
        )
        input_ = jnp.concatenate([obs, action], axis=-1)
        value1 = network1(input_)
        value2 = network2(input_)
        return jnp.concatenate([value1, value2], axis=-1)

    def _discriminator_fn(obs: Observation) -> jnp.ndarray:
        network = hk.Sequential(
            [
                hk.nets.MLP(
                    list(critic_hidden_layer_size) + [num_skills],
                    w_init=hk.initializers.VarianceScaling(1.0, "fan_in", "uniform"),
                    activation=jax.nn.relu,
                ),
            ]
        )
        return network(obs)

    policy = hk.without_apply_rng(hk.transform(_actor_fn))
    critic = hk.without_apply_rng(hk.transform(_critic_fn))
    discriminator = hk.without_apply_rng(hk.transform(_discriminator_fn))

    return policy, critic, discriminator

networks

Implements neural networks models that are commonly found in the RL literature.

QModule (Module) dataclass

Q Module.

Source code in qdax/core/neuroevolution/networks/networks.py
class QModule(nn.Module):
    """Q Module."""

    hidden_layer_sizes: Tuple[int, ...]
    n_critics: int = 2

    @nn.compact
    def __call__(self, obs: jnp.ndarray, actions: jnp.ndarray) -> jnp.ndarray:
        hidden = jnp.concatenate([obs, actions], axis=-1)
        res = []
        for _ in range(self.n_critics):
            q = networks.MLP(
                layer_sizes=self.hidden_layer_sizes + (1,),
                activation=nn.relu,
                kernel_init=jax.nn.initializers.lecun_uniform(),
            )(hidden)
            res.append(q)
        return jnp.concatenate(res, axis=-1)
MLP (Module) dataclass

MLP module.

Source code in qdax/core/neuroevolution/networks/networks.py
class MLP(nn.Module):
    """MLP module."""

    layer_sizes: Tuple[int, ...]
    activation: Callable[[jnp.ndarray], jnp.ndarray] = nn.relu
    kernel_init: Callable[..., Any] = jax.nn.initializers.lecun_uniform()
    final_activation: Optional[Callable[[jnp.ndarray], jnp.ndarray]] = None
    bias: bool = True
    kernel_init_final: Optional[Callable[..., Any]] = None

    @nn.compact
    def __call__(self, data: jnp.ndarray) -> jnp.ndarray:
        hidden = data
        for i, hidden_size in enumerate(self.layer_sizes):

            if i != len(self.layer_sizes) - 1:
                hidden = nn.Dense(
                    hidden_size,
                    # name=f"hidden_{i}", with this version of flax, changing the name
                    # changes the initialization
                    kernel_init=self.kernel_init,
                    use_bias=self.bias,
                )(hidden)
                hidden = self.activation(hidden)  # type: ignore

            else:
                if self.kernel_init_final is not None:
                    kernel_init = self.kernel_init_final
                else:
                    kernel_init = self.kernel_init

                hidden = nn.Dense(
                    hidden_size,
                    # name=f"hidden_{i}",
                    kernel_init=kernel_init,
                    use_bias=self.bias,
                )(hidden)

                if self.final_activation is not None:
                    hidden = self.final_activation(hidden)

        return hidden
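
A small, hypothetical usage sketch of the two Flax modules above (the observation and action sizes are arbitrary):

import jax
import jax.numpy as jnp

policy = MLP(layer_sizes=(64, 64, 4), final_activation=jnp.tanh)
critic = QModule(hidden_layer_sizes=(64, 64), n_critics=2)

key = jax.random.PRNGKey(0)
obs = jnp.ones((1, 10))
act = jnp.ones((1, 4))

policy_params = policy.init(key, obs)
critic_params = critic.init(key, obs, act)

actions = policy.apply(policy_params, obs)        # shape (1, 4), squashed in [-1, 1]
q_values = critic.apply(critic_params, obs, act)  # shape (1, 2), one value per critic
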

sac_networks

make_sac_networks(action_size, critic_hidden_layer_size=(256, 256), policy_hidden_layer_size=(256, 256))

Creates networks used in SAC.

Parameters:
  • action_size (int) – the size of the environment's action space

  • critic_hidden_layer_size (Tuple[int, ...]) – the number of neurons per hidden layer of the critic network. Defaults to (256, 256).

  • policy_hidden_layer_size (Tuple[int, ...]) – the number of neurons per hidden layer of the policy network. Defaults to (256, 256).

Returns:
  • Tuple[hk.Transformed, hk.Transformed] – the policy network and the critic network.

Source code in qdax/core/neuroevolution/networks/sac_networks.py
def make_sac_networks(
    action_size: int,
    critic_hidden_layer_size: Tuple[int, ...] = (256, 256),
    policy_hidden_layer_size: Tuple[int, ...] = (256, 256),
) -> Tuple[hk.Transformed, hk.Transformed]:
    """Creates networks used in SAC.

    Args:
        action_size: the size of the environment's action space
        hidden_layer_sizes: the number of neurons for hidden layers.
            Defaults to (256, 256).

    Returns:
        the policy network
        the critic network
    """

    def _actor_fn(obs: Observation) -> jnp.ndarray:
        network = hk.Sequential(
            [
                hk.nets.MLP(
                    list(policy_hidden_layer_size) + [2 * action_size],
                    w_init=hk.initializers.VarianceScaling(1.0, "fan_in", "uniform"),
                    activation=jax.nn.relu,
                ),
            ]
        )
        return network(obs)

    def _critic_fn(obs: Observation, action: Action) -> jnp.ndarray:
        network1 = hk.Sequential(
            [
                hk.nets.MLP(
                    list(critic_hidden_layer_size) + [1],
                    w_init=hk.initializers.VarianceScaling(1.0, "fan_in", "uniform"),
                    activation=jax.nn.relu,
                ),
            ]
        )
        network2 = hk.Sequential(
            [
                hk.nets.MLP(
                    list(critic_hidden_layer_size) + [1],
                    w_init=hk.initializers.VarianceScaling(1.0, "fan_in", "uniform"),
                    activation=jax.nn.relu,
                ),
            ]
        )
        input_ = jnp.concatenate([obs, action], axis=-1)
        value1 = network1(input_)
        value2 = network2(input_)
        return jnp.concatenate([value1, value2], axis=-1)

    policy = hk.without_apply_rng(hk.transform(_actor_fn))
    critic = hk.without_apply_rng(hk.transform(_critic_fn))

    return policy, critic
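
A small, hypothetical usage sketch of the returned haiku-transformed networks (the observation size is arbitrary):

import jax
import jax.numpy as jnp

policy, critic = make_sac_networks(action_size=4)

key = jax.random.PRNGKey(0)
obs = jnp.ones((1, 12))
act = jnp.ones((1, 4))

policy_params = policy.init(key, obs)
critic_params = critic.init(key, obs, act)

dist_params = policy.apply(policy_params, obs)    # shape (1, 2 * action_size)
q_values = critic.apply(critic_params, obs, act)  # shape (1, 2), one value per critic
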

seq2seq_networks

seq2seq example: Model code.

Inspired by Flax library - https://github.com/google/flax/blob/main/examples/seq2seq/models.py

Copyright 2022 The Flax Authors. Licensed under the Apache License, Version 2.0 (the "License")

EncoderLSTM (Module) dataclass

EncoderLSTM Module wrapped in a lifted scan transform.

Source code in qdax/core/neuroevolution/networks/seq2seq_networks.py
class EncoderLSTM(nn.Module):
    """EncoderLSTM Module wrapped in a lifted scan transform."""

    @functools.partial(
        nn.scan,
        variable_broadcast="params",
        in_axes=1,
        out_axes=1,
        split_rngs={"params": False},
    )
    @nn.compact
    def __call__(
        self, carry: Tuple[Array, Array], x: Array
    ) -> Tuple[Tuple[Array, Array], Array]:
        """Applies the module."""
        lstm_state, is_eos = carry
        features = lstm_state[0].shape[-1]
        new_lstm_state, y = nn.LSTMCell(features)(lstm_state, x)

        def select_carried_state(new_state: Array, old_state: Array) -> Array:
            return jnp.where(is_eos[:, np.newaxis], old_state, new_state)

        # LSTM state is a tuple (c, h).
        carried_lstm_state = tuple(
            select_carried_state(*s) for s in zip(new_lstm_state, lstm_state)
        )

        return (carried_lstm_state, is_eos), y

    @staticmethod
    def initialize_carry(batch_size: int, hidden_size: int) -> Tuple[Array, Array]:
        # Use a dummy key since the default state init fn is just zeros.
        return nn.LSTMCell(hidden_size, parent=None).initialize_carry(  # type: ignore
            jax.random.PRNGKey(0), (batch_size, hidden_size)
        )
Encoder (Module) dataclass

LSTM encoder, returning state after finding the EOS token in the input.

Source code in qdax/core/neuroevolution/networks/seq2seq_networks.py
class Encoder(nn.Module):
    """LSTM encoder, returning state after finding the EOS token in the input."""

    hidden_size: int

    @nn.compact
    def __call__(self, inputs: Array) -> Array:
        batch_size = inputs.shape[0]
        lstm = EncoderLSTM(name="encoder_lstm")
        init_lstm_state = lstm.initialize_carry(batch_size, self.hidden_size)

        # We use the `is_eos` array to determine whether the encoder should carry
        # over the last lstm state, or apply the LSTM cell on the previous state.
        init_is_eos = jnp.zeros(batch_size, dtype=bool)
        init_carry = (init_lstm_state, init_is_eos)
        (final_state, _), _ = lstm(init_carry, inputs)

        return final_state
DecoderLSTM (Module) dataclass

DecoderLSTM Module wrapped in a lifted scan transform.

Attributes:

  • teacher_force (bool) – See docstring on the Seq2seq module.

  • obs_size (int) – Size of the observations.

Source code in qdax/core/neuroevolution/networks/seq2seq_networks.py
class DecoderLSTM(nn.Module):
    """DecoderLSTM Module wrapped in a lifted scan transform.

    Attributes:
      teacher_force: See docstring on Seq2seq module.
      obs_size: Size of the observations.
    """

    teacher_force: bool
    obs_size: int

    @functools.partial(
        nn.scan,
        variable_broadcast="params",
        in_axes=1,
        out_axes=1,
        split_rngs={"params": False, "lstm": True},
    )
    @nn.compact
    def __call__(self, carry: Tuple[Array, Array], x: Array) -> Array:
        """Applies the DecoderLSTM model."""

        lstm_state, last_prediction = carry
        if not self.teacher_force:
            x = last_prediction

        features = lstm_state[0].shape[-1]
        new_lstm_state, y = nn.LSTMCell(features)(lstm_state, x)

        logits = nn.Dense(features=self.obs_size)(y)

        return (lstm_state, logits), (logits, logits)
Decoder (Module) dataclass

LSTM decoder.

Attributes:

  • init_state – [batch_size, hidden_size] initial state of the decoder (i.e., the final state of the encoder).

  • teacher_force (bool) – See docstring on the Seq2seq module.

  • obs_size (int) – Size of the observations.

Source code in qdax/core/neuroevolution/networks/seq2seq_networks.py
class Decoder(nn.Module):
    """LSTM decoder.

    Attributes:
      init_state: [batch_size, hidden_size]
        Initial state of the decoder (i.e., the final state of the encoder).
      teacher_force: See docstring on Seq2seq module.
      obs_size: Size of the observations.
    """

    teacher_force: bool
    obs_size: int

    @nn.compact
    def __call__(self, inputs: Array, init_state: Any) -> Tuple[Array, Array]:
        """Applies the decoder model.

        Args:
          inputs: [batch_size, max_output_len-1, obs_size]
            Contains the inputs to the decoder at each time step (only used when not
            using teacher forcing). Since each token at position i is fed as input
            to the decoder at position i+1, the last token is not provided.

        Returns:
          Pair (logits, predictions), which are two arrays of respectively decoded
          logits and predictions (in one hot-encoding format).
        """
        lstm = DecoderLSTM(teacher_force=self.teacher_force, obs_size=self.obs_size)
        init_carry = (init_state, inputs[:, 0])
        _, (logits, predictions) = lstm(init_carry, inputs)
        return logits, predictions
Seq2seq (Module) dataclass

Sequence-to-sequence class using encoder/decoder architecture.

Attributes:

  • teacher_force (bool) – whether to use decoder_inputs as input to the decoder at every step. If False, only the first input (i.e., the "=" token) is used, followed by samples taken from the previous output logits.

  • hidden_size (int) – the number of hidden dimensions in the encoder and decoder LSTMs.

  • obs_size (int) – the size of the observations.

  • eos_id – EOS id.

Source code in qdax/core/neuroevolution/networks/seq2seq_networks.py
class Seq2seq(nn.Module):
    """Sequence-to-sequence class using encoder/decoder architecture.

    Attributes:
      teacher_force: whether to use `decoder_inputs` as input to the decoder at
        every step. If False, only the first input (i.e., the "=" token) is used,
        followed by samples taken from the previous output logits.
      hidden_size: int, the number of hidden dimensions in the encoder and decoder
        LSTMs.
      obs_size: the size of the observations.
      eos_id: EOS id.
    """

    teacher_force: bool
    hidden_size: int
    obs_size: int

    def setup(self) -> None:
        self.encoder = Encoder(hidden_size=self.hidden_size)
        self.decoder = Decoder(teacher_force=self.teacher_force, obs_size=self.obs_size)

    @nn.compact
    def __call__(
        self, encoder_inputs: Array, decoder_inputs: Array
    ) -> Tuple[Array, Array]:
        """Applies the seq2seq model.

        Args:
          encoder_inputs: [batch_size, max_input_length, obs_size].
            padded batch of input sequences to encode.
          decoder_inputs: [batch_size, max_output_length, obs_size].
            padded batch of expected decoded sequences for teacher forcing.
            When sampling (i.e., `teacher_force = False`), only the first token is
            input into the decoder (which is the token "="), and samples are used
            for the following inputs. The second dimension of this tensor determines
            how many steps will be decoded, regardless of the value of
            `teacher_force`.

        Returns:
          Pair (logits, predictions), which are two arrays of length `batch_size`
          containing respectively decoded logits and predictions (in one hot
          encoding format).
        """
        # encode inputs
        init_decoder_state = self.encoder(encoder_inputs)

        # decode outputs
        logits, predictions = self.decoder(decoder_inputs, init_decoder_state)

        return logits, predictions

    def encode(self, encoder_inputs: Array) -> Array:
        # encode inputs
        init_decoder_state = self.encoder(encoder_inputs)
        final_output, _hidden_state = init_decoder_state
        return final_output

td3_networks

Implements a function to create neural networks for the TD3 algorithm.

make_td3_networks(action_size, critic_hidden_layer_sizes, policy_hidden_layer_sizes)

Creates networks used by the TD3 agent.

Parameters:
  • action_size (int) – Size of the action array used to interact with the environment.

  • critic_hidden_layer_sizes (Tuple[int, ...]) – Number of layers and units per layer used in the neural network defining the critic.

  • policy_hidden_layer_sizes (Tuple[int, ...]) – Number of layers and units per layer used in the neural network defining the policy.

Returns:
  • Tuple[MLP, QModule] – the neural network defining the policy and the module defining the critic; this module contains two neural networks.

Source code in qdax/core/neuroevolution/networks/td3_networks.py
def make_td3_networks(
    action_size: int,
    critic_hidden_layer_sizes: Tuple[int, ...],
    policy_hidden_layer_sizes: Tuple[int, ...],
) -> Tuple[MLP, QModule]:
    """Creates networks used by the TD3 agent.

    Args:
        action_size: Size of the action array used to interact with the environment.
        critic_hidden_layer_sizes: Number of layers and units per layer used in the
            neural network defining the critic.
        policy_hidden_layer_sizes: Number of layers and units per layer used in the
            neural network defining the policy.

    Returns:
        The neural network defining the policy and the module defining the critic.
        This module contains two neural networks.
    """

    # Instantiate policy and critics networks
    policy_layer_sizes = policy_hidden_layer_sizes + (action_size,)
    policy_network = MLP(
        layer_sizes=policy_layer_sizes,
        final_activation=jnp.tanh,
    )
    q_network = QModule(n_critics=2, hidden_layer_sizes=critic_hidden_layer_sizes)

    return (policy_network, q_network)
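
Example usage (a hedged sketch): the observation/action sizes and hidden layer sizes are placeholder values, and the QModule is assumed to evaluate (observation, action) pairs with its two critics, as in standard TD3.

import jax
import jax.numpy as jnp

from qdax.core.neuroevolution.networks.td3_networks import make_td3_networks

obs_size, action_size = 17, 6  # placeholder sizes

policy_network, q_network = make_td3_networks(
    action_size=action_size,
    critic_hidden_layer_sizes=(256, 256),
    policy_hidden_layer_sizes=(256, 256),
)

key_pi, key_q = jax.random.split(jax.random.PRNGKey(0))
fake_obs = jnp.zeros((1, obs_size))
fake_action = jnp.zeros((1, action_size))

policy_params = policy_network.init(key_pi, fake_obs)
critic_params = q_network.init(key_q, fake_obs, fake_action)

actions = policy_network.apply(policy_params, fake_obs)           # in [-1, 1] via tanh
q_values = q_network.apply(critic_params, fake_obs, fake_action)  # one value per critic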

mdp_utils

TrainingState (PyTreeNode) dataclass

The state of a training process. Can be used to store anything that is useful for a training process. This object is used in the package to store all stateful objects necessary for training an agent that learns how to act in an MDP.

Source code in qdax/core/neuroevolution/mdp_utils.py
class TrainingState(PyTreeNode):
    """The state of a training process. Can be used to store anything
    that is useful for a training process. This object is used in the
    package to store all stateful objects necessary for training an agent
    that learns how to act in an MDP.
    """

    pass
replace(self, **updates)

"Returns a new object replacing the specified fields with new values.

Source code in qdax/core/neuroevolution/mdp_utils.py
def replace(self, **updates):
  """ "Returns a new object replacing the specified fields with new values."""
  return dataclasses.replace(self, **updates)
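
Example (hypothetical, for illustration only): algorithms typically specialise TrainingState with the pytrees they need to carry through jitted functions; the field names below are not part of the QDax API.

from typing import Any

import jax.numpy as jnp
import optax

from qdax.core.neuroevolution.mdp_utils import TrainingState


class MyTrainingState(TrainingState):
    """Illustrative container holding everything needed to resume training."""

    policy_params: Any                      # a pytree of network parameters
    policy_optimizer_state: optax.OptState  # optimizer state for the policy
    random_key: jnp.ndarray                 # JAX PRNG key
    steps: jnp.ndarray                      # number of gradient steps so far


# Being a flax PyTreeNode, updates are functional:
# new_state = training_state.replace(steps=training_state.steps + 1)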

generate_unroll(init_state, policy_params, random_key, episode_length, play_step_fn)

Generates an episode according to the agent's policy and returns the final state of the episode together with the transitions of the episode.

Parameters:
  • init_state (State) – first state of the rollout.

  • policy_params (Union[jax.Array, numpy.ndarray, numpy.bool_, numpy.number, Iterable[ArrayTree], Mapping[Any, ArrayTree]]) – params of the individual.

  • random_key (Array) – random key for stochasticity handling.

  • episode_length (int) – length of the rollout.

  • play_step_fn (Callable[[brax.envs.base.State, Union[jax.Array, numpy.ndarray, numpy.bool_, numpy.number, Iterable[ArrayTree], Mapping[Any, ArrayTree]], jax.Array], Tuple[brax.envs.base.State, Union[jax.Array, numpy.ndarray, numpy.bool_, numpy.number, Iterable[ArrayTree], Mapping[Any, ArrayTree]], jax.Array, qdax.core.neuroevolution.buffers.buffer.Transition]]) – function describing how a step needs to be taken.

Returns:
  • Tuple[brax.envs.base.State, qdax.core.neuroevolution.buffers.buffer.Transition] – A new state, the experienced transition.

Source code in qdax/core/neuroevolution/mdp_utils.py
@partial(jax.jit, static_argnames=("play_step_fn", "episode_length"))
def generate_unroll(
    init_state: EnvState,
    policy_params: Params,
    random_key: RNGKey,
    episode_length: int,
    play_step_fn: Callable[
        [EnvState, Params, RNGKey],
        Tuple[
            EnvState,
            Params,
            RNGKey,
            Transition,
        ],
    ],
) -> Tuple[EnvState, Transition]:
    """Generates an episode according to the agent's policy, returns the final state of
    the episode and the transitions of the episode.

    Args:
        init_state: first state of the rollout.
        policy_params: params of the individual.
        random_key: random key for stochasticity handling.
        episode_length: length of the rollout.
        play_step_fn: function describing how a step needs to be taken.

    Returns:
        A new state, the experienced transition.
    """

    def _scan_play_step_fn(
        carry: Tuple[EnvState, Params, RNGKey], unused_arg: Any
    ) -> Tuple[Tuple[EnvState, Params, RNGKey], Transition]:
        env_state, policy_params, random_key, transitions = play_step_fn(*carry)
        return (env_state, policy_params, random_key), transitions

    (state, _, _), transitions = jax.lax.scan(
        _scan_play_step_fn,
        (init_state, policy_params, random_key),
        (),
        length=episode_length,
    )
    return state, transitions
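
Example usage (a hedged sketch): the environment name, layer sizes and episode length are placeholder choices, and the environment wrapper is assumed to expose info["truncation"] (as the Brax wrappers used in QDax do).

import jax
import jax.numpy as jnp

import qdax.environments
from qdax.core.neuroevolution.buffers.buffer import Transition
from qdax.core.neuroevolution.mdp_utils import generate_unroll
from qdax.core.neuroevolution.networks.networks import MLP

env = qdax.environments.create("pointmaze", episode_length=100)
policy_network = MLP(
    layer_sizes=(64, 64, env.action_size),
    final_activation=jnp.tanh,
)

random_key = jax.random.PRNGKey(0)
random_key, subkey = jax.random.split(random_key)
policy_params = policy_network.init(subkey, jnp.zeros((env.observation_size,)))


def play_step_fn(env_state, policy_params, random_key):
    """One deterministic environment step following the policy."""
    actions = policy_network.apply(policy_params, env_state.obs)
    next_state = env.step(env_state, actions)
    transition = Transition(
        obs=env_state.obs,
        next_obs=next_state.obs,
        rewards=next_state.reward,
        dones=next_state.done,
        truncations=next_state.info["truncation"],
        actions=actions,
    )
    return next_state, policy_params, random_key, transition


init_state = env.reset(jax.random.PRNGKey(1))
final_state, transitions = generate_unroll(
    init_state=init_state,
    policy_params=policy_params,
    random_key=random_key,
    episode_length=100,
    play_step_fn=play_step_fn,
)
# transitions has a leading time dimension of length episode_length.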

get_first_episode(transition)

Extracts the first episode from a batch of transitions and returns the batch of transitions masked with NaNs except for the first episode.

Source code in qdax/core/neuroevolution/mdp_utils.py
@jax.jit
def get_first_episode(transition: Transition) -> Transition:
    """Extracts the first episode from a batch of transitions, returns the batch of
    transitions that is masked with nans except for the first episode.
    """

    dones = jnp.roll(transition.dones, 1, axis=0).at[0].set(0)
    mask = 1 - jnp.clip(jnp.cumsum(dones, axis=0), 0, 1)

    def mask_episodes(x: jnp.ndarray) -> jnp.ndarray:
        # the double transpose trick is here to allow easy broadcasting
        return jnp.where(mask.T, x.T, jnp.nan * jnp.ones_like(x).T).T

    return jax.tree_map(mask_episodes, transition)  # type: ignore
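
To make the masking concrete, here is a toy, hand-checked walkthrough of the computation above (the numbers are made up):

import jax.numpy as jnp

dones = jnp.array([0.0, 0.0, 1.0, 0.0, 1.0])           # first episode ends at index 2
rolled = jnp.roll(dones, 1, axis=0).at[0].set(0)        # [0., 0., 0., 1., 0.]
mask = 1 - jnp.clip(jnp.cumsum(rolled, axis=0), 0, 1)   # [1., 1., 1., 0., 0.]

rewards = jnp.array([1.0, 2.0, 3.0, 4.0, 5.0])
masked = jnp.where(mask, rewards, jnp.nan)              # [1., 2., 3., nan, nan]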

init_population_controllers(policy_network, env, batch_size, random_key)

Initializes the population of controllers using a policy_network.

Parameters:
  • policy_network (Module) – The policy network structure used for creating policy controllers.

  • env (Env) – the BRAX environment.

  • batch_size (int) – the number of environments we play simultaneously.

  • random_key (Array) – a JAX random key.

Returns:
  • Tuple[Union[jax.Array, numpy.ndarray, numpy.bool_, numpy.number, Iterable[ArrayTree], Mapping[Any, ArrayTree]], jax.Array] – A tuple of the initial population and the new random key.

Source code in qdax/core/neuroevolution/mdp_utils.py
def init_population_controllers(
    policy_network: nn.Module,
    env: brax.envs.Env,
    batch_size: int,
    random_key: RNGKey,
) -> Tuple[Genotype, RNGKey]:
    """
    Initializes the population of controllers using a policy_network.

    Args:
        policy_network: The policy network structure used for creating policy
            controllers.
        env: the BRAX environment.
        batch_size: the number of environments we play simultaneously.
        random_key: a JAX random key.

    Returns:
        A tuple of the initial population and the new random key.
    """
    random_key, subkey = jax.random.split(random_key)

    keys = jax.random.split(subkey, num=batch_size)
    fake_batch = jnp.zeros(shape=(batch_size, env.observation_size))
    init_variables = jax.vmap(policy_network.init)(keys, fake_batch)

    return init_variables, random_key
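
Example usage (a minimal sketch; the environment name, layer sizes and batch size are placeholder choices):

import jax
import jax.numpy as jnp

import qdax.environments
from qdax.core.neuroevolution.mdp_utils import init_population_controllers
from qdax.core.neuroevolution.networks.networks import MLP

env = qdax.environments.create("walker2d_uni", episode_length=100)
policy_network = MLP(
    layer_sizes=(64, 64, env.action_size),
    final_activation=jnp.tanh,
)

init_variables, random_key = init_population_controllers(
    policy_network=policy_network,
    env=env,
    batch_size=128,
    random_key=jax.random.PRNGKey(0),
)
# init_variables is a pytree whose leaves have a leading dimension of 128,
# one set of parameters per controller in the initial population.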

normalization_utils

Utility functions to perform normalization (generally on observations in RL).

RunningMeanStdState (tuple)

Running statistics for observations/rewards

Source code in qdax/core/neuroevolution/normalization_utils.py
class RunningMeanStdState(NamedTuple):
    """Running statistics for observtations/rewards"""

    mean: jnp.ndarray
    var: jnp.ndarray
    count: jnp.ndarray

update_running_mean_std(std_state, obs)

Update running statistics with a batch of observations (Welford's algorithm)

Source code in qdax/core/neuroevolution/normalization_utils.py
def update_running_mean_std(
    std_state: RunningMeanStdState, obs: Observation
) -> RunningMeanStdState:
    """Update running statistics with batch of observations (Welford's algorithm)"""

    running_mean, running_variance, normalization_steps = std_state

    step_increment = obs.shape[0]

    total_new_steps = normalization_steps + step_increment

    # Compute the incremental update and divide by the number of new steps.
    input_to_old_mean = obs - running_mean
    mean_diff = jnp.sum(input_to_old_mean / total_new_steps, axis=0)
    new_mean = running_mean + mean_diff

    # Compute difference of input to the new mean for Welford update.
    input_to_new_mean = obs - new_mean
    var_diff = jnp.sum(input_to_new_mean * input_to_old_mean, axis=0)

    return RunningMeanStdState(new_mean, running_variance + var_diff, total_new_steps)

normalize_with_rmstd(obs, rmstd, std_min_value=1e-06, std_max_value=1000000.0, apply_clipping=True)

Normalize input with provided running statistics

Source code in qdax/core/neuroevolution/normalization_utils.py
def normalize_with_rmstd(
    obs: jnp.ndarray,
    rmstd: RunningMeanStdState,
    std_min_value: float = 1e-6,
    std_max_value: float = 1e6,
    apply_clipping: bool = True,
) -> jnp.ndarray:
    """Normalize input with provided running statistics"""

    running_mean, running_variance, normalization_steps = rmstd
    variance = running_variance / (normalization_steps + 1.0)
    # We clip because the running_variance can become negative,
    # presumably because of numerical instabilities.
    if apply_clipping:
        variance = jnp.clip(variance, std_min_value, std_max_value)
        return jnp.clip((obs - running_mean) / jnp.sqrt(variance), -5, 5)
    else:
        return (obs - running_mean) / jnp.sqrt(variance)
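
Example usage (a hedged sketch with illustrative shapes and values): statistics are accumulated with update_running_mean_std as observations come in, and normalize_with_rmstd is then applied before feeding observations to a network.

import jax.numpy as jnp

from qdax.core.neuroevolution.normalization_utils import (
    RunningMeanStdState,
    normalize_with_rmstd,
    update_running_mean_std,
)

obs_dim = 3
rmstd = RunningMeanStdState(
    mean=jnp.zeros(obs_dim),
    var=jnp.zeros(obs_dim),
    count=jnp.zeros(()),
)

# Fold each collected batch of observations into the running statistics.
batch = jnp.array([[1.0, 2.0, 3.0], [3.0, 2.0, 1.0]])
rmstd = update_running_mean_std(rmstd, batch)

# Normalise observations with the accumulated statistics
# (clipped to [-5, 5] when apply_clipping is True, the default).
normalized = normalize_with_rmstd(batch, rmstd)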

sac_td3_utils

Functions similar to the ones in mdp_utils; the main difference is the assumption that the policy parameters are included in the training state. By passing this whole training state, we can, for instance, update running statistics used for normalization.

We are currently thinking about elegant ways to unify both in order to avoid code repetition.

warmstart_buffer(replay_buffer, training_state, env_state, play_step_fn, num_warmstart_steps, env_batch_size)

Pre-populates the buffer with transitions. Returns the warmstarted buffer, the new state of the environment, and the updated training state.

Source code in qdax/core/neuroevolution/sac_td3_utils.py
@partial(
    jax.jit,
    static_argnames=(
        "num_warmstart_steps",
        "play_step_fn",
        "env_batch_size",
    ),
)
def warmstart_buffer(
    replay_buffer: ReplayBuffer,
    training_state: TrainingState,
    env_state: EnvState,
    play_step_fn: Callable[
        [EnvState, TrainingState],
        Tuple[
            EnvState,
            TrainingState,
            Transition,
        ],
    ],
    num_warmstart_steps: int,
    env_batch_size: int,
) -> Tuple[ReplayBuffer, EnvState, TrainingState]:
    """Pre-populates the buffer with transitions. Returns the warmstarted buffer
    and the new state of the environment.
    """

    def _scan_play_step_fn(
        carry: Tuple[EnvState, TrainingState], unused_arg: Any
    ) -> Tuple[Tuple[EnvState, TrainingState], Transition]:
        env_state, training_state, transitions = play_step_fn(*carry)
        return (env_state, training_state), transitions

    (env_state, training_state), transitions = jax.lax.scan(
        _scan_play_step_fn,
        (env_state, training_state),
        (),
        length=num_warmstart_steps // env_batch_size,
    )
    replay_buffer = replay_buffer.insert(transitions)

    return replay_buffer, env_state, training_state

generate_unroll(init_state, training_state, episode_length, play_step_fn)

Generates an episode according to the agent's policy and returns the final state of the episode together with the transitions of the episode.

Source code in qdax/core/neuroevolution/sac_td3_utils.py
@partial(jax.jit, static_argnames=("play_step_fn", "episode_length"))
def generate_unroll(
    init_state: EnvState,
    training_state: TrainingState,
    episode_length: int,
    play_step_fn: Callable[
        [EnvState, TrainingState],
        Tuple[
            EnvState,
            TrainingState,
            Transition,
        ],
    ],
) -> Tuple[EnvState, TrainingState, Transition]:
    """Generates an episode according to the agent's policy, returns the final state of the
    episode and the transitions of the episode.
    """

    def _scan_play_step_fn(
        carry: Tuple[EnvState, TrainingState], unused_arg: Any
    ) -> Tuple[Tuple[EnvState, TrainingState], Transition]:
        env_state, training_state, transitions = play_step_fn(*carry)
        return (env_state, training_state), transitions

    (state, training_state), transitions = jax.lax.scan(
        _scan_play_step_fn,
        (init_state, training_state),
        (),
        length=episode_length,
    )
    return state, training_state, transitions

do_iteration_fn(training_state, env_state, replay_buffer, env_batch_size, grad_updates_per_step, play_step_fn, update_fn)

Performs one environment step (over all environments simultaneously) followed by one training step. The number of gradient updates is controlled by the parameter grad_updates_per_step (0 means no update, while 1 means env_batch_size updates). Returns the updated states, the updated buffer and the aggregated metrics.

Source code in qdax/core/neuroevolution/sac_td3_utils.py
@partial(
    jax.jit,
    static_argnames=(
        "env_batch_size",
        "grad_updates_per_step",
        "play_step_fn",
        "update_fn",
    ),
)
def do_iteration_fn(
    training_state: TrainingState,
    env_state: EnvState,
    replay_buffer: ReplayBuffer,
    env_batch_size: int,
    grad_updates_per_step: float,
    play_step_fn: Callable[
        [EnvState, TrainingState],
        Tuple[
            EnvState,
            TrainingState,
            Transition,
        ],
    ],
    update_fn: Callable[
        [TrainingState, ReplayBuffer],
        Tuple[
            TrainingState,
            ReplayBuffer,
            Metrics,
        ],
    ],
) -> Tuple[TrainingState, EnvState, ReplayBuffer, Metrics]:
    """Performs one environment step (over all env simultaneously) followed by one
    training step. The number of updates is controlled by the parameter
    `grad_updates_per_step` (0 means no update while 1 means `env_batch_size`
    updates). Returns the updated states, the updated buffer and the aggregated
    metrics.
    """

    def _scan_update_fn(
        carry: Tuple[TrainingState, ReplayBuffer], unused_arg: Any
    ) -> Tuple[Tuple[TrainingState, ReplayBuffer], Metrics]:
        training_state, replay_buffer, metrics = update_fn(*carry)
        return (training_state, replay_buffer), metrics

    # play steps in the environment
    env_state, training_state, transitions = play_step_fn(env_state, training_state)

    # insert transitions in replay buffer
    replay_buffer = replay_buffer.insert(transitions)
    num_updates = int(grad_updates_per_step * env_batch_size)

    (training_state, replay_buffer), metrics = jax.lax.scan(
        _scan_update_fn,
        (training_state, replay_buffer),
        (),
        length=num_updates,
    )

    return training_state, env_state, replay_buffer, metrics
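
Example (a hedged, high-level sketch): a typical SAC/TD3-style training loop composes warmstart_buffer and do_iteration_fn. The training state, environment state, replay buffer, play_step_fn and update_fn are assumed to be built by the algorithm (for instance the SAC and TD3 baselines in QDax) and are passed in as arguments here; the batch size and step counts are placeholder values.

from qdax.core.neuroevolution.sac_td3_utils import do_iteration_fn, warmstart_buffer


def train(
    training_state,
    env_state,
    replay_buffer,
    play_step_fn,
    update_fn,
    env_batch_size=128,
    num_warmstart_steps=10_000,
    num_iterations=1_000,
):
    # 1) Pre-populate the replay buffer before any gradient update.
    replay_buffer, env_state, training_state = warmstart_buffer(
        replay_buffer=replay_buffer,
        training_state=training_state,
        env_state=env_state,
        play_step_fn=play_step_fn,
        num_warmstart_steps=num_warmstart_steps,
        env_batch_size=env_batch_size,
    )

    # 2) Alternate environment interaction and gradient updates.
    for _ in range(num_iterations):
        training_state, env_state, replay_buffer, metrics = do_iteration_fn(
            training_state=training_state,
            env_state=env_state,
            replay_buffer=replay_buffer,
            env_batch_size=env_batch_size,
            grad_updates_per_step=1.0,  # env_batch_size updates per iteration
            play_step_fn=play_step_fn,
            update_fn=update_fn,
        )

    return training_state, env_state, replay_buffer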