Environments¶
qdax.environments
special
¶
create(env_name, episode_length=1000, action_repeat=1, auto_reset=True, batch_size=None, eval_metrics=False, fixed_init_state=False, qdax_wrappers_kwargs=None, **kwargs)
¶
Creates an Env with a specified brax system. Please use the qdax.environments namespace to avoid confusion between this function and brax.envs.create.
Source code in qdax/environments/__init__.py
def create(
env_name: str,
episode_length: int = 1000,
action_repeat: int = 1,
auto_reset: bool = True,
batch_size: Optional[int] = None,
eval_metrics: bool = False,
fixed_init_state: bool = False,
qdax_wrappers_kwargs: Optional[List] = None,
**kwargs: Any,
) -> Union[Env, QDEnv]:
"""Creates an Env with a specified brax system.
Please use the qdax.environments namespace to avoid confusion between
this function and brax.envs.create.
"""
if env_name in _envs.keys():
env = _envs[env_name](legacy_spring=True, **kwargs)
elif env_name in _qdax_envs.keys():
env = _qdax_envs[env_name](**kwargs)
elif env_name in _qdax_custom_envs.keys():
base_env_name = _qdax_custom_envs[env_name]["env"]
if base_env_name in _envs.keys():
env = _envs[base_env_name](legacy_spring=True, **kwargs)
elif base_env_name in _qdax_envs.keys():
env = _qdax_envs[base_env_name](**kwargs) # type: ignore
else:
raise NotImplementedError("This environment name does not exist!")
if env_name in _qdax_custom_envs.keys():
# roll with qdax wrappers
wrappers = _qdax_custom_envs[env_name]["wrappers"]
if qdax_wrappers_kwargs is None:
kwargs_list = _qdax_custom_envs[env_name]["kwargs"]
else:
kwargs_list = qdax_wrappers_kwargs
for wrapper, kwargs in zip(wrappers, kwargs_list): # type: ignore
env = wrapper(env, base_env_name, **kwargs) # type: ignore
if episode_length is not None:
env = EpisodeWrapper(env, episode_length, action_repeat)
if batch_size:
env = VectorWrapper(env, batch_size)
if fixed_init_state:
# retrieve the base env
if env_name not in _qdax_custom_envs.keys():
base_env_name = env_name
# wrap the env
env = FixedInitialStateWrapper(env, base_env_name=base_env_name) # type: ignore
if auto_reset:
env = AutoResetWrapper(env)
if env_name in _qdax_custom_envs.keys():
env = StateDescriptorResetWrapper(env)
if eval_metrics:
env = EvalWrapper(env)
env = CompletedEvalWrapper(env)
return env
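A minimal usage sketch. It assumes "ant_omni" is one of the custom env names registered in this module's env maps; any registered name works the same way.
from brax import jumpy as jp
import qdax.environments

# build a wrapped QD env; extra kwargs are forwarded to the underlying env
env = qdax.environments.create("ant_omni", episode_length=100)
state = env.reset(rng=jp.random_prngkey(seed=0))
for i in range(10):
    action = jp.zeros((env.action_size,))
    state = env.step(state, action)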
create_fn(env_name, **kwargs)
¶
Returns a function that, when called, creates an Env. Please use the qdax.environments namespace to avoid confusion between this function and brax.envs.create_fn.
Source code in qdax/environments/__init__.py
def create_fn(env_name: str, **kwargs: Any) -> Callable[..., Env]:
"""Returns a function that when called, creates an Env.
Please use namespace to avoid confusion between this function and
brax.envs.create_fn.
"""
return functools.partial(create, env_name, **kwargs)
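A short sketch of how the returned partial is typically used; "pointmaze" is assumed to be a registered env name.
import qdax.environments

# bind the env name now, fill in the remaining arguments at call time
make_env = qdax.environments.create_fn("pointmaze")
env = make_env(episode_length=250)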
base_wrappers
¶
QDEnv (Env)
¶
Wrapper for all QD environments.
Source code in qdax/environments/base_wrappers.py
class QDEnv(Env):
"""
Wrapper for all QD environments.
"""
@property
@abstractmethod
def state_descriptor_length(self) -> int:
pass
@property
@abstractmethod
def state_descriptor_name(self) -> str:
pass
@property
@abstractmethod
def state_descriptor_limits(self) -> Tuple[List[float], List[float]]:
pass
@property
@abstractmethod
def behavior_descriptor_length(self) -> int:
pass
@property
@abstractmethod
def behavior_descriptor_limits(self) -> Tuple[List[float], List[float]]:
pass
@property
@abstractmethod
def name(self) -> str:
pass
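A hypothetical minimal subclass sketching what a concrete QDEnv must provide. The reset and step methods inherited (abstract) from Env would also need implementing; all names and limits below are illustrative only.
from typing import List, Tuple

from qdax.environments.base_wrappers import QDEnv


class MyQDEnv(QDEnv):
    """Illustrative only: a 2D xy descriptor bounded in [-1, 1]^2."""
    # NB: reset and step (abstract in brax's Env) must be implemented too

    @property
    def state_descriptor_length(self) -> int:
        return 2

    @property
    def state_descriptor_name(self) -> str:
        return "xy_position"

    @property
    def state_descriptor_limits(self) -> Tuple[List[float], List[float]]:
        return [-1.0, -1.0], [1.0, 1.0]

    @property
    def behavior_descriptor_length(self) -> int:
        return self.state_descriptor_length

    @property
    def behavior_descriptor_limits(self) -> Tuple[List[float], List[float]]:
        return self.state_descriptor_limits

    @property
    def name(self) -> str:
        return "my_qd_env"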
QDWrapper (QDEnv)
¶
Wrapper for QD environments.
Source code in qdax/environments/base_wrappers.py
class QDWrapper(QDEnv):
"""Wrapper for QD environments."""
def __init__(self, env: QDEnv):
super().__init__(config=None)
self.env = env
def reset(self, rng: jp.ndarray) -> State:
return self.env.reset(rng)
def step(self, state: State, action: jp.ndarray) -> State:
return self.env.step(state, action)
@property
def observation_size(self) -> int:
return self.env.observation_size # type: ignore
@property
def action_size(self) -> int:
return self.env.action_size # type: ignore
@property
def state_descriptor_length(self) -> int:
return self.env.state_descriptor_length
@property
def state_descriptor_name(self) -> str:
return self.env.state_descriptor_name
@property
def state_descriptor_limits(self) -> Tuple[List[float], List[float]]:
return self.env.state_descriptor_limits
@property
def behavior_descriptor_length(self) -> int:
return self.env.behavior_descriptor_length
@property
def behavior_descriptor_limits(self) -> Tuple[List[float], List[float]]:
return self.env.behavior_descriptor_limits
@property
def name(self) -> str:
return self.env.name
@property
def unwrapped(self) -> Env:
return self.env.unwrapped
def __getattr__(self, name: str) -> Any:
if name == "__setstate__":
raise AttributeError(name)
return getattr(self.env, name)
observation_size: int
property
readonly
¶
The size of the observation vector returned in step and reset.
action_size: int
property
readonly
¶
The size of the action vector expected by step.
reset(self, rng)
¶
Resets the environment to an initial state.
Source code in qdax/environments/base_wrappers.py
def reset(self, rng: jp.ndarray) -> State:
return self.env.reset(rng)
step(self, state, action)
¶
Run one timestep of the environment's dynamics.
Source code in qdax/environments/base_wrappers.py
def step(self, state: State, action: jp.ndarray) -> State:
return self.env.step(state, action)
StateDescriptorResetWrapper (QDWrapper)
¶
Automatically resets state descriptors.
Source code in qdax/environments/base_wrappers.py
class StateDescriptorResetWrapper(QDWrapper):
"""Automatically resets state descriptors."""
def reset(self, rng: jp.ndarray) -> State:
state = self.env.reset(rng)
state.info["first_state_descriptor"] = state.info["state_descriptor"]
return state
def step(self, state: State, action: jp.ndarray) -> State:
state = self.env.step(state, action)
def where_done(x: jp.ndarray, y: jp.ndarray) -> jp.ndarray:
done = state.done
if done.shape:
done = jp.reshape(done, tuple([x.shape[0]] + [1] * (len(x.shape) - 1)))
return jp.where(done, x, y)
state.info["state_descriptor"] = where_done(
state.info["first_state_descriptor"], state.info["state_descriptor"]
)
return state
reset(self, rng)
¶
Resets the environment to an initial state.
Source code in qdax/environments/base_wrappers.py
def reset(self, rng: jp.ndarray) -> State:
state = self.env.reset(rng)
state.info["first_state_descriptor"] = state.info["state_descriptor"]
return state
step(self, state, action)
¶
Run one timestep of the environment's dynamics.
Source code in qdax/environments/base_wrappers.py
def step(self, state: State, action: jp.ndarray) -> State:
state = self.env.step(state, action)
def where_done(x: jp.ndarray, y: jp.ndarray) -> jp.ndarray:
done = state.done
if done.shape:
done = jp.reshape(done, tuple([x.shape[0]] + [1] * (len(x.shape) - 1)))
return jp.where(done, x, y)
state.info["state_descriptor"] = where_done(
state.info["first_state_descriptor"], state.info["state_descriptor"]
)
return state
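To see what where_done does in isolation, here is a toy sketch with a batch of two environments where only the first is done; its descriptor is reset to the one recorded at reset time.
import jax.numpy as jnp

done = jnp.array([1.0, 0.0])  # env 0 just finished, env 1 still running
first_desc = jnp.array([[0.0, 0.0], [0.0, 0.0]])    # descriptors at reset
current_desc = jnp.array([[3.0, 1.0], [2.0, 5.0]])  # descriptors after the step

# reshape done to broadcast over the descriptor dims, as where_done does
done = jnp.reshape(done, (first_desc.shape[0],) + (1,) * (first_desc.ndim - 1))
new_desc = jnp.where(done, first_desc, current_desc)
# new_desc == [[0., 0.], [2., 5.]]: the finished env's descriptor is reset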
bd_extractors
¶
AuroraExtraInfo (PyTreeNode)
dataclass
¶
Information specific to the AURORA algorithm.
Parameters:
- model_params (Params) – the parameters of the dimensionality reduction model
Source code in qdax/environments/bd_extractors.py
class AuroraExtraInfo(flax.struct.PyTreeNode):
"""
Information specific to the AURORA algorithm.
Args:
model_params: the parameters of the dimensionality reduction model
"""
model_params: Params
replace(self, **updates)
¶
"Returns a new object replacing the specified fields with new values.
Source code in qdax/environments/bd_extractors.py
def replace(self, **updates):
""" "Returns a new object replacing the specified fields with new values."""
return dataclasses.replace(self, **updates)
AuroraExtraInfoNormalization (AuroraExtraInfo)
dataclass
¶
Information specific to the AURORA algorithm. In particular, it contains the normalization parameters for the observations.
Parameters:
- model_params (Params) – the parameters of the dimensionality reduction model
- mean_observations (jnp.ndarray) – the mean of the observations
- std_observations (jnp.ndarray) – the standard deviation of the observations
Source code in qdax/environments/bd_extractors.py
class AuroraExtraInfoNormalization(AuroraExtraInfo):
"""
Information specific to the AURORA algorithm. In particular, it contains
the normalization parameters for the observations.
Args:
model_params: the parameters of the dimensionality reduction model
mean_observations: the mean of observations
std_observations: the std of observations
"""
mean_observations: jnp.ndarray
std_observations: jnp.ndarray
@classmethod
def create(
cls,
model_params: Params,
mean_observations: jnp.ndarray,
std_observations: jnp.ndarray,
) -> AuroraExtraInfoNormalization:
return cls(
model_params=model_params,
mean_observations=mean_observations,
std_observations=std_observations,
)
replace(self, **updates)
¶
"Returns a new object replacing the specified fields with new values.
Source code in qdax/environments/bd_extractors.py
def replace(self, **updates):
""" "Returns a new object replacing the specified fields with new values."""
return dataclasses.replace(self, **updates)
get_final_xy_position(data, mask)
¶
Compute the final xy position.
This function assumes that the state descriptor is the xy position, as it simply selects the final state descriptor of each trajectory.
Source code in qdax/environments/bd_extractors.py
def get_final_xy_position(data: QDTransition, mask: jnp.ndarray) -> Descriptor:
"""Compute final xy positon.
This function suppose that state descriptor is the xy position, as it
just select the final one of the state descriptors given.
"""
# reshape mask for bd extraction
mask = jnp.expand_dims(mask, axis=-1)
# Get behavior descriptor
last_index = jnp.int32(jnp.sum(1.0 - mask, axis=1)) - 1
descriptors = jax.vmap(lambda x, y: x[y])(data.state_desc, last_index)
# remove the dim coming from the trajectory
return descriptors.squeeze(axis=1)
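A toy sketch of the indexing above, assuming the usual convention that mask is 1 on padded steps after the episode ends: sum(1 - mask) counts the valid steps, so last_index selects the final valid state descriptor.
import jax
import jax.numpy as jnp

# a batch of one trajectory of length 4 with xy descriptors; the episode
# has 3 valid steps, the fourth step is padding (mask == 1)
state_desc = jnp.array([[[0.0, 0.0], [1.0, 0.5], [2.0, 1.0], [9.0, 9.0]]])
mask = jnp.array([[0.0, 0.0, 0.0, 1.0]])

mask = jnp.expand_dims(mask, axis=-1)
last_index = jnp.int32(jnp.sum(1.0 - mask, axis=1)) - 1  # [[2]]
descriptors = jax.vmap(lambda x, y: x[y])(state_desc, last_index)
final_xy = descriptors.squeeze(axis=1)  # [[2., 1.]], the last valid xy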
get_feet_contact_proportion(data, mask)
¶
Compute the feet contact time proportion.
This function assumes that the state descriptor is the feet contact, as it simply computes the mean of the state descriptors over the trajectory.
Source code in qdax/environments/bd_extractors.py
def get_feet_contact_proportion(data: QDTransition, mask: jnp.ndarray) -> Descriptor:
"""Compute feet contact time proportion.
This function suppose that state descriptor is the feet contact, as it
just computes the mean of the state descriptors given.
"""
# reshape mask for bd extraction
mask = jnp.expand_dims(mask, axis=-1)
# Get behavior descriptor
descriptors = jnp.sum(data.state_desc * (1.0 - mask), axis=1)
descriptors = descriptors / jnp.sum(1.0 - mask, axis=1)
return descriptors
get_aurora_encoding(observations, aurora_extra_info, model)
¶
Compute the final AURORA embedding.
The observations are first normalized with the provided statistics, then encoded by the dimensionality-reduction model.
Source code in qdax/environments/bd_extractors.py
def get_aurora_encoding(
observations: jnp.ndarray,
aurora_extra_info: AuroraExtraInfoNormalization,
model: flax.linen.Module,
) -> Descriptor:
"""
Compute final aurora embedding.
This function suppose that state descriptor is the xy position, as it
just select the final one of the state descriptors given.
"""
model_params = aurora_extra_info.model_params
mean_observations = aurora_extra_info.mean_observations
std_observations = aurora_extra_info.std_observations
# lstm seq2seq
normalized_observations = (observations - mean_observations) / std_observations
descriptors = model.apply(
{"params": model_params}, normalized_observations, method=model.encode
)
return descriptors.squeeze()
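A self-contained usage sketch with a hypothetical stand-in encoder: the real AURORA model is a trained dimensionality-reduction module exposing an encode method, and TinyEncoder below only mimics that interface.
import flax.linen as nn
import jax
import jax.numpy as jnp

from qdax.environments.bd_extractors import (
    AuroraExtraInfoNormalization,
    get_aurora_encoding,
)


class TinyEncoder(nn.Module):
    """Hypothetical stand-in for the AURORA dimensionality-reduction model."""
    latent_dim: int = 2

    def setup(self) -> None:
        self.dense = nn.Dense(self.latent_dim)

    def __call__(self, x: jnp.ndarray) -> jnp.ndarray:
        return self.encode(x)

    def encode(self, x: jnp.ndarray) -> jnp.ndarray:
        return self.dense(x)


obs_dim = 8
observations = jnp.ones((1, obs_dim))  # a dummy batch of observations
model = TinyEncoder()
params = model.init(jax.random.PRNGKey(0), observations)["params"]

aurora_extra_info = AuroraExtraInfoNormalization.create(
    model_params=params,
    mean_observations=jnp.zeros((obs_dim,)),
    std_observations=jnp.ones((obs_dim,)),
)
descriptors = get_aurora_encoding(observations, aurora_extra_info, model)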
exploration_wrappers
¶
TrapWrapper (Wrapper)
¶
Wraps gym environments to add a trap to the environment.
Usage is simple: create an environment with Brax, pass it to the wrapper with the name of the environment, and it will work as before while simply adding the trap to the environment.
This wrapper also adds the xy position to the observation, as it is important information for the agent: now that there is a trap in its env, we expect its actions to depend on its xy position.
The xy position is normalised using the chosen limits of the env, which are [0, 30] for x and [-8, 8] for y.
The only env supported at the moment among the classic locomotion envs is Ant.
Note: Humanoid is not supported yet. Note: this works for walker2d etc., but it does not make sense as those can only move in one direction.
Example:
from brax import envs
from brax import jumpy as jp
# choose in ["ant"]
ENV_NAME = "ant"
env = envs.create(env_name=ENV_NAME)
qd_env = TrapWrapper(env, ENV_NAME)
state = qd_env.reset(rng=jp.random_prngkey(seed=0))
for i in range(10):
action = jp.zeros((qd_env.action_size,))
state = qd_env.step(state, action)
Source code in qdax/environments/exploration_wrappers.py
class TrapWrapper(Wrapper):
"""Wraps gym environments to add a Trap in the environment.
Utilisation is simple: create an environment with Brax, pass
it to the wrapper with the name of the environment, and it will
work like before and will simply add the Trap to the environment.
This wrapper also adds xy in the observation, as it is an important
information for an agent. Now that there is a trap in its env, we
expect its actions to depend on its xy position.
The xy position is normalised thanks to the decided limits of the env,
which are [0, 30] for x and [-8, 8] for y.
The only supported envs at the moment are among the classic
locomotion envs : Ant.
RMQ: Humanoid is not supported yet.
RMQ: works for walker2d etc.. but it does not make sens as they
can only go in one direction.
Example :
from brax import envs
from brax import jumpy as jp
# choose in ["ant"]
ENV_NAME = "ant"
env = envs.create(env_name=ENV_NAME)
qd_env = TrapWrapper(env, ENV_NAME)
state = qd_env.reset(rng=jp.random_prngkey(seed=0))
for i in range(10):
action = jp.zeros((qd_env.action_size,))
state = qd_env.step(state, action)
"""
def __init__(self, env: Env, env_name: str) -> None:
if (
env_name not in ENV_SYSTEM_CONFIG.keys()
or env_name not in COG_NAMES.keys()
or env_name not in ENV_TRAP_COLLISION.keys()
):
raise NotImplementedError(f"This wrapper does not support {env_name} yet.")
if env_name not in ["ant", "humanoid"]:
warnings.warn(
"Make sure your agent can move in two dimensions!",
stacklevel=2,
)
super().__init__(env)
self._env_name = env_name
# update the env config to add the trap
self._config = (
ENV_SYSTEM_CONFIG[env_name] + TRAP_CONFIG + ENV_TRAP_COLLISION[env_name]
)
# update the associated physical system
config = text_format.Parse(self._config, brax.Config())
if not hasattr(self.unwrapped, "sys"):
raise AttributeError("Cannot link env to a physical system.")
self.unwrapped.sys = brax.System(config)
self._cog_idx = self.unwrapped.sys.body.index[COG_NAMES[env_name]]
# we need to normalise the x/y position to avoid exploding values
self._substract = jnp.array([15, 0]) # come from env limits
self._divide = jnp.array([15, 8]) # come from env limits
@property
def name(self) -> str:
return self._env_name
@property
def observation_size(self) -> int:
"""The size of the observation vector returned in step and reset."""
rng = jp.random_prngkey(0)
reset_state = self.reset(rng)
return int(reset_state.obs.shape[-1])
def reset(self, rng: jp.ndarray) -> State:
state = self.env.reset(rng)
# add xy position to the observation
xy_pos = state.qp.pos[self._cog_idx][:2]
# normalise
xy_pos = (xy_pos - self._substract) / self._divide
new_obs = jp.concatenate([xy_pos, state.obs])
return state.replace(obs=new_obs) # type: ignore
def step(self, state: State, action: jp.ndarray) -> State:
state = self.env.step(state, action)
# add xy position to the observation
xy_pos = state.qp.pos[self._cog_idx][:2]
# normalise
xy_pos = (xy_pos - self._substract) / self._divide
new_obs = jp.concatenate([xy_pos, state.obs])
return state.replace(obs=new_obs) # type: ignore
observation_size: int
property
readonly
¶
The size of the observation vector returned in step and reset.
reset(self, rng)
¶
Resets the environment to an initial state.
Source code in qdax/environments/exploration_wrappers.py
def reset(self, rng: jp.ndarray) -> State:
state = self.env.reset(rng)
# add xy position to the observation
xy_pos = state.qp.pos[self._cog_idx][:2]
# normalise
xy_pos = (xy_pos - self._substract) / self._divide
new_obs = jp.concatenate([xy_pos, state.obs])
return state.replace(obs=new_obs) # type: ignore
step(self, state, action)
¶
Run one timestep of the environment's dynamics.
Source code in qdax/environments/exploration_wrappers.py
def step(self, state: State, action: jp.ndarray) -> State:
state = self.env.step(state, action)
# add xy position to the observation
xy_pos = state.qp.pos[self._cog_idx][:2]
# normalise
xy_pos = (xy_pos - self._substract) / self._divide
new_obs = jp.concatenate([xy_pos, state.obs])
return state.replace(obs=new_obs) # type: ignore
MazeWrapper (Wrapper)
¶
Wraps gym environments to add a maze to the environment and a new reward (distance to the goal).
Usage is simple: create an environment with Brax, pass it to the wrapper with the name of the environment, and it will work as before while simply adding the maze to the environment, along with the new reward.
This wrapper also adds the xy position to the observation, as it is important information for the agent: now that the agent is in a maze, we expect its actions to depend on its xy position.
The xy position is normalised using the chosen limits of the env, which are [-5, 40] for x and y.
The only env supported at the moment among the classic locomotion envs is Ant.
Note: Humanoid is not supported yet. Note: this works for walker2d etc., but it does not make sense as those can only move in one direction.
Example:
from brax import envs
from brax import jumpy as jp
# choose in ["ant"]
ENV_NAME = "ant"
env = envs.create(env_name=ENV_NAME)
qd_env = MazeWrapper(env, ENV_NAME)
state = qd_env.reset(rng=jp.random_prngkey(seed=0))
for i in range(10):
action = jp.zeros((qd_env.action_size,))
state = qd_env.step(state, action)
Source code in qdax/environments/exploration_wrappers.py
class MazeWrapper(Wrapper):
"""Wraps gym environments to add a maze in the environment
and a new reward (distance to the goal).
Utilisation is simple: create an environment with Brax, pass
it to the wrapper with the name of the environment, and it will
work like before and will simply add the Maze to the environment,
along with the new reward.
This wrapper also adds xy in the observation, as it is an important
information for an agent. Now that the agent is in a maze, we
expect its actions to depend on its xy position.
The xy position is normalised thanks to the decided limits of the env,
which are [-5, 40] for x and y.
The only supported envs at the moment are among the classic
locomotion envs : Ant.
RMQ: Humanoid is not supported yet.
RMQ: works for walker2d etc.. but it does not make sens as they
can only go in one direction.
Example :
from brax import envs
from brax import jumpy as jp
# choose in ["ant"]
ENV_NAME = "ant"
env = envs.create(env_name=ENV_NAME)
qd_env = MazeWrapper(env, ENV_NAME)
state = qd_env.reset(rng=jp.random_prngkey(seed=0))
for i in range(10):
action = jp.zeros((qd_env.action_size,))
state = qd_env.step(state, action)
"""
def __init__(self, env: Env, env_name: str) -> None:
if (
env_name not in ENV_SYSTEM_CONFIG.keys()
or env_name not in COG_NAMES.keys()
or env_name not in ENV_MAZE_COLLISION.keys()
):
raise NotImplementedError(
f"This wrapper does not support {env_name} yet.",
)
if env_name not in ["ant", "humanoid"]:
warnings.warn(
"Make sure your agent can move in two dimensions!",
stacklevel=2,
)
super().__init__(env)
self._env_name = env_name
self._config = (
ENV_SYSTEM_CONFIG[env_name] + MAZE_CONFIG + ENV_MAZE_COLLISION[env_name]
)
config = text_format.Parse(self._config, brax.Config())
if not hasattr(self.unwrapped, "sys"):
raise AttributeError("Cannot link env to a physical system.")
self.unwrapped.sys = brax.System(config)
self._cog_idx = self.unwrapped.sys.body.index[COG_NAMES[env_name]]
self._target_idx = self.unwrapped.sys.body.index["Target"]
# we need to normalise the x/y position to avoid exploding values
self._substract = jnp.array([17.5, 17.5]) # come from env limits
self._divide = jnp.array([22.5, 22.5]) # come from env limits
@property
def name(self) -> str:
return self._env_name
@property
def observation_size(self) -> int:
"""The size of the observation vector returned in step and reset."""
rng = jp.random_prngkey(0)
reset_state = self.reset(rng)
return int(reset_state.obs.shape[-1])
def reset(self, rng: jp.ndarray) -> State:
state = self.env.reset(rng)
# get xy position of the center of gravity and of the target
cog_xy_position = state.qp.pos[self._cog_idx][:2]
target_xy_position = state.qp.pos[self._target_idx][:2]
# update the reward
new_reward = -jp.norm(target_xy_position - cog_xy_position)
# add cog xy position to the observation - normalise
cog_xy_position = (cog_xy_position - self._substract) / self._divide
new_obs = jp.concatenate([cog_xy_position, state.obs])
return state.replace(obs=new_obs, reward=new_reward) # type: ignore
def step(self, state: State, action: jp.ndarray) -> State:
state = self.env.step(state, action)
# get xy position of the center of gravity and of the target
cog_xy_position = state.qp.pos[self._cog_idx][:2]
target_xy_position = state.qp.pos[self._target_idx][:2]
# update the reward
new_reward = -jp.norm(target_xy_position - cog_xy_position)
# add cog xy position to the observation - normalise
cog_xy_position = (cog_xy_position - self._substract) / self._divide
new_obs = jp.concatenate([cog_xy_position, state.obs])
# the brax ant can end its episode early by jumping above a manually
# designed z threshold; these lines avoid this by increasing the threshold
done = jp.where(
state.qp.pos[0, 2] < 0.2,
x=jp.array(1, dtype=jp.float32),
y=jp.array(0, dtype=jp.float32),
)
done = jp.where(
state.qp.pos[0, 2] > 5.0, x=jp.array(1, dtype=jp.float32), y=done
)
return state.replace(obs=new_obs, reward=new_reward, done=done) # type: ignore
observation_size: int
property
readonly
¶
The size of the observation vector returned in step and reset.
reset(self, rng)
¶
Resets the environment to an initial state.
Source code in qdax/environments/exploration_wrappers.py
def reset(self, rng: jp.ndarray) -> State:
state = self.env.reset(rng)
# get xy position of the center of gravity and of the target
cog_xy_position = state.qp.pos[self._cog_idx][:2]
target_xy_position = state.qp.pos[self._target_idx][:2]
# update the reward
new_reward = -jp.norm(target_xy_position - cog_xy_position)
# add cog xy position to the observation - normalise
cog_xy_position = (cog_xy_position - self._substract) / self._divide
new_obs = jp.concatenate([cog_xy_position, state.obs])
return state.replace(obs=new_obs, reward=new_reward) # type: ignore
step(self, state, action)
¶
Run one timestep of the environment's dynamics.
Source code in qdax/environments/exploration_wrappers.py
def step(self, state: State, action: jp.ndarray) -> State:
state = self.env.step(state, action)
# get xy position of the center of gravity and of the target
cog_xy_position = state.qp.pos[self._cog_idx][:2]
target_xy_position = state.qp.pos[self._target_idx][:2]
# update the reward
new_reward = -jp.norm(target_xy_position - cog_xy_position)
# add cog xy position to the observation - normalise
cog_xy_position = (cog_xy_position - self._substract) / self._divide
new_obs = jp.concatenate([cog_xy_position, state.obs])
# the brax ant can end its episode early by jumping above a manually
# designed z threshold; these lines avoid this by increasing the threshold
done = jp.where(
state.qp.pos[0, 2] < 0.2,
x=jp.array(1, dtype=jp.float32),
y=jp.array(0, dtype=jp.float32),
)
done = jp.where(
state.qp.pos[0, 2] > 5.0, x=jp.array(1, dtype=jp.float32), y=done
)
return state.replace(obs=new_obs, reward=new_reward, done=done) # type: ignore
humanoidtrap
¶
Trains a humanoid to run in the +x direction, with a trap added in front. Highly inspired by the brax humanoid.
HumanoidTrap (Env)
¶
Trains a humanoid to run in the +x direction.
Note: uses the legacy spring dynamics from Brax.
Source code in qdax/environments/humanoidtrap.py
class HumanoidTrap(Env):
"""Trains a humanoid to run in the +x direction.
Note: uses the legacy spring dynamics from Brax.
"""
def __init__(self, **kwargs: Dict[str, Any]) -> None:
config = _SYSTEM_CONFIG
# change compared to humanoid
config = config + TRAP_CONFIG + HUMANOID_TRAP_COLLISIONS
super().__init__(config=config, **kwargs)
body = bodies.Body(self.sys.config)
# change compared to humanoid
body = jp.take(body, body.idx[:-2]) # skip the floor body & trap body
self.mass = body.mass.reshape(-1, 1)
self.inertia = body.inertia
self.inertia_matrix = jp.array([jp.diag(a) for a in self.inertia])
def reset(self, rng: jp.ndarray) -> State:
"""Resets the environment to an initial state."""
rng, rng1, rng2 = jp.random_split(rng, 3)
qpos = self.sys.default_angle() + jp.random_uniform(
rng1, (self.sys.num_joint_dof,), -0.01, 0.01
)
qvel = jp.random_uniform(rng2, (self.sys.num_joint_dof,), -0.01, 0.01)
qp = self.sys.default_qp(joint_angle=qpos, joint_velocity=qvel)
info = self.sys.info(qp)
obs = self._get_obs(qp, info, jp.zeros(self.action_size))
reward, done, zero = jp.zeros(3)
metrics = {
"reward_linvel": zero,
"reward_quadctrl": zero,
"reward_alive": zero,
"reward_impact": zero,
}
return State(qp, obs, reward, done, metrics)
def step(self, state: State, action: jp.ndarray) -> State:
"""Run one timestep of the environment's dynamics."""
qp, info = self.sys.step(state.qp, action)
obs = self._get_obs(qp, info, action)
# change compared to humanoid
pos_before = state.qp.pos[:-2] # ignore floor & trap at last index
pos_after = qp.pos[:-2] # ignore floor & trap at last index
com_before = jp.sum(pos_before * self.mass, axis=0) / jp.sum(self.mass)
com_after = jp.sum(pos_after * self.mass, axis=0) / jp.sum(self.mass)
lin_vel_cost = 1.25 * (com_after[0] - com_before[0]) / self.sys.config.dt
quad_ctrl_cost = 0.01 * jp.sum(jp.square(action))
# can ignore contact cost, see: https://github.com/openai/gym/issues/1541
quad_impact_cost = jp.float32(0)
alive_bonus = jp.float32(5)
reward = lin_vel_cost - quad_ctrl_cost - quad_impact_cost + alive_bonus
done = jp.where(qp.pos[0, 2] < 0.8, jp.float32(1), jp.float32(0))
done = jp.where(qp.pos[0, 2] > 2.1, jp.float32(1), done)
state.metrics.update(
reward_linvel=lin_vel_cost,
reward_quadctrl=quad_ctrl_cost,
reward_alive=alive_bonus,
reward_impact=quad_impact_cost,
)
return state.replace(qp=qp, obs=obs, reward=reward, done=done)
def _get_obs(self, qp: brax.QP, info: brax.Info, action: jp.ndarray) -> jp.ndarray:
"""Observe humanoid body position, velocities, and angles."""
# some pre-processing to pull joint angles and velocities
joint_obs = [j.angle_vel(qp) for j in self.sys.joints]
# qpos:
# Z of the torso (1,)
# orientation of the torso as quaternion (4,)
# joint angles (8,)
joint_angles = [jp.array(j[0]).reshape(-1) for j in joint_obs]
qpos = [
qp.pos[0, 2:],
qp.rot[0],
] + joint_angles
# qvel:
# velocity of the torso (3,)
# angular velocity of the torso (3,)
# joint angle velocities (8,)
joint_velocities = [jp.array(j[1]).reshape(-1) for j in joint_obs]
qvel = [
qp.vel[0],
qp.ang[0],
] + joint_velocities
# actuator forces
qfrc_actuator = []
for act in self.sys.actuators:
torque = jp.take(action, act.act_index)
torque = torque.reshape(torque.shape[:-2] + (-1,))
torque *= jp.repeat(act.strength, act.act_index.shape[-1])
qfrc_actuator.append(torque)
# external contact forces:
# delta velocity (3,), delta ang (3,) * num bodies in the system
cfrc_ext = [info.contact.vel, info.contact.ang]
# flatten bottom dimension
cfrc_ext = [x.reshape(x.shape[:-2] + (-1,)) for x in cfrc_ext]
# center of mass obs - change compared to humanoid
body_pos = qp.pos[:-2] # ignore floor & trap at last index
body_vel = qp.vel[:-2] # ignore floor & trap at last index
com_vec = jp.sum(body_pos * self.mass, axis=0) / jp.sum(self.mass)
com_vel = body_vel * self.mass / jp.sum(self.mass)
v_outer = jp.vmap(lambda a: jp.outer(a, a))
v_cross = jp.vmap(jp.cross)
disp_vec = body_pos - com_vec
com_inert = self.inertia_matrix + self.mass.reshape((11, 1, 1)) * (
(jp.norm(disp_vec, axis=1) ** 2.0).reshape((11, 1, 1))
* jp.stack([jp.eye(3)] * 11)
- v_outer(disp_vec)
)
cinert = [com_inert.reshape(-1)]
square_disp = (1e-7 + (jp.norm(disp_vec, axis=1) ** 2.0)).reshape((11, 1))
com_angular_vel = v_cross(disp_vec, body_vel) / square_disp
cvel = [com_vel.reshape(-1), com_angular_vel.reshape(-1)]
return jp.concatenate(qpos + qvel + cinert + cvel + qfrc_actuator + cfrc_ext)
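A usage sketch mirroring the wrapper examples above; HumanoidTrap is a standalone env, so it is instantiated directly.
from brax import jumpy as jp

from qdax.environments.humanoidtrap import HumanoidTrap

env = HumanoidTrap()
state = env.reset(rng=jp.random_prngkey(seed=0))
for i in range(10):
    action = jp.zeros((env.action_size,))
    state = env.step(state, action)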
reset(self, rng)
¶
Resets the environment to an initial state.
Source code in qdax/environments/humanoidtrap.py
def reset(self, rng: jp.ndarray) -> State:
"""Resets the environment to an initial state."""
rng, rng1, rng2 = jp.random_split(rng, 3)
qpos = self.sys.default_angle() + jp.random_uniform(
rng1, (self.sys.num_joint_dof,), -0.01, 0.01
)
qvel = jp.random_uniform(rng2, (self.sys.num_joint_dof,), -0.01, 0.01)
qp = self.sys.default_qp(joint_angle=qpos, joint_velocity=qvel)
info = self.sys.info(qp)
obs = self._get_obs(qp, info, jp.zeros(self.action_size))
reward, done, zero = jp.zeros(3)
metrics = {
"reward_linvel": zero,
"reward_quadctrl": zero,
"reward_alive": zero,
"reward_impact": zero,
}
return State(qp, obs, reward, done, metrics)
step(self, state, action)
¶
Run one timestep of the environment's dynamics.
Source code in qdax/environments/humanoidtrap.py
def step(self, state: State, action: jp.ndarray) -> State:
"""Run one timestep of the environment's dynamics."""
qp, info = self.sys.step(state.qp, action)
obs = self._get_obs(qp, info, action)
# change compared to humanoid
pos_before = state.qp.pos[:-2] # ignore floor & trap at last index
pos_after = qp.pos[:-2] # ignore floor & trap at last index
com_before = jp.sum(pos_before * self.mass, axis=0) / jp.sum(self.mass)
com_after = jp.sum(pos_after * self.mass, axis=0) / jp.sum(self.mass)
lin_vel_cost = 1.25 * (com_after[0] - com_before[0]) / self.sys.config.dt
quad_ctrl_cost = 0.01 * jp.sum(jp.square(action))
# can ignore contact cost, see: https://github.com/openai/gym/issues/1541
quad_impact_cost = jp.float32(0)
alive_bonus = jp.float32(5)
reward = lin_vel_cost - quad_ctrl_cost - quad_impact_cost + alive_bonus
done = jp.where(qp.pos[0, 2] < 0.8, jp.float32(1), jp.float32(0))
done = jp.where(qp.pos[0, 2] > 2.1, jp.float32(1), done)
state.metrics.update(
reward_linvel=lin_vel_cost,
reward_quadctrl=quad_ctrl_cost,
reward_alive=alive_bonus,
reward_impact=quad_impact_cost,
)
return state.replace(qp=qp, obs=obs, reward=reward, done=done)
init_state_wrapper
¶
FixedInitialStateWrapper (Wrapper)
¶
Wrapper to make the initial state of the environment deterministic and fixed. This is done by removing the random noise from the DoF positions and velocities.
Source code in qdax/environments/init_state_wrapper.py
class FixedInitialStateWrapper(Wrapper):
"""Wrapper to make the initial state of the environment deterministic and fixed.
This is done by removing the random noise from the DoF positions and velocities.
"""
def __init__(
self,
env: Env,
base_env_name: str,
get_obs_fn: Optional[
Callable[[brax.QP, brax.Info, jp.ndarray], jp.ndarray]
] = None,
):
env_get_obs = {
"ant": lambda qp, info, action: self._get_obs(qp, info),
"halfcheetah": lambda qp, info, action: self._get_obs(qp, info),
"walker2d": lambda qp, info, action: self._get_obs(qp),
"hopper": lambda qp, info, action: self._get_obs(qp),
"humanoid": lambda qp, info, action: self._get_obs(qp, info, action),
}
super().__init__(env)
if get_obs_fn is not None:
self._get_obs_fn = get_obs_fn
elif base_env_name in env_get_obs.keys():
self._get_obs_fn = env_get_obs[base_env_name]
else:
raise NotImplementedError(
f"This wrapper does not support {base_env_name} yet."
)
def reset(self, rng: jp.ndarray) -> State:
"""Reset the state of the environment with a deterministic and fixed
initial state.
Args:
rng: random key to handle stochastic operations. Used by the parent
init reset function.
Returns:
A new state with a fixed observation.
"""
# Run the default reset method of parent environment
state = self.env.reset(rng)
# Compute new initial positions and velocities
qpos = self.sys.default_angle()
qvel = jp.zeros((self.sys.num_joint_dof,))
# update qd
qp = self.sys.default_qp(joint_angle=qpos, joint_velocity=qvel)
# get the new obs
obs = self._get_obs_fn(qp, self.sys.info(qp), jp.zeros(self.action_size))
return state.replace(qp=qp, obs=obs)
reset(self, rng)
¶
Reset the state of the environment with a deterministic and fixed initial state.
Parameters:
- rng (jp.ndarray) – random key to handle stochastic operations. Used by the parent init reset function.
Returns:
- State – a new state with a fixed observation.
Source code in qdax/environments/init_state_wrapper.py
def reset(self, rng: jp.ndarray) -> State:
"""Reset the state of the environment with a deterministic and fixed
initial state.
Args:
rng: random key to handle stochastic operations. Used by the parent
init reset function.
Returns:
A new state with a fixed observation.
"""
# Run the default reset method of parent environment
state = self.env.reset(rng)
# Compute new initial positions and velocities
qpos = self.sys.default_angle()
qvel = jp.zeros((self.sys.num_joint_dof,))
# update qd
qp = self.sys.default_qp(joint_angle=qpos, joint_velocity=qvel)
# get the new obs
obs = self._get_obs_fn(qp, self.sys.info(qp), jp.zeros(self.action_size))
return state.replace(qp=qp, obs=obs)
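A small usage sketch, assuming the base "ant" env: two resets with different keys now produce the same initial state.
from brax import envs
from brax import jumpy as jp

from qdax.environments.init_state_wrapper import FixedInitialStateWrapper

env = envs.create(env_name="ant")
fixed_env = FixedInitialStateWrapper(env, base_env_name="ant")
# the random key no longer influences the initial DoF positions/velocities
state_a = fixed_env.reset(rng=jp.random_prngkey(seed=0))
state_b = fixed_env.reset(rng=jp.random_prngkey(seed=1))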
locomotion_wrappers
¶
QDSystem (System)
¶
Inherits from the brax physics System.
Works exactly the same, but stores some information from the physical simulation in the aux_info attribute.
This is used in FeetContactWrapper to get the feet contact of the robot with the ground.
Source code in qdax/environments/locomotion_wrappers.py
class QDSystem(System):
"""Inheritance of brax physic system.
Work precisely the same but store some information from the physical
simulation in the aux_info attribute.
This is used in FeetContactWrapper to get the feet contact of the
robot with the ground.
"""
def __init__(
self, config: config_pb2.Config, resource_paths: Optional[Sequence[str]] = None
):
super().__init__(config, resource_paths=resource_paths)
self.aux_info = None
def step(self, qp: QP, act: jp.ndarray) -> Tuple[QP, Info]:
qp, info = super().step(qp, act)
self.aux_info = info
return qp, info
step(self, qp, act)
¶
Generic step function. Overridden with appropriate step at init.
Source code in qdax/environments/locomotion_wrappers.py
def step(self, qp: QP, act: jp.ndarray) -> Tuple[QP, Info]:
qp, info = super().step(qp, act)
self.aux_info = info
return qp, info
FeetContactWrapper (QDEnv)
¶
Wraps gym environments to add the feet contact data.
Usage is simple: create an environment with Brax, pass it to the wrapper with the name of the environment, and it will work as before while simply adding the feet_contact booleans to the information dictionary of the Brax state.
The only supported envs at the moment are among the classic locomotion envs: Walker2D, Hopper, Ant, Bullet.
New locomotion envs can easily be added by adding the config name of the feet of the corresponding environment to the FEET_NAMES dictionary.
Example:
from brax import envs
from brax import jumpy as jp
# choose in ["ant", "walker2d", "hopper", "halfcheetah"]
ENV_NAME = "ant"
env = envs.create(env_name=ENV_NAME)
qd_env = FeetContactWrapper(env, ENV_NAME)
state = qd_env.reset(rng=jp.random_prngkey(seed=0))
for i in range(10):
action = jp.zeros((qd_env.action_size,))
state = qd_env.step(state, action)
# retrieve feet contact
feet_contact = state.info["state_descriptor"]
# do whatever you want with feet_contact
print(f"Feet contact : {feet_contact}")
Source code in qdax/environments/locomotion_wrappers.py
class FeetContactWrapper(QDEnv):
"""Wraps gym environments to add the feet contact data.
Utilisation is simple: create an environment with Brax, pass
it to the wrapper with the name of the environment, and it will
work like before and will simply add the feet_contact booleans in
the information dictionary of the Brax.state.
The only supported envs at the moment are among the classic
locomotion envs : Walker2D, Hopper, Ant, Bullet.
New locomotions envs can easily be added by adding the config name
of the feet of the corresponding environment in the FEET_NAME dictionary.
Example :
from brax import envs
from brax import jumpy as jp
# choose in ["ant", "walker2d", "hopper", "halfcheetah"]
ENV_NAME = "ant"
env = envs.create(env_name=ENV_NAME)
qd_env = FeetContactWrapper(env, ENV_NAME)
state = qd_env.reset(rng=jp.random_prngkey(seed=0))
for i in range(10):
action = jp.zeros((qd_env.action_size,))
state = qd_env.step(state, action)
# retrieve feet contact
feet_contact = state.info["state_descriptor"]
# do whatever you want with feet_contact
print(f"Feet contact : {feet_contact}")
"""
def __init__(self, env: Env, env_name: str):
if env_name not in FEET_NAMES.keys():
raise NotImplementedError(f"This wrapper does not support {env_name} yet.")
super().__init__(config=None)
self.env = env
self._env_name = env_name
if hasattr(self.env, "sys"):
self.env.sys = QDSystem(self.env.sys.config)
self._feet_contact_idx = jp.array(
[env.sys.body.index.get(name) for name in FEET_NAMES[env_name]]
)
@property
def state_descriptor_length(self) -> int:
return self.behavior_descriptor_length
@property
def state_descriptor_name(self) -> str:
return "feet_contact"
@property
def state_descriptor_limits(self) -> Tuple[List, List]:
return self.behavior_descriptor_limits
@property
def behavior_descriptor_length(self) -> int:
return len(self._feet_contact_idx)
@property
def behavior_descriptor_limits(self) -> Tuple[List, List]:
bd_length = self.behavior_descriptor_length
return (jnp.zeros((bd_length,)), jnp.ones((bd_length,)))
@property
def name(self) -> str:
return self._env_name
def reset(self, rng: jp.ndarray) -> State:
state = self.env.reset(rng)
state.info["state_descriptor"] = self._get_feet_contact(
self.env.sys.info(state.qp)
)
return state
def step(self, state: State, action: jp.ndarray) -> State:
state = self.env.step(state, action)
state.info["state_descriptor"] = self._get_feet_contact(self.env.sys.aux_info)
return state
def _get_feet_contact(self, info: Info) -> jp.ndarray:
contacts = info.contact.vel
return jp.any(contacts[self._feet_contact_idx], axis=1).astype(jp.float32)
@property
def unwrapped(self) -> Env:
return self.env.unwrapped
def __getattr__(self, name: str) -> Any:
if name == "__setstate__":
raise AttributeError(name)
return getattr(self.env, name)
reset(self, rng)
¶
Resets the environment to an initial state.
Source code in qdax/environments/locomotion_wrappers.py
def reset(self, rng: jp.ndarray) -> State:
state = self.env.reset(rng)
state.info["state_descriptor"] = self._get_feet_contact(
self.env.sys.info(state.qp)
)
return state
step(self, state, action)
¶
Run one timestep of the environment's dynamics.
Source code in qdax/environments/locomotion_wrappers.py
def step(self, state: State, action: jp.ndarray) -> State:
state = self.env.step(state, action)
state.info["state_descriptor"] = self._get_feet_contact(self.env.sys.aux_info)
return state
XYPositionWrapper (QDEnv)
¶
Wraps gym environments to add the position data.
Usage is simple: create an environment with Brax, pass it to the wrapper with the name of the environment, and it will work as before while simply adding the actual position to the information dictionary of the Brax state.
One can also provide values to clip the state descriptors.
The only supported envs at the moment are among the classic locomotion envs: Ant, Humanoid.
New locomotion envs can easily be added by adding the config name of the center of gravity of the corresponding environment to the STATE_POSITION dictionary.
Note: this can be used with Hopper, Walker2d, Halfcheetah, but it makes less sense as those are limited to one direction.
Example:
from brax import envs
from brax import jumpy as jp
# choose in ["ant", "walker2d", "hopper", "halfcheetah", "humanoid"]
ENV_NAME = "ant"
env = envs.create(env_name=ENV_NAME)
qd_env = XYPositionWrapper(env, ENV_NAME)
state = qd_env.reset(rng=jp.random_prngkey(seed=0))
for i in range(10):
action = jp.zeros((qd_env.action_size,))
state = qd_env.step(state, action)
# retrieve the xy position
xy_position = state.info["state_descriptor"]
# do whatever you want with xy_position
print(f"xy position : {xy_position}")
Source code in qdax/environments/locomotion_wrappers.py
class XYPositionWrapper(QDEnv):
"""Wraps gym environments to add the position data.
Utilisation is simple: create an environment with Brax, pass
it to the wrapper with the name of the environment, and it will
work like before and will simply add the actual position in
the information dictionary of the Brax.state.
One can also add values to clip the state descriptors.
The only supported envs at the moment are among the classic
locomotion envs : Ant, Humanoid.
New locomotions envs can easily be added by adding the config name
of the feet of the corresponding environment in the STATE_POSITION
dictionary.
RMQ: this can be used with Hopper, Walker2d, Halfcheetah but it makes
less sens as those are limited to one direction.
Example :
from brax import envs
from brax import jumpy as jp
# choose in ["ant", "walker2d", "hopper", "halfcheetah", "humanoid"]
ENV_NAME = "ant"
env = envs.create(env_name=ENV_NAME)
qd_env = XYPositionWrapper(env, ENV_NAME)
state = qd_env.reset(rng=jp.random_prngkey(seed=0))
for i in range(10):
action = jp.zeros((qd_env.action_size,))
state = qd_env.step(state, action)
# retrieve the xy position
xy_position = state.info["state_descriptor"]
# do whatever you want with xy_position
print(f"xy position : {xy_position}")
"""
def __init__(
self,
env: Env,
env_name: str,
minval: Optional[List[float]] = None,
maxval: Optional[List[float]] = None,
):
if env_name not in COG_NAMES.keys():
raise NotImplementedError(f"This wrapper does not support {env_name} yet.")
super().__init__(config=None)
self.env = env
self._env_name = env_name
if hasattr(self.env, "sys"):
self._cog_idx = self.env.sys.body.index[COG_NAMES[env_name]]
else:
raise NotImplementedError(f"This wrapper does not support {env_name} yet.")
if minval is None:
minval = jnp.ones((2,)) * (-jnp.inf)
if maxval is None:
maxval = jnp.ones((2,)) * jnp.inf
if len(minval) == 2 and len(maxval) == 2:
self._minval = jnp.array(minval)
self._maxval = jnp.array(maxval)
else:
raise NotImplementedError(
"Please make sure to give two values for each limits."
)
@property
def state_descriptor_length(self) -> int:
return 2
@property
def state_descriptor_name(self) -> str:
return "xy_position"
@property
def state_descriptor_limits(self) -> Tuple[List[float], List[float]]:
return self._minval, self._maxval
@property
def behavior_descriptor_length(self) -> int:
return self.state_descriptor_length
@property
def behavior_descriptor_limits(self) -> Tuple[List[float], List[float]]:
return self.state_descriptor_limits
@property
def name(self) -> str:
return self._env_name
def reset(self, rng: jp.ndarray) -> State:
state = self.env.reset(rng)
state.info["state_descriptor"] = jnp.clip(
state.qp.pos[self._cog_idx][:2], a_min=self._minval, a_max=self._maxval
)
return state
def step(self, state: State, action: jp.ndarray) -> State:
state = self.env.step(state, action)
# get xy position of the center of gravity
state.info["state_descriptor"] = jnp.clip(
state.qp.pos[self._cog_idx][:2], a_min=self._minval, a_max=self._maxval
)
return state
@property
def unwrapped(self) -> Env:
return self.env.unwrapped
def __getattr__(self, name: str) -> Any:
if name == "__setstate__":
raise AttributeError(name)
return getattr(self.env, name)
reset(self, rng)
¶
Resets the environment to an initial state.
Source code in qdax/environments/locomotion_wrappers.py
def reset(self, rng: jp.ndarray) -> State:
state = self.env.reset(rng)
state.info["state_descriptor"] = jnp.clip(
state.qp.pos[self._cog_idx][:2], a_min=self._minval, a_max=self._maxval
)
return state
step(self, state, action)
¶
Run one timestep of the environment's dynamics.
Source code in qdax/environments/locomotion_wrappers.py
def step(self, state: State, action: jp.ndarray) -> State:
state = self.env.step(state, action)
# get xy position of the center of gravity
state.info["state_descriptor"] = jnp.clip(
state.qp.pos[self._cog_idx][:2], a_min=self._minval, a_max=self._maxval
)
return state
NoForwardRewardWrapper (Wrapper)
¶
Wraps gym environments to remove the forward reward.
Usage is simple: create an environment with Brax, pass it to the wrapper with the name of the environment, and it will work as before while simply removing the forward speed term from the reward.
Example:
from brax import envs
from brax import jumpy as jp
# choose in ["ant", "walker2d", "hopper", "halfcheetah", "humanoid"]
ENV_NAME = "ant"
env = envs.create(env_name=ENV_NAME)
qd_env = NoForwardRewardWrapper(env, ENV_NAME)
state = qd_env.reset(rng=jp.random_prngkey(seed=0))
for i in range(10):
action = jp.zeros((qd_env.action_size,))
state = qd_env.step(state, action)
Source code in qdax/environments/locomotion_wrappers.py
class NoForwardRewardWrapper(Wrapper):
"""Wraps gym environments to remove forward reward.
Utilisation is simple: create an environment with Brax, pass
it to the wrapper with the name of the environment, and it will
work like before and will simply remove the forward speed term
of the reward.
Example :
from brax import envs
from brax import jumpy as jp
# choose in ["ant", "walker2d", "hopper", "halfcheetah", "humanoid"]
ENV_NAME = "ant"
env = envs.create(env_name=ENV_NAME)
qd_env = NoForwardRewardWrapper(env, ENV_NAME)
state = qd_env.reset(rng=jp.random_prngkey(seed=0))
for i in range(10):
action = jp.zeros((qd_env.action_size,))
state = qd_env.step(state, action)
"""
def __init__(self, env: Env, env_name: str) -> None:
if env_name not in FORWARD_REWARD_NAMES.keys():
raise NotImplementedError(f"This wrapper does not support {env_name} yet.")
super().__init__(env)
self._env_name = env_name
self._fd_reward_field = FORWARD_REWARD_NAMES[env_name]
@property
def name(self) -> str:
return self._env_name
def step(self, state: State, action: jp.ndarray) -> State:
state = self.env.step(state, action)
# update the reward (remove forward_reward)
new_reward = state.reward - state.metrics[self._fd_reward_field]
return state.replace(reward=new_reward) # type: ignore
step(self, state, action)
¶
Run one timestep of the environment's dynamics.
Source code in qdax/environments/locomotion_wrappers.py
def step(self, state: State, action: jp.ndarray) -> State:
state = self.env.step(state, action)
# update the reward (remove forward_reward)
new_reward = state.reward - state.metrics[self._fd_reward_field]
return state.replace(reward=new_reward) # type: ignore
pointmaze
¶
PointMaze (Env)
¶
Jax/Brax implementation of the PointMaze, highly inspired by the old python implementation of the PointMaze.
In order to stay within the Brax API, a fake QP is used at several points of the implementation. This enables the reuse of brax.envs.State. To avoid this, it would be good to ask Brax to slightly enlarge their API for environments that are not physically simulated.
Source code in qdax/environments/pointmaze.py
class PointMaze(Env):
"""Jax/Brax implementation of the PointMaze.
Highly inspired from the old python implementation of
the PointMaze.
In order to stay in the Brax API, I will use a fake QP
at several moment of the implementation. This enable to
use the brax.envs.State from Brax. To avoid this,
it would be good to ask Brax to enlarge a bit their API
for environments that are not physically simulated.
"""
def __init__(
self,
scale_action_space: float = 10,
x_min: float = -1,
x_max: float = 1,
y_min: float = -1,
y_max: float = 1,
zone_width: float = 0.1,
zone_width_offset_from_x_min: float = 0.5,
zone_height_offset_from_y_max: float = -0.2,
wall_width_ratio: float = 0.75,
upper_wall_height_offset: float = 0.2,
lower_wall_height_offset: float = -0.5,
**kwargs: Any,
) -> None:
super().__init__(None, **kwargs)
self._scale_action_space = scale_action_space
self._x_min = x_min
self._x_max = x_max
self._y_min = y_min
self._y_max = y_max
self._low = jp.array([self._x_min, self._y_min], dtype=jp.float32)
self._high = jp.array([self._x_max, self._y_max], dtype=jp.float32)
self.n_zones = 1
self.zone_width = zone_width
self.zone_width_offset = self._x_min + zone_width_offset_from_x_min
self.zone_height_offset = self._y_max + zone_height_offset_from_y_max
self.viewer = None
# Walls
self.wallheight = 0.01
self.wallwidth = (self._x_max - self._x_min) * wall_width_ratio
self.upper_wall_width_offset = self._x_min + self.wallwidth / 2
self.upper_wall_height_offset = upper_wall_height_offset
self.lower_wall_width_offset = self._x_max - self.wallwidth / 2
self.lower_wall_height_offset = lower_wall_height_offset
@property
def descriptors_min_values(self) -> List[float]:
"""Minimum values for descriptors."""
return [self._x_min, self._y_min]
@property
def descriptors_max_values(self) -> List[float]:
"""Maximum values for descriptors."""
return [self._x_max, self._y_max]
@property
def descriptors_names(self) -> List[str]:
"""Descriptors names."""
return ["x_pos", "y_pos"]
@property
def state_descriptor_length(self) -> int:
return 2
@property
def state_descriptor_name(self) -> str:
return "xy_position"
@property
def state_descriptor_limits(self) -> Tuple[List[float], List[float]]:
return [self._x_min, self._y_min], [self._x_max, self._y_max]
@property
def behavior_descriptor_length(self) -> int:
return self.state_descriptor_length
@property
def behavior_descriptor_limits(self) -> Tuple[List[float], List[float]]:
return self.state_descriptor_limits
@property
def action_size(self) -> int:
"""The size of the observation vector returned in step and reset."""
return 2
def reset(self, rng: jp.ndarray) -> State:
"""Resets the environment to an initial state."""
rng, rng1, rng2 = jp.random_split(rng, 3)
# get initial position - reproduce the old implementation
x_init = jp.random_uniform(rng1, (), low=self._x_min, high=self._x_max) / 10
y_init = jp.random_uniform(rng2, (), low=self._y_min, high=-0.7)
obs_init = jp.array([x_init, y_init])
# create fake qp (to re-use brax.State)
fake_qp = brax.QP.zero()
# init reward, metrics and infos
reward, done = jp.zeros(2)
metrics: Dict = {}
# managing state descriptor by our own
info_init = {"state_descriptor": obs_init}
return State(fake_qp, obs_init, reward, done, metrics, info_init)
def step(self, state: State, action: jp.ndarray) -> State:
"""Run one timestep of the environment's dynamics."""
# clip action taken
min_action = self._low
max_action = self._high
action = jp.clip(action, min_action, max_action) / self._scale_action_space
# get the current position
x_pos_old, y_pos_old = state.obs
# compute the new position
x_pos = x_pos_old + action[0]
y_pos = y_pos_old + action[1]
# take into account a potential wall collision
y_pos = self._collision_lower_wall(y_pos, y_pos_old, x_pos, x_pos_old)
y_pos = self._collision_upper_wall(y_pos, y_pos_old, x_pos, x_pos_old)
# take into account border walls
x_pos = jp.clip(x_pos, jp.array(self._x_min), jp.array(self._x_max))
y_pos = jp.clip(y_pos, jp.array(self._y_min), jp.array(self._y_max))
reward = -jp.norm(
jp.array([x_pos - self.zone_width_offset, y_pos - self.zone_height_offset])
)
# determine if zone was reached
in_zone = self._in_zone(x_pos, y_pos)
done = jp.where(
jp.array(in_zone),
x=jp.array(1.0),
y=jp.array(0.0),
)
new_obs = jp.array([x_pos, y_pos])
# update state descriptor
state.info["state_descriptor"] = new_obs
# update the state
return state.replace(obs=new_obs, reward=reward, done=done) # type: ignore
def _in_zone(self, x_pos: jp.ndarray, y_pos: jp.ndarray) -> Union[bool, jp.ndarray]:
"""Determine if the point reached the goal area."""
zone_center_width, zone_center_height = (
self.zone_width_offset,
self.zone_height_offset,
)
condition_1 = zone_center_width - self.zone_width / 2 <= x_pos
condition_2 = x_pos <= zone_center_width + self.zone_width / 2
condition_3 = zone_center_height - self.zone_width / 2 <= y_pos
condition_4 = y_pos <= zone_center_height + self.zone_width / 2
return condition_1 & condition_2 & condition_3 & condition_4
def _collision_lower_wall(
self,
y_pos: jp.ndarray,
y_pos_old: jp.ndarray,
x_pos: jp.ndarray,
x_pos_old: jp.ndarray,
) -> jp.ndarray:
"""Manage potential collisions with the walls."""
# global conditions on the x axis contacts
x_hitting_wall = (self.lower_wall_height_offset - y_pos_old) / (
y_pos - y_pos_old
) * (x_pos - x_pos_old) + x_pos_old
x_axis_contact_condition = x_hitting_wall >= self._x_max - self.wallwidth
# From down - boolean style
y_axis_down_contact_condition_1 = y_pos_old <= self.lower_wall_height_offset
y_axis_down_contact_condition_2 = self.lower_wall_height_offset < y_pos
# y_pos update
new_y_pos = jp.where(
y_axis_down_contact_condition_1
& y_axis_down_contact_condition_2
& x_axis_contact_condition,
x=jp.array(self.lower_wall_height_offset),
y=y_pos,
)
# From up - boolean style
y_axis_up_contact_condition_1 = (
y_pos < self.lower_wall_height_offset + self.wallheight
)
y_axis_up_contact_condition_2 = (
self.lower_wall_height_offset + self.wallheight <= y_pos_old
)
y_axis_up_contact_condition_3 = y_pos_old < self.upper_wall_height_offset
# y_pos update
new_y_pos = jp.where(
y_axis_up_contact_condition_1
& y_axis_up_contact_condition_2
& y_axis_up_contact_condition_3
& x_axis_contact_condition,
x=jp.array(self.lower_wall_height_offset + self.wallheight),
y=new_y_pos,
)
return new_y_pos
def _collision_upper_wall(
self,
y_pos: jp.ndarray,
y_pos_old: jp.ndarray,
x_pos: jp.ndarray,
x_pos_old: jp.ndarray,
) -> jp.ndarray:
"""Manage potential collisions with the walls."""
# global conditions on the x axis contacts
x_hitting_wall = (self.upper_wall_height_offset - y_pos_old) / (
y_pos - y_pos_old
) * (x_pos - x_pos_old) + x_pos_old
x_axis_contact_condition = x_hitting_wall <= self._x_min + self.wallwidth
# From up - boolean style
y_axis_up_contact_condition_1 = (
y_pos_old >= self.upper_wall_height_offset + self.wallheight
)
y_axis_up_contact_condition_2 = (
self.upper_wall_height_offset + self.wallheight > y_pos
)
# y_pos update
new_y_pos = jp.where(
y_axis_up_contact_condition_1
& y_axis_up_contact_condition_2
& x_axis_contact_condition,
x=jp.array(self.upper_wall_height_offset + self.wallheight),
y=y_pos,
)
# From down - boolean style
y_axis_down_contact_condition_1 = y_pos > self.upper_wall_height_offset
y_axis_down_contact_condition_2 = self.upper_wall_height_offset >= y_pos_old
y_axis_down_contact_condition_3 = y_pos_old > self.lower_wall_height_offset
# y_pos update
new_y_pos = jp.where(
y_axis_down_contact_condition_1
& y_axis_down_contact_condition_2
& y_axis_down_contact_condition_3
& x_axis_contact_condition,
x=jp.array(self.upper_wall_height_offset),
y=new_y_pos,
)
return new_y_pos
descriptors_min_values: List[float]
property
readonly
¶
Minimum values for descriptors.
descriptors_max_values: List[float]
property
readonly
¶
Maximum values for descriptors.
descriptors_names: List[str]
property
readonly
¶
Descriptors names.
action_size: int
property
readonly
¶
The size of the action vector expected by step.
reset(self, rng)
¶
Resets the environment to an initial state.
Source code in qdax/environments/pointmaze.py
def reset(self, rng: jp.ndarray) -> State:
"""Resets the environment to an initial state."""
rng, rng1, rng2 = jp.random_split(rng, 3)
# get initial position - reproduce the old implementation
x_init = jp.random_uniform(rng1, (), low=self._x_min, high=self._x_max) / 10
y_init = jp.random_uniform(rng2, (), low=self._y_min, high=-0.7)
obs_init = jp.array([x_init, y_init])
# create fake qp (to re-use brax.State)
fake_qp = brax.QP.zero()
# init reward, metrics and infos
reward, done = jp.zeros(2)
metrics: Dict = {}
# managing state descriptor by our own
info_init = {"state_descriptor": obs_init}
return State(fake_qp, obs_init, reward, done, metrics, info_init)
step(self, state, action)
¶
Run one timestep of the environment's dynamics.
Source code in qdax/environments/pointmaze.py
def step(self, state: State, action: jp.ndarray) -> State:
"""Run one timestep of the environment's dynamics."""
# clip action taken
min_action = self._low
max_action = self._high
action = jp.clip(action, min_action, max_action) / self._scale_action_space
# get the current position
x_pos_old, y_pos_old = state.obs
# compute the new position
x_pos = x_pos_old + action[0]
y_pos = y_pos_old + action[1]
# take into account a potential wall collision
y_pos = self._collision_lower_wall(y_pos, y_pos_old, x_pos, x_pos_old)
y_pos = self._collision_upper_wall(y_pos, y_pos_old, x_pos, x_pos_old)
# take into account border walls
x_pos = jp.clip(x_pos, jp.array(self._x_min), jp.array(self._x_max))
y_pos = jp.clip(y_pos, jp.array(self._y_min), jp.array(self._y_max))
reward = -jp.norm(
jp.array([x_pos - self.zone_width_offset, y_pos - self.zone_height_offset])
)
# determine if zone was reached
in_zone = self._in_zone(x_pos, y_pos)
done = jp.where(
jp.array(in_zone),
x=jp.array(1.0),
y=jp.array(0.0),
)
new_obs = jp.array([x_pos, y_pos])
# update state descriptor
state.info["state_descriptor"] = new_obs
# update the state
return state.replace(obs=new_obs, reward=reward, done=done) # type: ignore
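A short rollout sketch: actions are 2D displacements, and the current xy position doubles as the state descriptor.
from brax import jumpy as jp

from qdax.environments.pointmaze import PointMaze

env = PointMaze()
state = env.reset(rng=jp.random_prngkey(seed=0))
for i in range(10):
    action = jp.zeros((env.action_size,))  # 2D displacement, here standing still
    state = env.step(state, action)
    xy_position = state.info["state_descriptor"]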
wrappers
¶
CompletedEvalMetrics (PyTreeNode)
dataclass
¶
CompletedEvalMetrics(current_episode_metrics: Dict[str, Union[numpy.ndarray, jax.Array]], completed_episodes_metrics: Dict[str, Union[numpy.ndarray, jax.Array]], completed_episodes: Union[numpy.ndarray, jax.Array], completed_episodes_steps: Union[numpy.ndarray, jax.Array])
Source code in qdax/environments/wrappers.py
class CompletedEvalMetrics(flax.struct.PyTreeNode):
current_episode_metrics: Dict[str, jp.ndarray]
completed_episodes_metrics: Dict[str, jp.ndarray]
completed_episodes: jp.ndarray
completed_episodes_steps: jp.ndarray
replace(self, **updates)
¶
"Returns a new object replacing the specified fields with new values.
Source code in qdax/environments/wrappers.py
def replace(self, **updates):
""" "Returns a new object replacing the specified fields with new values."""
return dataclasses.replace(self, **updates)
CompletedEvalWrapper (Wrapper)
¶
Brax env with eval metrics for completed episodes.
Source code in qdax/environments/wrappers.py
class CompletedEvalWrapper(Wrapper):
"""Brax env with eval metrics for completed episodes."""
STATE_INFO_KEY = "completed_eval_metrics"
def reset(self, rng: jp.ndarray) -> State:
reset_state = self.env.reset(rng)
reset_state.metrics["reward"] = reset_state.reward
eval_metrics = CompletedEvalMetrics(
current_episode_metrics=jax.tree_util.tree_map(
jp.zeros_like, reset_state.metrics
),
completed_episodes_metrics=jax.tree_util.tree_map(
lambda x: jp.zeros_like(jp.sum(x)), reset_state.metrics
),
completed_episodes=jp.zeros(()),
completed_episodes_steps=jp.zeros(()),
)
reset_state.info[self.STATE_INFO_KEY] = eval_metrics
return reset_state
def step(self, state: State, action: jp.ndarray) -> State:
state_metrics = state.info[self.STATE_INFO_KEY]
if not isinstance(state_metrics, CompletedEvalMetrics):
raise ValueError(f"Incorrect type for state_metrics: {type(state_metrics)}")
del state.info[self.STATE_INFO_KEY]
nstate = self.env.step(state, action)
nstate.metrics["reward"] = nstate.reward
# steps stores the highest step reached when done = True, and then
# the next steps becomes action_repeat
completed_episodes_steps = state_metrics.completed_episodes_steps + jp.sum(
nstate.info["steps"] * nstate.done
)
current_episode_metrics = jax.tree_util.tree_map(
lambda a, b: a + b, state_metrics.current_episode_metrics, nstate.metrics
)
completed_episodes = state_metrics.completed_episodes + jp.sum(nstate.done)
completed_episodes_metrics = jax.tree_util.tree_map(
lambda a, b: a + jp.sum(b * nstate.done),
state_metrics.completed_episodes_metrics,
current_episode_metrics,
)
current_episode_metrics = jax.tree_util.tree_map(
lambda a, b: a * (1 - nstate.done) + b * nstate.done,
current_episode_metrics,
nstate.metrics,
)
eval_metrics = CompletedEvalMetrics(
current_episode_metrics=current_episode_metrics,
completed_episodes_metrics=completed_episodes_metrics,
completed_episodes=completed_episodes,
completed_episodes_steps=completed_episodes_steps,
)
nstate.info[self.STATE_INFO_KEY] = eval_metrics
return nstate
reset(self, rng)
¶
Resets the environment to an initial state.
Source code in qdax/environments/wrappers.py
def reset(self, rng: jp.ndarray) -> State:
reset_state = self.env.reset(rng)
reset_state.metrics["reward"] = reset_state.reward
eval_metrics = CompletedEvalMetrics(
current_episode_metrics=jax.tree_util.tree_map(
jp.zeros_like, reset_state.metrics
),
completed_episodes_metrics=jax.tree_util.tree_map(
lambda x: jp.zeros_like(jp.sum(x)), reset_state.metrics
),
completed_episodes=jp.zeros(()),
completed_episodes_steps=jp.zeros(()),
)
reset_state.info[self.STATE_INFO_KEY] = eval_metrics
return reset_state
step(self, state, action)
¶
Run one timestep of the environment's dynamics.
Source code in qdax/environments/wrappers.py
def step(self, state: State, action: jp.ndarray) -> State:
state_metrics = state.info[self.STATE_INFO_KEY]
if not isinstance(state_metrics, CompletedEvalMetrics):
raise ValueError(f"Incorrect type for state_metrics: {type(state_metrics)}")
del state.info[self.STATE_INFO_KEY]
nstate = self.env.step(state, action)
nstate.metrics["reward"] = nstate.reward
# steps stores the highest step reached when done = True, and then
# the next steps becomes action_repeat
completed_episodes_steps = state_metrics.completed_episodes_steps + jp.sum(
nstate.info["steps"] * nstate.done
)
current_episode_metrics = jax.tree_util.tree_map(
lambda a, b: a + b, state_metrics.current_episode_metrics, nstate.metrics
)
completed_episodes = state_metrics.completed_episodes + jp.sum(nstate.done)
completed_episodes_metrics = jax.tree_util.tree_map(
lambda a, b: a + jp.sum(b * nstate.done),
state_metrics.completed_episodes_metrics,
current_episode_metrics,
)
current_episode_metrics = jax.tree_util.tree_map(
lambda a, b: a * (1 - nstate.done) + b * nstate.done,
current_episode_metrics,
nstate.metrics,
)
eval_metrics = CompletedEvalMetrics(
current_episode_metrics=current_episode_metrics,
completed_episodes_metrics=completed_episodes_metrics,
completed_episodes=completed_episodes,
completed_episodes_steps=completed_episodes_steps,
)
nstate.info[self.STATE_INFO_KEY] = eval_metrics
return nstate
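A usage sketch: wrap an env that tracks episode steps (e.g. one built by brax.envs.create, which applies an EpisodeWrapper and hence populates info["steps"]), roll it out, then read the accumulated metrics from state.info.
from brax import envs
from brax import jumpy as jp

from qdax.environments.wrappers import CompletedEvalWrapper

env = CompletedEvalWrapper(envs.create(env_name="ant", episode_length=100))
state = env.reset(rng=jp.random_prngkey(seed=0))
for i in range(10):
    state = env.step(state, jp.zeros((env.action_size,)))

eval_metrics = state.info[CompletedEvalWrapper.STATE_INFO_KEY]
# number of episodes finished so far and running per-episode metric sums
completed = eval_metrics.completed_episodes
running_reward = eval_metrics.current_episode_metrics["reward"]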