Environments

qdax.environments special

create(env_name, episode_length=1000, action_repeat=1, auto_reset=True, batch_size=None, eval_metrics=False, fixed_init_state=False, qdax_wrappers_kwargs=None, **kwargs)

Creates an Env with a specified brax system. Please use namespace to avoid confusion between this function and brax.envs.create.

Source code in qdax/environments/__init__.py
def create(
    env_name: str,
    episode_length: int = 1000,
    action_repeat: int = 1,
    auto_reset: bool = True,
    batch_size: Optional[int] = None,
    eval_metrics: bool = False,
    fixed_init_state: bool = False,
    qdax_wrappers_kwargs: Optional[List] = None,
    **kwargs: Any,
) -> Union[Env, QDEnv]:
    """Creates an Env with a specified brax system.
    Please use namespace to avoid confusion between this function and
    brax.envs.create.
    """

    if env_name in _envs.keys():
        env = _envs[env_name](legacy_spring=True, **kwargs)
    elif env_name in _qdax_envs.keys():
        env = _qdax_envs[env_name](**kwargs)
    elif env_name in _qdax_custom_envs.keys():
        base_env_name = _qdax_custom_envs[env_name]["env"]
        if base_env_name in _envs.keys():
            env = _envs[base_env_name](legacy_spring=True, **kwargs)
        elif base_env_name in _qdax_envs.keys():
            env = _qdax_envs[base_env_name](**kwargs)  # type: ignore
    else:
        raise NotImplementedError("This environment name does not exist!")

    if env_name in _qdax_custom_envs.keys():
        # roll with qdax wrappers
        wrappers = _qdax_custom_envs[env_name]["wrappers"]
        if qdax_wrappers_kwargs is None:
            kwargs_list = _qdax_custom_envs[env_name]["kwargs"]
        else:
            kwargs_list = qdax_wrappers_kwargs
        for wrapper, kwargs in zip(wrappers, kwargs_list):  # type: ignore
            env = wrapper(env, base_env_name, **kwargs)  # type: ignore

    if episode_length is not None:
        env = EpisodeWrapper(env, episode_length, action_repeat)
    if batch_size:
        env = VectorWrapper(env, batch_size)
    if fixed_init_state:
        # retrieve the base env
        if env_name not in _qdax_custom_envs.keys():
            base_env_name = env_name
        # wrap the env
        env = FixedInitialStateWrapper(env, base_env_name=base_env_name)  # type: ignore
    if auto_reset:
        env = AutoResetWrapper(env)
        if env_name in _qdax_custom_envs.keys():
            env = StateDescriptorResetWrapper(env)
    if eval_metrics:
        env = EvalWrapper(env)
        env = CompletedEvalWrapper(env)

    return env
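
A minimal usage sketch of this factory (the environment name "pointmaze" is only an assumed example; any key registered in _envs, _qdax_envs or _qdax_custom_envs works the same way):

from brax import jumpy as jp
from qdax import environments

# episode_length and auto_reset are forwarded to the wrappers applied above
env = environments.create(env_name="pointmaze", episode_length=100, auto_reset=True)

state = env.reset(rng=jp.random_prngkey(seed=0))
for _ in range(10):
    action = jp.zeros((env.action_size,))
    state = env.step(state, action)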

create_fn(env_name, **kwargs)

Returns a function that when called, creates an Env. Please use namespace to avoid confusion between this function and brax.envs.create_fn.

Source code in qdax/environments/__init__.py
def create_fn(env_name: str, **kwargs: Any) -> Callable[..., Env]:
    """Returns a function that when called, creates an Env.
    Please use namespace to avoid confusion between this function and
    brax.envs.create_fn.
    """
    return functools.partial(create, env_name, **kwargs)
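
A sketch of the resulting partial application (same assumed environment name as above):

from qdax import environments

# pre-configure the factory, then build the environment later
make_env = environments.create_fn("pointmaze", episode_length=100)
env = make_env(auto_reset=True)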

base_wrappers

QDEnv (Env)

Wrapper for all QD environments.

Source code in qdax/environments/base_wrappers.py
class QDEnv(Env):
    """
    Wrapper for all QD environments.
    """

    @property
    @abstractmethod
    def state_descriptor_length(self) -> int:
        pass

    @property
    @abstractmethod
    def state_descriptor_name(self) -> str:
        pass

    @property
    @abstractmethod
    def state_descriptor_limits(self) -> Tuple[List[float], List[float]]:
        pass

    @property
    @abstractmethod
    def behavior_descriptor_length(self) -> int:
        pass

    @property
    @abstractmethod
    def behavior_descriptor_limits(self) -> Tuple[List[float], List[float]]:
        pass

    @property
    @abstractmethod
    def name(self) -> str:
        pass
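
For illustration only, a concrete subclass has to provide all of the abstract properties above (the values below are placeholders, not taken from any QDax environment; the brax Env methods such as reset and step are omitted from this sketch):

from typing import List, Tuple
from qdax.environments.base_wrappers import QDEnv

class MyQDEnv(QDEnv):
    # placeholder descriptor interface; reset/step/... still need to be defined
    @property
    def state_descriptor_length(self) -> int:
        return 2

    @property
    def state_descriptor_name(self) -> str:
        return "xy_position"

    @property
    def state_descriptor_limits(self) -> Tuple[List[float], List[float]]:
        return [0.0, 0.0], [1.0, 1.0]

    @property
    def behavior_descriptor_length(self) -> int:
        return 2

    @property
    def behavior_descriptor_limits(self) -> Tuple[List[float], List[float]]:
        return [0.0, 0.0], [1.0, 1.0]

    @property
    def name(self) -> str:
        return "my_qd_env"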

QDWrapper (QDEnv)

Wrapper for QD environments.

Source code in qdax/environments/base_wrappers.py
class QDWrapper(QDEnv):
    """Wrapper for QD environments."""

    def __init__(self, env: QDEnv):
        super().__init__(config=None)
        self.env = env

    def reset(self, rng: jp.ndarray) -> State:
        return self.env.reset(rng)

    def step(self, state: State, action: jp.ndarray) -> State:
        return self.env.step(state, action)

    @property
    def observation_size(self) -> int:
        return self.env.observation_size  # type: ignore

    @property
    def action_size(self) -> int:
        return self.env.action_size  # type: ignore

    @property
    def state_descriptor_length(self) -> int:
        return self.env.state_descriptor_length

    @property
    def state_descriptor_name(self) -> str:
        return self.env.state_descriptor_name

    @property
    def state_descriptor_limits(self) -> Tuple[List[float], List[float]]:
        return self.env.state_descriptor_limits

    @property
    def behavior_descriptor_length(self) -> int:
        return self.env.behavior_descriptor_length

    @property
    def behavior_descriptor_limits(self) -> Tuple[List[float], List[float]]:
        return self.env.behavior_descriptor_limits

    @property
    def name(self) -> str:
        return self.env.name

    @property
    def unwrapped(self) -> Env:
        return self.env.unwrapped

    def __getattr__(self, name: str) -> Any:
        if name == "__setstate__":
            raise AttributeError(name)
        return getattr(self.env, name)
observation_size: int property readonly

The size of the observation vector returned in step and reset.

action_size: int property readonly

The size of the action vector expected by step.

reset(self, rng)

Resets the environment to an initial state.

Source code in qdax/environments/base_wrappers.py
def reset(self, rng: jp.ndarray) -> State:
    return self.env.reset(rng)
step(self, state, action)

Run one timestep of the environment's dynamics.

Source code in qdax/environments/base_wrappers.py
def step(self, state: State, action: jp.ndarray) -> State:
    return self.env.step(state, action)

StateDescriptorResetWrapper (QDWrapper)

Automatically resets state descriptors.

Source code in qdax/environments/base_wrappers.py
class StateDescriptorResetWrapper(QDWrapper):
    """Automatically resets state descriptors."""

    def reset(self, rng: jp.ndarray) -> State:
        state = self.env.reset(rng)
        state.info["first_state_descriptor"] = state.info["state_descriptor"]
        return state

    def step(self, state: State, action: jp.ndarray) -> State:

        state = self.env.step(state, action)

        def where_done(x: jp.ndarray, y: jp.ndarray) -> jp.ndarray:
            done = state.done
            if done.shape:
                done = jp.reshape(done, tuple([x.shape[0]] + [1] * (len(x.shape) - 1)))
            return jp.where(done, x, y)

        state.info["state_descriptor"] = where_done(
            state.info["first_state_descriptor"], state.info["state_descriptor"]
        )
        return state
reset(self, rng)

Resets the environment to an initial state.

Source code in qdax/environments/base_wrappers.py
def reset(self, rng: jp.ndarray) -> State:
    state = self.env.reset(rng)
    state.info["first_state_descriptor"] = state.info["state_descriptor"]
    return state
step(self, state, action)

Run one timestep of the environment's dynamics.

Source code in qdax/environments/base_wrappers.py
def step(self, state: State, action: jp.ndarray) -> State:

    state = self.env.step(state, action)

    def where_done(x: jp.ndarray, y: jp.ndarray) -> jp.ndarray:
        done = state.done
        if done.shape:
            done = jp.reshape(done, tuple([x.shape[0]] + [1] * (len(x.shape) - 1)))
        return jp.where(done, x, y)

    state.info["state_descriptor"] = where_done(
        state.info["first_state_descriptor"], state.info["state_descriptor"]
    )
    return state
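
The where_done helper above broadcasts the batched done flags over the descriptor dimensions before selecting between the stored initial descriptor and the current one. A standalone sketch of that selection logic (shapes are hypothetical, a batch of 3 environments with 2-dimensional descriptors):

import jax.numpy as jnp

done = jnp.array([1.0, 0.0, 1.0])   # (batch,) done flags from the last step
first_desc = jnp.zeros((3, 2))      # descriptors stored at reset
current_desc = jnp.ones((3, 2))     # descriptors produced by the last step

# reshape done to (batch, 1) so it broadcasts against the descriptors
done = jnp.reshape(done, (first_desc.shape[0], 1))
selected = jnp.where(done, first_desc, current_desc)
# rows 0 and 2 are reset to their initial descriptor, row 1 keeps the current one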

bd_extractors

AuroraExtraInfo (PyTreeNode) dataclass

Information specific to the AURORA algorithm.

Parameters:
  • model_params (Params) – the parameters of the dimensionality reduction model

Source code in qdax/environments/bd_extractors.py
class AuroraExtraInfo(flax.struct.PyTreeNode):
    """
    Information specific to the AURORA algorithm.

    Args:
        model_params: the parameters of the dimensionality reduction model
    """

    model_params: Params
replace(self, **updates)

"Returns a new object replacing the specified fields with new values.

Source code in qdax/environments/bd_extractors.py
def replace(self, **updates):
  """ "Returns a new object replacing the specified fields with new values."""
  return dataclasses.replace(self, **updates)

AuroraExtraInfoNormalization (AuroraExtraInfo) dataclass

Information specific to the AURORA algorithm. In particular, it contains the normalization parameters for the observations.

Parameters:
  • model_params (Params) – the parameters of the dimensionality reduction model

  • mean_observations (jnp.ndarray) – the mean of observations

  • std_observations (jnp.ndarray) – the std of observations

Source code in qdax/environments/bd_extractors.py
class AuroraExtraInfoNormalization(AuroraExtraInfo):
    """
    Information specific to the AURORA algorithm. In particular, it contains
    the normalization parameters for the observations.

    Args:
        model_params: the parameters of the dimensionality reduction model
        mean_observations: the mean of observations
        std_observations: the std of observations
    """

    mean_observations: jnp.ndarray
    std_observations: jnp.ndarray

    @classmethod
    def create(
        cls,
        model_params: Params,
        mean_observations: jnp.ndarray,
        std_observations: jnp.ndarray,
    ) -> AuroraExtraInfoNormalization:
        return cls(
            model_params=model_params,
            mean_observations=mean_observations,
            std_observations=std_observations,
        )
replace(self, **updates)

"Returns a new object replacing the specified fields with new values.

Source code in qdax/environments/bd_extractors.py
def replace(self, **updates):
  """ "Returns a new object replacing the specified fields with new values."""
  return dataclasses.replace(self, **updates)
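
A minimal construction sketch (the parameter pytree and the observation statistics below are placeholders):

import jax.numpy as jnp
from qdax.environments.bd_extractors import AuroraExtraInfoNormalization

# placeholder model parameters and observation statistics
model_params = {"encoder": {"kernel": jnp.zeros((8, 2))}}
aurora_extra_info = AuroraExtraInfoNormalization.create(
    model_params=model_params,
    mean_observations=jnp.zeros((8,)),
    std_observations=jnp.ones((8,)),
)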

get_final_xy_position(data, mask)

Compute the final xy position.

This function assumes that the state descriptor is the xy position, as it simply selects the final state descriptor of those given.

Source code in qdax/environments/bd_extractors.py
def get_final_xy_position(data: QDTransition, mask: jnp.ndarray) -> Descriptor:
    """Compute final xy positon.

    This function suppose that state descriptor is the xy position, as it
    just select the final one of the state descriptors given.
    """
    # reshape mask for bd extraction
    mask = jnp.expand_dims(mask, axis=-1)

    # Get behavior descriptor
    last_index = jnp.int32(jnp.sum(1.0 - mask, axis=1)) - 1
    descriptors = jax.vmap(lambda x, y: x[y])(data.state_desc, last_index)

    # remove the dim coming from the trajectory
    return descriptors.squeeze(axis=1)
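
A shape-oriented sketch of the calling convention (batch size and episode length are arbitrary; a namedtuple stands in for QDTransition, keeping only the field used here). data.state_desc holds the xy position at every timestep and mask is 1.0 for the timesteps that occur after the episode is done:

from collections import namedtuple

import jax.numpy as jnp
from qdax.environments.bd_extractors import get_final_xy_position

# stand-in for QDTransition, keeping only the field used by the extractor
FakeTransition = namedtuple("FakeTransition", ["state_desc"])

batch_size, episode_length = 4, 10
data = FakeTransition(state_desc=jnp.zeros((batch_size, episode_length, 2)))
mask = jnp.zeros((batch_size, episode_length))  # no masked timesteps here

descriptors = get_final_xy_position(data, mask)  # shape (batch_size, 2)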

get_feet_contact_proportion(data, mask)

Compute feet contact time proportion.

This function assumes that the state descriptor is the feet contact, as it simply computes the mean of the given state descriptors.

Source code in qdax/environments/bd_extractors.py
def get_feet_contact_proportion(data: QDTransition, mask: jnp.ndarray) -> Descriptor:
    """Compute feet contact time proportion.

    This function assumes that the state descriptor is the feet contact, as it
    simply computes the mean of the given state descriptors.
    """
    # reshape mask for bd extraction
    mask = jnp.expand_dims(mask, axis=-1)

    # Get behavior descriptor
    descriptors = jnp.sum(data.state_desc * (1.0 - mask), axis=1)
    descriptors = descriptors / jnp.sum(1.0 - mask, axis=1)

    return descriptors

get_aurora_encoding(observations, aurora_extra_info, model)

Compute the final AURORA embedding.

This function encodes the observations with the provided dimensionality reduction model, after normalising them with the given mean and standard deviation.

Source code in qdax/environments/bd_extractors.py
def get_aurora_encoding(
    observations: jnp.ndarray,
    aurora_extra_info: AuroraExtraInfoNormalization,
    model: flax.linen.Module,
) -> Descriptor:
    """
    Compute the final AURORA embedding.

    This function encodes the observations with the provided dimensionality
    reduction model, after normalising them with the given mean and std.
    """
    model_params = aurora_extra_info.model_params
    mean_observations = aurora_extra_info.mean_observations
    std_observations = aurora_extra_info.std_observations

    # lstm seq2seq
    normalized_observations = (observations - mean_observations) / std_observations
    descriptors = model.apply(
        {"params": model_params}, normalized_observations, method=model.encode
    )

    return descriptors.squeeze()
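
A minimal sketch with a hypothetical flax module exposing the encode method expected by this function (the architecture below is a placeholder, not the AURORA encoder used in QDax):

import flax.linen as nn
import jax
import jax.numpy as jnp
from qdax.environments.bd_extractors import (
    AuroraExtraInfoNormalization,
    get_aurora_encoding,
)

class TinyEncoder(nn.Module):
    """Placeholder model with an encode method, used only for illustration."""

    latent_size: int = 2

    def setup(self) -> None:
        self.dense = nn.Dense(self.latent_size)

    def __call__(self, x: jnp.ndarray) -> jnp.ndarray:
        return self.encode(x)

    def encode(self, x: jnp.ndarray) -> jnp.ndarray:
        return self.dense(x)

model = TinyEncoder()
observations = jnp.zeros((1, 8))
model_params = model.init(jax.random.PRNGKey(0), observations)["params"]

aurora_extra_info = AuroraExtraInfoNormalization.create(
    model_params=model_params,
    mean_observations=jnp.zeros((8,)),
    std_observations=jnp.ones((8,)),
)
descriptors = get_aurora_encoding(observations, aurora_extra_info, model)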

exploration_wrappers

TrapWrapper (Wrapper)

Wraps gym environments to add a Trap in the environment.

Usage is simple: create an environment with Brax, pass it to the wrapper together with the name of the environment, and it will behave as before while simply adding the trap to the environment.

This wrapper also adds the xy position to the observation, as it is important information for the agent: now that there is a trap in its environment, we expect its actions to depend on its xy position.

The xy position is normalised using the chosen limits of the environment, which are [0, 30] for x and [-8, 8] for y.

The only classic locomotion environment supported at the moment is Ant.

Note: Humanoid is not supported yet. Note: this also works for walker2d etc., but it does not make sense for them as they can only move in one direction.

Example:

from brax import envs
from brax import jumpy as jp
from qdax.environments.exploration_wrappers import TrapWrapper

# choose in ["ant"]
ENV_NAME = "ant"
env = envs.create(env_name=ENV_NAME)
qd_env = TrapWrapper(env, ENV_NAME)

state = qd_env.reset(rng=jp.random_prngkey(seed=0))
for i in range(10):
    action = jp.zeros((qd_env.action_size,))
    state = qd_env.step(state, action)
Source code in qdax/environments/exploration_wrappers.py
class TrapWrapper(Wrapper):
    """Wraps gym environments to add a Trap in the environment.

    Utilisation is simple: create an environment with Brax, pass
    it to the wrapper with the name of the environment, and it will
    work like before and will simply add the Trap to the environment.

    This wrapper also adds xy in the observation, as it is an important
    information for an agent. Now that there is a trap in its env, we
    expect its actions to depend on its xy position.

    The xy position is normalised thanks to the decided limits of the env,
    which are [0, 30] for x and [-8, 8] for y.

    The only supported envs at the moment are among the classic
    locomotion envs : Ant.

    RMQ: Humanoid is not supported yet.
    RMQ: works for walker2d etc.. but it does not make sens as they
    can only go in one direction.


    Example :

        from brax import envs
        from brax import jumpy as jp

        # choose in ["ant"]
        ENV_NAME = "ant"
        env = envs.create(env_name=ENV_NAME)
        qd_env = TrapWrapper(env, ENV_NAME)

        state = qd_env.reset(rng=jp.random_prngkey(seed=0))
        for i in range(10):
            action = jp.zeros((qd_env.action_size,))
            state = qd_env.step(state, action)


    """

    def __init__(self, env: Env, env_name: str) -> None:
        if (
            env_name not in ENV_SYSTEM_CONFIG.keys()
            or env_name not in COG_NAMES.keys()
            or env_name not in ENV_TRAP_COLLISION.keys()
        ):
            raise NotImplementedError(f"This wrapper does not support {env_name} yet.")
        if env_name not in ["ant", "humanoid"]:
            warnings.warn(
                "Make sure your agent can move in two dimensions!",
                stacklevel=2,
            )
        super().__init__(env)
        self._env_name = env_name
        # update the env config to add the trap
        self._config = (
            ENV_SYSTEM_CONFIG[env_name] + TRAP_CONFIG + ENV_TRAP_COLLISION[env_name]
        )
        # update the associated physical system
        config = text_format.Parse(self._config, brax.Config())
        if not hasattr(self.unwrapped, "sys"):
            raise AttributeError("Cannot link env to a physical system.")
        self.unwrapped.sys = brax.System(config)
        self._cog_idx = self.unwrapped.sys.body.index[COG_NAMES[env_name]]

        # we need to normalise the x/y position to avoid values exploding
        self._substract = jnp.array([15, 0])  # come from env limits
        self._divide = jnp.array([15, 8])  # come from env limits

    @property
    def name(self) -> str:
        return self._env_name

    @property
    def observation_size(self) -> int:
        """The size of the observation vector returned in step and reset."""
        rng = jp.random_prngkey(0)
        reset_state = self.reset(rng)
        return int(reset_state.obs.shape[-1])

    def reset(self, rng: jp.ndarray) -> State:
        state = self.env.reset(rng)
        # add xy position to the observation
        xy_pos = state.qp.pos[self._cog_idx][:2]
        # normalise
        xy_pos = (xy_pos - self._substract) / self._divide
        new_obs = jp.concatenate([xy_pos, state.obs])
        return state.replace(obs=new_obs)  # type: ignore

    def step(self, state: State, action: jp.ndarray) -> State:
        state = self.env.step(state, action)
        # add xy position to the observation
        xy_pos = state.qp.pos[self._cog_idx][:2]
        # normalise
        xy_pos = (xy_pos - self._substract) / self._divide
        new_obs = jp.concatenate([xy_pos, state.obs])
        return state.replace(obs=new_obs)  # type: ignore
observation_size: int property readonly

The size of the observation vector returned in step and reset.

reset(self, rng)

Resets the environment to an initial state.

Source code in qdax/environments/exploration_wrappers.py
def reset(self, rng: jp.ndarray) -> State:
    state = self.env.reset(rng)
    # add xy position to the observation
    xy_pos = state.qp.pos[self._cog_idx][:2]
    # normalise
    xy_pos = (xy_pos - self._substract) / self._divide
    new_obs = jp.concatenate([xy_pos, state.obs])
    return state.replace(obs=new_obs)  # type: ignore
step(self, state, action)

Run one timestep of the environment's dynamics.

Source code in qdax/environments/exploration_wrappers.py
def step(self, state: State, action: jp.ndarray) -> State:
    state = self.env.step(state, action)
    # add xy position to the observation
    xy_pos = state.qp.pos[self._cog_idx][:2]
    # normalise
    xy_pos = (xy_pos - self._substract) / self._divide
    new_obs = jp.concatenate([xy_pos, state.obs])
    return state.replace(obs=new_obs)  # type: ignore

MazeWrapper (Wrapper)

Wraps gym environments to add a maze in the environment and a new reward (distance to the goal).

Usage is simple: create an environment with Brax, pass it to the wrapper together with the name of the environment, and it will behave as before while simply adding the maze to the environment, along with the new reward.

This wrapper also adds the xy position to the observation, as it is important information for the agent: now that the agent is in a maze, we expect its actions to depend on its xy position.

The xy position is normalised using the chosen limits of the environment, which are [-5, 40] for both x and y.

The only classic locomotion environment supported at the moment is Ant.

Note: Humanoid is not supported yet. Note: this also works for walker2d etc., but it does not make sense for them as they can only move in one direction.

Example:

from brax import envs
from brax import jumpy as jp
from qdax.environments.exploration_wrappers import MazeWrapper

# choose in ["ant"]
ENV_NAME = "ant"
env = envs.create(env_name=ENV_NAME)
qd_env = MazeWrapper(env, ENV_NAME)

state = qd_env.reset(rng=jp.random_prngkey(seed=0))
for i in range(10):
    action = jp.zeros((qd_env.action_size,))
    state = qd_env.step(state, action)
Source code in qdax/environments/exploration_wrappers.py
class MazeWrapper(Wrapper):
    """Wraps gym environments to add a maze in the environment
    and a new reward (distance to the goal).

    Utilisation is simple: create an environment with Brax, pass
    it to the wrapper with the name of the environment, and it will
    work like before and will simply add the Maze to the environment,
    along with the new reward.

    This wrapper also adds xy in the observation, as it is an important
    information for an agent. Now that the agent is in a maze, we
    expect its actions to depend on its xy position.

    The xy position is normalised thanks to the decided limits of the env,
    which are [-5, 40] for x and y.

    The only supported envs at the moment are among the classic
    locomotion envs : Ant.

    RMQ: Humanoid is not supported yet.
    RMQ: works for walker2d etc.. but it does not make sens as they
    can only go in one direction.

    Example :

        from brax import envs
        from brax import jumpy as jp

        # choose in ["ant"]
        ENV_NAME = "ant"
        env = envs.create(env_name=ENV_NAME)
        qd_env = MazeWrapper(env, ENV_NAME)

        state = qd_env.reset(rng=jp.random_prngkey(seed=0))
        for i in range(10):
            action = jp.zeros((qd_env.action_size,))
            state = qd_env.step(state, action)


    """

    def __init__(self, env: Env, env_name: str) -> None:
        if (
            env_name not in ENV_SYSTEM_CONFIG.keys()
            or env_name not in COG_NAMES.keys()
            or env_name not in ENV_MAZE_COLLISION.keys()
        ):
            raise NotImplementedError(
                f"This wrapper does not support {env_name} yet.",
            )
        if env_name not in ["ant", "humanoid"]:
            warnings.warn(
                "Make sure your agent can move in two dimensions!",
                stacklevel=2,
            )
        super().__init__(env)
        self._env_name = env_name
        self._config = (
            ENV_SYSTEM_CONFIG[env_name] + MAZE_CONFIG + ENV_MAZE_COLLISION[env_name]
        )
        config = text_format.Parse(self._config, brax.Config())
        if not hasattr(self.unwrapped, "sys"):
            raise AttributeError("Cannot link env to a physical system.")
        self.unwrapped.sys = brax.System(config)
        self._cog_idx = self.unwrapped.sys.body.index[COG_NAMES[env_name]]
        self._target_idx = self.unwrapped.sys.body.index["Target"]

        # we need to normalise the x/y position to avoid values exploding
        self._substract = jnp.array([17.5, 17.5])  # come from env limits
        self._divide = jnp.array([22.5, 22.5])  # come from env limits

    @property
    def name(self) -> str:
        return self._env_name

    @property
    def observation_size(self) -> int:
        """The size of the observation vector returned in step and reset."""
        rng = jp.random_prngkey(0)
        reset_state = self.reset(rng)
        return int(reset_state.obs.shape[-1])

    def reset(self, rng: jp.ndarray) -> State:
        state = self.env.reset(rng)
        # get xy position of the center of gravity and of the target
        cog_xy_position = state.qp.pos[self._cog_idx][:2]
        target_xy_position = state.qp.pos[self._target_idx][:2]
        # update the reward
        new_reward = -jp.norm(target_xy_position - cog_xy_position)
        # add cog xy position to the observation - normalise
        cog_xy_position = (cog_xy_position - self._substract) / self._divide
        new_obs = jp.concatenate([cog_xy_position, state.obs])
        return state.replace(obs=new_obs, reward=new_reward)  # type: ignore

    def step(self, state: State, action: jp.ndarray) -> State:
        state = self.env.step(state, action)
        # get xy position of the center of gravity and of the target
        cog_xy_position = state.qp.pos[self._cog_idx][:2]
        target_xy_position = state.qp.pos[self._target_idx][:2]
        # update the reward
        new_reward = -jp.norm(target_xy_position - cog_xy_position)
        # add cog xy position to the observation - normalise
        cog_xy_position = (cog_xy_position - self._substract) / self._divide
        new_obs = jp.concatenate([cog_xy_position, state.obs])
        # brax ant suicides by jumping over a manually designed z threshold
        # these lines avoid this by increasing the threshold
        done = jp.where(
            state.qp.pos[0, 2] < 0.2,
            x=jp.array(1, dtype=jp.float32),
            y=jp.array(0, dtype=jp.float32),
        )
        done = jp.where(
            state.qp.pos[0, 2] > 5.0, x=jp.array(1, dtype=jp.float32), y=done
        )
        return state.replace(obs=new_obs, reward=new_reward, done=done)  # type: ignore
observation_size: int property readonly

The size of the observation vector returned in step and reset.

reset(self, rng)

Resets the environment to an initial state.

Source code in qdax/environments/exploration_wrappers.py
def reset(self, rng: jp.ndarray) -> State:
    state = self.env.reset(rng)
    # get xy position of the center of gravity and of the target
    cog_xy_position = state.qp.pos[self._cog_idx][:2]
    target_xy_position = state.qp.pos[self._target_idx][:2]
    # update the reward
    new_reward = -jp.norm(target_xy_position - cog_xy_position)
    # add cog xy position to the observation - normalise
    cog_xy_position = (cog_xy_position - self._substract) / self._divide
    new_obs = jp.concatenate([cog_xy_position, state.obs])
    return state.replace(obs=new_obs, reward=new_reward)  # type: ignore
step(self, state, action)

Run one timestep of the environment's dynamics.

Source code in qdax/environments/exploration_wrappers.py
def step(self, state: State, action: jp.ndarray) -> State:
    state = self.env.step(state, action)
    # get xy position of the center of gravity and of the target
    cog_xy_position = state.qp.pos[self._cog_idx][:2]
    target_xy_position = state.qp.pos[self._target_idx][:2]
    # update the reward
    new_reward = -jp.norm(target_xy_position - cog_xy_position)
    # add cog xy position to the observation - normalise
    cog_xy_position = (cog_xy_position - self._substract) / self._divide
    new_obs = jp.concatenate([cog_xy_position, state.obs])
    # brax ant suicides by jumping over a manually designed z threshold
    # these lines avoid this by increasing the threshold
    done = jp.where(
        state.qp.pos[0, 2] < 0.2,
        x=jp.array(1, dtype=jp.float32),
        y=jp.array(0, dtype=jp.float32),
    )
    done = jp.where(
        state.qp.pos[0, 2] > 5.0, x=jp.array(1, dtype=jp.float32), y=done
    )
    return state.replace(obs=new_obs, reward=new_reward, done=done)  # type: ignore

humanoidtrap

Trains a humanoid to run in the +x direction, with a trap added in front. Heavily inspired by the brax humanoid.

HumanoidTrap (Env)

Trains a humanoid to run in the +x direction.

Note: uses the legacy spring physics from Brax.

Source code in qdax/environments/humanoidtrap.py
class HumanoidTrap(Env):
    """Trains a humanoid to run in the +x direction.

    Note: uses the legacy spring physics from Brax.
    """

    def __init__(self, **kwargs: Dict[str, Any]) -> None:
        config = _SYSTEM_CONFIG
        # change compared to humanoid
        config = config + TRAP_CONFIG + HUMANOID_TRAP_COLLISIONS
        super().__init__(config=config, **kwargs)
        body = bodies.Body(self.sys.config)
        # change compared to humanoid
        body = jp.take(body, body.idx[:-2])  # skip the floor body & trap body
        self.mass = body.mass.reshape(-1, 1)
        self.inertia = body.inertia
        self.inertia_matrix = jp.array([jp.diag(a) for a in self.inertia])

    def reset(self, rng: jp.ndarray) -> State:
        """Resets the environment to an initial state."""
        rng, rng1, rng2 = jp.random_split(rng, 3)
        qpos = self.sys.default_angle() + jp.random_uniform(
            rng1, (self.sys.num_joint_dof,), -0.01, 0.01
        )
        qvel = jp.random_uniform(rng2, (self.sys.num_joint_dof,), -0.01, 0.01)
        qp = self.sys.default_qp(joint_angle=qpos, joint_velocity=qvel)
        info = self.sys.info(qp)
        obs = self._get_obs(qp, info, jp.zeros(self.action_size))
        reward, done, zero = jp.zeros(3)
        metrics = {
            "reward_linvel": zero,
            "reward_quadctrl": zero,
            "reward_alive": zero,
            "reward_impact": zero,
        }
        return State(qp, obs, reward, done, metrics)

    def step(self, state: State, action: jp.ndarray) -> State:
        """Run one timestep of the environment's dynamics."""
        qp, info = self.sys.step(state.qp, action)
        obs = self._get_obs(qp, info, action)

        # change compared to humanoid
        pos_before = state.qp.pos[:-2]  # ignore floor & trap at last index
        pos_after = qp.pos[:-2]  # ignore floor & trap at last index
        com_before = jp.sum(pos_before * self.mass, axis=0) / jp.sum(self.mass)
        com_after = jp.sum(pos_after * self.mass, axis=0) / jp.sum(self.mass)
        lin_vel_cost = 1.25 * (com_after[0] - com_before[0]) / self.sys.config.dt
        quad_ctrl_cost = 0.01 * jp.sum(jp.square(action))
        # can ignore contact cost, see: https://github.com/openai/gym/issues/1541
        quad_impact_cost = jp.float32(0)
        alive_bonus = jp.float32(5)
        reward = lin_vel_cost - quad_ctrl_cost - quad_impact_cost + alive_bonus

        done = jp.where(qp.pos[0, 2] < 0.8, jp.float32(1), jp.float32(0))
        done = jp.where(qp.pos[0, 2] > 2.1, jp.float32(1), done)
        state.metrics.update(
            reward_linvel=lin_vel_cost,
            reward_quadctrl=quad_ctrl_cost,
            reward_alive=alive_bonus,
            reward_impact=quad_impact_cost,
        )

        return state.replace(qp=qp, obs=obs, reward=reward, done=done)

    def _get_obs(self, qp: brax.QP, info: brax.Info, action: jp.ndarray) -> jp.ndarray:
        """Observe humanoid body position, velocities, and angles."""
        # some pre-processing to pull joint angles and velocities
        joint_obs = [j.angle_vel(qp) for j in self.sys.joints]

        # qpos:
        # Z of the torso (1,)
        # orientation of the torso as quaternion (4,)
        # joint angles (8,)
        joint_angles = [jp.array(j[0]).reshape(-1) for j in joint_obs]
        qpos = [
            qp.pos[0, 2:],
            qp.rot[0],
        ] + joint_angles

        # qvel:
        # velocity of the torso (3,)
        # angular velocity of the torso (3,)
        # joint angle velocities (8,)
        joint_velocities = [jp.array(j[1]).reshape(-1) for j in joint_obs]
        qvel = [
            qp.vel[0],
            qp.ang[0],
        ] + joint_velocities

        # actuator forces
        qfrc_actuator = []
        for act in self.sys.actuators:
            torque = jp.take(action, act.act_index)
            torque = torque.reshape(torque.shape[:-2] + (-1,))
            torque *= jp.repeat(act.strength, act.act_index.shape[-1])
            qfrc_actuator.append(torque)

        # external contact forces:
        # delta velocity (3,), delta ang (3,) * num bodies in the system
        cfrc_ext = [info.contact.vel, info.contact.ang]
        # flatten bottom dimension
        cfrc_ext = [x.reshape(x.shape[:-2] + (-1,)) for x in cfrc_ext]

        # center of mass obs: changed compared to humanoid
        body_pos = qp.pos[:-2]  # ignore floor & target at last index
        body_vel = qp.vel[:-2]  # ignore floor & target at last index

        com_vec = jp.sum(body_pos * self.mass, axis=0) / jp.sum(self.mass)
        com_vel = body_vel * self.mass / jp.sum(self.mass)

        v_outer = jp.vmap(lambda a: jp.outer(a, a))
        v_cross = jp.vmap(jp.cross)

        disp_vec = body_pos - com_vec
        com_inert = self.inertia_matrix + self.mass.reshape((11, 1, 1)) * (
            (jp.norm(disp_vec, axis=1) ** 2.0).reshape((11, 1, 1))
            * jp.stack([jp.eye(3)] * 11)
            - v_outer(disp_vec)
        )

        cinert = [com_inert.reshape(-1)]

        square_disp = (1e-7 + (jp.norm(disp_vec, axis=1) ** 2.0)).reshape((11, 1))
        com_angular_vel = v_cross(disp_vec, body_vel) / square_disp
        cvel = [com_vel.reshape(-1), com_angular_vel.reshape(-1)]

        return jp.concatenate(qpos + qvel + cinert + cvel + qfrc_actuator + cfrc_ext)
reset(self, rng)

Resets the environment to an initial state.

Source code in qdax/environments/humanoidtrap.py
def reset(self, rng: jp.ndarray) -> State:
    """Resets the environment to an initial state."""
    rng, rng1, rng2 = jp.random_split(rng, 3)
    qpos = self.sys.default_angle() + jp.random_uniform(
        rng1, (self.sys.num_joint_dof,), -0.01, 0.01
    )
    qvel = jp.random_uniform(rng2, (self.sys.num_joint_dof,), -0.01, 0.01)
    qp = self.sys.default_qp(joint_angle=qpos, joint_velocity=qvel)
    info = self.sys.info(qp)
    obs = self._get_obs(qp, info, jp.zeros(self.action_size))
    reward, done, zero = jp.zeros(3)
    metrics = {
        "reward_linvel": zero,
        "reward_quadctrl": zero,
        "reward_alive": zero,
        "reward_impact": zero,
    }
    return State(qp, obs, reward, done, metrics)
step(self, state, action)

Run one timestep of the environment's dynamics.

Source code in qdax/environments/humanoidtrap.py
def step(self, state: State, action: jp.ndarray) -> State:
    """Run one timestep of the environment's dynamics."""
    qp, info = self.sys.step(state.qp, action)
    obs = self._get_obs(qp, info, action)

    # change compared to humanoid
    pos_before = state.qp.pos[:-2]  # ignore floor & trap at last index
    pos_after = qp.pos[:-2]  # ignore floor & trap at last index
    com_before = jp.sum(pos_before * self.mass, axis=0) / jp.sum(self.mass)
    com_after = jp.sum(pos_after * self.mass, axis=0) / jp.sum(self.mass)
    lin_vel_cost = 1.25 * (com_after[0] - com_before[0]) / self.sys.config.dt
    quad_ctrl_cost = 0.01 * jp.sum(jp.square(action))
    # can ignore contact cost, see: https://github.com/openai/gym/issues/1541
    quad_impact_cost = jp.float32(0)
    alive_bonus = jp.float32(5)
    reward = lin_vel_cost - quad_ctrl_cost - quad_impact_cost + alive_bonus

    done = jp.where(qp.pos[0, 2] < 0.8, jp.float32(1), jp.float32(0))
    done = jp.where(qp.pos[0, 2] > 2.1, jp.float32(1), done)
    state.metrics.update(
        reward_linvel=lin_vel_cost,
        reward_quadctrl=quad_ctrl_cost,
        reward_alive=alive_bonus,
        reward_impact=quad_impact_cost,
    )

    return state.replace(qp=qp, obs=obs, reward=reward, done=done)
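
A minimal interaction sketch (direct instantiation is shown only for illustration; in QDax this environment is normally obtained through qdax.environments.create):

from brax import jumpy as jp
from qdax.environments.humanoidtrap import HumanoidTrap

env = HumanoidTrap()
state = env.reset(rng=jp.random_prngkey(seed=0))
for _ in range(10):
    action = jp.zeros((env.action_size,))
    state = env.step(state, action)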

init_state_wrapper

FixedInitialStateWrapper (Wrapper)

Wrapper to make the initial state of the environment deterministic and fixed. This is done by removing the random noise from the DoF positions and velocities.

Source code in qdax/environments/init_state_wrapper.py
class FixedInitialStateWrapper(Wrapper):
    """Wrapper to make the initial state of the environment deterministic and fixed.
    This is done by removing the random noise from the DoF positions and velocities.
    """

    def __init__(
        self,
        env: Env,
        base_env_name: str,
        get_obs_fn: Optional[
            Callable[[brax.QP, brax.Info, jp.ndarray], jp.ndarray]
        ] = None,
    ):
        env_get_obs = {
            "ant": lambda qp, info, action: self._get_obs(qp, info),
            "halfcheetah": lambda qp, info, action: self._get_obs(qp, info),
            "walker2d": lambda qp, info, action: self._get_obs(qp),
            "hopper": lambda qp, info, action: self._get_obs(qp),
            "humanoid": lambda qp, info, action: self._get_obs(qp, info, action),
        }

        super().__init__(env)

        if get_obs_fn is not None:
            self._get_obs_fn = get_obs_fn
        elif base_env_name in env_get_obs.keys():
            self._get_obs_fn = env_get_obs[base_env_name]
        else:
            raise NotImplementedError(
                f"This wrapper does not support {base_env_name} yet."
            )

    def reset(self, rng: jp.ndarray) -> State:
        """Reset the state of the environment with a deterministic and fixed
        initial state.

        Args:
            rng: random key to handle stochastic operations. Used by the parent
                init reset function.

        Returns:
            A new state with a fixed observation.
        """
        # Run the default reset method of parent environment
        state = self.env.reset(rng)

        # Compute new initial positions and velocities
        qpos = self.sys.default_angle()
        qvel = jp.zeros((self.sys.num_joint_dof,))

        # update qd
        qp = self.sys.default_qp(joint_angle=qpos, joint_velocity=qvel)

        # get the new obs
        obs = self._get_obs_fn(qp, self.sys.info(qp), jp.zeros(self.action_size))

        return state.replace(qp=qp, obs=obs)
reset(self, rng)

Reset the state of the environment with a deterministic and fixed initial state.

Parameters:
  • rng (Union[numpy.ndarray, jax.Array]) – random key to handle stochastic operations. Used by the parent init reset function.

Returns:
  • State – A new state with a fixed observation.

Source code in qdax/environments/init_state_wrapper.py
def reset(self, rng: jp.ndarray) -> State:
    """Reset the state of the environment with a deterministic and fixed
    initial state.

    Args:
        rng: random key to handle stochastic operations. Used by the parent
            init reset function.

    Returns:
        A new state with a fixed observation.
    """
    # Run the default reset method of parent environment
    state = self.env.reset(rng)

    # Compute new initial positions and velocities
    qpos = self.sys.default_angle()
    qvel = jp.zeros((self.sys.num_joint_dof,))

    # update qd
    qp = self.sys.default_qp(joint_angle=qpos, joint_velocity=qvel)

    # get the new obs
    obs = self._get_obs_fn(qp, self.sys.info(qp), jp.zeros(self.action_size))

    return state.replace(qp=qp, obs=obs)
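
In QDax this wrapper is typically applied through the create factory shown at the top of this page; a sketch follows ("ant_uni" is an assumed registered environment name):

from brax import jumpy as jp
from qdax import environments

# fixed_init_state=True wraps the environment with FixedInitialStateWrapper
env = environments.create(env_name="ant_uni", fixed_init_state=True)

# resets with different keys now produce the same initial joint configuration
state_a = env.reset(rng=jp.random_prngkey(seed=0))
state_b = env.reset(rng=jp.random_prngkey(seed=1))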

locomotion_wrappers

QDSystem (System)

Subclass of the brax physics System.

Works precisely the same, but stores some information from the physical simulation in the aux_info attribute.

This is used in FeetContactWrapper to get the contact of the robot's feet with the ground.

Source code in qdax/environments/locomotion_wrappers.py
class QDSystem(System):
    """Inheritance of brax physic system.

    Work precisely the same but store some information from the physical
    simulation in the aux_info attribute.

    This is used in FeetContactWrapper to get the feet contact of the
    robot with the ground.
    """

    def __init__(
        self, config: config_pb2.Config, resource_paths: Optional[Sequence[str]] = None
    ):
        super().__init__(config, resource_paths=resource_paths)
        self.aux_info = None

    def step(self, qp: QP, act: jp.ndarray) -> Tuple[QP, Info]:
        qp, info = super().step(qp, act)
        self.aux_info = info
        return qp, info
step(self, qp, act)

Generic step function. Overridden with appropriate step at init.

Source code in qdax/environments/locomotion_wrappers.py
def step(self, qp: QP, act: jp.ndarray) -> Tuple[QP, Info]:
    qp, info = super().step(qp, act)
    self.aux_info = info
    return qp, info

FeetContactWrapper (QDEnv)

Wraps gym environments to add the feet contact data.

Usage is simple: create an environment with Brax, pass it to the wrapper together with the name of the environment, and it will behave as before while simply adding the feet_contact booleans to the information dictionary of the brax state.

The only supported envs at the moment are the classic locomotion envs: Walker2D, Hopper, Ant, Bullet.

New locomotion envs can easily be added by adding the config names of the feet of the corresponding environment to the FEET_NAMES dictionary.

Example:

from brax import envs
from brax import jumpy as jp
from qdax.environments.locomotion_wrappers import FeetContactWrapper

# choose in ["ant", "walker2d", "hopper", "halfcheetah"]
ENV_NAME = "ant"
env = envs.create(env_name=ENV_NAME)
qd_env = FeetContactWrapper(env, ENV_NAME)

state = qd_env.reset(rng=jp.random_prngkey(seed=0))
for i in range(10):
    action = jp.zeros((qd_env.action_size,))
    state = qd_env.step(state, action)

    # retrieve feet contact
    feet_contact = state.info["state_descriptor"]

    # do whatever you want with feet_contact
    print(f"Feet contact : {feet_contact}")
Source code in qdax/environments/locomotion_wrappers.py
class FeetContactWrapper(QDEnv):
    """Wraps gym environments to add the feet contact data.

    Utilisation is simple: create an environment with Brax, pass
    it to the wrapper with the name of the environment, and it will
    work like before and will simply add the feet_contact booleans in
    the information dictionary of the Brax.state.

    The only supported envs at the moment are among the classic
    locomotion envs : Walker2D, Hopper, Ant, Bullet.

    New locomotions envs can easily be added by adding the config name
    of the feet of the corresponding environment in the FEET_NAME dictionary.

    Example :

        from brax import envs
        from brax import jumpy as jp

        # choose in ["ant", "walker2d", "hopper", "halfcheetah"]
        ENV_NAME = "ant"
        env = envs.create(env_name=ENV_NAME)
        qd_env = FeetContactWrapper(env, ENV_NAME)

        state = qd_env.reset(rng=jp.random_prngkey(seed=0))
        for i in range(10):
            action = jp.zeros((qd_env.action_size,))
            state = qd_env.step(state, action)

            # retrieve feet contact
            feet_contact = state.info["state_descriptor"]

            # do whatever you want with feet_contact
            print(f"Feet contact : {feet_contact}")


    """

    def __init__(self, env: Env, env_name: str):

        if env_name not in FEET_NAMES.keys():
            raise NotImplementedError(f"This wrapper does not support {env_name} yet.")

        super().__init__(config=None)

        self.env = env
        self._env_name = env_name
        if hasattr(self.env, "sys"):
            self.env.sys = QDSystem(self.env.sys.config)

        self._feet_contact_idx = jp.array(
            [env.sys.body.index.get(name) for name in FEET_NAMES[env_name]]
        )

    @property
    def state_descriptor_length(self) -> int:
        return self.behavior_descriptor_length

    @property
    def state_descriptor_name(self) -> str:
        return "feet_contact"

    @property
    def state_descriptor_limits(self) -> Tuple[List, List]:
        return self.behavior_descriptor_limits

    @property
    def behavior_descriptor_length(self) -> int:
        return len(self._feet_contact_idx)

    @property
    def behavior_descriptor_limits(self) -> Tuple[List, List]:
        bd_length = self.behavior_descriptor_length
        return (jnp.zeros((bd_length,)), jnp.ones((bd_length,)))

    @property
    def name(self) -> str:
        return self._env_name

    def reset(self, rng: jp.ndarray) -> State:
        state = self.env.reset(rng)
        state.info["state_descriptor"] = self._get_feet_contact(
            self.env.sys.info(state.qp)
        )
        return state

    def step(self, state: State, action: jp.ndarray) -> State:
        state = self.env.step(state, action)
        state.info["state_descriptor"] = self._get_feet_contact(self.env.sys.aux_info)
        return state

    def _get_feet_contact(self, info: Info) -> jp.ndarray:
        contacts = info.contact.vel
        return jp.any(contacts[self._feet_contact_idx], axis=1).astype(jp.float32)

    @property
    def unwrapped(self) -> Env:
        return self.env.unwrapped

    def __getattr__(self, name: str) -> Any:
        if name == "__setstate__":
            raise AttributeError(name)
        return getattr(self.env, name)
reset(self, rng)

Resets the environment to an initial state.

Source code in qdax/environments/locomotion_wrappers.py
def reset(self, rng: jp.ndarray) -> State:
    state = self.env.reset(rng)
    state.info["state_descriptor"] = self._get_feet_contact(
        self.env.sys.info(state.qp)
    )
    return state
step(self, state, action)

Run one timestep of the environment's dynamics.

Source code in qdax/environments/locomotion_wrappers.py
def step(self, state: State, action: jp.ndarray) -> State:
    state = self.env.step(state, action)
    state.info["state_descriptor"] = self._get_feet_contact(self.env.sys.aux_info)
    return state

XYPositionWrapper (QDEnv)

Wraps gym environments to add the position data.

Usage is simple: create an environment with Brax, pass it to the wrapper together with the name of the environment, and it will behave as before while simply adding the current position to the information dictionary of the brax state.

One can also provide values to clip the state descriptors.

The only supported envs at the moment are the classic locomotion envs: Ant, Humanoid.

New locomotion envs can easily be added by adding the config name of the centre of gravity of the corresponding environment to the COG_NAMES dictionary.

Note: this can be used with Hopper, Walker2d, Halfcheetah, but it makes less sense as those are limited to one direction.

Example:

from brax import envs
from brax import jumpy as jp
from qdax.environments.locomotion_wrappers import XYPositionWrapper

# choose in ["ant", "walker2d", "hopper", "halfcheetah", "humanoid"]
ENV_NAME = "ant"
env = envs.create(env_name=ENV_NAME)
qd_env = XYPositionWrapper(env, ENV_NAME)

state = qd_env.reset(rng=jp.random_prngkey(seed=0))
for i in range(10):
    action = jp.zeros((qd_env.action_size,))
    state = qd_env.step(state, action)

    # retrieve the xy position
    xy_position = state.info["state_descriptor"]

    # do whatever you want with xy_position
    print(f"xy position : {xy_position}")
Source code in qdax/environments/locomotion_wrappers.py
class XYPositionWrapper(QDEnv):
    """Wraps gym environments to add the position data.

    Utilisation is simple: create an environment with Brax, pass
    it to the wrapper with the name of the environment, and it will
    work like before and will simply add the actual position in
    the information dictionary of the Brax.state.

    One can also add values to clip the state descriptors.

    The only supported envs at the moment are among the classic
    locomotion envs : Ant, Humanoid.

    New locomotions envs can easily be added by adding the config name
    of the feet of the corresponding environment in the STATE_POSITION
    dictionary.

    RMQ: this can be used with Hopper, Walker2d, Halfcheetah but it makes
    less sens as those are limited to one direction.

    Example :

        from brax import envs
        from brax import jumpy as jp

        # choose in ["ant", "walker2d", "hopper", "halfcheetah", "humanoid"]
        ENV_NAME = "ant"
        env = envs.create(env_name=ENV_NAME)
        qd_env = XYPositionWrapper(env, ENV_NAME)

        state = qd_env.reset(rng=jp.random_prngkey(seed=0))
        for i in range(10):
            action = jp.zeros((qd_env.action_size,))
            state = qd_env.step(state, action)

            # retrieve the xy position
            xy_position = state.info["state_descriptor"]

            # do whatever you want with xy_position
            print(f"xy position : {xy_position}")


    """

    def __init__(
        self,
        env: Env,
        env_name: str,
        minval: Optional[List[float]] = None,
        maxval: Optional[List[float]] = None,
    ):
        if env_name not in COG_NAMES.keys():
            raise NotImplementedError(f"This wrapper does not support {env_name} yet.")

        super().__init__(config=None)

        self.env = env
        self._env_name = env_name
        if hasattr(self.env, "sys"):
            self._cog_idx = self.env.sys.body.index[COG_NAMES[env_name]]
        else:
            raise NotImplementedError(f"This wrapper does not support {env_name} yet.")

        if minval is None:
            minval = jnp.ones((2,)) * (-jnp.inf)

        if maxval is None:
            maxval = jnp.ones((2,)) * jnp.inf

        if len(minval) == 2 and len(maxval) == 2:
            self._minval = jnp.array(minval)
            self._maxval = jnp.array(maxval)
        else:
            raise NotImplementedError(
                "Please make sure to give two values for each limits."
            )

    @property
    def state_descriptor_length(self) -> int:
        return 2

    @property
    def state_descriptor_name(self) -> str:
        return "xy_position"

    @property
    def state_descriptor_limits(self) -> Tuple[List[float], List[float]]:
        return self._minval, self._maxval

    @property
    def behavior_descriptor_length(self) -> int:
        return self.state_descriptor_length

    @property
    def behavior_descriptor_limits(self) -> Tuple[List[float], List[float]]:
        return self.state_descriptor_limits

    @property
    def name(self) -> str:
        return self._env_name

    def reset(self, rng: jp.ndarray) -> State:
        state = self.env.reset(rng)
        state.info["state_descriptor"] = jnp.clip(
            state.qp.pos[self._cog_idx][:2], a_min=self._minval, a_max=self._maxval
        )
        return state

    def step(self, state: State, action: jp.ndarray) -> State:
        state = self.env.step(state, action)
        # get xy position of the center of gravity
        state.info["state_descriptor"] = jnp.clip(
            state.qp.pos[self._cog_idx][:2], a_min=self._minval, a_max=self._maxval
        )
        return state

    @property
    def unwrapped(self) -> Env:
        return self.env.unwrapped

    def __getattr__(self, name: str) -> Any:
        if name == "__setstate__":
            raise AttributeError(name)
        return getattr(self.env, name)
reset(self, rng)

Resets the environment to an initial state.

Source code in qdax/environments/locomotion_wrappers.py
def reset(self, rng: jp.ndarray) -> State:
    state = self.env.reset(rng)
    state.info["state_descriptor"] = jnp.clip(
        state.qp.pos[self._cog_idx][:2], a_min=self._minval, a_max=self._maxval
    )
    return state
step(self, state, action)

Run one timestep of the environment's dynamics.

Source code in qdax/environments/locomotion_wrappers.py
def step(self, state: State, action: jp.ndarray) -> State:
    state = self.env.step(state, action)
    # get xy position of the center of gravity
    state.info["state_descriptor"] = jnp.clip(
        state.qp.pos[self._cog_idx][:2], a_min=self._minval, a_max=self._maxval
    )
    return state

NoForwardRewardWrapper (Wrapper)

Wraps gym environments to remove forward reward.

Usage is simple: create an environment with Brax, pass it to the wrapper together with the name of the environment, and it will behave as before while simply removing the forward speed term from the reward.

Example:

from brax import envs
from brax import jumpy as jp
from qdax.environments.locomotion_wrappers import NoForwardRewardWrapper

# choose in ["ant", "walker2d", "hopper", "halfcheetah", "humanoid"]
ENV_NAME = "ant"
env = envs.create(env_name=ENV_NAME)
qd_env = NoForwardRewardWrapper(env, ENV_NAME)

state = qd_env.reset(rng=jp.random_prngkey(seed=0))
for i in range(10):
    action = jp.zeros((qd_env.action_size,))
    state = qd_env.step(state, action)
Source code in qdax/environments/locomotion_wrappers.py
class NoForwardRewardWrapper(Wrapper):
    """Wraps gym environments to remove forward reward.

    Utilisation is simple: create an environment with Brax, pass
    it to the wrapper with the name of the environment, and it will
    work like before and will simply remove the forward speed term
    of the reward.

    Example :

        from brax import envs
        from brax import jumpy as jp

        # choose in ["ant", "walker2d", "hopper", "halfcheetah", "humanoid"]
        ENV_NAME = "ant"
        env = envs.create(env_name=ENV_NAME)
        qd_env = NoForwardRewardWrapper(env, ENV_NAME)

        state = qd_env.reset(rng=jp.random_prngkey(seed=0))
        for i in range(10):
            action = jp.zeros((qd_env.action_size,))
            state = qd_env.step(state, action)
    """

    def __init__(self, env: Env, env_name: str) -> None:
        if env_name not in FORWARD_REWARD_NAMES.keys():
            raise NotImplementedError(f"This wrapper does not support {env_name} yet.")
        super().__init__(env)
        self._env_name = env_name
        self._fd_reward_field = FORWARD_REWARD_NAMES[env_name]

    @property
    def name(self) -> str:
        return self._env_name

    def step(self, state: State, action: jp.ndarray) -> State:
        state = self.env.step(state, action)
        # update the reward (remove forward_reward)
        new_reward = state.reward - state.metrics[self._fd_reward_field]
        return state.replace(reward=new_reward)  # type: ignore
step(self, state, action)

Run one timestep of the environment's dynamics.

Source code in qdax/environments/locomotion_wrappers.py
def step(self, state: State, action: jp.ndarray) -> State:
    state = self.env.step(state, action)
    # update the reward (remove forward_reward)
    new_reward = state.reward - state.metrics[self._fd_reward_field]
    return state.replace(reward=new_reward)  # type: ignore

pointmaze

PointMaze (Env)

Jax/Brax implementation of the PointMaze, heavily inspired by the old Python implementation of the PointMaze.

In order to stay within the Brax API, a fake QP is used at several points of the implementation. This enables the use of brax.envs.State from Brax. To avoid this, it would be good to ask Brax to slightly enlarge their API for environments that are not physically simulated.
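
A minimal interaction sketch (direct instantiation is shown for illustration; the name under which this environment is registered for qdax.environments.create is assumed to be "pointmaze"):

from brax import jumpy as jp
from qdax.environments.pointmaze import PointMaze

env = PointMaze()
state = env.reset(rng=jp.random_prngkey(seed=0))
for _ in range(10):
    action = jp.zeros((env.action_size,))   # actions are 2D displacements
    state = env.step(state, action)
    print(state.info["state_descriptor"])   # xy position of the point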

Source code in qdax/environments/pointmaze.py
class PointMaze(Env):
    """Jax/Brax implementation of the PointMaze.
    Highly inspired from the old python implementation of
    the PointMaze.

    In order to stay in the Brax API, I will use a fake QP
    at several moment of the implementation. This enable to
    use the brax.envs.State from Brax. To avoid this,
    it would be good to ask Brax to enlarge a bit their API
    for environments that are not physically simulated.
    """

    def __init__(
        self,
        scale_action_space: float = 10,
        x_min: float = -1,
        x_max: float = 1,
        y_min: float = -1,
        y_max: float = 1,
        zone_width: float = 0.1,
        zone_width_offset_from_x_min: float = 0.5,
        zone_height_offset_from_y_max: float = -0.2,
        wall_width_ratio: float = 0.75,
        upper_wall_height_offset: float = 0.2,
        lower_wall_height_offset: float = -0.5,
        **kwargs: Any,
    ) -> None:

        super().__init__(None, **kwargs)

        self._scale_action_space = scale_action_space
        self._x_min = x_min
        self._x_max = x_max
        self._y_min = y_min
        self._y_max = y_max

        self._low = jp.array([self._x_min, self._y_min], dtype=jp.float32)
        self._high = jp.array([self._x_max, self._y_max], dtype=jp.float32)

        self.n_zones = 1

        self.zone_width = zone_width

        self.zone_width_offset = self._x_min + zone_width_offset_from_x_min
        self.zone_height_offset = self._y_max + zone_height_offset_from_y_max

        self.viewer = None

        # Walls
        self.wallheight = 0.01
        self.wallwidth = (self._x_max - self._x_min) * wall_width_ratio

        self.upper_wall_width_offset = self._x_min + self.wallwidth / 2
        self.upper_wall_height_offset = upper_wall_height_offset

        self.lower_wall_width_offset = self._x_max - self.wallwidth / 2
        self.lower_wall_height_offset = lower_wall_height_offset

    @property
    def descriptors_min_values(self) -> List[float]:
        """Minimum values for descriptors."""
        return [self._x_min, self._y_min]

    @property
    def descriptors_max_values(self) -> List[float]:
        """Maximum values for descriptors."""
        return [self._x_max, self._y_max]

    @property
    def descriptors_names(self) -> List[str]:
        """Descriptors names."""
        return ["x_pos", "y_pos"]

    @property
    def state_descriptor_length(self) -> int:
        return 2

    @property
    def state_descriptor_name(self) -> str:
        return "xy_position"

    @property
    def state_descriptor_limits(self) -> Tuple[List[float], List[float]]:
        return [self._x_min, self._y_min], [self._x_max, self._y_max]

    @property
    def behavior_descriptor_length(self) -> int:
        return self.state_descriptor_length

    @property
    def behavior_descriptor_limits(self) -> Tuple[List[float], List[float]]:
        return self.state_descriptor_limits

    @property
    def action_size(self) -> int:
        """The size of the observation vector returned in step and reset."""
        return 2

    def reset(self, rng: jp.ndarray) -> State:
        """Resets the environment to an initial state."""
        rng, rng1, rng2 = jp.random_split(rng, 3)
        # get initial position - reproduce the old implementation
        x_init = jp.random_uniform(rng1, (), low=self._x_min, high=self._x_max) / 10
        y_init = jp.random_uniform(rng2, (), low=self._y_min, high=-0.7)
        obs_init = jp.array([x_init, y_init])
        # create fake qp (to re-use brax.State)
        fake_qp = brax.QP.zero()
        # init reward, metrics and infos
        reward, done = jp.zeros(2)
        metrics: Dict = {}
        # manage the state descriptor ourselves
        info_init = {"state_descriptor": obs_init}
        return State(fake_qp, obs_init, reward, done, metrics, info_init)

    def step(self, state: State, action: jp.ndarray) -> State:
        """Run one timestep of the environment's dynamics."""

        # clip action taken
        min_action = self._low
        max_action = self._high
        action = jp.clip(action, min_action, max_action) / self._scale_action_space

        # get the current position
        x_pos_old, y_pos_old = state.obs

        # compute the new position
        x_pos = x_pos_old + action[0]
        y_pos = y_pos_old + action[1]

        # take into account a potential wall collision
        y_pos = self._collision_lower_wall(y_pos, y_pos_old, x_pos, x_pos_old)
        y_pos = self._collision_upper_wall(y_pos, y_pos_old, x_pos, x_pos_old)

        # take into account border walls
        x_pos = jp.clip(x_pos, jp.array(self._x_min), jp.array(self._x_max))
        y_pos = jp.clip(y_pos, jp.array(self._y_min), jp.array(self._y_max))

        reward = -jp.norm(
            jp.array([x_pos - self.zone_width_offset, y_pos - self.zone_height_offset])
        )
        # determine if zone was reached
        in_zone = self._in_zone(x_pos, y_pos)

        done = jp.where(
            jp.array(in_zone),
            x=jp.array(1.0),
            y=jp.array(0.0),
        )

        new_obs = jp.array([x_pos, y_pos])

        # update state descriptor
        state.info["state_descriptor"] = new_obs
        # update the state
        return state.replace(obs=new_obs, reward=reward, done=done)  # type: ignore

    def _in_zone(self, x_pos: jp.ndarray, y_pos: jp.ndarray) -> Union[bool, jp.ndarray]:
        """Determine if the point reached the goal area."""
        zone_center_width, zone_center_height = (
            self.zone_width_offset,
            self.zone_height_offset,
        )

        condition_1 = zone_center_width - self.zone_width / 2 <= x_pos
        condition_2 = x_pos <= zone_center_width + self.zone_width / 2

        condition_3 = zone_center_height - self.zone_width / 2 <= y_pos
        condition_4 = y_pos <= zone_center_height + self.zone_width / 2

        return condition_1 & condition_2 & condition_3 & condition_4

    def _collision_lower_wall(
        self,
        y_pos: jp.ndarray,
        y_pos_old: jp.ndarray,
        x_pos: jp.ndarray,
        x_pos_old: jp.ndarray,
    ) -> jp.ndarray:
        """Manage potential collisions with the walls."""

        # global conditions on the x axis contacts
        x_hitting_wall = (self.lower_wall_height_offset - y_pos_old) / (
            y_pos - y_pos_old
        ) * (x_pos - x_pos_old) + x_pos_old
        x_axis_contact_condition = x_hitting_wall >= self._x_max - self.wallwidth

        # From down - boolean style
        y_axis_down_contact_condition_1 = y_pos_old <= self.lower_wall_height_offset
        y_axis_down_contact_condition_2 = self.lower_wall_height_offset < y_pos
        # y_pos update
        new_y_pos = jp.where(
            y_axis_down_contact_condition_1
            & y_axis_down_contact_condition_2
            & x_axis_contact_condition,
            x=jp.array(self.lower_wall_height_offset),
            y=y_pos,
        )

        # From up - boolean style
        y_axis_up_contact_condition_1 = (
            y_pos < self.lower_wall_height_offset + self.wallheight
        )
        y_axis_up_contact_condition_2 = (
            self.lower_wall_height_offset + self.wallheight <= y_pos_old
        )
        y_axis_up_contact_condition_3 = y_pos_old < self.upper_wall_height_offset
        # y_pos update
        new_y_pos = jp.where(
            y_axis_up_contact_condition_1
            & y_axis_up_contact_condition_2
            & y_axis_up_contact_condition_3
            & x_axis_contact_condition,
            x=jp.array(self.lower_wall_height_offset + self.wallheight),
            y=new_y_pos,
        )

        return new_y_pos

    def _collision_upper_wall(
        self,
        y_pos: jp.ndarray,
        y_pos_old: jp.ndarray,
        x_pos: jp.ndarray,
        x_pos_old: jp.ndarray,
    ) -> jp.ndarray:
        """Manage potential collisions with the walls."""

        # global conditions on the x axis contacts
        x_hitting_wall = (self.upper_wall_height_offset - y_pos_old) / (
            y_pos - y_pos_old
        ) * (x_pos - x_pos_old) + x_pos_old
        x_axis_contact_condition = x_hitting_wall <= self._x_min + self.wallwidth

        # From up - boolean style
        y_axis_up_contact_condition_1 = (
            y_pos_old >= self.upper_wall_height_offset + self.wallheight
        )
        y_axis_up_contact_condition_2 = (
            self.upper_wall_height_offset + self.wallheight > y_pos
        )
        # y_pos update
        new_y_pos = jp.where(
            y_axis_up_contact_condition_1
            & y_axis_up_contact_condition_2
            & x_axis_contact_condition,
            x=jp.array(self.upper_wall_height_offset + self.wallheight),
            y=y_pos,
        )

        # From down - boolean style
        y_axis_down_contact_condition_1 = y_pos > self.upper_wall_height_offset
        y_axis_down_contact_condition_2 = self.upper_wall_height_offset >= y_pos_old
        y_axis_down_contact_condition_3 = y_pos_old > self.lower_wall_height_offset
        # y_pos update
        new_y_pos = jp.where(
            y_axis_down_contact_condition_1
            & y_axis_down_contact_condition_2
            & y_axis_down_contact_condition_3
            & x_axis_contact_condition,
            x=jp.array(self.upper_wall_height_offset),
            y=new_y_pos,
        )

        return new_y_pos
descriptors_min_values: List[float] property readonly

Minimum values for descriptors.

descriptors_max_values: List[float] property readonly

Maximum values for descriptors.

descriptors_names: List[str] property readonly

Descriptor names.

action_size: int property readonly

The size of the action vector expected by step.

reset(self, rng)

Resets the environment to an initial state.

Source code in qdax/environments/pointmaze.py
def reset(self, rng: jp.ndarray) -> State:
    """Resets the environment to an initial state."""
    rng, rng1, rng2 = jp.random_split(rng, 3)
    # get initial position - reproduce the old implementation
    x_init = jp.random_uniform(rng1, (), low=self._x_min, high=self._x_max) / 10
    y_init = jp.random_uniform(rng2, (), low=self._y_min, high=-0.7)
    obs_init = jp.array([x_init, y_init])
    # create fake qp (to re-use brax.State)
    fake_qp = brax.QP.zero()
    # init reward, metrics and infos
    reward, done = jp.zeros(2)
    metrics: Dict = {}
    # manage the state descriptor ourselves
    info_init = {"state_descriptor": obs_init}
    return State(fake_qp, obs_init, reward, done, metrics, info_init)
step(self, state, action)

Run one timestep of the environment's dynamics.

Source code in qdax/environments/pointmaze.py
def step(self, state: State, action: jp.ndarray) -> State:
    """Run one timestep of the environment's dynamics."""

    # clip action taken
    min_action = self._low
    max_action = self._high
    action = jp.clip(action, min_action, max_action) / self._scale_action_space

    # get the current position
    x_pos_old, y_pos_old = state.obs

    # compute the new position
    x_pos = x_pos_old + action[0]
    y_pos = y_pos_old + action[1]

    # take into account a potential wall collision
    y_pos = self._collision_lower_wall(y_pos, y_pos_old, x_pos, x_pos_old)
    y_pos = self._collision_upper_wall(y_pos, y_pos_old, x_pos, x_pos_old)

    # take into account border walls
    x_pos = jp.clip(x_pos, jp.array(self._x_min), jp.array(self._x_max))
    y_pos = jp.clip(y_pos, jp.array(self._y_min), jp.array(self._y_max))

    reward = -jp.norm(
        jp.array([x_pos - self.zone_width_offset, y_pos - self.zone_height_offset])
    )
    # determine if zone was reached
    in_zone = self._in_zone(x_pos, y_pos)

    done = jp.where(
        jp.array(in_zone),
        x=jp.array(1.0),
        y=jp.array(0.0),
    )

    new_obs = jp.array([x_pos, y_pos])

    # update state descriptor
    state.info["state_descriptor"] = new_obs
    # update the state
    return state.replace(obs=new_obs, reward=reward, done=done)  # type: ignore
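
To make the dynamics above concrete, here is a short, hedged rollout sketch; the random actions, loop length, and direct class import are illustrative choices, not part of the documented API.

import jax
from qdax.environments.pointmaze import PointMaze

env = PointMaze()
key = jax.random.PRNGKey(0)
state = env.reset(key)

for _ in range(5):
    key, subkey = jax.random.split(key)
    # actions are clipped to the maze bounds and divided by scale_action_space
    # inside step(), so each move is a small displacement
    action = jax.random.uniform(subkey, (2,), minval=-1.0, maxval=1.0)
    state = env.step(state, action)
    print(state.obs, state.reward, state.done)  # done flips to 1.0 in the goal zone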

wrappers

CompletedEvalMetrics (PyTreeNode) dataclass

CompletedEvalMetrics(current_episode_metrics: Dict[str, Union[numpy.ndarray, jax.Array]], completed_episodes_metrics: Dict[str, Union[numpy.ndarray, jax.Array]], completed_episodes: Union[numpy.ndarray, jax.Array], completed_episodes_steps: Union[numpy.ndarray, jax.Array])

Source code in qdax/environments/wrappers.py
class CompletedEvalMetrics(flax.struct.PyTreeNode):
    current_episode_metrics: Dict[str, jp.ndarray]
    completed_episodes_metrics: Dict[str, jp.ndarray]
    completed_episodes: jp.ndarray
    completed_episodes_steps: jp.ndarray
replace(self, **updates)

"Returns a new object replacing the specified fields with new values.

Source code in qdax/environments/wrappers.py
def replace(self, **updates):
  """ "Returns a new object replacing the specified fields with new values."""
  return dataclasses.replace(self, **updates)
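
Since CompletedEvalMetrics is a flax PyTreeNode, it composes with jax tree utilities and with replace. The following sketch (field values are placeholders chosen for illustration) zeroes the per-episode metrics while keeping the cumulative counters untouched.

import jax
import jax.numpy as jnp
from qdax.environments.wrappers import CompletedEvalMetrics

metrics = CompletedEvalMetrics(
    current_episode_metrics={"reward": jnp.ones(())},
    completed_episodes_metrics={"reward": jnp.ones(())},
    completed_episodes=jnp.ones(()),
    completed_episodes_steps=jnp.ones(()),
)
# reset only the per-episode field, keeping the cumulative ones
metrics = metrics.replace(
    current_episode_metrics=jax.tree_util.tree_map(
        jnp.zeros_like, metrics.current_episode_metrics
    )
)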

CompletedEvalWrapper (Wrapper)

Brax env with eval metrics for completed episodes.

Source code in qdax/environments/wrappers.py
class CompletedEvalWrapper(Wrapper):
    """Brax env with eval metrics for completed episodes."""

    STATE_INFO_KEY = "completed_eval_metrics"

    def reset(self, rng: jp.ndarray) -> State:
        reset_state = self.env.reset(rng)
        reset_state.metrics["reward"] = reset_state.reward
        eval_metrics = CompletedEvalMetrics(
            current_episode_metrics=jax.tree_util.tree_map(
                jp.zeros_like, reset_state.metrics
            ),
            completed_episodes_metrics=jax.tree_util.tree_map(
                lambda x: jp.zeros_like(jp.sum(x)), reset_state.metrics
            ),
            completed_episodes=jp.zeros(()),
            completed_episodes_steps=jp.zeros(()),
        )
        reset_state.info[self.STATE_INFO_KEY] = eval_metrics
        return reset_state

    def step(self, state: State, action: jp.ndarray) -> State:
        state_metrics = state.info[self.STATE_INFO_KEY]
        if not isinstance(state_metrics, CompletedEvalMetrics):
            raise ValueError(f"Incorrect type for state_metrics: {type(state_metrics)}")
        del state.info[self.STATE_INFO_KEY]
        nstate = self.env.step(state, action)
        nstate.metrics["reward"] = nstate.reward
        # steps stores the highest step count reached when done = True, and then
        # the next steps value becomes action_repeat
        completed_episodes_steps = state_metrics.completed_episodes_steps + jp.sum(
            nstate.info["steps"] * nstate.done
        )
        current_episode_metrics = jax.tree_util.tree_map(
            lambda a, b: a + b, state_metrics.current_episode_metrics, nstate.metrics
        )
        completed_episodes = state_metrics.completed_episodes + jp.sum(nstate.done)
        completed_episodes_metrics = jax.tree_util.tree_map(
            lambda a, b: a + jp.sum(b * nstate.done),
            state_metrics.completed_episodes_metrics,
            current_episode_metrics,
        )
        current_episode_metrics = jax.tree_util.tree_map(
            lambda a, b: a * (1 - nstate.done) + b * nstate.done,
            current_episode_metrics,
            nstate.metrics,
        )

        eval_metrics = CompletedEvalMetrics(
            current_episode_metrics=current_episode_metrics,
            completed_episodes_metrics=completed_episodes_metrics,
            completed_episodes=completed_episodes,
            completed_episodes_steps=completed_episodes_steps,
        )
        nstate.info[self.STATE_INFO_KEY] = eval_metrics
        return nstate
reset(self, rng)

Resets the environment to an initial state.

Source code in qdax/environments/wrappers.py
def reset(self, rng: jp.ndarray) -> State:
    reset_state = self.env.reset(rng)
    reset_state.metrics["reward"] = reset_state.reward
    eval_metrics = CompletedEvalMetrics(
        current_episode_metrics=jax.tree_util.tree_map(
            jp.zeros_like, reset_state.metrics
        ),
        completed_episodes_metrics=jax.tree_util.tree_map(
            lambda x: jp.zeros_like(jp.sum(x)), reset_state.metrics
        ),
        completed_episodes=jp.zeros(()),
        completed_episodes_steps=jp.zeros(()),
    )
    reset_state.info[self.STATE_INFO_KEY] = eval_metrics
    return reset_state
step(self, state, action)

Run one timestep of the environment's dynamics.

Source code in qdax/environments/wrappers.py
def step(self, state: State, action: jp.ndarray) -> State:
    state_metrics = state.info[self.STATE_INFO_KEY]
    if not isinstance(state_metrics, CompletedEvalMetrics):
        raise ValueError(f"Incorrect type for state_metrics: {type(state_metrics)}")
    del state.info[self.STATE_INFO_KEY]
    nstate = self.env.step(state, action)
    nstate.metrics["reward"] = nstate.reward
    # steps stores the highest step count reached when done = True, and then
    # the next steps value becomes action_repeat
    completed_episodes_steps = state_metrics.completed_episodes_steps + jp.sum(
        nstate.info["steps"] * nstate.done
    )
    current_episode_metrics = jax.tree_util.tree_map(
        lambda a, b: a + b, state_metrics.current_episode_metrics, nstate.metrics
    )
    completed_episodes = state_metrics.completed_episodes + jp.sum(nstate.done)
    completed_episodes_metrics = jax.tree_util.tree_map(
        lambda a, b: a + jp.sum(b * nstate.done),
        state_metrics.completed_episodes_metrics,
        current_episode_metrics,
    )
    current_episode_metrics = jax.tree_util.tree_map(
        lambda a, b: a * (1 - nstate.done) + b * nstate.done,
        current_episode_metrics,
        nstate.metrics,
    )

    eval_metrics = CompletedEvalMetrics(
        current_episode_metrics=current_episode_metrics,
        completed_episodes_metrics=completed_episodes_metrics,
        completed_episodes=completed_episodes,
        completed_episodes_steps=completed_episodes_steps,
    )
    nstate.info[self.STATE_INFO_KEY] = eval_metrics
    return nstate
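
Putting the wrapper to use requires a base environment whose state info carries a "steps" counter. The sketch below assumes brax v1 import paths for EpisodeWrapper and AutoResetWrapper and uses dummy zero actions, so it is an illustration rather than a canonical recipe.

import jax
import jax.numpy as jnp
from brax.envs.wrappers import AutoResetWrapper, EpisodeWrapper  # assumed brax v1 path
from qdax.environments.pointmaze import PointMaze
from qdax.environments.wrappers import CompletedEvalWrapper

# EpisodeWrapper populates state.info["steps"], which step() above relies on;
# AutoResetWrapper restarts episodes so several of them can complete.
env = AutoResetWrapper(EpisodeWrapper(PointMaze(), episode_length=100, action_repeat=1))
eval_env = CompletedEvalWrapper(env)

state = eval_env.reset(jax.random.PRNGKey(0))
for _ in range(300):
    state = eval_env.step(state, jnp.zeros((2,)))  # dummy actions

metrics = state.info[CompletedEvalWrapper.STATE_INFO_KEY]
print(metrics.completed_episodes)           # number of episodes that terminated
print(metrics.completed_episodes_metrics)   # per-metric sums over those episodes
print(metrics.current_episode_metrics)      # running sums for the ongoing episode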