Neuro-evolution components¶
qdax.core.neuroevolution special ¶
buffers special ¶
buffer ¶
Transition (PyTreeNode) dataclass ¶
Stores data corresponding to a transition collected by a classic RL algorithm.
Source code in qdax/core/neuroevolution/buffers/buffer.py
class Transition(flax.struct.PyTreeNode):
"""Stores data corresponding to a transition collected by a classic RL algorithm."""
obs: Observation
next_obs: Observation
rewards: Reward
dones: Done
truncations: jnp.ndarray # Indicates if an episode has reached max time step
actions: Action
@property
def observation_dim(self) -> int:
"""
Returns:
the dimension of the observation
"""
return self.obs.shape[-1] # type: ignore
@property
def action_dim(self) -> int:
"""
Returns:
the dimension of the action
"""
return self.actions.shape[-1] # type: ignore
@property
def flatten_dim(self) -> int:
"""
Returns:
the dimension of the transition once flattened.
"""
flatten_dim = 2 * self.observation_dim + self.action_dim + 3
return flatten_dim
def flatten(self) -> jnp.ndarray:
"""
Returns:
a jnp.ndarray that corresponds to the flattened transition.
"""
flatten_transition = jnp.concatenate(
[
self.obs,
self.next_obs,
jnp.expand_dims(self.rewards, axis=-1),
jnp.expand_dims(self.dones, axis=-1),
jnp.expand_dims(self.truncations, axis=-1),
self.actions,
],
axis=-1,
)
return flatten_transition
@classmethod
def from_flatten(
cls,
flattened_transition: jnp.ndarray,
transition: Transition,
) -> Transition:
"""
Creates a transition from a flattened transition in a jnp.ndarray.
Args:
flattened_transition: flattened transition in a jnp.ndarray of shape
(batch_size, flatten_dim)
transition: a transition object (might be a dummy one) to
get the dimensions right
Returns:
a Transition object
"""
obs_dim = transition.observation_dim
action_dim = transition.action_dim
obs = flattened_transition[:, :obs_dim]
next_obs = flattened_transition[:, obs_dim : (2 * obs_dim)]
rewards = jnp.ravel(flattened_transition[:, (2 * obs_dim) : (2 * obs_dim + 1)])
dones = jnp.ravel(
flattened_transition[:, (2 * obs_dim + 1) : (2 * obs_dim + 2)]
)
truncations = jnp.ravel(
flattened_transition[:, (2 * obs_dim + 2) : (2 * obs_dim + 3)]
)
actions = flattened_transition[
:, (2 * obs_dim + 3) : (2 * obs_dim + 3 + action_dim)
]
return cls(
obs=obs,
next_obs=next_obs,
rewards=rewards,
dones=dones,
truncations=truncations,
actions=actions,
)
@classmethod
def init_dummy(cls, observation_dim: int, action_dim: int) -> Transition:
"""
Initialize a dummy transition that then can be passed to constructors to get
all shapes right.
Args:
observation_dim: observation dimension
action_dim: action dimension
Returns:
a dummy transition
"""
dummy_transition = Transition(
obs=jnp.zeros(shape=(1, observation_dim)),
next_obs=jnp.zeros(shape=(1, observation_dim)),
rewards=jnp.zeros(shape=(1,)),
dones=jnp.zeros(shape=(1,)),
truncations=jnp.zeros(shape=(1,)),
actions=jnp.zeros(shape=(1, action_dim)),
)
return dummy_transition
observation_dim: int property readonly ¶
Returns: the dimension of the observation.
action_dim: int property readonly ¶
Returns: the dimension of the action.
flatten_dim: int property readonly ¶
Returns: the dimension of the transition once flattened.
flatten(self) ¶
Returns: a jnp.ndarray that corresponds to the flattened transition.
Source code in qdax/core/neuroevolution/buffers/buffer.py
def flatten(self) -> jnp.ndarray:
"""
Returns:
a jnp.ndarray that corresponds to the flattened transition.
"""
flatten_transition = jnp.concatenate(
[
self.obs,
self.next_obs,
jnp.expand_dims(self.rewards, axis=-1),
jnp.expand_dims(self.dones, axis=-1),
jnp.expand_dims(self.truncations, axis=-1),
self.actions,
],
axis=-1,
)
return flatten_transition
from_flatten(flattened_transition, transition) classmethod ¶
Creates a transition from a flattened transition in a jnp.ndarray.
Source code in qdax/core/neuroevolution/buffers/buffer.py
@classmethod
def from_flatten(
cls,
flattened_transition: jnp.ndarray,
transition: Transition,
) -> Transition:
"""
Creates a transition from a flattened transition in a jnp.ndarray.
Args:
flattened_transition: flattened transition in a jnp.ndarray of shape
(batch_size, flatten_dim)
transition: a transition object (might be a dummy one) to
get the dimensions right
Returns:
a Transition object
"""
obs_dim = transition.observation_dim
action_dim = transition.action_dim
obs = flattened_transition[:, :obs_dim]
next_obs = flattened_transition[:, obs_dim : (2 * obs_dim)]
rewards = jnp.ravel(flattened_transition[:, (2 * obs_dim) : (2 * obs_dim + 1)])
dones = jnp.ravel(
flattened_transition[:, (2 * obs_dim + 1) : (2 * obs_dim + 2)]
)
truncations = jnp.ravel(
flattened_transition[:, (2 * obs_dim + 2) : (2 * obs_dim + 3)]
)
actions = flattened_transition[
:, (2 * obs_dim + 3) : (2 * obs_dim + 3 + action_dim)
]
return cls(
obs=obs,
next_obs=next_obs,
rewards=rewards,
dones=dones,
truncations=truncations,
actions=actions,
)
init_dummy(observation_dim, action_dim) classmethod ¶
Initialize a dummy transition that can then be passed to constructors to get all shapes right.
Source code in qdax/core/neuroevolution/buffers/buffer.py
@classmethod
def init_dummy(cls, observation_dim: int, action_dim: int) -> Transition:
"""
Initialize a dummy transition that then can be passed to constructors to get
all shapes right.
Args:
observation_dim: observation dimension
action_dim: action dimension
Returns:
a dummy transition
"""
dummy_transition = Transition(
obs=jnp.zeros(shape=(1, observation_dim)),
next_obs=jnp.zeros(shape=(1, observation_dim)),
rewards=jnp.zeros(shape=(1,)),
dones=jnp.zeros(shape=(1,)),
truncations=jnp.zeros(shape=(1,)),
actions=jnp.zeros(shape=(1, action_dim)),
)
return dummy_transition
replace(self, **updates) ¶
Returns a new object replacing the specified fields with new values.
Source code in qdax/core/neuroevolution/buffers/buffer.py
def replace(self, **updates):
""" "Returns a new object replacing the specified fields with new values."""
return dataclasses.replace(self, **updates)
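To illustrate the flattened layout used above (obs, next_obs, reward, done, truncation and action concatenated along the last axis), here is a minimal round-trip sketch that only uses the methods documented on this page; the dimensions are arbitrary.

import jax.numpy as jnp

from qdax.core.neuroevolution.buffers.buffer import Transition

# Build a dummy transition to fix the dimensions (obs_dim=3, action_dim=2).
dummy = Transition.init_dummy(observation_dim=3, action_dim=2)

# flatten_dim = 2 * obs_dim + action_dim + 3 (reward, done, truncation).
assert dummy.flatten_dim == 2 * 3 + 2 + 3

# Flatten, then reconstruct: from_flatten uses the dummy to recover field shapes.
flat = dummy.flatten()  # shape (1, flatten_dim)
rebuilt = Transition.from_flatten(flat, dummy)
assert jnp.allclose(rebuilt.obs, dummy.obs)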
QDTransition (Transition) dataclass ¶
Stores data corresponding to a transition collected by a QD algorithm.
Source code in qdax/core/neuroevolution/buffers/buffer.py
class QDTransition(Transition):
"""Stores data corresponding to a transition collected by a QD algorithm."""
state_desc: StateDescriptor
next_state_desc: StateDescriptor
@property
def state_descriptor_dim(self) -> int:
"""
Returns:
the dimension of the state descriptors.
"""
return self.state_desc.shape[-1] # type: ignore
@property
def flatten_dim(self) -> int:
"""
Returns:
the dimension of the transition once flattened.
"""
flatten_dim = (
2 * self.observation_dim
+ self.action_dim
+ 3
+ 2 * self.state_descriptor_dim
)
return flatten_dim
def flatten(self) -> jnp.ndarray:
"""
Returns:
a jnp.ndarray that corresponds to the flattened transition.
"""
flatten_transition = jnp.concatenate(
[
self.obs,
self.next_obs,
jnp.expand_dims(self.rewards, axis=-1),
jnp.expand_dims(self.dones, axis=-1),
jnp.expand_dims(self.truncations, axis=-1),
self.actions,
self.state_desc,
self.next_state_desc,
],
axis=-1,
)
return flatten_transition
@classmethod
def from_flatten(
cls,
flattened_transition: jnp.ndarray,
transition: QDTransition,
) -> QDTransition:
"""
Creates a transition from a flattened transition in a jnp.ndarray.
Args:
flattened_transition: flattened transition in a jnp.ndarray of shape
(batch_size, flatten_dim)
transition: a transition object (might be a dummy one) to
get the dimensions right
Returns:
a Transition object
"""
obs_dim = transition.observation_dim
action_dim = transition.action_dim
desc_dim = transition.state_descriptor_dim
obs = flattened_transition[:, :obs_dim]
next_obs = flattened_transition[:, obs_dim : (2 * obs_dim)]
rewards = jnp.ravel(flattened_transition[:, (2 * obs_dim) : (2 * obs_dim + 1)])
dones = jnp.ravel(
flattened_transition[:, (2 * obs_dim + 1) : (2 * obs_dim + 2)]
)
truncations = jnp.ravel(
flattened_transition[:, (2 * obs_dim + 2) : (2 * obs_dim + 3)]
)
actions = flattened_transition[
:, (2 * obs_dim + 3) : (2 * obs_dim + 3 + action_dim)
]
state_desc = flattened_transition[
:,
(2 * obs_dim + 3 + action_dim) : (2 * obs_dim + 3 + action_dim + desc_dim),
]
next_state_desc = flattened_transition[
:,
(2 * obs_dim + 3 + action_dim + desc_dim) : (
2 * obs_dim + 3 + action_dim + 2 * desc_dim
),
]
return cls(
obs=obs,
next_obs=next_obs,
rewards=rewards,
dones=dones,
truncations=truncations,
actions=actions,
state_desc=state_desc,
next_state_desc=next_state_desc,
)
@classmethod
def init_dummy( # type: ignore
cls, observation_dim: int, action_dim: int, descriptor_dim: int
) -> QDTransition:
"""
Initialize a dummy transition that then can be passed to constructors to get
all shapes right.
Args:
observation_dim: observation dimension
action_dim: action dimension
Returns:
a dummy transition
"""
dummy_transition = QDTransition(
obs=jnp.zeros(shape=(1, observation_dim)),
next_obs=jnp.zeros(shape=(1, observation_dim)),
rewards=jnp.zeros(shape=(1,)),
dones=jnp.zeros(shape=(1,)),
truncations=jnp.zeros(shape=(1,)),
actions=jnp.zeros(shape=(1, action_dim)),
state_desc=jnp.zeros(shape=(1, descriptor_dim)),
next_state_desc=jnp.zeros(shape=(1, descriptor_dim)),
)
return dummy_transition
state_descriptor_dim: int property readonly ¶
Returns: the dimension of the state descriptors.
flatten_dim: int property readonly ¶
Returns: the dimension of the transition once flattened.
replace(self, **updates) ¶
Returns a new object replacing the specified fields with new values.
Source code in qdax/core/neuroevolution/buffers/buffer.py
def replace(self, **updates):
""" "Returns a new object replacing the specified fields with new values."""
return dataclasses.replace(self, **updates)
flatten(self) ¶
Returns: a jnp.ndarray that corresponds to the flattened transition.
Source code in qdax/core/neuroevolution/buffers/buffer.py
def flatten(self) -> jnp.ndarray:
"""
Returns:
a jnp.ndarray that corresponds to the flattened transition.
"""
flatten_transition = jnp.concatenate(
[
self.obs,
self.next_obs,
jnp.expand_dims(self.rewards, axis=-1),
jnp.expand_dims(self.dones, axis=-1),
jnp.expand_dims(self.truncations, axis=-1),
self.actions,
self.state_desc,
self.next_state_desc,
],
axis=-1,
)
return flatten_transition
from_flatten(flattened_transition, transition) classmethod ¶
Creates a transition from a flattened transition in a jnp.ndarray.
Source code in qdax/core/neuroevolution/buffers/buffer.py
@classmethod
def from_flatten(
cls,
flattened_transition: jnp.ndarray,
transition: QDTransition,
) -> QDTransition:
"""
Creates a transition from a flattened transition in a jnp.ndarray.
Args:
flattened_transition: flattened transition in a jnp.ndarray of shape
(batch_size, flatten_dim)
transition: a transition object (might be a dummy one) to
get the dimensions right
Returns:
a Transition object
"""
obs_dim = transition.observation_dim
action_dim = transition.action_dim
desc_dim = transition.state_descriptor_dim
obs = flattened_transition[:, :obs_dim]
next_obs = flattened_transition[:, obs_dim : (2 * obs_dim)]
rewards = jnp.ravel(flattened_transition[:, (2 * obs_dim) : (2 * obs_dim + 1)])
dones = jnp.ravel(
flattened_transition[:, (2 * obs_dim + 1) : (2 * obs_dim + 2)]
)
truncations = jnp.ravel(
flattened_transition[:, (2 * obs_dim + 2) : (2 * obs_dim + 3)]
)
actions = flattened_transition[
:, (2 * obs_dim + 3) : (2 * obs_dim + 3 + action_dim)
]
state_desc = flattened_transition[
:,
(2 * obs_dim + 3 + action_dim) : (2 * obs_dim + 3 + action_dim + desc_dim),
]
next_state_desc = flattened_transition[
:,
(2 * obs_dim + 3 + action_dim + desc_dim) : (
2 * obs_dim + 3 + action_dim + 2 * desc_dim
),
]
return cls(
obs=obs,
next_obs=next_obs,
rewards=rewards,
dones=dones,
truncations=truncations,
actions=actions,
state_desc=state_desc,
next_state_desc=next_state_desc,
)
init_dummy(observation_dim, action_dim, descriptor_dim) classmethod ¶
Initialize a dummy transition that can then be passed to constructors to get all shapes right.
Source code in qdax/core/neuroevolution/buffers/buffer.py
@classmethod
def init_dummy( # type: ignore
cls, observation_dim: int, action_dim: int, descriptor_dim: int
) -> QDTransition:
"""
Initialize a dummy transition that then can be passed to constructors to get
all shapes right.
Args:
observation_dim: observation dimension
action_dim: action dimension
Returns:
a dummy transition
"""
dummy_transition = QDTransition(
obs=jnp.zeros(shape=(1, observation_dim)),
next_obs=jnp.zeros(shape=(1, observation_dim)),
rewards=jnp.zeros(shape=(1,)),
dones=jnp.zeros(shape=(1,)),
truncations=jnp.zeros(shape=(1,)),
actions=jnp.zeros(shape=(1, action_dim)),
state_desc=jnp.zeros(shape=(1, descriptor_dim)),
next_state_desc=jnp.zeros(shape=(1, descriptor_dim)),
)
return dummy_transition
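QDTransition appends the state descriptor and next state descriptor after the actions, so the flattened width grows by twice the descriptor dimension. A small sketch with arbitrary dimensions:

from qdax.core.neuroevolution.buffers.buffer import QDTransition

dummy = QDTransition.init_dummy(observation_dim=3, action_dim=2, descriptor_dim=2)

# 2 * obs + action + 3 + 2 * descriptor = 6 + 2 + 3 + 4 = 15
assert dummy.flatten_dim == 15

flat = dummy.flatten()  # shape (1, 15)
rebuilt = QDTransition.from_flatten(flat, dummy)
assert rebuilt.state_desc.shape == (1, 2)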
ReplayBuffer (PyTreeNode) dataclass ¶
A replay buffer where transitions are flattened before being stored. Transitions are unflattened on the fly when sampled from the buffer. Data shape: (buffer_size, transition_concat_shape).
Source code in qdax/core/neuroevolution/buffers/buffer.py
class ReplayBuffer(flax.struct.PyTreeNode):
"""
A replay buffer where transitions are flattened before being stored.
Transitions are unflatenned on the fly when sampled in the buffer.
data shape: (buffer_size, transition_concat_shape)
"""
data: jnp.ndarray
buffer_size: int = flax.struct.field(pytree_node=False)
transition: Transition
current_position: jnp.ndarray = flax.struct.field()
current_size: jnp.ndarray = flax.struct.field()
@classmethod
def init(
cls,
buffer_size: int,
transition: Transition,
) -> ReplayBuffer:
"""
The constructor of the buffer.
Note: We have to define a classmethod instead of just doing it in post_init
because post_init is called every time the dataclass is tree_mapped. This is a
workaround proposed in https://github.com/google/flax/issues/1628.
Args:
buffer_size: the size of the replay buffer, e.g. 1e6
transition: a transition object (might be a dummy one) to get
the dimensions right
"""
flatten_dim = transition.flatten_dim
data = jnp.ones((buffer_size, flatten_dim)) * jnp.nan
current_size = jnp.array(0, dtype=int)
current_position = jnp.array(0, dtype=int)
return cls(
data=data,
current_size=current_size,
current_position=current_position,
buffer_size=buffer_size,
transition=transition,
)
@partial(jax.jit, static_argnames=("sample_size",))
def sample(
self,
random_key: RNGKey,
sample_size: int,
) -> Tuple[Transition, RNGKey]:
"""
Sample a batch of transitions in the replay buffer.
"""
random_key, subkey = jax.random.split(random_key)
idx = jax.random.randint(
subkey,
shape=(sample_size,),
minval=0,
maxval=self.current_size,
)
samples = jnp.take(self.data, idx, axis=0, mode="clip")
transitions = self.transition.__class__.from_flatten(samples, self.transition)
return transitions, random_key
@jax.jit
def insert(self, transitions: Transition) -> ReplayBuffer:
"""
Insert a batch of transitions in the replay buffer. The transitions are
flattened before insertion.
Args:
transitions: A transition object in which each field is assumed to have
a shape (batch_size, field_dim).
"""
flattened_transitions = transitions.flatten()
flattened_transitions = flattened_transitions.reshape(
(-1, flattened_transitions.shape[-1])
)
num_transitions = flattened_transitions.shape[0]
max_replay_size = self.buffer_size
# Make sure update is not larger than the maximum replay size.
if num_transitions > max_replay_size:
raise ValueError(
"Trying to insert a batch of samples larger than the maximum replay "
f"size. num_samples: {num_transitions}, "
f"max replay size {max_replay_size}"
)
# get current position
position = self.current_position
# check if there is an overlap
roll = jnp.minimum(0, max_replay_size - position - num_transitions)
# roll the data to avoid overlap
data = jnp.roll(self.data, roll, axis=0)
# update the position accordingly
new_position = position + roll
# replace old data by the new one
new_data = jax.lax.dynamic_update_slice_in_dim(
data,
flattened_transitions,
start_index=new_position,
axis=0,
)
# update the position and the size
new_position = (new_position + num_transitions) % max_replay_size
new_size = jnp.minimum(self.current_size + num_transitions, max_replay_size)
# update the replay buffer
replay_buffer = self.replace(
current_position=new_position,
current_size=new_size,
data=new_data,
)
return replay_buffer # type: ignore
replace(self, **updates) ¶
Returns a new object replacing the specified fields with new values.
Source code in qdax/core/neuroevolution/buffers/buffer.py
def replace(self, **updates):
""" "Returns a new object replacing the specified fields with new values."""
return dataclasses.replace(self, **updates)
init(buffer_size, transition) classmethod ¶
The constructor of the buffer.
Note: We have to define a classmethod instead of just doing it in post_init because post_init is called every time the dataclass is tree_mapped. This is a workaround proposed in https://github.com/google/flax/issues/1628.
Source code in qdax/core/neuroevolution/buffers/buffer.py
@classmethod
def init(
cls,
buffer_size: int,
transition: Transition,
) -> ReplayBuffer:
"""
The constructor of the buffer.
Note: We have to define a classmethod instead of just doing it in post_init
because post_init is called every time the dataclass is tree_mapped. This is a
workaround proposed in https://github.com/google/flax/issues/1628.
Args:
buffer_size: the size of the replay buffer, e.g. 1e6
transition: a transition object (might be a dummy one) to get
the dimensions right
"""
flatten_dim = transition.flatten_dim
data = jnp.ones((buffer_size, flatten_dim)) * jnp.nan
current_size = jnp.array(0, dtype=int)
current_position = jnp.array(0, dtype=int)
return cls(
data=data,
current_size=current_size,
current_position=current_position,
buffer_size=buffer_size,
transition=transition,
)
sample(self, random_key, sample_size) ¶
Sample a batch of transitions from the replay buffer.
Source code in qdax/core/neuroevolution/buffers/buffer.py
@partial(jax.jit, static_argnames=("sample_size",))
def sample(
self,
random_key: RNGKey,
sample_size: int,
) -> Tuple[Transition, RNGKey]:
"""
Sample a batch of transitions in the replay buffer.
"""
random_key, subkey = jax.random.split(random_key)
idx = jax.random.randint(
subkey,
shape=(sample_size,),
minval=0,
maxval=self.current_size,
)
samples = jnp.take(self.data, idx, axis=0, mode="clip")
transitions = self.transition.__class__.from_flatten(samples, self.transition)
return transitions, random_key
insert(self, transitions) ¶
Insert a batch of transitions into the replay buffer. The transitions are flattened before insertion.
Source code in qdax/core/neuroevolution/buffers/buffer.py
@jax.jit
def insert(self, transitions: Transition) -> ReplayBuffer:
"""
Insert a batch of transitions in the replay buffer. The transitions are
flattened before insertion.
Args:
transitions: A transition object in which each field is assumed to have
a shape (batch_size, field_dim).
"""
flattened_transitions = transitions.flatten()
flattened_transitions = flattened_transitions.reshape(
(-1, flattened_transitions.shape[-1])
)
num_transitions = flattened_transitions.shape[0]
max_replay_size = self.buffer_size
# Make sure update is not larger than the maximum replay size.
if num_transitions > max_replay_size:
raise ValueError(
"Trying to insert a batch of samples larger than the maximum replay "
f"size. num_samples: {num_transitions}, "
f"max replay size {max_replay_size}"
)
# get current position
position = self.current_position
# check if there is an overlap
roll = jnp.minimum(0, max_replay_size - position - num_transitions)
# roll the data to avoid overlap
data = jnp.roll(self.data, roll, axis=0)
# update the position accordingly
new_position = position + roll
# replace old data by the new one
new_data = jax.lax.dynamic_update_slice_in_dim(
data,
flattened_transitions,
start_index=new_position,
axis=0,
)
# update the position and the size
new_position = (new_position + num_transitions) % max_replay_size
new_size = jnp.minimum(self.current_size + num_transitions, max_replay_size)
# update the replay buffer
replay_buffer = self.replace(
current_position=new_position,
current_size=new_size,
data=new_data,
)
return replay_buffer # type: ignore
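Putting init, insert and sample together, a minimal usage sketch built only from the methods documented above; the dimensions and the inserted transitions are arbitrary.

import jax
import jax.numpy as jnp

from qdax.core.neuroevolution.buffers.buffer import ReplayBuffer, Transition

obs_dim, action_dim, batch_size = 3, 2, 4

# A dummy transition is only used to fix the flattened layout of the buffer.
dummy = Transition.init_dummy(observation_dim=obs_dim, action_dim=action_dim)
replay_buffer = ReplayBuffer.init(buffer_size=1024, transition=dummy)

# Insert a batch of (arbitrary) transitions; each field has shape (batch_size, field_dim).
batch = Transition(
    obs=jnp.ones((batch_size, obs_dim)),
    next_obs=jnp.ones((batch_size, obs_dim)),
    rewards=jnp.ones((batch_size,)),
    dones=jnp.zeros((batch_size,)),
    truncations=jnp.zeros((batch_size,)),
    actions=jnp.ones((batch_size, action_dim)),
)
replay_buffer = replay_buffer.insert(batch)

# Sample transitions uniformly among the ones already stored.
random_key = jax.random.PRNGKey(0)
transitions, random_key = replay_buffer.sample(random_key, sample_size=2)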
trajectory_buffer ¶
TrajectoryBuffer (PyTreeNode) dataclass ¶
A buffer that stores transitions in the form of trajectories. Like FlatBuffer,
transitions are flattened before being stored and unflattened on the fly, and the
data shape is: (buffer_size, transition_concat_shape).
The specificity lies in the additional episodic data buffer that maps transitions
that belong to the same trajectory to their position in the main buffer.
Examples:
Assume we have a buffer of size 6, we insert 3 transitions at a time
(env_batch_size=3) and the episode length is 2.
The dones data we insert is dones=[0,1,0].
Data (dones):
[ 0. 1. 0. nan nan nan] # We inserted [0,1,0] contiguously
Episodic data:
[[ 0. nan] # For episode 0, first timestep is at index 0 in the buffer
[ 1. nan] # For episode 1, first timestep is at index 1 in the buffer
[ 2. nan]] # For episode 2, first timestep is at index 2 in the buffer
Trajectory positions:
[0. 1. 0.] # For episodes 0 and 2, done=0 so we stay in the same episode,
# for episode 1, done=1 so we move to the next episode index
Timestep positions:
[1. 0. 1.] # For episodes 0 and 2: done=0 so we increment the timestep counter,
# for episode 1: done=1 so we reset the timestep counter
Now we subsequently add dones=[1,1,1]
Data (dones):
[0. 1. 0. 1. 1. 1.]
Episodic data:
[[ 0. 3.]
[ 4. nan] # Episode 1 has been reset
[ 2. 5.]]
Trajectory positions:
[1. 2. 1.]
Timestep positions:
[0. 0. 0.] # All timestep counters have been reset
Source code in qdax/core/neuroevolution/buffers/trajectory_buffer.py
class TrajectoryBuffer(struct.PyTreeNode):
"""
A buffer that stores transitions in the form of trajectories. Like `FlatBuffer`
transitions are flattened before being stored and unflattened on the fly and the
data shape is: (buffer_size, transition_concat_shape).
The speicificity lies in the additional episodic data buffer that maps transitions
that belong to the same trajectory to their position in the main buffer.
Example:
Assume we have a buffer of size 6, we insert 3 transitions at a time
(env_batch_size=3) and the episode length is 2.
The `dones` data we insert is dones=[0,1,0].
Data (dones):
[ 0. 1. 0. nan nan nan] # We inserted [0,1,0] contiguously
Episodic data:
[[ 0. nan] # For episode 0, first timestep is at index 0 in the buffer
[ 1. nan] # For episode 1, first timestep is at index 1 in the buffer
[ 2. nan]] # For episode 2, first timestep is at index 2 in the buffer
Trajectory positions:
[0. 1. 0.] # For episode 0 and 2, done=0 so we stay in the same episode,
# for episode 1, done=1 so we move to the next episode index
Timestep positions:
[1. 0. 1.] # For episode 0 and 2: done=0 so we increment the timestep count-
# er, for episode 1: done=1 so we reset the timestep counter
Now we subsequently add dones=[1,1,1]
Data (dones):
[0. 1. 0. 1. 1. 1.]
Episodic data:
[[ 0. 3.]
[ 4. nan] # Episode 1 has been reset
[ 2. 5.]]
Trajectory positions:
[1. 2. 1.]
Timestep positions:
[0. 0. 0.] # All timestep counters have been reset
"""
data: jnp.ndarray
buffer_size: int = struct.field(pytree_node=False)
transition: Transition
episode_length: int = struct.field(pytree_node=False)
env_batch_size: int = struct.field(pytree_node=False)
num_trajectories: int = struct.field(pytree_node=False)
current_position: jnp.ndarray = struct.field()
current_size: jnp.ndarray = struct.field()
trajectory_positions: jnp.ndarray = struct.field()
timestep_positions: jnp.ndarray = struct.field()
episodic_data: jnp.ndarray = struct.field()
current_episodic_data_size: jnp.ndarray = struct.field()
returns: jnp.ndarray = struct.field()
@classmethod
def init( # type: ignore
cls,
buffer_size: int,
transition: Transition,
env_batch_size: int,
episode_length: int,
) -> TrajectoryBuffer:
"""
The constructor of the buffer.
Note: We have to define a classmethod instead of just doing it in post_init
because post_init is called every time the dataclass is tree_mapped. This is a
workaround proposed in https://github.com/google/flax/issues/1628.
"""
flatten_dim = transition.flatten_dim
data = jnp.ones((buffer_size, flatten_dim)) * jnp.nan
num_trajectories = buffer_size // episode_length
assert (
num_trajectories % env_batch_size == 0
), "num_trajectories must be a multiple of env batch size"
assert (
buffer_size % episode_length == 0
), "buffer_size must be a multiple of episode_length"
current_position = jnp.zeros((), dtype=int)
trajectory_positions = jnp.zeros(env_batch_size, dtype=int)
timestep_positions = jnp.zeros(env_batch_size, dtype=int)
episodic_data = jnp.ones((num_trajectories, episode_length)) * jnp.nan
current_size = jnp.array(0, dtype=int)
current_episodic_data_size = jnp.array(0, dtype=int)
returns = jnp.ones(
buffer_size + 1,
) * (-jnp.inf)
return cls(
data=data,
current_position=current_position,
buffer_size=buffer_size,
timestep_positions=timestep_positions,
trajectory_positions=trajectory_positions,
episode_length=episode_length,
env_batch_size=env_batch_size,
episodic_data=episodic_data,
num_trajectories=num_trajectories,
current_size=current_size,
current_episodic_data_size=current_episodic_data_size,
transition=transition,
returns=returns,
)
@partial(jax.jit, static_argnames=("sample_size"))
def sample(
self,
random_key: RNGKey,
sample_size: int,
) -> Tuple[Transition, RNGKey]:
"""
Sample transitions from the buffer. If sample_traj=False, returns stacked
transitions in the shape (sample_size,), if sample_traj=True, return transitions
in the shape (sample_size, episode_length,).
"""
# Here we want to sample single transitions
# We sample uniformly at random the indexes of valid transitions
random_key, subkey = jax.random.split(random_key)
idx = jax.random.randint(
subkey,
shape=(sample_size,),
minval=0,
maxval=self.current_size,
)
# We select the corresponding transitions
samples = jnp.take(self.data, idx, axis=0, mode="clip")
# (sample_size, concat_dim)
transitions = self.transition.__class__.from_flatten(samples, self.transition)
return transitions, random_key
def sample_with_returns(
self,
random_key: RNGKey,
sample_size: int,
) -> Tuple[Transition, Reward, RNGKey]:
"""Sample transitions and the return corresponding to their episode. The returns
are compute by the method `compute_returns`.
Args:
random_key: a random key
sample_size: the number of transitions
Returns:
The transitions, the associated returns and a new random key.
"""
# Here we want to sample single transitions
# We sample uniformly at random the indexes of valid transitions
random_key, subkey = jax.random.split(random_key)
idx = jax.random.randint(
subkey,
shape=(sample_size,),
minval=0,
maxval=self.current_size,
)
# We select the corresponding transitions
samples = jnp.take(self.data, idx, axis=0, mode="clip")
returns = jnp.take(self.returns, idx, mode="clip")
# (sample_size, concat_dim)
transitions = self.transition.__class__.from_flatten(samples, self.transition)
return transitions, returns, random_key
@jax.jit
def insert(self, transitions: Transition) -> TrajectoryBuffer:
"""
Scan over 'insert_one_transition', to add multiple transitions.
"""
@jax.jit
def insert_one_transition(
replay_buffer: TrajectoryBuffer, flattened_transitions: jnp.ndarray
) -> Tuple[TrajectoryBuffer, Any]:
"""
Insert a batch (one step over all envs) of transitions in the buffer.
"""
# Step 1: reset episodes for override
# We start by selecting the episodes that are currently being inserted
active_trajectories_indexes = (
jnp.arange(self.env_batch_size, dtype=int)
+ (replay_buffer.trajectory_positions % self.num_trajectories)
* self.env_batch_size
) % self.num_trajectories
current_episodes = replay_buffer.episodic_data[active_trajectories_indexes]
# The override condition is: "if we want to insert à timestep 0, we clear
# the corresponding row first"
condition = replay_buffer.timestep_positions % self.episode_length == 0
# Clear episodes that match the condition, don't modify others
override_episodes = jnp.where(
jnp.expand_dims(condition, -1),
jnp.ones_like(current_episodes) * jnp.nan,
current_episodes,
)
new_episodic_data = replay_buffer.episodic_data.at[
active_trajectories_indexes
].set(override_episodes)
# Step 2: insert data in main buffer and indexes in episodic buffer
# Apply changes in the episodic_data array
# Insert transitions in the buffer
new_data = jax.lax.dynamic_update_slice_in_dim(
replay_buffer.data,
flattened_transitions,
start_index=replay_buffer.current_position % self.buffer_size,
axis=0,
)
# We inserted from current_position to current_position + env_batch_size
inserted_indexes = (
jnp.arange(
self.env_batch_size,
)
+ replay_buffer.current_position
)
# We set the indexes of inserted episodes in the episodic_data array
new_episodic_data = new_episodic_data.at[
active_trajectories_indexes,
replay_buffer.timestep_positions,
].set(inserted_indexes)
# Step 3: update the counters
dones = flattened_transitions[
:, (2 * (self.transition.observation_dim) + 1)
].ravel()
# Increment the trajectory counter if done
new_trajectory_positions = replay_buffer.trajectory_positions + dones
# Within a trajectory, increment position if not done, else reset position
new_timestep_positions = jnp.where(
dones, jnp.zeros_like(dones), 1 + replay_buffer.timestep_positions
)
# Update the insertion position in the main buffer
new_current_position = (
replay_buffer.current_position + self.env_batch_size
) % self.buffer_size
# Update the size counter of the main buffer
new_current_size = jnp.minimum(
replay_buffer.current_size + self.env_batch_size, self.buffer_size
)
# Update the size of the episodic data buffer
new_current_episodic_data_size = jnp.minimum(
jnp.min(replay_buffer.trajectory_positions + 1) * self.env_batch_size,
self.num_trajectories,
)
replay_buffer = replay_buffer.replace(
timestep_positions=jnp.array(new_timestep_positions, dtype=int),
trajectory_positions=jnp.array(new_trajectory_positions, dtype=int),
data=new_data,
current_position=jnp.array(new_current_position, dtype=int),
episodic_data=new_episodic_data,
current_size=new_current_size,
current_episodic_data_size=jnp.array(
new_current_episodic_data_size, dtype=int
),
)
return replay_buffer, None
flattened_transitions = transitions.flatten()
flattened_transitions = flattened_transitions.reshape(
(-1, self.env_batch_size, flattened_transitions.shape[-1])
)
replay_buffer, _ = jax.lax.scan(
insert_one_transition,
self,
flattened_transitions,
)
replay_buffer = replay_buffer.compute_returns()
return replay_buffer # type: ignore
def compute_returns(
self,
) -> TrajectoryBuffer:
"""Computes the return for each episode in the buffer.
Returns:
The buffer with the returns updated.
"""
reward_idx = 2 * self.transition.observation_dim
indexes = self.episodic_data
rewards = self.data[:, reward_idx]
episodic_returns = jnp.where(
jnp.isnan(indexes),
0,
rewards[jnp.array(jnp.nan_to_num(indexes, 0), dtype=int)],
).sum(axis=1)
values = episodic_returns[:, None].repeat(self.episode_length, axis=1)
returns = self.returns.at[
jnp.array(jnp.nan_to_num(indexes, nan=-1), dtype=int)
].set(values)
returns = returns.at[-1].set(jnp.nan)
return self.replace(returns=returns) # type: ignore
init(buffer_size, transition, env_batch_size, episode_length) classmethod ¶
The constructor of the buffer.
Note: We have to define a classmethod instead of just doing it in post_init because post_init is called every time the dataclass is tree_mapped. This is a workaround proposed in https://github.com/google/flax/issues/1628.
Source code in qdax/core/neuroevolution/buffers/trajectory_buffer.py
@classmethod
def init( # type: ignore
cls,
buffer_size: int,
transition: Transition,
env_batch_size: int,
episode_length: int,
) -> TrajectoryBuffer:
"""
The constructor of the buffer.
Note: We have to define a classmethod instead of just doing it in post_init
because post_init is called every time the dataclass is tree_mapped. This is a
workaround proposed in https://github.com/google/flax/issues/1628.
"""
flatten_dim = transition.flatten_dim
data = jnp.ones((buffer_size, flatten_dim)) * jnp.nan
num_trajectories = buffer_size // episode_length
assert (
num_trajectories % env_batch_size == 0
), "num_trajectories must be a multiple of env batch size"
assert (
buffer_size % episode_length == 0
), "buffer_size must be a multiple of episode_length"
current_position = jnp.zeros((), dtype=int)
trajectory_positions = jnp.zeros(env_batch_size, dtype=int)
timestep_positions = jnp.zeros(env_batch_size, dtype=int)
episodic_data = jnp.ones((num_trajectories, episode_length)) * jnp.nan
current_size = jnp.array(0, dtype=int)
current_episodic_data_size = jnp.array(0, dtype=int)
returns = jnp.ones(
buffer_size + 1,
) * (-jnp.inf)
return cls(
data=data,
current_position=current_position,
buffer_size=buffer_size,
timestep_positions=timestep_positions,
trajectory_positions=trajectory_positions,
episode_length=episode_length,
env_batch_size=env_batch_size,
episodic_data=episodic_data,
num_trajectories=num_trajectories,
current_size=current_size,
current_episodic_data_size=current_episodic_data_size,
transition=transition,
returns=returns,
)
replace(self, **updates) ¶
Returns a new object replacing the specified fields with new values.
Source code in qdax/core/neuroevolution/buffers/trajectory_buffer.py
def replace(self, **updates):
""" "Returns a new object replacing the specified fields with new values."""
return dataclasses.replace(self, **updates)
sample(self, random_key, sample_size) ¶
Sample transitions uniformly at random from the buffer; they are returned stacked with shape (sample_size,).
Source code in qdax/core/neuroevolution/buffers/trajectory_buffer.py
@partial(jax.jit, static_argnames=("sample_size"))
def sample(
self,
random_key: RNGKey,
sample_size: int,
) -> Tuple[Transition, RNGKey]:
"""
Sample transitions from the buffer. If sample_traj=False, returns stacked
transitions in the shape (sample_size,), if sample_traj=True, return transitions
in the shape (sample_size, episode_length,).
"""
# Here we want to sample single transitions
# We sample uniformly at random the indexes of valid transitions
random_key, subkey = jax.random.split(random_key)
idx = jax.random.randint(
subkey,
shape=(sample_size,),
minval=0,
maxval=self.current_size,
)
# We select the corresponding transitions
samples = jnp.take(self.data, idx, axis=0, mode="clip")
# (sample_size, concat_dim)
transitions = self.transition.__class__.from_flatten(samples, self.transition)
return transitions, random_key
sample_with_returns(self, random_key, sample_size) ¶
Sample transitions and the return corresponding to their episode. The returns are computed by the method compute_returns.
Source code in qdax/core/neuroevolution/buffers/trajectory_buffer.py
def sample_with_returns(
self,
random_key: RNGKey,
sample_size: int,
) -> Tuple[Transition, Reward, RNGKey]:
"""Sample transitions and the return corresponding to their episode. The returns
are compute by the method `compute_returns`.
Args:
random_key: a random key
sample_size: the number of transitions
Returns:
The transitions, the associated returns and a new random key.
"""
# Here we want to sample single transitions
# We sample uniformly at random the indexes of valid transitions
random_key, subkey = jax.random.split(random_key)
idx = jax.random.randint(
subkey,
shape=(sample_size,),
minval=0,
maxval=self.current_size,
)
# We select the corresponding transitions
samples = jnp.take(self.data, idx, axis=0, mode="clip")
returns = jnp.take(self.returns, idx, mode="clip")
# (sample_size, concat_dim)
transitions = self.transition.__class__.from_flatten(samples, self.transition)
return transitions, returns, random_key
insert(self, transitions) ¶
Scan over 'insert_one_transition' to add multiple transitions.
Source code in qdax/core/neuroevolution/buffers/trajectory_buffer.py
@jax.jit
def insert(self, transitions: Transition) -> TrajectoryBuffer:
"""
Scan over 'insert_one_transition', to add multiple transitions.
"""
@jax.jit
def insert_one_transition(
replay_buffer: TrajectoryBuffer, flattened_transitions: jnp.ndarray
) -> Tuple[TrajectoryBuffer, Any]:
"""
Insert a batch (one step over all envs) of transitions in the buffer.
"""
# Step 1: reset episodes for override
# We start by selecting the episodes that are currently being inserted
active_trajectories_indexes = (
jnp.arange(self.env_batch_size, dtype=int)
+ (replay_buffer.trajectory_positions % self.num_trajectories)
* self.env_batch_size
) % self.num_trajectories
current_episodes = replay_buffer.episodic_data[active_trajectories_indexes]
# The override condition is: "if we want to insert à timestep 0, we clear
# the corresponding row first"
condition = replay_buffer.timestep_positions % self.episode_length == 0
# Clear episodes that match the condition, don't modify others
override_episodes = jnp.where(
jnp.expand_dims(condition, -1),
jnp.ones_like(current_episodes) * jnp.nan,
current_episodes,
)
new_episodic_data = replay_buffer.episodic_data.at[
active_trajectories_indexes
].set(override_episodes)
# Step 2: insert data in main buffer and indexes in episodic buffer
# Apply changes in the episodic_data array
# Insert transitions in the buffer
new_data = jax.lax.dynamic_update_slice_in_dim(
replay_buffer.data,
flattened_transitions,
start_index=replay_buffer.current_position % self.buffer_size,
axis=0,
)
# We inserted from current_position to current_position + env_batch_size
inserted_indexes = (
jnp.arange(
self.env_batch_size,
)
+ replay_buffer.current_position
)
# We set the indexes of inserted episodes in the episodic_data array
new_episodic_data = new_episodic_data.at[
active_trajectories_indexes,
replay_buffer.timestep_positions,
].set(inserted_indexes)
# Step 3: update the counters
dones = flattened_transitions[
:, (2 * (self.transition.observation_dim) + 1)
].ravel()
# Increment the trajectory counter if done
new_trajectory_positions = replay_buffer.trajectory_positions + dones
# Within a trajectory, increment position if not done, else reset position
new_timestep_positions = jnp.where(
dones, jnp.zeros_like(dones), 1 + replay_buffer.timestep_positions
)
# Update the insertion position in the main buffer
new_current_position = (
replay_buffer.current_position + self.env_batch_size
) % self.buffer_size
# Update the size counter of the main buffer
new_current_size = jnp.minimum(
replay_buffer.current_size + self.env_batch_size, self.buffer_size
)
# Update the size of the episodic data buffer
new_current_episodic_data_size = jnp.minimum(
jnp.min(replay_buffer.trajectory_positions + 1) * self.env_batch_size,
self.num_trajectories,
)
replay_buffer = replay_buffer.replace(
timestep_positions=jnp.array(new_timestep_positions, dtype=int),
trajectory_positions=jnp.array(new_trajectory_positions, dtype=int),
data=new_data,
current_position=jnp.array(new_current_position, dtype=int),
episodic_data=new_episodic_data,
current_size=new_current_size,
current_episodic_data_size=jnp.array(
new_current_episodic_data_size, dtype=int
),
)
return replay_buffer, None
flattened_transitions = transitions.flatten()
flattened_transitions = flattened_transitions.reshape(
(-1, self.env_batch_size, flattened_transitions.shape[-1])
)
replay_buffer, _ = jax.lax.scan(
insert_one_transition,
self,
flattened_transitions,
)
replay_buffer = replay_buffer.compute_returns()
return replay_buffer # type: ignore
compute_returns(self) ¶
Computes the return for each episode in the buffer.
Source code in qdax/core/neuroevolution/buffers/trajectory_buffer.py
def compute_returns(
self,
) -> TrajectoryBuffer:
"""Computes the return for each episode in the buffer.
Returns:
The buffer with the returns updated.
"""
reward_idx = 2 * self.transition.observation_dim
indexes = self.episodic_data
rewards = self.data[:, reward_idx]
episodic_returns = jnp.where(
jnp.isnan(indexes),
0,
rewards[jnp.array(jnp.nan_to_num(indexes, 0), dtype=int)],
).sum(axis=1)
values = episodic_returns[:, None].repeat(self.episode_length, axis=1)
returns = self.returns.at[
jnp.array(jnp.nan_to_num(indexes, nan=-1), dtype=int)
].set(values)
returns = returns.at[-1].set(jnp.nan)
return self.replace(returns=returns) # type: ignore
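A minimal usage sketch of TrajectoryBuffer, respecting the constraints checked in init (buffer_size must be a multiple of episode_length, and num_trajectories = buffer_size // episode_length must be a multiple of env_batch_size); the dimensions and inserted data are arbitrary and mirror the example in the class docstring.

import jax
import jax.numpy as jnp

from qdax.core.neuroevolution.buffers.buffer import Transition
from qdax.core.neuroevolution.buffers.trajectory_buffer import TrajectoryBuffer

obs_dim, action_dim = 3, 2
env_batch_size, episode_length = 3, 2

dummy = Transition.init_dummy(observation_dim=obs_dim, action_dim=action_dim)
buffer = TrajectoryBuffer.init(
    buffer_size=env_batch_size * episode_length,  # 6: a multiple of episode_length
    transition=dummy,
    env_batch_size=env_batch_size,
    episode_length=episode_length,
)

# One environment step over the 3 parallel envs, with dones=[0, 1, 0] as in the
# example of the class docstring.
step = Transition(
    obs=jnp.zeros((env_batch_size, obs_dim)),
    next_obs=jnp.zeros((env_batch_size, obs_dim)),
    rewards=jnp.ones((env_batch_size,)),
    dones=jnp.array([0.0, 1.0, 0.0]),
    truncations=jnp.zeros((env_batch_size,)),
    actions=jnp.zeros((env_batch_size, action_dim)),
)
buffer = buffer.insert(step)  # also refreshes the per-episode returns

random_key = jax.random.PRNGKey(0)
transitions, returns, random_key = buffer.sample_with_returns(random_key, sample_size=2)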
losses special ¶
dads_loss ¶
make_dads_loss_fn(policy_fn, critic_fn, dynamics_fn, parametric_action_distribution, reward_scaling, discount, action_size, num_skills) ¶
Creates the loss used in DADS.
Source code in qdax/core/neuroevolution/losses/dads_loss.py
def make_dads_loss_fn(
policy_fn: Callable[[Params, Observation], jnp.ndarray],
critic_fn: Callable[[Params, Observation, Action], jnp.ndarray],
dynamics_fn: Callable[
[Params, StateDescriptor, Skill, StateDescriptor], jnp.ndarray
],
parametric_action_distribution: ParametricDistribution,
reward_scaling: float,
discount: float,
action_size: int,
num_skills: int,
) -> Tuple[
Callable[[jnp.ndarray, Params, QDTransition, RNGKey], jnp.ndarray],
Callable[[Params, Params, jnp.ndarray, QDTransition, RNGKey], jnp.ndarray],
Callable[[Params, Params, Params, QDTransition, RNGKey], jnp.ndarray],
Callable[[Params, QDTransition, RNGKey], jnp.ndarray],
]:
"""Creates the loss used in DADS.
Args:
policy_fn: the apply function of the policy
critic_fn: the apply function of the critic
dynamics_fn: the apply function of the dynamics network
parametric_action_distribution: the distribution over action
reward_scaling: a multiplicative factor to the reward
discount: the discount factor
action_size: the size of the environment's action space
num_skills: the number of skills set
Returns:
the loss of the entropy parameter auto-tuning
the loss of the policy
the loss of the critic
the loss of the dynamics network
"""
_alpha_loss_fn, _policy_loss_fn, _critic_loss_fn = make_sac_loss_fn(
policy_fn=policy_fn,
critic_fn=critic_fn,
reward_scaling=reward_scaling,
discount=discount,
action_size=action_size,
parametric_action_distribution=parametric_action_distribution,
)
_dynamics_loss_fn = functools.partial(
dads_dynamics_loss_fn, dynamics_fn=dynamics_fn, num_skills=num_skills
)
return _alpha_loss_fn, _policy_loss_fn, _critic_loss_fn, _dynamics_loss_fn
dads_dynamics_loss_fn(dynamics_params, dynamics_fn, num_skills, transitions) ¶
Computes the loss used to train the dynamics network.
Source code in qdax/core/neuroevolution/losses/dads_loss.py
def dads_dynamics_loss_fn(
dynamics_params: Params,
dynamics_fn: Callable[
[Params, StateDescriptor, Skill, StateDescriptor], jnp.ndarray
],
num_skills: int,
transitions: QDTransition,
) -> jnp.ndarray:
"""Computes the loss used to train the dynamics network.
Args:
dynamics_params: the parameters of the neural network
used to predict the dynamics.
dynamics_fn: the apply function of the dynamics network
num_skills: the number of skills.
transitions: the batch of transitions used to train. They
have been sampled from a replay buffer beforehand.
Returns:
The loss obtained on the batch of transitions.
"""
active_skills = transitions.obs[:, -num_skills:]
target = transitions.next_state_desc
log_prob = dynamics_fn( # type: ignore
dynamics_params,
obs=transitions.state_desc,
skill=active_skills,
target=target,
)
# prevent training on malformed target
loss = -jnp.mean(log_prob * (1 - transitions.dones))
return loss
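Reading directly from the code above, the dynamics loss is the masked negative log-likelihood of the next state descriptor under the skill-conditioned dynamics model; the mask removes transitions whose episode has ended, to avoid training on malformed targets:

\mathcal{L}_{\mathrm{dyn}}(\theta) = -\,\mathbb{E}\big[(1 - d_t)\,\log q_\theta\big(s^{\mathrm{desc}}_{t+1} \mid s^{\mathrm{desc}}_{t}, z_t\big)\big]

where z_t is the active skill read from the last num_skills entries of the observation and d_t is the done flag.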
diayn_loss ¶
make_diayn_loss_fn(policy_fn, critic_fn, discriminator_fn, parametric_action_distribution, reward_scaling, discount, action_size, num_skills) ¶
Creates the loss used in DIAYN.
Source code in qdax/core/neuroevolution/losses/diayn_loss.py
def make_diayn_loss_fn(
policy_fn: Callable[[Params, Observation], jnp.ndarray],
critic_fn: Callable[[Params, Observation, Action], jnp.ndarray],
discriminator_fn: Callable[[Params, StateDescriptor], jnp.ndarray],
parametric_action_distribution: ParametricDistribution,
reward_scaling: float,
discount: float,
action_size: int,
num_skills: int,
) -> Tuple[
Callable[[jnp.ndarray, Params, QDTransition, RNGKey], jnp.ndarray],
Callable[[Params, Params, jnp.ndarray, QDTransition, RNGKey], jnp.ndarray],
Callable[[Params, Params, Params, QDTransition, RNGKey], jnp.ndarray],
Callable[[Params, QDTransition, RNGKey], jnp.ndarray],
]:
"""Creates the loss used in DIAYN.
Args:
policy_fn: the apply function of the policy
critic_fn: the apply function of the critic
discriminator_fn: the apply function of the discriminator
parametric_action_distribution: the distribution over actions
reward_scaling: a multiplicative factor to the reward
discount: the discount factor
action_size: the size of the environment's action space
num_skills: the number of skills set
Returns:
the loss of the entropy parameter auto-tuning
the loss of the policy
the loss of the critic
the loss of the discriminator
"""
_alpha_loss_fn, _policy_loss_fn, _critic_loss_fn = make_sac_loss_fn(
policy_fn=policy_fn,
critic_fn=critic_fn,
reward_scaling=reward_scaling,
discount=discount,
action_size=action_size,
parametric_action_distribution=parametric_action_distribution,
)
_discriminator_loss_fn = functools.partial(
diayn_discriminator_loss_fn,
discriminator_fn=discriminator_fn,
num_skills=num_skills,
)
return _alpha_loss_fn, _policy_loss_fn, _critic_loss_fn, _discriminator_loss_fn
diayn_discriminator_loss_fn(discriminator_params, discriminator_fn, num_skills, transitions) ¶
Computes the loss used to train the discriminator. The discriminator is trained to predict the skill that was used to generate each transition. In our case, skills are one-hot encoded, so the discriminator is trained like a multi-label classifier, using the categorical cross-entropy loss.
In this loss, log softmax outputs the log probabilities for each skill. Multiplying by the skills (which are one-hot vectors) keeps only the log probability of the true skill.
Source code in qdax/core/neuroevolution/losses/diayn_loss.py
def diayn_discriminator_loss_fn(
discriminator_params: Params,
discriminator_fn: Callable[[Params, StateDescriptor], jnp.ndarray],
num_skills: int,
transitions: QDTransition,
) -> jnp.ndarray:
"""Computes the loss used to train the discriminator. The
discriminator is trained to predict the skill that has been
used to generate transitions. In our case, skills are one
hot encoded, the discriminator is hence trained like a
multi-label classifier, using the categorical cross entropy
loss.
In this loss, log softmax outputs the log probabilities for
each skill. By multiplying by the skills (that are one hot
vectors), we filter to keep only the log probability of the
true skill.
Args:
discriminator_params: the parameters of the neural network
used to discriminate the skills.
discriminator_fn: the apply function of the discriminator.
num_skills: the number of skills.
transitions: the transitions sampled from the replay buffer.
Returns:
The loss of the discriminator on those transitions.
"""
state_desc = transitions.state_desc
skills = transitions.obs[:, -num_skills:]
logits = jnp.sum(
jax.nn.log_softmax(discriminator_fn(discriminator_params, state_desc)) * skills,
axis=1,
)
loss = -jnp.mean(logits)
return loss
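The masking trick described above (multiplying the log-softmax output by the one-hot skill) can be seen on a tiny standalone example; the logits below are made up and independent of any discriminator network:

import jax
import jax.numpy as jnp

# Made-up discriminator outputs for 2 transitions and 3 skills.
logits = jnp.array([[2.0, 0.5, -1.0],
                    [0.0, 1.0, 3.0]])
# One-hot skills actually used for these transitions (skill 0 and skill 2).
skills = jnp.array([[1.0, 0.0, 0.0],
                    [0.0, 0.0, 1.0]])

# Summing log-softmax * one-hot keeps only the log-probability of the true skill...
log_prob_true_skill = jnp.sum(jax.nn.log_softmax(logits) * skills, axis=1)
# ...and the loss is the mean categorical cross-entropy.
loss = -jnp.mean(log_prob_true_skill)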
sac_loss ¶
make_sac_loss_fn(policy_fn, critic_fn, parametric_action_distribution, reward_scaling, discount, action_size) ¶
Creates the loss used in SAC.
Source code in qdax/core/neuroevolution/losses/sac_loss.py
def make_sac_loss_fn(
policy_fn: Callable[[Params, Observation], jnp.ndarray],
critic_fn: Callable[[Params, Observation, Action], jnp.ndarray],
parametric_action_distribution: ParametricDistribution,
reward_scaling: float,
discount: float,
action_size: int,
) -> Tuple[
Callable[[jnp.ndarray, Params, Transition, RNGKey], jnp.ndarray],
Callable[[Params, Params, jnp.ndarray, Transition, RNGKey], jnp.ndarray],
Callable[[Params, Params, Params, Transition, RNGKey], jnp.ndarray],
]:
"""Creates the loss used in SAC.
Args:
policy_fn: the apply function of the policy
critic_fn: the apply function of the critic
parametric_action_distribution: the distribution over actions
reward_scaling: a multiplicative factor to the reward
discount: the discount factor
action_size: the size of the environment's action space
Returns:
the loss of the entropy parameter auto-tuning
the loss of the policy
the loss of the critic
"""
_policy_loss_fn = functools.partial(
sac_policy_loss_fn,
policy_fn=policy_fn,
critic_fn=critic_fn,
parametric_action_distribution=parametric_action_distribution,
)
_critic_loss_fn = functools.partial(
sac_critic_loss_fn,
policy_fn=policy_fn,
critic_fn=critic_fn,
parametric_action_distribution=parametric_action_distribution,
reward_scaling=reward_scaling,
discount=discount,
)
_alpha_loss_fn = functools.partial(
sac_alpha_loss_fn,
policy_fn=policy_fn,
parametric_action_distribution=parametric_action_distribution,
action_size=action_size,
)
return _alpha_loss_fn, _policy_loss_fn, _critic_loss_fn
sac_policy_loss_fn(policy_params, policy_fn, critic_fn, parametric_action_distribution, critic_params, alpha, transitions, random_key) ¶
Creates the policy loss used in SAC.
Source code in qdax/core/neuroevolution/losses/sac_loss.py
def sac_policy_loss_fn(
policy_params: Params,
policy_fn: Callable[[Params, Observation], jnp.ndarray],
critic_fn: Callable[[Params, Observation, Action], jnp.ndarray],
parametric_action_distribution: ParametricDistribution,
critic_params: Params,
alpha: jnp.ndarray,
transitions: Transition,
random_key: RNGKey,
) -> jnp.ndarray:
"""
Creates the policy loss used in SAC.
Args:
policy_params: parameters of the policy
policy_fn: the apply function of the policy
critic_fn: the apply function of the critic
parametric_action_distribution: the distribution over actions
critic_params: parameters of the critic
alpha: entropy coefficient value
transitions: transitions collected by the agent
random_key: random key
Returns:
the loss of the policy
"""
dist_params = policy_fn(policy_params, transitions.obs)
action = parametric_action_distribution.sample_no_postprocessing(
dist_params, random_key
)
log_prob = parametric_action_distribution.log_prob(dist_params, action)
action = parametric_action_distribution.postprocess(action)
q_action = critic_fn(critic_params, transitions.obs, action)
min_q = jnp.min(q_action, axis=-1)
actor_loss = alpha * log_prob - min_q
return jnp.mean(actor_loss)
sac_critic_loss_fn(critic_params, policy_fn, critic_fn, parametric_action_distribution, reward_scaling, discount, policy_params, target_critic_params, alpha, transitions, random_key) ¶
Creates the critic loss used in SAC.
Source code in qdax/core/neuroevolution/losses/sac_loss.py
def sac_critic_loss_fn(
critic_params: Params,
policy_fn: Callable[[Params, Observation], jnp.ndarray],
critic_fn: Callable[[Params, Observation, Action], jnp.ndarray],
parametric_action_distribution: ParametricDistribution,
reward_scaling: float,
discount: float,
policy_params: Params,
target_critic_params: Params,
alpha: jnp.ndarray,
transitions: Transition,
random_key: RNGKey,
) -> jnp.ndarray:
"""
Creates the critic loss used in SAC.
Args:
critic_params: parameters of the critic
policy_fn: the apply function of the policy
critic_fn: the apply function of the critic
parametric_action_distribution: the distribution over actions
policy_params: parameters of the policy
target_critic_params: parameters of the target critic
alpha: entropy coefficient value
transitions: transitions collected by the agent
random_key: random key
reward_scaling: a multiplicative factor to the reward
discount: the discount factor
Returns:
the loss of the critic
"""
q_old_action = critic_fn(critic_params, transitions.obs, transitions.actions)
next_dist_params = policy_fn(policy_params, transitions.next_obs)
next_action = parametric_action_distribution.sample_no_postprocessing(
next_dist_params, random_key
)
next_log_prob = parametric_action_distribution.log_prob(
next_dist_params, next_action
)
next_action = parametric_action_distribution.postprocess(next_action)
next_q = critic_fn(target_critic_params, transitions.next_obs, next_action)
next_v = jnp.min(next_q, axis=-1) - alpha * next_log_prob
target_q = jax.lax.stop_gradient(
transitions.rewards * reward_scaling
+ (1.0 - transitions.dones) * discount * next_v
)
q_error = q_old_action - jnp.expand_dims(target_q, -1)
q_error *= jnp.expand_dims(1 - transitions.truncations, -1)
q_loss = 0.5 * jnp.mean(jnp.square(q_error))
return q_loss
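Spelled out, the stop-gradient target built in the code above is the entropy-regularised bootstrap value, with c the reward_scaling factor and gamma the discount; the squared error is then masked for truncated transitions:

y_t = c\,r_t + (1 - d_t)\,\gamma\Big(\min_i Q^{\mathrm{target}}_i(s_{t+1}, a_{t+1}) - \alpha \log \pi(a_{t+1} \mid s_{t+1})\Big), \qquad a_{t+1} \sim \pi(\cdot \mid s_{t+1})

\mathcal{L}_{\mathrm{critic}} = \tfrac{1}{2}\,\mathbb{E}\Big[(1 - \mathrm{trunc}_t)\,\big(Q_i(s_t, a_t) - y_t\big)^2\Big]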
sac_alpha_loss_fn(log_alpha, policy_fn, parametric_action_distribution, action_size, policy_params, transitions, random_key) ¶
Creates the alpha loss used in SAC. Eq. 18 from https://arxiv.org/pdf/1812.05905.pdf.
Source code in qdax/core/neuroevolution/losses/sac_loss.py
def sac_alpha_loss_fn(
log_alpha: jnp.ndarray,
policy_fn: Callable[[Params, Observation], jnp.ndarray],
parametric_action_distribution: ParametricDistribution,
action_size: int,
policy_params: Params,
transitions: Transition,
random_key: RNGKey,
) -> jnp.ndarray:
"""
Creates the alpha loss used in SAC.
Eq 18 from https://arxiv.org/pdf/1812.05905.pdf.
Args:
log_alpha: entropy coefficient log value
policy_fn: the apply function of the policy
parametric_action_distribution: the distribution over actions
policy_params: parameters of the policy
transitions: transitions collected by the agent
random_key: random key
action_size: the size of the environment's action space
Returns:
the loss of the entropy parameter auto-tuning
"""
target_entropy = -0.5 * action_size
dist_params = policy_fn(policy_params, transitions.obs)
action = parametric_action_distribution.sample_no_postprocessing(
dist_params, random_key
)
log_prob = parametric_action_distribution.log_prob(dist_params, action)
alpha = jnp.exp(log_alpha)
alpha_loss = alpha * jax.lax.stop_gradient(-log_prob - target_entropy)
loss = jnp.mean(alpha_loss)
return loss
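In the notation of Eq. 18 of the SAC paper, and with the target entropy set to -0.5 * action_size in this implementation, the temperature loss above reads:

J(\alpha) = \mathbb{E}_{a_t \sim \pi}\big[\alpha\,(-\log \pi(a_t \mid s_t) - \bar{\mathcal{H}})\big], \qquad \bar{\mathcal{H}} = -\tfrac{1}{2}\,\lvert\mathcal{A}\rvert

where the bracketed term is treated as a constant via stop_gradient, so only alpha (parameterised as exp(log_alpha)) receives a gradient.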
td3_loss ¶
Implements a function to create critic and actor losses for the TD3 algorithm.
make_td3_loss_fn(policy_fn, critic_fn, reward_scaling, discount, noise_clip, policy_noise) ¶
Creates the loss functions for TD3.
Source code in qdax/core/neuroevolution/losses/td3_loss.py
def make_td3_loss_fn(
policy_fn: Callable[[Params, Observation], jnp.ndarray],
critic_fn: Callable[[Params, Observation, Action], jnp.ndarray],
reward_scaling: float,
discount: float,
noise_clip: float,
policy_noise: float,
) -> Tuple[
Callable[[Params, Params, Transition], jnp.ndarray],
Callable[[Params, Params, Params, Transition, RNGKey], jnp.ndarray],
]:
"""Creates the loss functions for TD3.
Args:
policy_fn: forward pass through the neural network defining the policy.
critic_fn: forward pass through the neural network defining the critic.
reward_scaling: value to multiply the reward given by the environment.
discount: discount factor.
noise_clip: value that clips the noise to avoid extreme values.
policy_noise: noise applied to smooth the bootstrapping.
Returns:
Return the loss functions used to train the policy and the critic in TD3.
"""
@jax.jit
def _policy_loss_fn(
policy_params: Params,
critic_params: Params,
transitions: Transition,
) -> jnp.ndarray:
"""Policy loss function for TD3 agent"""
action = policy_fn(policy_params, transitions.obs)
q_value = critic_fn(
critic_params, obs=transitions.obs, actions=action # type: ignore
)
q1_action = jnp.take(q_value, jnp.asarray([0]), axis=-1)
policy_loss = -jnp.mean(q1_action)
return policy_loss
@jax.jit
def _critic_loss_fn(
critic_params: Params,
target_policy_params: Params,
target_critic_params: Params,
transitions: Transition,
random_key: RNGKey,
) -> jnp.ndarray:
"""Critics loss function for TD3 agent"""
noise = (
jax.random.normal(random_key, shape=transitions.actions.shape)
* policy_noise
).clip(-noise_clip, noise_clip)
next_action = (
policy_fn(target_policy_params, transitions.next_obs) + noise
).clip(-1.0, 1.0)
next_q = critic_fn( # type: ignore
target_critic_params, obs=transitions.next_obs, actions=next_action
)
next_v = jnp.min(next_q, axis=-1)
target_q = jax.lax.stop_gradient(
transitions.rewards * reward_scaling
+ (1.0 - transitions.dones) * discount * next_v
)
q_old_action = critic_fn( # type: ignore
critic_params,
obs=transitions.obs,
actions=transitions.actions,
)
q_error = q_old_action - jnp.expand_dims(target_q, -1)
# Better bootstrapping for truncated episodes.
q_error = q_error * jnp.expand_dims(1.0 - transitions.truncations, -1)
# compute the loss
q_losses = jnp.mean(jnp.square(q_error), axis=-2)
q_loss = jnp.sum(q_losses, axis=-1)
return q_loss
return _policy_loss_fn, _critic_loss_fn
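A hedged usage sketch showing how the two returned loss functions would typically be differentiated. The policy_fn and critic_fn below are deliberately simple linear placeholders that only illustrate the expected call signatures; they are not QDax networks, and the parameters and transitions are arbitrary (a real setup would use Flax modules and transitions sampled from a replay buffer).

import jax
import jax.numpy as jnp

from qdax.core.neuroevolution.buffers.buffer import Transition
from qdax.core.neuroevolution.losses.td3_loss import make_td3_loss_fn

obs_dim, action_dim, batch = 3, 2, 4

# Placeholder policy/critic (hypothetical, for illustration only).
def policy_fn(params, obs):
    return jnp.tanh(obs @ params["policy"])  # (batch, action_dim)

def critic_fn(params, obs, actions):
    inputs = jnp.concatenate([obs, actions], axis=-1)
    return inputs @ params["critic"]  # (batch, 2): twin Q-values

policy_loss_fn, critic_loss_fn = make_td3_loss_fn(
    policy_fn=policy_fn,
    critic_fn=critic_fn,
    reward_scaling=1.0,
    discount=0.99,
    noise_clip=0.5,
    policy_noise=0.2,
)

policy_params = {"policy": jnp.zeros((obs_dim, action_dim))}
critic_params = {"critic": jnp.zeros((obs_dim + action_dim, 2))}

transitions = Transition(
    obs=jnp.zeros((batch, obs_dim)),
    next_obs=jnp.zeros((batch, obs_dim)),
    rewards=jnp.zeros((batch,)),
    dones=jnp.zeros((batch,)),
    truncations=jnp.zeros((batch,)),
    actions=jnp.zeros((batch, action_dim)),
)

# Gradients with respect to the corresponding parameters.
policy_grad = jax.grad(policy_loss_fn)(policy_params, critic_params, transitions)
critic_grad = jax.grad(critic_loss_fn)(
    critic_params, policy_params, critic_params, transitions, jax.random.PRNGKey(0)
)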
td3_policy_loss_fn(policy_params, critic_params, policy_fn, critic_fn, transitions)
¶
Policy loss function for TD3 agent.
Source code in qdax/core/neuroevolution/losses/td3_loss.py
def td3_policy_loss_fn(
policy_params: Params,
critic_params: Params,
policy_fn: Callable[[Params, Observation], jnp.ndarray],
critic_fn: Callable[[Params, Observation, Action], jnp.ndarray],
transitions: Transition,
) -> jnp.ndarray:
"""Policy loss function for TD3 agent.
Args:
policy_params: policy parameters.
critic_params: critic parameters.
policy_fn: forward pass through the neural network defining the policy.
critic_fn: forward pass through the neural network defining the critic.
transitions: collected transitions.
Returns:
Return the loss function used to train the policy in TD3.
"""
action = policy_fn(policy_params, transitions.obs)
q_value = critic_fn(
critic_params, obs=transitions.obs, actions=action # type: ignore
)
q1_action = jnp.take(q_value, jnp.asarray([0]), axis=-1)
policy_loss = -jnp.mean(q1_action)
return policy_loss
td3_critic_loss_fn(critic_params, target_policy_params, target_critic_params, policy_fn, critic_fn, policy_noise, noise_clip, reward_scaling, discount, transitions, random_key)
¶
Critics loss function for TD3 agent.
Source code in qdax/core/neuroevolution/losses/td3_loss.py
def td3_critic_loss_fn(
critic_params: Params,
target_policy_params: Params,
target_critic_params: Params,
policy_fn: Callable[[Params, Observation], jnp.ndarray],
critic_fn: Callable[[Params, Observation, Action], jnp.ndarray],
policy_noise: float,
noise_clip: float,
reward_scaling: float,
discount: float,
transitions: Transition,
random_key: RNGKey,
) -> jnp.ndarray:
"""Critics loss function for TD3 agent.
Args:
critic_params: critic parameters.
target_policy_params: target policy parameters.
target_critic_params: target critic parameters.
policy_fn: forward pass through the neural network defining the policy.
critic_fn: forward pass through the neural network defining the critic.
policy_noise: policy noise.
noise_clip: noise clip.
reward_scaling: reward scaling coefficient.
discount: discount factor.
transitions: collected transitions.
Returns:
Return the loss function used to train the critic in TD3.
"""
noise = (
jax.random.normal(random_key, shape=transitions.actions.shape) * policy_noise
).clip(-noise_clip, noise_clip)
next_action = (policy_fn(target_policy_params, transitions.next_obs) + noise).clip(
-1.0, 1.0
)
next_q = critic_fn( # type: ignore
target_critic_params, obs=transitions.next_obs, actions=next_action
)
next_v = jnp.min(next_q, axis=-1)
target_q = jax.lax.stop_gradient(
transitions.rewards * reward_scaling
+ (1.0 - transitions.dones) * discount * next_v
)
q_old_action = critic_fn( # type: ignore
critic_params,
obs=transitions.obs,
actions=transitions.actions,
)
q_error = q_old_action - jnp.expand_dims(target_q, -1)
# Better bootstrapping for truncated episodes.
q_error = q_error * jnp.expand_dims(1.0 - transitions.truncations, -1)
# compute the loss
q_losses = jnp.mean(jnp.square(q_error), axis=-2)
q_loss = jnp.sum(q_losses, axis=-1)
return q_loss
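When the networks and hyperparameters are fixed up front, these standalone losses are typically specialised once with functools.partial so that the resulting callables only take parameters and data. A hedged sketch, reusing the policy_network, q_network, parameters and batch names from the sketch above:
import functools

import jax

from qdax.core.neuroevolution.losses.td3_loss import (
    td3_critic_loss_fn,
    td3_policy_loss_fn,
)

policy_loss = functools.partial(
    td3_policy_loss_fn,
    policy_fn=policy_network.apply,
    critic_fn=q_network.apply,
)
critic_loss = functools.partial(
    td3_critic_loss_fn,
    policy_fn=policy_network.apply,
    critic_fn=q_network.apply,
    policy_noise=0.2,
    noise_clip=0.5,
    reward_scaling=1.0,
    discount=0.99,
)

# The remaining arguments are passed exactly as in the signatures above.
loss_value = policy_loss(policy_params, critic_params, transitions=batch)
random_key, subkey = jax.random.split(random_key)
loss_value = critic_loss(
    critic_params,
    target_policy_params,
    target_critic_params,
    transitions=batch,
    random_key=subkey,
)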
networks
special
¶
dads_networks
¶
GaussianMixture (Module)
¶
Module that outputs a Gaussian Mixture Distribution.
Source code in qdax/core/neuroevolution/networks/dads_networks.py
class GaussianMixture(hk.Module):
"""Module that outputs a Gaussian Mixture Distribution."""
def __init__(
self,
num_dimensions: int,
num_components: int,
reinterpreted_batch_ndims: Optional[int] = None,
identity_covariance: bool = True,
initializer: Optional[Initializer] = None,
name: str = "GaussianMixture",
):
"""Module that outputs a Gaussian Mixture Distribution
with identity covariance matrix."""
super().__init__(name=name)
if initializer is None:
initializer = VarianceScaling(1.0, "fan_in", "uniform")
self._num_dimensions = num_dimensions
self._num_components = num_components
self._reinterpreted_batch_ndims = reinterpreted_batch_ndims
self._identity_covariance = identity_covariance
self.initializer = initializer
logits_size = self._num_components
self.logit_layer = hk.Linear(logits_size, w_init=self.initializer)
# Create two layers that outputs a location and a scale, respectively, for
# each dimension and each component.
self.loc_layer = hk.Linear(
self._num_dimensions * self._num_components, w_init=self.initializer
)
if not self._identity_covariance:
self.scale_layer = hk.Linear(
self._num_dimensions * self._num_components, w_init=self.initializer
)
def __call__(self, inputs: jnp.ndarray) -> tfp.distributions.Distribution:
# Compute logits, locs, and scales if necessary.
logits = self.logit_layer(inputs)
locs = self.loc_layer(inputs)
shape = [-1, self._num_components, self._num_dimensions]  # [B, C, D]
# Reshape the mixture's location and scale parameters appropriately.
locs = locs.reshape(shape)
if not self._identity_covariance:
scales = self.scale_layer(inputs)
scales = scales.reshape(shape)
else:
scales = jnp.ones_like(locs)
# Create the mixture distribution
components = tfp.distributions.MultivariateNormalDiag(
loc=locs, scale_diag=scales
)
mixture = tfp.distributions.Categorical(logits=logits)
distribution = tfp.distributions.MixtureSameFamily(
mixture_distribution=mixture, components_distribution=components
)
return distribution
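A hedged sketch of using the module on its own (the import path is inferred from the source location above; dimensions and inputs are illustrative): wrap a forward function in hk.transform and evaluate the log-probability of a target under the predicted mixture.
import haiku as hk
import jax
import jax.numpy as jnp

from qdax.core.neuroevolution.networks.dads_networks import GaussianMixture

def _forward(inputs: jnp.ndarray, target: jnp.ndarray) -> jnp.ndarray:
    # A 4-component mixture over a 3-dimensional event space.
    dist = GaussianMixture(num_dimensions=3, num_components=4)(inputs)
    return dist.log_prob(target)

forward = hk.without_apply_rng(hk.transform(_forward))

key = jax.random.PRNGKey(0)
inputs = jnp.zeros((16, 10))  # batch of 16 feature vectors
target = jnp.zeros((16, 3))   # batch of 16 targets in the event space
params = forward.init(key, inputs, target)
log_probs = forward.apply(params, inputs, target)  # shape (16,)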
__init__(self, num_dimensions, num_components, reinterpreted_batch_ndims=None, identity_covariance=True, initializer=None, name='GaussianMixture')
special
¶
Module that outputs a Gaussian Mixture Distribution with identity covariance matrix.
Source code in qdax/core/neuroevolution/networks/dads_networks.py
def __init__(
self,
num_dimensions: int,
num_components: int,
reinterpreted_batch_ndims: Optional[int] = None,
identity_covariance: bool = True,
initializer: Optional[Initializer] = None,
name: str = "GaussianMixture",
):
"""Module that outputs a Gaussian Mixture Distribution
with identity covariance matrix."""
super().__init__(name=name)
if initializer is None:
initializer = VarianceScaling(1.0, "fan_in", "uniform")
self._num_dimensions = num_dimensions
self._num_components = num_components
self._reinterpreted_batch_ndims = reinterpreted_batch_ndims
self._identity_covariance = identity_covariance
self.initializer = initializer
logits_size = self._num_components
self.logit_layer = hk.Linear(logits_size, w_init=self.initializer)
# Create two layers that outputs a location and a scale, respectively, for
# each dimension and each component.
self.loc_layer = hk.Linear(
self._num_dimensions * self._num_components, w_init=self.initializer
)
if not self._identity_covariance:
self.scale_layer = hk.Linear(
self._num_dimensions * self._num_components, w_init=self.initializer
)
DynamicsNetwork (Module)
¶
Dynamics network (used in DADS).
Source code in qdax/core/neuroevolution/networks/dads_networks.py
class DynamicsNetwork(hk.Module):
"""Dynamics network (used in DADS)."""
def __init__(
self,
hidden_layer_sizes: tuple,
output_size: int,
omit_input_dynamics_dim: int = 2,
name: Optional[str] = None,
identity_covariance: bool = True,
initializer: Optional[Initializer] = None,
):
super().__init__(name=name)
if initializer is None:
initializer = VarianceScaling(1.0, "fan_in", "uniform")
self.distribution = GaussianMixture(
output_size,
num_components=4,
reinterpreted_batch_ndims=None,
identity_covariance=identity_covariance,
initializer=initializer,
)
self.network = hk.Sequential(
[
hk.nets.MLP(
list(hidden_layer_sizes),
w_init=initializer,
activation=jax.nn.relu,
activate_final=True,
),
]
)
self._omit_input_dynamics_dim = omit_input_dynamics_dim
def __call__(
self, obs: StateDescriptor, skill: Skill, target: StateDescriptor
) -> jnp.ndarray:
"""Normalizes the observation, predicts a distribution probability conditioned
on (obs,skill) and returns the log_prob of the target.
"""
obs = obs[:, self._omit_input_dynamics_dim :]
obs = jnp.concatenate((obs, skill), axis=1)
out = self.network(obs)
dist = self.distribution(out)
return dist.log_prob(target)
make_dads_networks(action_size, descriptor_size, critic_hidden_layer_size=(256, 256), policy_hidden_layer_size=(256, 256), omit_input_dynamics_dim=2, identity_covariance=True, dynamics_initializer=None)
¶
Creates networks used in DADS.
Source code in qdax/core/neuroevolution/networks/dads_networks.py
def make_dads_networks(
action_size: int,
descriptor_size: int,
critic_hidden_layer_size: Tuple[int, ...] = (256, 256),
policy_hidden_layer_size: Tuple[int, ...] = (256, 256),
omit_input_dynamics_dim: int = 2,
identity_covariance: bool = True,
dynamics_initializer: Optional[Initializer] = None,
) -> Tuple[hk.Transformed, hk.Transformed, hk.Transformed]:
"""Creates networks used in DADS.
Args:
action_size: the size of the environment's action space
descriptor_size: the size of the environment's descriptor space (i.e. the
dimension of the dynamics network's input)
critic_hidden_layer_size, policy_hidden_layer_size: the number of neurons
for the hidden layers of the critic and policy networks. Defaults to (256, 256).
omit_input_dynamics_dim: how many descriptors we omit when creating the input
of the dynamics networks. Defaults to 2.
identity_covariance: whether to fix the covariance matrix of the Gaussian models
to identity. Defaults to True.
dynamics_initializer: the initializer of the dynamics layers. Defaults to None.
Returns:
the policy network
the critic network
the dynamics network
"""
def _actor_fn(obs: Observation) -> jnp.ndarray:
network = hk.Sequential(
[
hk.nets.MLP(
list(policy_hidden_layer_size) + [2 * action_size],
w_init=hk.initializers.VarianceScaling(1.0, "fan_in", "uniform"),
activation=jax.nn.relu,
),
]
)
return network(obs)
def _critic_fn(obs: Observation, action: Action) -> jnp.ndarray:
network1 = hk.Sequential(
[
hk.nets.MLP(
list(critic_hidden_layer_size) + [1],
w_init=hk.initializers.VarianceScaling(1.0, "fan_in", "uniform"),
activation=jax.nn.relu,
),
]
)
network2 = hk.Sequential(
[
hk.nets.MLP(
list(critic_hidden_layer_size) + [1],
w_init=hk.initializers.VarianceScaling(1.0, "fan_in", "uniform"),
activation=jax.nn.relu,
),
]
)
input_ = jnp.concatenate([obs, action], axis=-1)
value1 = network1(input_)
value2 = network2(input_)
return jnp.concatenate([value1, value2], axis=-1)
def _dynamics_fn(
obs: StateDescriptor, skill: Skill, target: StateDescriptor
) -> jnp.ndarray:
dynamics_network = DynamicsNetwork(
critic_hidden_layer_size,
descriptor_size,
omit_input_dynamics_dim=omit_input_dynamics_dim,
identity_covariance=identity_covariance,
initializer=dynamics_initializer,
)
return dynamics_network(obs, skill, target)
policy = hk.without_apply_rng(hk.transform(_actor_fn))
critic = hk.without_apply_rng(hk.transform(_critic_fn))
dynamics = hk.without_apply_rng(hk.transform(_dynamics_fn))
return policy, critic, dynamics
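A hedged sketch of initialising the three transformed networks (the import path follows the source location above; all sizes and dummy inputs are illustrative assumptions):
import jax
import jax.numpy as jnp

from qdax.core.neuroevolution.networks.dads_networks import make_dads_networks

action_size, descriptor_size, num_skills, obs_size = 2, 2, 5, 10

policy, critic, dynamics = make_dads_networks(
    action_size=action_size, descriptor_size=descriptor_size
)

policy_key, critic_key, dynamics_key = jax.random.split(jax.random.PRNGKey(0), 3)

# Dummy inputs used only to build parameter shapes. In DADS the policy input is
# typically the observation concatenated with the skill (an assumption here,
# not something enforced by this factory).
fake_policy_obs = jnp.zeros((1, obs_size + num_skills))
fake_action = jnp.zeros((1, action_size))
fake_descriptor = jnp.zeros((1, descriptor_size))
fake_skill = jnp.zeros((1, num_skills))

policy_params = policy.init(policy_key, fake_policy_obs)
critic_params = critic.init(critic_key, fake_policy_obs, fake_action)
dynamics_params = dynamics.init(
    dynamics_key, fake_descriptor, fake_skill, fake_descriptor
)

# The actor head outputs 2 * action_size values (e.g. a mean and a log-std per
# action dimension for a squashed Gaussian policy).
dist_params = policy.apply(policy_params, fake_policy_obs)  # shape (1, 2 * action_size)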
diayn_networks
¶
make_diayn_networks(action_size, num_skills, critic_hidden_layer_size=(256, 256), policy_hidden_layer_size=(256, 256))
¶
Creates networks used in DIAYN.
Source code in qdax/core/neuroevolution/networks/diayn_networks.py
def make_diayn_networks(
action_size: int,
num_skills: int,
critic_hidden_layer_size: Tuple[int, ...] = (256, 256),
policy_hidden_layer_size: Tuple[int, ...] = (256, 256),
) -> Tuple[hk.Transformed, hk.Transformed, hk.Transformed]:
"""Creates networks used in DIAYN.
Args:
action_size: the size of the environment's action space
num_skills: the number of skills.
critic_hidden_layer_size, policy_hidden_layer_size: the number of neurons
for the hidden layers of the critic and policy networks. Defaults to (256, 256).
Returns:
the policy network
the critic network
the discriminator network
"""
def _actor_fn(obs: Observation) -> jnp.ndarray:
network = hk.Sequential(
[
hk.nets.MLP(
list(policy_hidden_layer_size) + [2 * action_size],
w_init=hk.initializers.VarianceScaling(1.0, "fan_in", "uniform"),
activation=jax.nn.relu,
),
]
)
return network(obs)
def _critic_fn(obs: Observation, action: Action) -> jnp.ndarray:
network1 = hk.Sequential(
[
hk.nets.MLP(
list(critic_hidden_layer_size) + [1],
w_init=hk.initializers.VarianceScaling(1.0, "fan_in", "uniform"),
activation=jax.nn.relu,
),
]
)
network2 = hk.Sequential(
[
hk.nets.MLP(
list(critic_hidden_layer_size) + [1],
w_init=hk.initializers.VarianceScaling(1.0, "fan_in", "uniform"),
activation=jax.nn.relu,
),
]
)
input_ = jnp.concatenate([obs, action], axis=-1)
value1 = network1(input_)
value2 = network2(input_)
return jnp.concatenate([value1, value2], axis=-1)
def _discriminator_fn(obs: Observation) -> jnp.ndarray:
network = hk.Sequential(
[
hk.nets.MLP(
list(critic_hidden_layer_size) + [num_skills],
w_init=hk.initializers.VarianceScaling(1.0, "fan_in", "uniform"),
activation=jax.nn.relu,
),
]
)
return network(obs)
policy = hk.without_apply_rng(hk.transform(_actor_fn))
critic = hk.without_apply_rng(hk.transform(_critic_fn))
discriminator = hk.without_apply_rng(hk.transform(_discriminator_fn))
return policy, critic, discriminator
networks
¶
Implements neural network models that are commonly found in the RL literature.
QModule (Module)
dataclass
¶
Q Module.
Source code in qdax/core/neuroevolution/networks/networks.py
class QModule(nn.Module):
"""Q Module."""
hidden_layer_sizes: Tuple[int, ...]
n_critics: int = 2
@nn.compact
def __call__(self, obs: jnp.ndarray, actions: jnp.ndarray) -> jnp.ndarray:
hidden = jnp.concatenate([obs, actions], axis=-1)
res = []
for _ in range(self.n_critics):
q = networks.MLP(
layer_sizes=self.hidden_layer_sizes + (1,),
activation=nn.relu,
kernel_init=jax.nn.initializers.lecun_uniform(),
)(hidden)
res.append(q)
return jnp.concatenate(res, axis=-1)
MLP (Module)
dataclass
¶
MLP module.
Source code in qdax/core/neuroevolution/networks/networks.py
class MLP(nn.Module):
"""MLP module."""
layer_sizes: Tuple[int, ...]
activation: Callable[[jnp.ndarray], jnp.ndarray] = nn.relu
kernel_init: Callable[..., Any] = jax.nn.initializers.lecun_uniform()
final_activation: Optional[Callable[[jnp.ndarray], jnp.ndarray]] = None
bias: bool = True
kernel_init_final: Optional[Callable[..., Any]] = None
@nn.compact
def __call__(self, data: jnp.ndarray) -> jnp.ndarray:
hidden = data
for i, hidden_size in enumerate(self.layer_sizes):
if i != len(self.layer_sizes) - 1:
hidden = nn.Dense(
hidden_size,
# name=f"hidden_{i}", with this version of flax, changing the name
# changes the initialization
kernel_init=self.kernel_init,
use_bias=self.bias,
)(hidden)
hidden = self.activation(hidden) # type: ignore
else:
if self.kernel_init_final is not None:
kernel_init = self.kernel_init_final
else:
kernel_init = self.kernel_init
hidden = nn.Dense(
hidden_size,
# name=f"hidden_{i}",
kernel_init=kernel_init,
use_bias=self.bias,
)(hidden)
if self.final_activation is not None:
hidden = self.final_activation(hidden)
return hidden
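A hedged sketch of using these Flax modules directly (sizes and dummy inputs are illustrative assumptions):
import jax
import jax.numpy as jnp

from qdax.core.neuroevolution.networks.networks import MLP, QModule

obs_dim, action_dim = 8, 2

policy = MLP(layer_sizes=(64, 64, action_dim), final_activation=jnp.tanh)
critics = QModule(hidden_layer_sizes=(64, 64), n_critics=2)

policy_key, critic_key = jax.random.split(jax.random.PRNGKey(0))
fake_obs = jnp.zeros((1, obs_dim))
fake_action = jnp.zeros((1, action_dim))

policy_params = policy.init(policy_key, fake_obs)
critic_params = critics.init(critic_key, fake_obs, fake_action)

actions = policy.apply(policy_params, fake_obs)                 # shape (1, action_dim)
q_values = critics.apply(critic_params, fake_obs, fake_action)  # shape (1, n_critics)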
sac_networks
¶
make_sac_networks(action_size, critic_hidden_layer_size=(256, 256), policy_hidden_layer_size=(256, 256))
¶
Creates networks used in SAC.
Source code in qdax/core/neuroevolution/networks/sac_networks.py
def make_sac_networks(
action_size: int,
critic_hidden_layer_size: Tuple[int, ...] = (256, 256),
policy_hidden_layer_size: Tuple[int, ...] = (256, 256),
) -> Tuple[hk.Transformed, hk.Transformed]:
"""Creates networks used in SAC.
Args:
action_size: the size of the environment's action space
critic_hidden_layer_size, policy_hidden_layer_size: the number of neurons
for the hidden layers of the critic and policy networks. Defaults to (256, 256).
Returns:
the policy network
the critic network
"""
def _actor_fn(obs: Observation) -> jnp.ndarray:
network = hk.Sequential(
[
hk.nets.MLP(
list(policy_hidden_layer_size) + [2 * action_size],
w_init=hk.initializers.VarianceScaling(1.0, "fan_in", "uniform"),
activation=jax.nn.relu,
),
]
)
return network(obs)
def _critic_fn(obs: Observation, action: Action) -> jnp.ndarray:
network1 = hk.Sequential(
[
hk.nets.MLP(
list(critic_hidden_layer_size) + [1],
w_init=hk.initializers.VarianceScaling(1.0, "fan_in", "uniform"),
activation=jax.nn.relu,
),
]
)
network2 = hk.Sequential(
[
hk.nets.MLP(
list(critic_hidden_layer_size) + [1],
w_init=hk.initializers.VarianceScaling(1.0, "fan_in", "uniform"),
activation=jax.nn.relu,
),
]
)
input_ = jnp.concatenate([obs, action], axis=-1)
value1 = network1(input_)
value2 = network2(input_)
return jnp.concatenate([value1, value2], axis=-1)
policy = hk.without_apply_rng(hk.transform(_actor_fn))
critic = hk.without_apply_rng(hk.transform(_critic_fn))
return policy, critic
seq2seq_networks
¶
seq2seq example: Model code.
Inspired by Flax library - https://github.com/google/flax/blob/main/examples/seq2seq/models.py
Copyright 2022 The Flax Authors. Licensed under the Apache License, Version 2.0 (the "License")
EncoderLSTM (Module)
dataclass
¶
EncoderLSTM Module wrapped in a lifted scan transform.
Source code in qdax/core/neuroevolution/networks/seq2seq_networks.py
class EncoderLSTM(nn.Module):
"""EncoderLSTM Module wrapped in a lifted scan transform."""
@functools.partial(
nn.scan,
variable_broadcast="params",
in_axes=1,
out_axes=1,
split_rngs={"params": False},
)
@nn.compact
def __call__(
self, carry: Tuple[Array, Array], x: Array
) -> Tuple[Tuple[Array, Array], Array]:
"""Applies the module."""
lstm_state, is_eos = carry
features = lstm_state[0].shape[-1]
new_lstm_state, y = nn.LSTMCell(features)(lstm_state, x)
def select_carried_state(new_state: Array, old_state: Array) -> Array:
return jnp.where(is_eos[:, np.newaxis], old_state, new_state)
# LSTM state is a tuple (c, h).
carried_lstm_state = tuple(
select_carried_state(*s) for s in zip(new_lstm_state, lstm_state)
)
return (carried_lstm_state, is_eos), y
@staticmethod
def initialize_carry(batch_size: int, hidden_size: int) -> Tuple[Array, Array]:
# Use a dummy key since the default state init fn is just zeros.
return nn.LSTMCell(hidden_size, parent=None).initialize_carry( # type: ignore
jax.random.PRNGKey(0), (batch_size, hidden_size)
)
Encoder (Module)
dataclass
¶
LSTM encoder, returning state after finding the EOS token in the input.
Source code in qdax/core/neuroevolution/networks/seq2seq_networks.py
class Encoder(nn.Module):
"""LSTM encoder, returning state after finding the EOS token in the input."""
hidden_size: int
@nn.compact
def __call__(self, inputs: Array) -> Array:
batch_size = inputs.shape[0]
lstm = EncoderLSTM(name="encoder_lstm")
init_lstm_state = lstm.initialize_carry(batch_size, self.hidden_size)
# We use the `is_eos` array to determine whether the encoder should carry
# over the last lstm state, or apply the LSTM cell on the previous state.
init_is_eos = jnp.zeros(batch_size, dtype=bool)
init_carry = (init_lstm_state, init_is_eos)
(final_state, _), _ = lstm(init_carry, inputs)
return final_state
DecoderLSTM (Module)
dataclass
¶
DecoderLSTM Module wrapped in a lifted scan transform.
Attributes:
Name | Type | Description
---|---|---
teacher_force | bool | See docstring on Seq2seq module.
obs_size | int | Size of the observations.
Source code in qdax/core/neuroevolution/networks/seq2seq_networks.py
class DecoderLSTM(nn.Module):
"""DecoderLSTM Module wrapped in a lifted scan transform.
Attributes:
teacher_force: See docstring on Seq2seq module.
obs_size: Size of the observations.
"""
teacher_force: bool
obs_size: int
@functools.partial(
nn.scan,
variable_broadcast="params",
in_axes=1,
out_axes=1,
split_rngs={"params": False, "lstm": True},
)
@nn.compact
def __call__(self, carry: Tuple[Array, Array], x: Array) -> Array:
"""Applies the DecoderLSTM model."""
lstm_state, last_prediction = carry
if not self.teacher_force:
x = last_prediction
features = lstm_state[0].shape[-1]
new_lstm_state, y = nn.LSTMCell(features)(lstm_state, x)
logits = nn.Dense(features=self.obs_size)(y)
return (lstm_state, logits), (logits, logits)
Decoder (Module)
dataclass
¶
LSTM decoder.
Attributes:
Name | Type | Description
---|---|---
init_state | | [batch_size, hidden_size] Initial state of the decoder (i.e., the final state of the encoder).
teacher_force | bool | See docstring on Seq2seq module.
obs_size | int | Size of the observations.
Source code in qdax/core/neuroevolution/networks/seq2seq_networks.py
class Decoder(nn.Module):
"""LSTM decoder.
Attributes:
init_state: [batch_size, hidden_size]
Initial state of the decoder (i.e., the final state of the encoder).
teacher_force: See docstring on Seq2seq module.
obs_size: Size of the observations.
"""
teacher_force: bool
obs_size: int
@nn.compact
def __call__(self, inputs: Array, init_state: Any) -> Tuple[Array, Array]:
"""Applies the decoder model.
Args:
inputs: [batch_size, max_output_len-1, obs_size]
Contains the inputs to the decoder at each time step (only used when not
using teacher forcing). Since each token at position i is fed as input
to the decoder at position i+1, the last token is not provided.
Returns:
Pair (logits, predictions), which are two arrays of respectively decoded
logits and predictions (in one hot-encoding format).
"""
lstm = DecoderLSTM(teacher_force=self.teacher_force, obs_size=self.obs_size)
init_carry = (init_state, inputs[:, 0])
_, (logits, predictions) = lstm(init_carry, inputs)
return logits, predictions
Seq2seq (Module)
dataclass
¶
Sequence-to-sequence class using encoder/decoder architecture.
Attributes:
Name | Type | Description
---|---|---
teacher_force | bool | whether to use decoder_inputs as input to the decoder at every step. If False, only the first input (the "=" token) is used, followed by samples taken from the previous output logits.
hidden_size | int | the number of hidden dimensions in the encoder and decoder LSTMs.
obs_size | int | the size of the observations.
eos_id | | EOS id.
Source code in qdax/core/neuroevolution/networks/seq2seq_networks.py
class Seq2seq(nn.Module):
"""Sequence-to-sequence class using encoder/decoder architecture.
Attributes:
teacher_force: whether to use `decoder_inputs` as input to the decoder at
every step. If False, only the first input (i.e., the "=" token) is used,
followed by samples taken from the previous output logits.
hidden_size: int, the number of hidden dimensions in the encoder and decoder
LSTMs.
obs_size: the size of the observations.
eos_id: EOS id.
"""
teacher_force: bool
hidden_size: int
obs_size: int
def setup(self) -> None:
self.encoder = Encoder(hidden_size=self.hidden_size)
self.decoder = Decoder(teacher_force=self.teacher_force, obs_size=self.obs_size)
@nn.compact
def __call__(
self, encoder_inputs: Array, decoder_inputs: Array
) -> Tuple[Array, Array]:
"""Applies the seq2seq model.
Args:
encoder_inputs: [batch_size, max_input_length, obs_size].
padded batch of input sequences to encode.
decoder_inputs: [batch_size, max_output_length, obs_size].
padded batch of expected decoded sequences for teacher forcing.
When sampling (i.e., `teacher_force = False`), only the first token is
input into the decoder (which is the token "="), and samples are used
for the following inputs. The second dimension of this tensor determines
how many steps will be decoded, regardless of the value of
`teacher_force`.
Returns:
Pair (logits, predictions), which are two arrays of length `batch_size`
containing respectively decoded logits and predictions (in one hot
encoding format).
"""
# encode inputs
init_decoder_state = self.encoder(encoder_inputs)
# decode outputs
logits, predictions = self.decoder(decoder_inputs, init_decoder_state)
return logits, predictions
def encode(self, encoder_inputs: Array) -> Array:
# encode inputs
init_decoder_state = self.encoder(encoder_inputs)
final_output, _hidden_state = init_decoder_state
return final_output
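A hedged sketch of running the model (the import path follows the source location above; shapes are illustrative). The encode method can also be applied on its own, e.g. to use the encoder's final state as a learned descriptor:
import jax
import jax.numpy as jnp

from qdax.core.neuroevolution.networks.seq2seq_networks import Seq2seq

batch_size, seq_len, obs_size, hidden_size = 4, 10, 6, 16

model = Seq2seq(teacher_force=True, hidden_size=hidden_size, obs_size=obs_size)

key = jax.random.PRNGKey(0)
encoder_inputs = jnp.zeros((batch_size, seq_len, obs_size))
decoder_inputs = jnp.zeros((batch_size, seq_len, obs_size))

params = model.init(key, encoder_inputs, decoder_inputs)
logits, predictions = model.apply(params, encoder_inputs, decoder_inputs)

# Encoder only: returns the final carried state, shape (batch_size, hidden_size).
latent = model.apply(params, encoder_inputs, method=model.encode)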
td3_networks
¶
Implements a function to create neural networks for the TD3 algorithm.
make_td3_networks(action_size, critic_hidden_layer_sizes, policy_hidden_layer_sizes)
¶
Creates networks used by the TD3 agent.
Source code in qdax/core/neuroevolution/networks/td3_networks.py
def make_td3_networks(
action_size: int,
critic_hidden_layer_sizes: Tuple[int, ...],
policy_hidden_layer_sizes: Tuple[int, ...],
) -> Tuple[MLP, QModule]:
"""Creates networks used by the TD3 agent.
Args:
action_size: Size of the action array used to interact with the environment.
critic_hidden_layer_sizes: Number of layers and units per layer used in the
neural network defining the critic.
policy_hidden_layer_sizes: Number of layers and units per layer used in the
neural network defining the policy.
Returns:
The neural network defining the policy and the module defining the critic.
This module contains two neural networks.
"""
# Instantiate policy and critics networks
policy_layer_sizes = policy_hidden_layer_sizes + (action_size,)
policy_network = MLP(
layer_sizes=policy_layer_sizes,
final_activation=jnp.tanh,
)
q_network = QModule(n_critics=2, hidden_layer_sizes=critic_hidden_layer_sizes)
return (policy_network, q_network)
mdp_utils
¶
TrainingState (PyTreeNode)
dataclass
¶
The state of a training process. Can be used to store anything that is useful for a training process. This object is used in the package to store all stateful objects necessary for training an agent that learns how to act in an MDP.
Source code in qdax/core/neuroevolution/mdp_utils.py
class TrainingState(PyTreeNode):
"""The state of a training process. Can be used to store anything
that is useful for a training process. This object is used in the
package to store all stateful object necessary for training an agent
that learns how to act in an MDP.
"""
pass
replace(self, **updates)
¶
"Returns a new object replacing the specified fields with new values.
Source code in qdax/core/neuroevolution/mdp_utils.py
def replace(self, **updates):
""" "Returns a new object replacing the specified fields with new values."""
return dataclasses.replace(self, **updates)
generate_unroll(init_state, policy_params, random_key, episode_length, play_step_fn)
¶
Generates an episode according to the agent's policy, returns the final state of the episode and the transitions of the episode.
Source code in qdax/core/neuroevolution/mdp_utils.py
@partial(jax.jit, static_argnames=("play_step_fn", "episode_length"))
def generate_unroll(
init_state: EnvState,
policy_params: Params,
random_key: RNGKey,
episode_length: int,
play_step_fn: Callable[
[EnvState, Params, RNGKey],
Tuple[
EnvState,
Params,
RNGKey,
Transition,
],
],
) -> Tuple[EnvState, Transition]:
"""Generates an episode according to the agent's policy, returns the final state of
the episode and the transitions of the episode.
Args:
init_state: first state of the rollout.
policy_params: params of the individual.
random_key: random key for stochasticity handling.
episode_length: length of the rollout.
play_step_fn: function describing how a step needs to be taken.
Returns:
A new state, the experienced transition.
"""
def _scan_play_step_fn(
carry: Tuple[EnvState, Params, RNGKey], unused_arg: Any
) -> Tuple[Tuple[EnvState, Params, RNGKey], Transition]:
env_state, policy_params, random_key, transitions = play_step_fn(*carry)
return (env_state, policy_params, random_key), transitions
(state, _, _), transitions = jax.lax.scan(
_scan_play_step_fn,
(init_state, policy_params, random_key),
(),
length=episode_length,
)
return state, transitions
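A hedged sketch of the play_step_fn contract this function expects. The env (a Brax-style environment whose info dict exposes a "truncation" entry, as QDax's wrapped environments do), policy_network and policy_params names are assumptions standing in for objects created elsewhere:
import jax

from qdax.core.neuroevolution.buffers.buffer import Transition

def play_step_fn(env_state, policy_params, random_key):
    # One deterministic environment step with the given policy (a sketch).
    actions = policy_network.apply(policy_params, env_state.obs)
    next_state = env.step(env_state, actions)
    transition = Transition(
        obs=env_state.obs,
        next_obs=next_state.obs,
        rewards=next_state.reward,
        dones=next_state.done,
        truncations=next_state.info["truncation"],
        actions=actions,
    )
    return next_state, policy_params, random_key, transition

final_state, transitions = generate_unroll(
    init_state=env.reset(jax.random.PRNGKey(0)),
    policy_params=policy_params,
    random_key=jax.random.PRNGKey(1),
    episode_length=100,
    play_step_fn=play_step_fn,
)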
get_first_episode(transition)
¶
Extracts the first episode from a batch of transitions, returns the batch of transitions that is masked with nans except for the first episode.
Source code in qdax/core/neuroevolution/mdp_utils.py
@jax.jit
def get_first_episode(transition: Transition) -> Transition:
"""Extracts the first episode from a batch of transitions, returns the batch of
transitions that is masked with nans except for the first episode.
"""
dones = jnp.roll(transition.dones, 1, axis=0).at[0].set(0)
mask = 1 - jnp.clip(jnp.cumsum(dones, axis=0), 0, 1)
def mask_episodes(x: jnp.ndarray) -> jnp.ndarray:
# the double transpose trick is here to allow easy broadcasting
return jnp.where(mask.T, x.T, jnp.nan * jnp.ones_like(x).T).T
return jax.tree_map(mask_episodes, transition) # type: ignore
init_population_controllers(policy_network, env, batch_size, random_key)
¶
Initializes the population of controllers using a policy_network.
Source code in qdax/core/neuroevolution/mdp_utils.py
def init_population_controllers(
policy_network: nn.Module,
env: brax.envs.Env,
batch_size: int,
random_key: RNGKey,
) -> Tuple[Genotype, RNGKey]:
"""
Initializes the population of controllers using a policy_network.
Args:
policy_network: The policy network structure used for creating policy
controllers.
env: the BRAX environment.
batch_size: the number of environments we play simultaneously.
random_key: a JAX random key.
Returns:
A tuple of the initial population and the new random key.
"""
random_key, subkey = jax.random.split(random_key)
keys = jax.random.split(subkey, num=batch_size)
fake_batch = jnp.zeros(shape=(batch_size, env.observation_size))
init_variables = jax.vmap(policy_network.init)(keys, fake_batch)
return init_variables, random_key
normalization_utils
¶
Utility functions to perform normalization (generally on observations in RL).
RunningMeanStdState (tuple)
¶
Running statistics for observations/rewards
Source code in qdax/core/neuroevolution/normalization_utils.py
class RunningMeanStdState(NamedTuple):
"""Running statistics for observtations/rewards"""
mean: jnp.ndarray
var: jnp.ndarray
count: jnp.ndarray
update_running_mean_std(std_state, obs)
¶
Update running statistics with batch of observations (Welford's algorithm)
Source code in qdax/core/neuroevolution/normalization_utils.py
def update_running_mean_std(
std_state: RunningMeanStdState, obs: Observation
) -> RunningMeanStdState:
"""Update running statistics with batch of observations (Welford's algorithm)"""
running_mean, running_variance, normalization_steps = std_state
step_increment = obs.shape[0]
total_new_steps = normalization_steps + step_increment
# Compute the incremental update and divide by the number of new steps.
input_to_old_mean = obs - running_mean
mean_diff = jnp.sum(input_to_old_mean / total_new_steps, axis=0)
new_mean = running_mean + mean_diff
# Compute difference of input to the new mean for Welford update.
input_to_new_mean = obs - new_mean
var_diff = jnp.sum(input_to_new_mean * input_to_old_mean, axis=0)
return RunningMeanStdState(new_mean, running_variance + var_diff, total_new_steps)
normalize_with_rmstd(obs, rmstd, std_min_value=1e-06, std_max_value=1000000.0, apply_clipping=True)
¶
Normalize input with provided running statistics
Source code in qdax/core/neuroevolution/normalization_utils.py
def normalize_with_rmstd(
obs: jnp.ndarray,
rmstd: RunningMeanStdState,
std_min_value: float = 1e-6,
std_max_value: float = 1e6,
apply_clipping: bool = True,
) -> jnp.ndarray:
"""Normalize input with provided running statistics"""
running_mean, running_variance, normalization_steps = rmstd
variance = running_variance / (normalization_steps + 1.0)
# We clip because the running_variance can become negative,
# presumably because of numerical instabilities.
if apply_clipping:
variance = jnp.clip(variance, std_min_value, std_max_value)
return jnp.clip((obs - running_mean) / jnp.sqrt(variance), -5, 5)
else:
return (obs - running_mean) / jnp.sqrt(variance)
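A hedged sketch of the intended usage (shapes are illustrative): accumulate statistics over batches of observations, then normalise new observations with them.
import jax.numpy as jnp

from qdax.core.neuroevolution.normalization_utils import (
    RunningMeanStdState,
    normalize_with_rmstd,
    update_running_mean_std,
)

obs_dim = 4
rmstd = RunningMeanStdState(
    mean=jnp.zeros(obs_dim),
    var=jnp.ones(obs_dim),
    count=jnp.zeros(()),
)

batch_of_obs = jnp.ones((32, obs_dim))
rmstd = update_running_mean_std(rmstd, batch_of_obs)
normalized_obs = normalize_with_rmstd(batch_of_obs, rmstd)  # clipped to [-5, 5]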
sac_td3_utils
¶
Functions similar to the ones in mdp_utils; the main difference is the assumption that the policy parameters are included in the training state. By passing the whole training state, we can, for instance, update running statistics used for normalization.
We are currently thinking about elegant ways to unify both in order to avoid code repetition.
warmstart_buffer(replay_buffer, training_state, env_state, play_step_fn, num_warmstart_steps, env_batch_size)
¶
Pre-populates the buffer with transitions. Returns the warmstarted buffer and the new state of the environment.
Source code in qdax/core/neuroevolution/sac_td3_utils.py
@partial(
jax.jit,
static_argnames=(
"num_warmstart_steps",
"play_step_fn",
"env_batch_size",
),
)
def warmstart_buffer(
replay_buffer: ReplayBuffer,
training_state: TrainingState,
env_state: EnvState,
play_step_fn: Callable[
[EnvState, TrainingState],
Tuple[
EnvState,
TrainingState,
Transition,
],
],
num_warmstart_steps: int,
env_batch_size: int,
) -> Tuple[ReplayBuffer, EnvState, TrainingState]:
"""Pre-populates the buffer with transitions. Returns the warmstarted buffer
and the new state of the environment.
"""
def _scan_play_step_fn(
carry: Tuple[EnvState, TrainingState], unused_arg: Any
) -> Tuple[Tuple[EnvState, TrainingState], Transition]:
env_state, training_state, transitions = play_step_fn(*carry)
return (env_state, training_state), transitions
(env_state, training_state), transitions = jax.lax.scan(
_scan_play_step_fn,
(env_state, training_state),
(),
length=num_warmstart_steps // env_batch_size,
)
replay_buffer = replay_buffer.insert(transitions)
return replay_buffer, env_state, training_state
generate_unroll(init_state, training_state, episode_length, play_step_fn)
¶
Generates an episode according to the agent's policy, returns the final state of the episode and the transitions of the episode.
Source code in qdax/core/neuroevolution/sac_td3_utils.py
@partial(jax.jit, static_argnames=("play_step_fn", "episode_length"))
def generate_unroll(
init_state: EnvState,
training_state: TrainingState,
episode_length: int,
play_step_fn: Callable[
[EnvState, TrainingState],
Tuple[
EnvState,
TrainingState,
Transition,
],
],
) -> Tuple[EnvState, TrainingState, Transition]:
"""Generates an episode according to the agent's policy, returns the final state of the
episode and the transitions of the episode.
"""
def _scan_play_step_fn(
carry: Tuple[EnvState, TrainingState], unused_arg: Any
) -> Tuple[Tuple[EnvState, TrainingState], Transition]:
env_state, training_state, transitions = play_step_fn(*carry)
return (env_state, training_state), transitions
(state, training_state), transitions = jax.lax.scan(
_scan_play_step_fn,
(init_state, training_state),
(),
length=episode_length,
)
return state, training_state, transitions
do_iteration_fn(training_state, env_state, replay_buffer, env_batch_size, grad_updates_per_step, play_step_fn, update_fn)
¶
Performs one environment step (over all environments simultaneously) followed by one training step. The number of updates is controlled by the parameter grad_updates_per_step (0 means no update while 1 means env_batch_size updates). Returns the updated states, the updated buffer and the aggregated metrics.
Source code in qdax/core/neuroevolution/sac_td3_utils.py
@partial(
jax.jit,
static_argnames=(
"env_batch_size",
"grad_updates_per_step",
"play_step_fn",
"update_fn",
),
)
def do_iteration_fn(
training_state: TrainingState,
env_state: EnvState,
replay_buffer: ReplayBuffer,
env_batch_size: int,
grad_updates_per_step: float,
play_step_fn: Callable[
[EnvState, TrainingState],
Tuple[
EnvState,
TrainingState,
Transition,
],
],
update_fn: Callable[
[TrainingState, ReplayBuffer],
Tuple[
TrainingState,
ReplayBuffer,
Metrics,
],
],
) -> Tuple[TrainingState, EnvState, ReplayBuffer, Metrics]:
"""Performs one environment step (over all env simultaneously) followed by one
training step. The number of updates is controlled by the parameter
`grad_updates_per_step` (0 means no update while 1 means `env_batch_size`
updates). Returns the updated states, the updated buffer and the aggregated
metrics.
"""
def _scan_update_fn(
carry: Tuple[TrainingState, ReplayBuffer], unused_arg: Any
) -> Tuple[Tuple[TrainingState, ReplayBuffer], Metrics]:
training_state, replay_buffer, metrics = update_fn(*carry)
return (training_state, replay_buffer), metrics
# play steps in the environment
env_state, training_state, transitions = play_step_fn(env_state, training_state)
# insert transitions in replay buffer
replay_buffer = replay_buffer.insert(transitions)
num_updates = int(grad_updates_per_step * env_batch_size)
(training_state, replay_buffer), metrics = jax.lax.scan(
_scan_update_fn,
(training_state, replay_buffer),
(),
length=num_updates,
)
return training_state, env_state, replay_buffer, metrics
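A hedged sketch of the usual outer loop: bind the static arguments once with functools.partial, then scan the iteration for a fixed number of steps. The training_state, env_state, replay_buffer, play_step_fn and update_fn names are assumptions, typically provided by an agent (e.g. SAC or TD3) built from the losses and networks documented above:
import functools

import jax

do_iteration = functools.partial(
    do_iteration_fn,
    env_batch_size=64,
    grad_updates_per_step=1.0,
    play_step_fn=play_step_fn,
    update_fn=update_fn,
)

def _scan_do_iteration(carry, _):
    training_state, env_state, replay_buffer = carry
    training_state, env_state, replay_buffer, metrics = do_iteration(
        training_state, env_state, replay_buffer
    )
    return (training_state, env_state, replay_buffer), metrics

(training_state, env_state, replay_buffer), all_metrics = jax.lax.scan(
    _scan_do_iteration,
    (training_state, env_state, replay_buffer),
    None,
    length=1_000,
)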