Utils

qdax.utils special

metrics

Defines functions to retrieve metrics from training processes.

CSVLogger

Logger to save metrics of an experiment in a csv file during the training process.

Source code in qdax/utils/metrics.py
class CSVLogger:
    """Logger to save metrics of an experiment in a csv file
    during the training process.

    The file is created (or truncated) when the logger is instantiated, and
    every call to `log` re-opens it in append mode to add one row.
    """

    def __init__(self, filename: str, header: List) -> None:
        """Create the csv logger, create a file and write the
        header.

        Args:
            filename: path to which the file will be saved.
            header: header of the csv file.
        """
        self._filename = filename
        self._header = header
        # newline="" hands control of the line endings to the csv module;
        # without it, platforms that translate newlines on write produce an
        # extra blank line after every row.
        with open(self._filename, "w", newline="") as file:
            writer = csv.DictWriter(file, fieldnames=self._header)
            # write the header
            writer.writeheader()

    def log(self, metrics: Dict[str, float]) -> None:
        """Log new metrics to the csv file.

        Args:
            metrics: A dictionary containing the metrics that
                need to be saved. Its keys are matched against the
                header given at construction time.
        """
        # same newline handling as in __init__, so that appended rows use
        # the same line endings as the header
        with open(self._filename, "a", newline="") as file:
            writer = csv.DictWriter(file, fieldnames=self._header)
            # write new metrics in a row
            writer.writerow(metrics)
__init__(self, filename, header) special

Create the csv logger, create a file and write the header.

Parameters:
  • filename (str) – path to which the file will be saved.

  • header (List) – header of the csv file.

Source code in qdax/utils/metrics.py
def __init__(self, filename: str, header: List) -> None:
    """Create the csv logger, create a file and write the
    header.

    Args:
        filename: path to which the file will be saved.
        header: header of the csv file.
    """
    self._filename = filename
    self._header = header
    # newline="" hands control of the line endings to the csv module;
    # without it, platforms that translate newlines on write produce an
    # extra blank line after every row.
    with open(self._filename, "w", newline="") as file:
        writer = csv.DictWriter(file, fieldnames=self._header)
        # write the header
        writer.writeheader()
log(self, metrics)

Log new metrics to the csv file.

Parameters:
  • metrics (Dict[str, float]) – A dictionary containing the metrics that need to be saved.

Source code in qdax/utils/metrics.py
def log(self, metrics: Dict[str, float]) -> None:
    """Log new metrics to the csv file.

    Args:
        metrics: A dictionary containing the metrics that
            need to be saved. Its keys are matched against the
            stored header.
    """
    # newline="" hands control of the line endings to the csv module;
    # without it, platforms that translate newlines on write produce an
    # extra blank line after every row.
    with open(self._filename, "a", newline="") as file:
        writer = csv.DictWriter(file, fieldnames=self._header)
        # write new metrics in a row
        writer.writerow(metrics)

default_ga_metrics(repertoire)

Compute the usual GA metrics that one can retrieve from a GA repertoire.

Parameters:
  • repertoire (GARepertoire) – a GA repertoire

Returns:
  • Metrics – a dictionary containing the max fitness of the repertoire.

Source code in qdax/utils/metrics.py
def default_ga_metrics(
    repertoire: GARepertoire,
) -> Metrics:
    """Summarise a GA repertoire with the usual GA metrics.

    Args:
        repertoire: a GA repertoire

    Returns:
        a dictionary with a single "max_fitness" entry: the maximum of
            the repertoire fitnesses taken along the first axis.
    """
    return {"max_fitness": jnp.max(repertoire.fitnesses, axis=0)}

default_qd_metrics(repertoire, qd_offset)

Compute the usual QD metrics that one can retrieve from a MAP Elites repertoire.

Parameters:
  • repertoire (MapElitesRepertoire) – a MAP-Elites repertoire

  • qd_offset (float) – an offset used to ensure that the QD score will be positive and increasing with the number of individuals.

Returns:
  • Metrics – a dictionary containing the QD score (sum of fitnesses modified to be all positive), the max fitness of the repertoire, the coverage (percentage of niches filled in the repertoire).

Source code in qdax/utils/metrics.py
def default_qd_metrics(repertoire: MapElitesRepertoire, qd_offset: float) -> Metrics:
    """Compute the usual QD metrics of a MAP-Elites repertoire.

    A cell counts as empty when its fitness equals -inf.

    Args:
        repertoire: a MAP-Elites repertoire
        qd_offset: value added to the QD score once per filled cell, used
            to ensure that the QD score will be positive and increasing
            with the number of individuals.

    Returns:
        a dictionary containing the QD score (sum of the fitnesses of the
            filled cells plus the offset for each of them), the max
            fitness of the repertoire and the coverage (percentage of
            filled cells).
    """
    fitnesses = repertoire.fitnesses

    # 1.0 for a filled cell, 0.0 for an empty one
    is_empty = fitnesses == -jnp.inf
    filled = 1.0 - is_empty

    # empty cells are excluded from the sum, then every filled cell is
    # shifted by the offset
    qd_score = jnp.sum(fitnesses, where=~is_empty) + qd_offset * jnp.sum(filled)

    return {
        "qd_score": qd_score,
        "max_fitness": jnp.max(fitnesses),
        "coverage": 100 * jnp.mean(filled),
    }

default_moqd_metrics(repertoire, reference_point)

Compute the MOQD metric given a MOME repertoire and a reference point.

Parameters:
  • repertoire (MOMERepertoire) – a MOME repertoire.

  • reference_point (jnp.ndarray) – the hypervolume of a pareto front has to be computed relatively to a reference point.

Returns:
  • Metrics – A dictionary containing all the computed metrics.

Source code in qdax/utils/metrics.py
def default_moqd_metrics(
    repertoire: MOMERepertoire, reference_point: jnp.ndarray
) -> Metrics:
    """Compute the MOQD metrics of a MOME repertoire.

    Args:
        repertoire: a MOME repertoire.
        reference_point: point relative to which the hypervolume of a
            pareto front is computed.

    Returns:
        A dictionary with the per-cell hypervolumes ("moqd_score", -inf for
        empty cells), their maximum, the per-objective maxima, the maximum
        of the summed objectives, the coverage in percent, the number of
        stored solutions and the hypervolume of the global pareto front.
    """
    fitnesses = repertoire.fitnesses

    # a slot is empty when all of its values along the last axis are -inf;
    # a cell is filled as soon as one of its slots is not empty
    slot_is_empty = jnp.all(fitnesses == -jnp.inf, axis=-1)
    cell_is_filled = jnp.any(~slot_is_empty, axis=-1)

    def cell_hypervolume(front: jnp.ndarray) -> jnp.ndarray:
        return compute_hypervolume(front, reference_point=reference_point)

    # hypervolume of each cell, replaced by -inf for empty cells
    per_cell_hypervolume = jnp.where(
        cell_is_filled, jax.vmap(cell_hypervolume)(fitnesses), -jnp.inf
    )

    global_front, _ = repertoire.compute_global_pareto_front()

    return {
        "moqd_score": per_cell_hypervolume,
        "max_hypervolume": jnp.max(per_cell_hypervolume),
        "max_scores": jnp.max(fitnesses, axis=(0, 1)),
        "max_sum_scores": jnp.max(jnp.sum(fitnesses, axis=-1), axis=(0, 1)),
        "coverage": 100 * jnp.mean(cell_is_filled),
        "number_solutions": jnp.sum(~slot_is_empty),
        "global_hypervolume": compute_hypervolume(
            global_front, reference_point=reference_point
        ),
    }

pareto_front

Utils to handle pareto fronts.

compute_pareto_dominance(criteria_point, batch_of_criteria)

Returns if a point is pareto dominated given a set of points or not. We use maximization convention here.

criteria_point has shape (num_criteria,) batch_of_criteria has shape (num_points, num_criteria)

Parameters:
  • criteria_point (Array) – a vector of values.

  • batch_of_criteria (Array) – a batch of vector of values.

Returns:
  • Array – Return booleans when the vector is dominated by the batch.

Source code in qdax/utils/pareto_front.py
def compute_pareto_dominance(
    criteria_point: jnp.ndarray, batch_of_criteria: jnp.ndarray
) -> jnp.ndarray:
    """Tell whether a point is pareto dominated by a set of points.
    Maximization convention is used.

    criteria_point has shape (num_criteria,)
    batch_of_criteria has shape (num_points, num_criteria)

    Args:
        criteria_point: a vector of values.
        batch_of_criteria: a batch of vector of values.

    Returns:
        A boolean that is True when at least one point of the batch is
        strictly greater than the vector on every criterion.
    """
    # positive entries mark criteria on which a batch point beats the vector
    gap = jnp.subtract(batch_of_criteria, criteria_point)
    beats_on_every_criterion = jnp.all(gap > 0, axis=-1)
    return jnp.any(beats_on_every_criterion)

compute_pareto_front(batch_of_criteria)

Returns an array of boolean that states for each element if it is in the pareto front or not.

Parameters:
  • batch_of_criteria (Array) – a batch of points of shape (num_points, num_criteria)

Returns:
  • Array – An array of boolean with the boolean stating if each point is on the front or not.

Source code in qdax/utils/pareto_front.py
def compute_pareto_front(batch_of_criteria: jnp.ndarray) -> jnp.ndarray:
    """Flag, for every point of a batch, whether it belongs to the pareto
    front of that batch.

    Args:
        batch_of_criteria: a batch of points of shape (num_points, num_criteria)

    Returns:
        An array of booleans, one per point, that is True when the point is
        not dominated by the batch.
    """
    # test every point (mapped over axis 0) against the full, unmapped batch
    is_dominated = jax.vmap(compute_pareto_dominance, in_axes=(0, None))(
        batch_of_criteria, batch_of_criteria
    )
    return ~is_dominated

compute_masked_pareto_dominance(criteria_point, batch_of_criteria, mask)

Returns if a point is pareto dominated given a set of points or not. We use maximization convention here.

This function is to be used with constant size batches of criteria, thus a mask is used to know which values are padded.

Parameters:
  • criteria_point (Array) – values to be evaluated, of shape (num_criteria,)

  • batch_of_criteria (Array) – set of points to compare with, of shape (batch_size, num_criteria)

  • mask (Array) – mask of shape (batch_size,), 1.0 where there is no element, 0 otherwise

Returns:
  • Array – Boolean assessing if the point is dominated or not.

Source code in qdax/utils/pareto_front.py
def compute_masked_pareto_dominance(
    criteria_point: jnp.ndarray, batch_of_criteria: jnp.ndarray, mask: Mask
) -> jnp.ndarray:
    """Tell whether a point is pareto dominated by a padded set of points.
    Maximization convention is used.

    This function is meant for batches of criteria of constant size, where
    a mask tells which rows are padding.

    Args:
        criteria_point: values to be evaluated, of shape (num_criteria,)
        batch_of_criteria: set of points to compare with,
            of shape (batch_size, num_criteria)
        mask: mask of shape (batch_size,), set where there is no element

    Returns:
        Boolean assessing if the point is dominated or not.
    """
    gap = jnp.subtract(batch_of_criteria, criteria_point)

    # masked rows are overwritten with -1 on every criterion so that they
    # can never pass the strict "> 0" dominance test below
    row_is_masked = jnp.expand_dims(mask, axis=-1)
    gap = jnp.where(row_is_masked, -jnp.ones_like(gap), gap)

    beats_on_every_criterion = jnp.all(gap > 0, axis=-1)
    return jnp.any(beats_on_every_criterion)

compute_masked_pareto_front(batch_of_criteria, mask)

Returns an array of boolean that states for each element if it is to be considered or not. This function is to be used with batches of constant size criteria, thus a mask is used to know which values are padded.

Parameters:
  • batch_of_criteria (Array) – data points considered

  • mask (Array) – mask to hide several points

Returns:
  • Array – An array of boolean stating the points to consider.

Source code in qdax/utils/pareto_front.py
def compute_masked_pareto_front(
    batch_of_criteria: jnp.ndarray, mask: Mask
) -> jnp.ndarray:
    """Flag, for every point of a padded batch, whether it is to be
    considered or not. This function is meant for batches of criteria of
    constant size, where a mask tells which rows are padding.

    Args:
        batch_of_criteria: data points considered
        mask: mask to hide several points

    Returns:
        An array of booleans that is True for the points that are neither
        dominated by the unmasked points nor masked themselves.
    """
    # test every point (mapped over axis 0) against the full batch and mask
    is_dominated = jax.vmap(
        compute_masked_pareto_dominance, in_axes=(0, None, None)
    )(batch_of_criteria, batch_of_criteria, mask)

    # masked points are never part of the front
    return ~is_dominated * ~mask

compute_hypervolume(pareto_front, reference_point)

Compute hypervolume of a pareto front.

Parameters:
  • pareto_front (qdax.types.ParetoFront[jax.Array]) – a pareto front, shape (pareto_size, num_objectives)

  • reference_point (Array) – a reference point to compute the volume, of shape (num_objectives,)

Returns:
  • Array – The hypervolume of the pareto front.

Source code in qdax/utils/pareto_front.py
def compute_hypervolume(
    pareto_front: ParetoFront[jnp.ndarray], reference_point: jnp.ndarray
) -> jnp.ndarray:
    """Compute hypervolume of a pareto front.

    Entries equal to -inf are treated as padding and do not contribute to
    the result.

    Args:
        pareto_front: a pareto front, shape (pareto_size, num_objectives)
        reference_point: a reference point to compute the volume, of shape
            (num_objectives,)

    Returns:
        The hypervolume of the pareto front.

    Raises:
        AssertionError: raised by the chex check when the front does not
            have exactly 2 objectives.
    """
    # check the number of objectives
    custom_message = (
        "Hypervolume calculation for more than" " 2 objectives not yet supported."
    )
    chex.assert_axis_dimension(
        tensor=pareto_front,
        axis=1,
        expected=2,
        custom_message=custom_message,
    )

    # concatenate the reference point to prepare for the area computation
    pareto_front = jnp.concatenate(  # type: ignore
        (pareto_front, jnp.expand_dims(reference_point, axis=0)), axis=0
    )
    # get ordered indices for the first objective
    idx = jnp.argsort(pareto_front[:, 0])
    # Note: this orders it in inversely for the second objective
    sorted_points = pareto_front[idx, :]

    # create the mask - hide fake elements (those having -inf fitness)
    mask = sorted_points != -jnp.inf

    # sort the front and offset it with the ref point.
    # The masked entries are selected away with jnp.where instead of being
    # multiplied by the mask: -inf * 0 is NaN, which would otherwise
    # propagate to the final sum.
    sorted_front = jnp.where(mask, sorted_points - reference_point, 0.0)

    # compute area rectangles between successive points
    sumdiff = (sorted_front[1:, 0] - sorted_front[:-1, 0]) * sorted_front[1:, 1]

    # remove the irrelevant values - where a mask was applied
    sumdiff = jnp.where(mask[:-1, 0], sumdiff, 0.0)

    # get the hypervolume by summing the successive areas
    hypervolume = jnp.sum(sumdiff)

    return hypervolume

plotting

get_voronoi_finite_polygons_2d(centroids, radius=None)

Reconstruct infinite voronoi regions in a 2D diagram to finite regions.

Source code in qdax/utils/plotting.py
def get_voronoi_finite_polygons_2d(
    centroids: np.ndarray, radius: Optional[float] = None
) -> Tuple[List, np.ndarray]:
    """Reconstruct infinite voronoi regions in a 2D diagram to finite
    regions.

    Args:
        centroids: the points from which the voronoi diagram is built.
        radius: distance at which the missing endpoint of an infinite ridge
            is placed from its finite endpoint. Defaults to None, in which
            case the range (max - min) of all the point coordinates is used.

    Raises:
        ValueError: if the points of the diagram are not 2-dimensional.

    Returns:
        The list of regions, one per input point and each given as a list
        of vertex indices, and the array of vertex coordinates (the original
        voronoi vertices followed by the added far points).
    """
    voronoi_diagram = Voronoi(centroids)
    points = voronoi_diagram.points
    if points.shape[1] != 2:
        raise ValueError("Requires 2D input")

    new_regions = []
    new_vertices = voronoi_diagram.vertices.tolist()

    center = points.mean(axis=0)
    if radius is None:
        # np.ptp instead of the ndarray.ptp method, which NumPy 2.0 removed
        radius = np.ptp(points)

    # Construct a map containing all ridges for a given point:
    # point index -> [(neighbour point index, vertex index, vertex index)]
    all_ridges: Dict[int, List[Tuple[int, int, int]]] = {}
    for (p1, p2), (v1, v2) in zip(
        voronoi_diagram.ridge_points, voronoi_diagram.ridge_vertices
    ):
        all_ridges.setdefault(p1, []).append((p2, v1, v2))
        all_ridges.setdefault(p2, []).append((p1, v1, v2))

    # Reconstruct infinite regions
    for p1, region in enumerate(voronoi_diagram.point_region):
        vertices = voronoi_diagram.regions[region]

        if all(v >= 0 for v in vertices):
            # finite region
            new_regions.append(vertices)
            continue

        # reconstruct a non-finite region (a negative index marks a vertex
        # at infinity)
        ridges = all_ridges[p1]
        new_region = [v for v in vertices if v >= 0]

        for p2, v1, v2 in ridges:
            if v2 < 0:
                # make sure the vertex at infinity, if any, is v1
                v1, v2 = v2, v1
            if v1 >= 0:
                # finite ridge: already in the region
                continue

            # Compute the missing endpoint of an infinite ridge
            t = points[p2] - points[p1]  # tangent
            t /= np.linalg.norm(t)
            n = np.array([-t[1], t[0]])  # normal

            # orient the normal away from the center of the points
            midpoint = points[[p1, p2]].mean(axis=0)
            direction = np.sign(np.dot(midpoint - center, n)) * n
            far_point = voronoi_diagram.vertices[v2] + direction * radius

            new_region.append(len(new_vertices))
            new_vertices.append(far_point.tolist())

        # sort region counterclockwise
        vs = np.asarray([new_vertices[v] for v in new_region])
        c = vs.mean(axis=0)
        angles = np.arctan2(vs[:, 1] - c[1], vs[:, 0] - c[0])
        new_region = np.array(new_region)[np.argsort(angles)]

        # finish
        new_regions.append(new_region.tolist())

    return new_regions, np.asarray(new_vertices)

plot_2d_map_elites_repertoire(centroids, repertoire_fitnesses, minval, maxval, repertoire_descriptors=None, ax=None, vmin=None, vmax=None)

Plot a visual representation of a 2d map elites repertoire.

TODO: Use repertoire as input directly. Because this function is very specific to repertoires.

Parameters:
  • centroids (Array) – the centroids of the repertoire

  • repertoire_fitnesses (Array) – the fitness of the repertoire

  • minval (Array) – minimum values for the descriptors

  • maxval (Array) – maximum values for the descriptors

  • repertoire_descriptors (Optional[jax.Array]) – the descriptors. Defaults to None.

  • ax (Optional[matplotlib.axes._axes.Axes]) – a matplotlib axe for the figure to plot. Defaults to None.

  • vmin (Optional[float]) – minimum value for the fitness. Defaults to None. If not given, the value will be set to the minimum fitness in the repertoire.

  • vmax (Optional[float]) – maximum value for the fitness. Defaults to None. If not given, the value will be set to the maximum fitness in the repertoire.

Exceptions:
  • NotImplementedError – does not work for descriptors dimension different from 2.

Returns:
  • Tuple[Optional[matplotlib.figure.Figure], matplotlib.axes._axes.Axes] – A figure and axes object, corresponding to the visualisation of the repertoire.

Source code in qdax/utils/plotting.py
def plot_2d_map_elites_repertoire(
    centroids: jnp.ndarray,
    repertoire_fitnesses: jnp.ndarray,
    minval: jnp.ndarray,
    maxval: jnp.ndarray,
    repertoire_descriptors: Optional[jnp.ndarray] = None,
    ax: Optional[plt.Axes] = None,
    vmin: Optional[float] = None,
    vmax: Optional[float] = None,
) -> Tuple[Optional[Figure], Axes]:
    """Plot a visual representation of a 2d map elites repertoire.

    Every centroid is drawn as its voronoi cell; the cells whose fitness is
    not -inf are colored according to their fitness.

    Note: this function updates the global matplotlib rcParams (font sizes
    and figure size).

    TODO: Use repertoire as input directly. Because this
    function is very specific to repertoires.

    Args:
        centroids: the centroids of the repertoire
        repertoire_fitnesses: the fitness of the repertoire
        minval: minimum values for the descriptors, either a scalar used
            for both axes or one value per axis
        maxval: maximum values for the descriptors, either a scalar used
            for both axes or one value per axis
        repertoire_descriptors: the descriptors. Defaults to None. When
            given, the descriptors of the non-empty cells are added to the
            plot as points.
        ax: a matplotlib axe for the figure to plot. Defaults to None, in
            which case a new figure is created.
        vmin: minimum value for the fitness. Defaults to None. If not given,
            the value will be set to the minimum fitness in the repertoire.
        vmax: maximum value for the fitness. Defaults to None. If not given,
            the value will be set to the maximum fitness in the repertoire.

    Raises:
        NotImplementedError: does not work for descriptors dimension different
        from 2.

    Returns:
        A figure and axes object, corresponding to the visualisation of the
        repertoire. The figure is None when an axes object was passed in.
    """

    # TODO: check it and fix it if needed
    # cells with a fitness of -inf are considered empty
    grid_empty = repertoire_fitnesses == -jnp.inf
    num_descriptors = centroids.shape[1]
    if num_descriptors != 2:
        raise NotImplementedError("Grid plot supports 2 descriptors only for now.")

    my_cmap = cm.viridis

    # default color range: min and max fitness over the non-empty cells
    fitnesses = repertoire_fitnesses
    if vmin is None:
        vmin = float(jnp.min(fitnesses[~grid_empty]))
    if vmax is None:
        vmax = float(jnp.max(fitnesses[~grid_empty]))

    # set the parameters
    font_size = 12
    params = {
        "axes.labelsize": font_size,
        "legend.fontsize": font_size,
        "xtick.labelsize": font_size,
        "ytick.labelsize": font_size,
        "text.usetex": False,
        "figure.figsize": [10, 10],
    }

    # applied to the global matplotlib configuration, before the figure
    # (if any) is created
    mpl.rcParams.update(params)

    # create the plot object; fig stays None when the caller provides ax
    fig = None
    if ax is None:
        fig, ax = plt.subplots(facecolor="white", edgecolor="white")

    assert (
        len(np.array(minval).shape) < 2
    ), f"minval : {minval} should be float or couple of floats"
    assert (
        len(np.array(maxval).shape) < 2
    ), f"maxval : {maxval} should be float or couple of floats"

    # scalar bounds are shared by both axes, otherwise index 0 is used for
    # the x axis and index 1 for the y axis
    if len(np.array(minval).shape) == 0 and len(np.array(maxval).shape) == 0:
        ax.set_xlim(minval, maxval)
        ax.set_ylim(minval, maxval)
    else:
        ax.set_xlim(minval[0], maxval[0])
        ax.set_ylim(minval[1], maxval[1])

    ax.set(adjustable="box", aspect="equal")

    # create the regions and vertices from centroids
    regions, vertices = get_voronoi_finite_polygons_2d(centroids)

    # maps a fitness to the [0, 1] range expected by the colormap
    norm = Normalize(vmin=vmin, vmax=vmax)

    # fill the plot with contours (outline of every cell, empty or not)
    for region in regions:
        polygon = vertices[region]
        ax.fill(*zip(*polygon), alpha=0.05, edgecolor="black", facecolor="white", lw=1)

    # fill the plot with the colors (non-empty cells only)
    for idx, fitness in enumerate(fitnesses):
        if fitness > -jnp.inf:
            region = regions[idx]
            polygon = vertices[region]

            ax.fill(*zip(*polygon), alpha=0.8, color=my_cmap(norm(fitness)))

    # if descriptors are specified, add points location
    if repertoire_descriptors is not None:
        descriptors = repertoire_descriptors[~grid_empty]
        ax.scatter(
            descriptors[:, 0],
            descriptors[:, 1],
            c=fitnesses[~grid_empty],
            cmap=my_cmap,
            s=10,
            zorder=0,
        )

    # aesthetic: axis labels and a colorbar appended on the right of the plot
    ax.set_xlabel("Behavior Dimension 1")
    ax.set_ylabel("Behavior Dimension 2")
    divider = make_axes_locatable(ax)
    cax = divider.append_axes("right", size="5%", pad=0.05)
    cbar = plt.colorbar(mpl.cm.ScalarMappable(norm=norm, cmap=my_cmap), cax=cax)
    cbar.ax.tick_params(labelsize=font_size)

    ax.set_title("MAP-Elites Grid")
    ax.set_aspect("equal")

    return fig, ax

plot_map_elites_results(env_steps, metrics, repertoire, min_bd, max_bd)

Plots three usual QD metrics, namely the coverage, the maximum fitness and the QD-score, along the number of environment steps. This function also plots a visualisation of the final map elites grid obtained. It ensures that those plots are aligned together to give a simple and efficient visualisation of an optimization process.

Parameters:
  • env_steps (Array) – the array containing the number of steps done in the environment.

  • metrics (Dict) – a dictionary containing metrics from the optimization process.

  • repertoire (MapElitesRepertoire) – the final repertoire obtained.

  • min_bd (Array) – the minimal possible values for the bd.

  • max_bd (Array) – the maximal possible values for the bd.

Returns:
  • Tuple[Optional[matplotlib.figure.Figure], matplotlib.axes._axes.Axes] – A figure and axes with the plots of the metrics and visualisation of the grid.

Source code in qdax/utils/plotting.py
def plot_map_elites_results(
    env_steps: jnp.ndarray,
    metrics: Dict,
    repertoire: MapElitesRepertoire,
    min_bd: jnp.ndarray,
    max_bd: jnp.ndarray,
) -> Tuple[Optional[Figure], Axes]:
    """Plot the coverage, the maximum fitness and the QD-score along the
    number of environment steps, next to a visualisation of the final map
    elites grid. The four plots share a single row of one figure to give a
    simple and efficient visualisation of an optimization process.

    Note: this function updates the global matplotlib rcParams.

    Args:
        env_steps: the array containing the number of steps done in the environment.
        metrics: a dictionary containing metrics from the optimization process,
            read under the keys "coverage", "max_fitness" and "qd_score".
        repertoire: the final repertoire obtained.
        min_bd: the minimal possible values for the bd.
        max_bd: the maximal possible values for the bd.

    Returns:
        The figure holding the four plots and the axes returned by the grid
        plot.
    """
    # Customize matplotlib params
    label_size = 16
    mpl.rcParams.update(
        {
            "axes.labelsize": label_size,
            "axes.titlesize": label_size,
            "legend.fontsize": label_size,
            "xtick.labelsize": label_size,
            "ytick.labelsize": label_size,
            "text.usetex": False,
            "axes.titlepad": 10,
        }
    )

    # one row: three metric curves followed by the repertoire
    fig, axes = plt.subplots(nrows=1, ncols=4, figsize=(40, 10))

    # (metric key, y label, title) of the three curves, in plotting order
    curves = (
        ("coverage", "Coverage in %", "Coverage evolution during training"),
        (
            "max_fitness",
            "Maximum fitness",
            "Maximum fitness evolution during training",
        ),
        ("qd_score", "QD Score", "QD Score evolution during training"),
    )
    for panel, (metric_key, y_label, title) in zip(axes, curves):
        panel.plot(env_steps, metrics[metric_key])
        panel.set_xlabel("Environment steps")
        panel.set_ylabel(y_label)
        panel.set_title(title)
        panel.set_aspect(0.95 / panel.get_data_ratio(), adjustable="box")

    # the last panel receives the final repertoire
    _, repertoire_axes = plot_2d_map_elites_repertoire(
        centroids=repertoire.centroids,
        repertoire_fitnesses=repertoire.fitnesses,
        minval=min_bd,
        maxval=max_bd,
        repertoire_descriptors=repertoire.descriptors,
        ax=axes[3],
    )

    return fig, repertoire_axes

multiline(xs, ys, c, ax=None, **kwargs)

Plot lines with different colorings (with c a container of numbers mapped to colormap)

Note

len(xs) == len(ys) == len(c) is the number of line segments; len(xs[i]) == len(ys[i]) is the number of points for each line (indexed by i).

Parameters:
  • xs (Iterable) – First dimension of the trajectory.

  • ys (Iterable) – Second dimension of the trajectory.

  • c (Iterable) – Colors - one for each trajectory.

  • ax (Optional[matplotlib.axes._axes.Axes]) – A matplotlib axe. Defaults to None.

Returns:
  • LineCollection – Return a collection of lines corresponding to the trajectories.

Source code in qdax/utils/plotting.py
def multiline(
    xs: Iterable, ys: Iterable, c: Iterable, ax: Optional[Axes] = None, **kwargs: Any
) -> LineCollection:
    """Draw several lines at once, each colored through the colormap
    according to its entry in c.

    Note:
        len(xs) == len(ys) == len(c) is the number of line segments
        len(xs[i]) == len(ys[i]) is the number of points for each line (indexed by i)

    Args:
        xs: First dimension of the trajectory.
        ys: Second dimension of the trajectory.
        c: Colors - one for each trajectory.
        ax: A matplotlib axe. Defaults to None, in which case the current
            axes are used.
        **kwargs: forwarded to the LineCollection constructor.

    Returns:
        Return a collection of lines corresponding to the trajectories.
    """
    # fall back on the current axes when none is given
    target_ax = ax if ax is not None else plt.gca()

    # one (num_points, 2) array per line
    lines = LineCollection(
        [np.column_stack(pair) for pair in zip(xs, ys)], **kwargs
    )

    # the values mapped to colors must be an array: a plain list is an error
    lines.set_array(np.asarray(c))

    # adding a collection does not rescale xlim/ylim, hence the explicit call
    target_ax.add_collection(lines)
    target_ax.autoscale()

    return lines

plot_skills_trajectory(trajectories, skills, min_values, max_values)

Plots skills trajectories on a single plot with different colors to recognize the skills.

The plot can contain several trajectories of the same skill.

Parameters:
  • trajectories (Array) – skills trajectories

  • skills (Array) – skills corresponding to the given trajectories

  • min_values (Array) – minimum values that can be taken by the steps of the trajectory

  • max_values (Array) – maximum values that can be taken by the steps of the trajectory

Returns:
  • Tuple[matplotlib.figure.Figure, matplotlib.axes._axes.Axes] – A figure and axes.

Source code in qdax/utils/plotting.py
def plot_skills_trajectory(
    trajectories: jnp.ndarray,
    skills: jnp.ndarray,
    min_values: jnp.ndarray,
    max_values: jnp.ndarray,
) -> Tuple[Figure, Axes]:
    """Draw skill trajectories on a single plot, one color per skill.

    Several trajectories of the same skill can be drawn on the plot.

    Args:
        trajectories: skills trajectories, unpacked into the x and the y
            coordinates of the trajectories
        skills: skills corresponding to the given trajectories; the index
            of the largest entry along axis 1 identifies the skill
        min_values: minimum values that can be taken by the steps
            of the trajectory
        max_values: maximum values that can be taken by the steps
            of the trajectory

    Returns:
        A figure and axes.
    """
    # skill index of every trajectory (from 0 to number of skills - 1),
    # used as its color value
    skill_ids = skills.argmax(axis=1)
    skill_count = skills.shape[1]

    fig, ax = plt.subplots()

    # draw all the trajectories at once
    xs, ys = trajectories
    lines = multiline(xs=xs, ys=ys, c=skill_ids, ax=ax, cmap="gist_rainbow")

    # axis limits, labels and aspect
    ax.set_ylim(min_values[1], max_values[1])
    ax.set_xlim(min_values[0], max_values[0])
    ax.set_xlabel("Behavior Dimension 1")
    ax.set_ylabel("Behavior Dimension 2")
    ax.set_aspect("equal")

    # colorbar on the right, with one tick per skill
    colorbar_ax = make_axes_locatable(ax).append_axes("right", size="5%", pad=0.05)
    colorbar = fig.colorbar(lines, cax=colorbar_ax)
    colorbar.set_ticks(np.arange(skill_count, dtype=int))

    ax.set_title("Skill trajectories")

    return fig, ax

plot_mome_pareto_fronts(centroids, repertoire, maxval, minval, axes=None, color_style='hsv', with_global=False)

Plot the pareto fronts from all cells of the mome repertoire.

Parameters:
  • centroids (Array) – centroids of the repertoire

  • repertoire (MOMERepertoire) – mome repertoire

  • maxval (float) – maximum values for the descriptors

  • minval (float) – minimum values for the descriptors

  • axes (Optional[matplotlib.axes._axes.Axes]) – matplotlib axes. Defaults to None.

  • color_style (Optional[str]) – style of the colors used. Defaults to "hsv".

  • with_global (Optional[bool]) – plot the global pareto front in addition. Defaults to False.

Returns:
  • Axes – Returns the axes object with the plot.

Source code in qdax/utils/plotting.py
def plot_mome_pareto_fronts(
    centroids: jnp.ndarray,
    repertoire: MOMERepertoire,
    maxval: float,
    minval: float,
    axes: Optional[plt.Axes] = None,
    color_style: Optional[str] = "hsv",
    with_global: Optional[bool] = False,
) -> plt.Axes:
    """Plot the pareto fronts from all cells of the mome repertoire.

    Two panels are drawn: ``axes[0]`` shows the fitnesses of every non-empty
    cell and ``axes[1]`` shows the matching descriptors on top of the
    Voronoi tessellation built from the centroids. All the points of a cell
    share one color, derived from the position of the cell's centroid.

    Args:
        centroids: centroids of the repertoire
        repertoire: mome repertoire; both its fitnesses and its descriptors
            must have a last dimension of size 2.
        maxval: maximum values for the descriptors
        minval: minimum values for the descriptors
        axes: matplotlib axes, indexable as ``axes[0]`` and ``axes[1]``.
            Defaults to None, in which case a new 1x2 figure is created.
        color_style: style of the colors used, either "hsv" (hue from the
            angle of the centroid around the center of the descriptor space,
            brightness from its distance to that center) or "spectral"
            (color from the first coordinate of the centroid).
            Defaults to "hsv".
        with_global: plot the global pareto front in addition.
            Defaults to False.

    Returns:
        Returns the axes object with the plot.
    """
    fitnesses = repertoire.fitnesses
    repertoire_descriptors = repertoire.descriptors

    assert fitnesses.shape[-1] == repertoire_descriptors.shape[-1] == 2
    assert color_style in ["hsv", "spectral"], "color_style must be hsv or spectral"

    num_centroids = len(centroids)
    # a slot is considered empty as soon as one of its objectives is -inf
    grid_empty = jnp.any(fitnesses == -jnp.inf, axis=-1)

    # Extract polar coordinates
    if color_style == "hsv":
        # the center of the descriptor space is the midpoint of
        # [minval, maxval], not half of its width
        center = jnp.array([(maxval + minval) / 2] * centroids.shape[1])
        offsets = centroids - center
        polars = jnp.stack(
            (
                jnp.sqrt((jnp.sum(offsets**2, axis=-1)))
                / (maxval - minval)
                / jnp.sqrt(2),
                # arctan2 keeps the quadrant (full turn of hues) and is
                # well defined when the x offset is zero
                jnp.arctan2(offsets[:, 1], offsets[:, 0]),
            ),
            axis=-1,
        )
    elif color_style == "spectral":
        cmap = cm.get_cmap("Spectral")

    if axes is None:
        _, axes = plt.subplots(ncols=2, figsize=(12, 6))

    for i in range(num_centroids):
        if jnp.sum(~grid_empty[i]) > 0:
            cell_scores = fitnesses[i][~grid_empty[i]]
            cell = repertoire_descriptors[i][~grid_empty[i]]
            if color_style == "hsv":
                color = vector_to_rgb(polars[i, 1], polars[i, 0])
            else:
                color = cmap((centroids[i, 0] - minval) / (maxval - minval))
            axes[0].plot(cell_scores[:, 0], cell_scores[:, 1], "o", color=color)

            axes[1].plot(cell[:, 0], cell[:, 1], "o", color=color)

    # create the regions and vertices from centroids
    regions, vertices = get_voronoi_finite_polygons_2d(centroids)

    # fill the plot with contours
    for region in regions:
        polygon = vertices[region]
        axes[1].fill(
            *zip(*polygon), alpha=0.2, edgecolor="black", facecolor="white", lw=1
        )
    axes[0].set_title("Fitness")
    axes[1].set_title("Descriptor")
    axes[1].set_xlim(minval, maxval)
    axes[1].set_ylim(minval, maxval)

    if with_global:
        global_pareto_front, pareto_bool = repertoire.compute_global_pareto_front()
        global_pareto_descriptors = jnp.concatenate(repertoire_descriptors)[pareto_bool]
        axes[0].scatter(
            global_pareto_front[:, 0],
            global_pareto_front[:, 1],
            marker="o",
            edgecolors="black",
            facecolors="none",
            zorder=3,
            label="Global Pareto Front",
        )
        # sort along the first objective so the dashed line joins
        # neighbouring points of the front
        sorted_index = jnp.argsort(global_pareto_front[:, 0])
        axes[0].plot(
            global_pareto_front[sorted_index, 0],
            global_pareto_front[sorted_index, 1],
            linestyle="--",
            linewidth=2,
            color="k",
            zorder=3,
        )
        axes[1].scatter(
            global_pareto_descriptors[:, 0],
            global_pareto_descriptors[:, 1],
            marker="o",
            edgecolors="black",
            facecolors="none",
            zorder=3,
            label="Global Pareto Descriptor",
        )

    return axes

vector_to_rgb(angle, absolute)

Returns a color based on polar coordinates.

Parameters:
  • angle (float) – a given angle

  • absolute (float) – a ref

Returns:
  • Any – An appropriate color.

Source code in qdax/utils/plotting.py
def vector_to_rgb(angle: float, absolute: float) -> Any:
    """Returns a color based on polar coordinates.

    The angle selects the hue and the radius selects the brightness (the
    "value" channel of HSV); the saturation is always 1.

    Args:
        angle: angle in radians. Any value is accepted, it is wrapped to a
            single turn before being mapped to a hue.
        absolute: radial coordinate. It is rescaled as
            ``(absolute + 0.5) / 1.5`` before being used as the HSV value.

    Returns:
        The RGB color returned by ``mpl.colors.hsv_to_rgb``.
    """

    # normalize angle: the remainder of a positive modulus is never
    # negative, so no further sign correction is needed
    angle = angle % (2 * np.pi)

    # raise absolute to avoid black colours
    absolute = (absolute + 0.5) / 1.5

    return mpl.colors.hsv_to_rgb((angle / 2 / np.pi, 1, absolute))

plot_global_pareto_front(pareto_front, ax=None, label=None, color=None)

Plots the global Pareto Front.

Parameters:
  • pareto_front (Array) – a pareto front

  • ax (Optional[matplotlib.axes._axes.Axes]) – a matplotlib ax. Defaults to None.

  • label (Optional[str]) – a given label. Defaults to None.

  • color (Optional[str]) – a color for the plotted points. Defaults to None.

Returns:
  • Tuple[Optional[matplotlib.figure.Figure], matplotlib.axes._axes.Axes] – A figure and an axe.

Source code in qdax/utils/plotting.py
def plot_global_pareto_front(
    pareto_front: jnp.ndarray,
    ax: Optional[plt.Axes] = None,
    label: Optional[str] = None,
    color: Optional[str] = None,
) -> Tuple[Optional[Figure], plt.Axes]:
    """Scatter the points of a global Pareto Front.

    Args:
        pareto_front: a pareto front; its first two columns are used as the
            x and y coordinates of the points.
        ax: a matplotlib ax to draw on. Defaults to None, in which case a
            new 6x6 figure is created.
        label: a given label for the points. Defaults to None.
        color: a color for the plotted points. Defaults to None.

    Returns:
        The newly created figure (None when an ax was given) and the ax
        holding the scatter plot.
    """
    # a figure only exists (and is returned) when no ax was supplied
    created_figure = None
    if ax is None:
        created_figure, ax = plt.subplots(figsize=(6, 6))

    xs = pareto_front[:, 0]
    ys = pareto_front[:, 1]
    ax.scatter(xs, ys, color=color, label=label)

    return created_figure, ax

plot_multidimensional_map_elites_grid(repertoire, minval, maxval, grid_shape, ax=None, vmin=None, vmax=None)

Plot a visual 2D representation of a multidimensional MAP-Elites repertoire (where the dimensionality of descriptors can be greater than 2).

Parameters:
  • repertoire (MapElitesRepertoire) – the MAP-Elites repertoire to plot.

  • minval (Array) – minimum values for the descriptors

  • maxval (Array) – maximum values for the descriptors

  • grid_shape (Tuple[int, ...]) – the resolution of the grid.

  • ax (Optional[matplotlib.axes._axes.Axes]) – a matplotlib axe for the figure to plot. Defaults to None.

  • vmin (Optional[float]) – minimum value for the fitness. Defaults to None.

  • vmax (Optional[float]) – maximum value for the fitness. Defaults to None.

Exceptions:
  • ValueError – if grid_shape is not a tuple

Returns:
  • Tuple[Optional[matplotlib.figure.Figure], matplotlib.axes._axes.Axes] – A figure and axes object, corresponding to the visualisation of the repertoire.

Source code in qdax/utils/plotting.py
def plot_multidimensional_map_elites_grid(
    repertoire: MapElitesRepertoire,
    minval: jnp.ndarray,
    maxval: jnp.ndarray,
    grid_shape: Tuple[int, ...],
    ax: Optional[plt.Axes] = None,
    vmin: Optional[float] = None,
    vmax: Optional[float] = None,
) -> Tuple[Optional[Figure], Axes]:
    """Plot a visual 2D representation of a multidimensional MAP-Elites repertoire
    (where the dimensionality of descriptors can be greater than 2).

    The descriptors of the non-empty cells are discretised according to
    grid_shape and projected onto a 2D image: the horizontal axis has
    prod(grid_shape[0::2]) cells and the vertical axis prod(grid_shape[1::2])
    cells. When several solutions land on the same 2D cell, the highest
    fitness is displayed.

    Note that this function updates the global matplotlib rcParams (font
    sizes, figure size, usetex).

    Args:
        repertoire: the MAP-Elites repertoire to plot.
        minval: minimum values for the descriptors, either one value per
            descriptor or a single value shared by all descriptors.
        maxval: maximum values for the descriptors, either one value per
            descriptor or a single value shared by all descriptors.
        grid_shape: the resolution of the grid, one entry per descriptor.
        ax: a matplotlib axe for the figure to plot. Defaults to None.
        vmin: minimum value for the fitness. Defaults to None, in which case
            the smallest fitness of the non-empty cells is used.
        vmax: maximum value for the fitness. Defaults to None, in which case
            the largest fitness of the non-empty cells is used.
    Raises:
        ValueError: if grid_shape is not a tuple.
    Returns:
        A figure and axes object, corresponding to the visualisation of the
        repertoire. The figure is None when an ax was given.
    """

    descriptors = repertoire.descriptors
    fitnesses = repertoire.fitnesses

    is_grid_empty = fitnesses.ravel() == -jnp.inf
    num_descriptors = descriptors.shape[1]

    if isinstance(grid_shape, tuple):
        assert (
            len(grid_shape) == num_descriptors
        ), "grid_shape should have the same length as num_descriptors"
    else:
        raise ValueError("resolution should be a tuple")

    assert np.size(minval) == num_descriptors or np.size(minval) == 1, (
        f"minval : {minval} should either be of size 1 "
        f"or have the same size as the number of descriptors: {num_descriptors}"
    )
    assert np.size(maxval) == num_descriptors or np.size(maxval) == 1, (
        f"maxval : {maxval} should either be of size 1 "
        f"or have the same size as the number of descriptors: {num_descriptors}"
    )

    # size-1 bounds are accepted above: expand them to one value per
    # descriptor so that the per-dimension indexing used for the tick labels
    # (minval[0], minval[1], ...) is always valid
    minval = jnp.broadcast_to(jnp.ravel(jnp.asarray(minval)), (num_descriptors,))
    maxval = jnp.broadcast_to(jnp.ravel(jnp.asarray(maxval)), (num_descriptors,))

    non_empty_descriptors = descriptors[~is_grid_empty]
    non_empty_fitnesses = fitnesses[~is_grid_empty]

    # convert the descriptors to integer coordinates, depending on the resolution.
    resolutions_array = jnp.array(grid_shape)
    descriptors_integers = jnp.asarray(
        jnp.floor(
            resolutions_array * (non_empty_descriptors - minval) / (maxval - minval)
        ),
        dtype=jnp.int32,
    )

    # total number of grid cells along each dimension of the grid
    size_grid_x = np.prod(np.array(grid_shape[0::2]))
    # convert to int for the 1d case - if not, value 1.0 is given
    size_grid_y = np.prod(np.array(grid_shape[1::2]), dtype=int)

    # initialise the grid (NaN marks a cell without any solution)
    grid_2d = np.full(
        (size_grid_x.item(), size_grid_y.item()),
        fill_value=jnp.nan,
    )

    # put solutions in the grid according to their projected 2-dimensional coordinates
    for desc, fit in zip(descriptors_integers, non_empty_fitnesses):
        projection_2d = _get_projection_in_2d(desc, grid_shape)
        # keep the best fitness when several solutions share a 2D cell
        if jnp.isnan(grid_2d[projection_2d]) or fit.item() > grid_2d[projection_2d]:
            grid_2d[projection_2d] = fit.item()

    # set plot parameters
    font_size = 12
    params = {
        "axes.labelsize": font_size,
        "legend.fontsize": font_size,
        "xtick.labelsize": font_size,
        "ytick.labelsize": font_size,
        "text.usetex": False,
        "figure.figsize": [10, 10],
    }

    # NOTE: this changes the process-wide matplotlib configuration
    mpl.rcParams.update(params)

    # create the plot object
    fig = None
    if ax is None:
        fig, ax = plt.subplots(facecolor="white", edgecolor="white")
    ax.set(adjustable="box", aspect="equal")

    my_cmap = cm.viridis

    if vmin is None:
        vmin = float(jnp.min(non_empty_fitnesses))
    if vmax is None:
        vmax = float(jnp.max(non_empty_fitnesses))

    # the grid is transposed so that its first axis is drawn horizontally
    ax.imshow(
        grid_2d.T,
        origin="lower",
        aspect="equal",
        vmin=vmin,
        vmax=vmax,
        cmap=my_cmap,
    )

    # aesthetic
    ax.set_xlabel("Behavior Dimension 1")
    ax.set_ylabel("Behavior Dimension 2")
    divider = make_axes_locatable(ax)
    cax = divider.append_axes("right", size="5%", pad=0.05)
    norm = Normalize(vmin=vmin, vmax=vmax)
    cbar = plt.colorbar(mpl.cm.ScalarMappable(norm=norm, cmap=my_cmap), cax=cax)
    cbar.ax.tick_params(labelsize=font_size)

    ax.set_title("MAP-Elites Grid")
    ax.set_aspect("equal")

    def _get_ticks_positions(
        total_size_grid_axis: int, step_ticks_on_axis: int
    ) -> jnp.ndarray:
        """
        Get the positions of the ticks on the grid axis.
        Args:
            total_size_grid_axis: total size of the grid axis
            step_ticks_on_axis: step of the ticks
        Returns:
            The positions of the ticks on the plot.
        """
        # shifted by half a cell so the ticks fall on cell borders
        return np.arange(0, total_size_grid_axis + 1, step_ticks_on_axis) - 0.5

    # Ticks position
    major_ticks_x = _get_ticks_positions(
        size_grid_x.item(), step_ticks_on_axis=np.prod(grid_shape[2::2]).item()
    )
    minor_ticks_x = _get_ticks_positions(
        size_grid_x.item(), step_ticks_on_axis=np.prod(grid_shape[4::2]).item()
    )
    major_ticks_y = _get_ticks_positions(
        size_grid_y.item(), step_ticks_on_axis=np.prod(grid_shape[3::2]).item()
    )
    minor_ticks_y = _get_ticks_positions(
        size_grid_y.item(), step_ticks_on_axis=np.prod(grid_shape[5::2]).item()
    )

    ax.set_xticks(
        major_ticks_x,
    )
    ax.set_xticks(
        minor_ticks_x,
        minor=True,
    )
    ax.set_yticks(
        major_ticks_y,
    )
    ax.set_yticks(
        minor_ticks_y,
        minor=True,
    )

    # Ticks aesthetics
    ax.tick_params(
        which="minor",
        color="gray",
        labelcolor="gray",
        size=5,
    )
    ax.tick_params(
        which="major",
        labelsize=font_size,
        size=7,
    )

    ax.grid(which="minor", alpha=1.0, color="#000000", linewidth=0.5)
    if len(grid_shape) > 2:
        ax.grid(which="major", alpha=1.0, color="#000000", linewidth=2.5)

    def _get_positions_labels(
        _minval: float, _maxval: float, _number_ticks: int, _step_labels_ticks: int
    ) -> List[str]:
        """Build one label per tick, leaving most of them empty.

        Only every `_step_labels_ticks`-th tick and the last tick get a
        value, evenly spaced between `_minval` and `_maxval`.
        """
        positions = jnp.linspace(_minval, _maxval, num=_number_ticks)

        list_str_positions = []
        for index_tick, position in enumerate(positions):
            if index_tick % _step_labels_ticks != 0:
                character = ""
            else:
                character = f"{position:.2E}"
            list_str_positions.append(character)
        # forcing the last tick label
        list_str_positions[-1] = f"{positions[-1]:.2E}"
        return list_str_positions

    number_label_ticks = 4

    # labels are only customised when there are at least number_label_ticks
    # major ticks on the axis
    if len(major_ticks_x) // number_label_ticks > 0:
        ax.set_xticklabels(
            _get_positions_labels(
                _minval=minval[0],
                _maxval=maxval[0],
                _number_ticks=len(major_ticks_x),
                _step_labels_ticks=len(major_ticks_x) // number_label_ticks,
            )
        )
    if len(major_ticks_y) // number_label_ticks > 0:
        ax.set_yticklabels(
            _get_positions_labels(
                _minval=minval[1],
                _maxval=maxval[1],
                _number_ticks=len(major_ticks_y),
                _step_labels_ticks=len(major_ticks_y) // number_label_ticks,
            )
        )

    return fig, ax

sampling

Core components of the MAP-Elites-sampling algorithm.

average(quantities)

Default expectation extractor using average.

Source code in qdax/utils/sampling.py
@jax.jit
def average(quantities: jnp.ndarray) -> jnp.ndarray:
    """Default expectation extractor: arithmetic mean taken over axis 1."""
    sample_mean = jnp.mean(quantities, axis=1)
    return sample_mean

median(quantities)

Alternative expectation extractor using median. More robust to outliers than average.

Source code in qdax/utils/sampling.py
@jax.jit
def median(quantities: jnp.ndarray) -> jnp.ndarray:
    """Alternative expectation extractor: median taken over axis 1.

    More robust to outliers than average.
    """
    sample_median = jnp.median(quantities, axis=1)
    return sample_median

mode(quantities)

Alternative expectation extractor using mode. More robust to outliers than average. WARNING: for multidimensional objects such as descriptor, do dimension-wise selection.

Source code in qdax/utils/sampling.py
@jax.jit
def mode(quantities: jnp.ndarray) -> jnp.ndarray:
    """Alternative expectation extractor using the mode (most frequent value).

    More robust to outliers than average.
    WARNING: for multidimensional objects such as descriptor, the mode is
    selected independently for each dimension.
    """

    def _most_frequent(column: jnp.ndarray) -> jnp.ndarray:
        # a fixed output size keeps jnp.unique usable under jit
        values, occurrences = jnp.unique(
            column, return_counts=True, size=column.size
        )
        return values[jnp.argmax(occurrences)]

    def _individual_mode(samples: jnp.ndarray) -> jnp.ndarray:
        # flatten to (num_samples, num_dims) so single- and multi-dimension
        # quantities follow the same path
        flat_samples = samples.reshape((samples.shape[0], -1))

        # one vote per dimension
        per_dimension = jax.vmap(_most_frequent)(flat_samples.T)
        return jnp.squeeze(per_dimension)

    # vmap over individuals
    return jax.vmap(_individual_mode)(quantities)

closest(quantities)

Alternative expectation extractor selecting individual that has the minimum distance to all other individuals. This is an approximation of the geometric median. More robust to outliers than average.

Source code in qdax/utils/sampling.py
@jax.jit
def closest(quantities: jnp.ndarray) -> jnp.ndarray:
    """Alternative expectation extractor returning, for each individual,
    the sample with the smallest mean distance to all its samples. This
    is an approximation of the geometric median.
    More robust to outliers than average."""

    def _euclidean(a: jnp.ndarray, b: jnp.ndarray) -> jnp.ndarray:
        return jnp.sqrt(jnp.sum(jnp.square(a - b)))

    # pairwise[i, j] is the distance between sample i and sample j
    _one_to_all = jax.vmap(_euclidean, in_axes=(None, 0))
    _all_to_all = jax.vmap(_one_to_all, in_axes=(0, None))

    def _most_central(samples: jnp.ndarray) -> jnp.ndarray:
        pairwise = _all_to_all(samples, samples)
        mean_distance = jnp.mean(pairwise, axis=0)
        return samples[jnp.argmin(mean_distance)]

    # vmap over individuals
    return jax.vmap(_most_central)(quantities)

std(quantities)

Default reproducibility extractor using standard deviation.

Source code in qdax/utils/sampling.py
@jax.jit
def std(quantities: jnp.ndarray) -> jnp.ndarray:
    """Default reproducibility extractor: standard deviation over axis 1."""
    sample_spread = jnp.std(quantities, axis=1)
    return sample_spread

mad(quantities)

Alternative reproducibility extractor using Median Absolute Deviation. More robust to outliers than standard deviation.

Source code in qdax/utils/sampling.py
@jax.jit
def mad(quantities: jnp.ndarray) -> jnp.ndarray:
    """Alternative reproducibility extractor using Median Absolute Deviation.

    More robust to outliers than standard deviation.
    """
    # keepdims leaves a size-1 sample axis, so the subtraction broadcasts
    # the median of each individual against all of its samples
    center = jnp.median(quantities, axis=1, keepdims=True)
    absolute_deviation = jnp.abs(quantities - center)
    return jnp.median(absolute_deviation, axis=1)

iqr(quantities)

Alternative reproducibility extractor using Inter-Quartile Range. More robust to outliers than standard deviation.

Source code in qdax/utils/sampling.py
@jax.jit
def iqr(quantities: jnp.ndarray) -> jnp.ndarray:
    """Alternative reproducibility extractor using Inter-Quartile Range.

    More robust to outliers than standard deviation.
    """
    # spread between the 25% and 75% quantiles of the samples (axis 1)
    lower_quartile = jnp.quantile(quantities, 0.25, axis=1)
    upper_quartile = jnp.quantile(quantities, 0.75, axis=1)
    return upper_quartile - lower_quartile

dummy_extra_scores_extractor(extra_scores, num_samples)

Extract the final extra scores of a policy from multiple samples of the same policy in the environment. This dummy implementation just returns the full concatenated extra_scores of all samples without extra computation.

Parameters:
  • extra_scores (Dict[str, Union[jax.Array, numpy.ndarray, numpy.bool_, numpy.number, Iterable[ArrayTree], Mapping[Any, ArrayTree]]]) – extra scores of the samples

  • num_samples (int) – the number of samples used

Returns:
  • Dict[str, Union[jax.Array, numpy.ndarray, numpy.bool_, numpy.number, Iterable[ArrayTree], Mapping[Any, ArrayTree]]] – the new extra scores after extraction

Source code in qdax/utils/sampling.py
@partial(jax.jit, static_argnames=("num_samples",))
def dummy_extra_scores_extractor(
    extra_scores: ExtraScores,
    num_samples: int,
) -> ExtraScores:
    """
    Extract the final extra scores of a policy from multiple samples of
    the same policy in the environment.
    This dummy implementation performs no computation: the extra scores
    of all the samples are returned exactly as they were given.

    Args:
        extra_scores: extra scores of the samples
        num_samples: the number of samples used (not used by this
            implementation)

    Returns:
        the extra scores, unchanged
    """
    # identity extractor: forward the per-sample extra scores untouched
    unchanged_scores = extra_scores
    return unchanged_scores

multi_sample_scoring_function(policies_params, random_key, scoring_fn, num_samples)

Wrap scoring_function to perform sampling.

This function returns the fitnesses, descriptors, and extra_scores computed over num_samples evaluations with the scoring_fn.

Parameters:
  • policies_params (Union[jax.Array, numpy.ndarray, numpy.bool_, numpy.number, Iterable[ArrayTree], Mapping[Any, ArrayTree]]) – policies to evaluate

  • random_key (Array) – JAX random key

  • scoring_fn (Callable[[Union[jax.Array, numpy.ndarray, numpy.bool_, numpy.number, Iterable[ArrayTree], Mapping[Any, ArrayTree]], jax.Array], Tuple[jax.Array, jax.Array, Dict[str, Union[jax.Array, numpy.ndarray, numpy.bool_, numpy.number, Iterable[ArrayTree], Mapping[Any, ArrayTree]]], jax.Array]]) – scoring function used for evaluation

  • num_samples (int) – number of samples to generate for each individual

Returns:
  • Tuple[jax.Array, jax.Array, Dict[str, Union[jax.Array, numpy.ndarray, numpy.bool_, numpy.number, Iterable[ArrayTree], Mapping[Any, ArrayTree]]], jax.Array] – (n, num_samples) array of fitnesses, (n, num_samples, num_descriptors) array of descriptors, dict with num_samples extra_scores per individual, JAX random key

Source code in qdax/utils/sampling.py
@partial(
    jax.jit,
    static_argnames=(
        "scoring_fn",
        "num_samples",
    ),
)
def multi_sample_scoring_function(
    policies_params: Genotype,
    random_key: RNGKey,
    scoring_fn: Callable[
        [Genotype, RNGKey],
        Tuple[Fitness, Descriptor, ExtraScores, RNGKey],
    ],
    num_samples: int,
) -> Tuple[Fitness, Descriptor, ExtraScores, RNGKey]:
    """
    Wrap scoring_function to perform sampling.

    The same policies are evaluated num_samples times with scoring_fn, each
    evaluation using its own random key, and the raw results of all the
    evaluations are returned.

    Args:
        policies_params: policies to evaluate
        random_key: JAX random key
        scoring_fn: scoring function used for evaluation
        num_samples: number of samples to generate for each individual

    Returns:
        (n, num_samples) array of fitnesses,
        (n, num_samples, num_descriptors) array of descriptors,
        dict with num_samples extra_scores per individual,
        JAX random key
    """

    # one half of the split is handed back to the caller, the other half
    # is expanded into one key per sample
    next_random_key, sampling_key = jax.random.split(random_key)
    per_sample_keys = jax.random.split(sampling_key, num=num_samples)

    # the policies are shared by all the evaluations (None) while the keys
    # are mapped over their leading axis (0); out_axes=1 inserts the sample
    # axis right after the batch axis of every output
    batched_scoring_fn = jax.vmap(scoring_fn, in_axes=(None, 0), out_axes=1)

    # the keys returned by the individual evaluations are discarded
    sampled_fitnesses, sampled_descriptors, sampled_extra_scores, _ = (
        batched_scoring_fn(policies_params, per_sample_keys)
    )

    return (
        sampled_fitnesses,
        sampled_descriptors,
        sampled_extra_scores,
        next_random_key,
    )

sampling(policies_params, random_key, scoring_fn, num_samples, extra_scores_extractor=<PjitFunction of <function dummy_extra_scores_extractor at 0x7f8b72fde280>>, fitness_extractor=<PjitFunction of <function average at 0x7f8b72fed700>>, descriptor_extractor=<PjitFunction of <function average at 0x7f8b72fed700>>)

Wrap scoring_function to perform sampling.

This function returns the expected fitnesses and descriptors for each individual over num_samples evaluations, using the provided extractor functions for the fitness and the descriptor.

Parameters:
  • policies_params (Union[jax.Array, numpy.ndarray, numpy.bool_, numpy.number, Iterable[ArrayTree], Mapping[Any, ArrayTree]]) – policies to evaluate

  • random_key (Array) – JAX random key

  • scoring_fn (Callable[[Union[jax.Array, numpy.ndarray, numpy.bool_, numpy.number, Iterable[ArrayTree], Mapping[Any, ArrayTree]], jax.Array], Tuple[jax.Array, jax.Array, Dict[str, Union[jax.Array, numpy.ndarray, numpy.bool_, numpy.number, Iterable[ArrayTree], Mapping[Any, ArrayTree]]], jax.Array]]) – scoring function used for evaluation

  • num_samples (int) – number of samples to generate for each individual

  • extra_scores_extractor (Callable[[Dict[str, Union[jax.Array, numpy.ndarray, numpy.bool_, numpy.number, Iterable[ArrayTree], Mapping[Any, ArrayTree]]], int], Dict[str, Union[jax.Array, numpy.ndarray, numpy.bool_, numpy.number, Iterable[ArrayTree], Mapping[Any, ArrayTree]]]]) – function to extract the extra_scores from multiple samples of the same policy.

  • fitness_extractor (Callable[[jax.Array], jax.Array]) – function to extract the fitness expectation from multiple samples of the same policy.

  • descriptor_extractor (Callable[[jax.Array], jax.Array]) – function to extract the descriptor expectation from multiple samples of the same policy.

Returns:
  • Tuple[jax.Array, jax.Array, Dict[str, Union[jax.Array, numpy.ndarray, numpy.bool_, numpy.number, Iterable[ArrayTree], Mapping[Any, ArrayTree]]], jax.Array] – The expected fitnesses, descriptors and extra_scores of the individuals A new random key

Source code in qdax/utils/sampling.py
@partial(
    jax.jit,
    static_argnames=(
        "scoring_fn",
        "num_samples",
        "extra_scores_extractor",
        "fitness_extractor",
        "descriptor_extractor",
    ),
)
def sampling(
    policies_params: Genotype,
    random_key: RNGKey,
    scoring_fn: Callable[
        [Genotype, RNGKey],
        Tuple[Fitness, Descriptor, ExtraScores, RNGKey],
    ],
    num_samples: int,
    extra_scores_extractor: Callable[
        [ExtraScores, int], ExtraScores
    ] = dummy_extra_scores_extractor,
    fitness_extractor: Callable[[jnp.ndarray], jnp.ndarray] = average,
    descriptor_extractor: Callable[[jnp.ndarray], jnp.ndarray] = average,
) -> Tuple[Fitness, Descriptor, ExtraScores, RNGKey]:
    """Wrap scoring_function to perform sampling.

    Every individual is evaluated `num_samples` times, then the samples are
    reduced to a single fitness, descriptor and extra_scores per individual
    with the provided extractor functions.

    Args:
        policies_params: policies to evaluate
        random_key: JAX random key
        scoring_fn: scoring function used for evaluation
        num_samples: number of samples to generate for each individual
        extra_scores_extractor: function to extract the extra_scores from
            multiple samples of the same policy.
        fitness_extractor: function to extract the fitness expectation from
            multiple samples of the same policy.
        descriptor_extractor: function to extract the descriptor expectation
            from multiple samples of the same policy.

    Returns:
        The expected fitnesses, descriptors and extra_scores of the individuals
        A new random key
    """

    # evaluate each policy num_samples times
    sampled = multi_sample_scoring_function(
        policies_params, random_key, scoring_fn, num_samples
    )
    sampled_fitnesses, sampled_descriptors, sampled_extra_scores, random_key = sampled

    # reduce the samples of each individual to its final scores
    fitnesses = fitness_extractor(sampled_fitnesses)
    descriptors = descriptor_extractor(sampled_descriptors)
    extra_scores = extra_scores_extractor(sampled_extra_scores, num_samples)

    return fitnesses, descriptors, extra_scores, random_key

sampling_reproducibility(policies_params, random_key, scoring_fn, num_samples, extra_scores_extractor=<PjitFunction of <function dummy_extra_scores_extractor at 0x7f8b72fde280>>, fitness_extractor=<PjitFunction of <function average at 0x7f8b72fed700>>, descriptor_extractor=<PjitFunction of <function average at 0x7f8b72fed700>>, fitness_reproducibility_extractor=<PjitFunction of <function std at 0x7f8b72fdd670>>, descriptor_reproducibility_extractor=<PjitFunction of <function std at 0x7f8b72fdd670>>)

Wrap scoring_function to perform sampling and compute the expectation and reproducibility.

This function returns the expectation and the reproducibility of fitnesses and descriptors for each individual over num_samples evaluations, using the provided extractor functions for the fitness and the descriptor.

Parameters:
  • policies_params (Union[jax.Array, numpy.ndarray, numpy.bool_, numpy.number, Iterable[ArrayTree], Mapping[Any, ArrayTree]]) – policies to evaluate

  • random_key (Array) – JAX random key

  • scoring_fn (Callable[[Union[jax.Array, numpy.ndarray, numpy.bool_, numpy.number, Iterable[ArrayTree], Mapping[Any, ArrayTree]], jax.Array], Tuple[jax.Array, jax.Array, Dict[str, Union[jax.Array, numpy.ndarray, numpy.bool_, numpy.number, Iterable[ArrayTree], Mapping[Any, ArrayTree]]], jax.Array]]) – scoring function used for evaluation

  • num_samples (int) – number of samples to generate for each individual

  • extra_scores_extractor (Callable[[Dict[str, Union[jax.Array, numpy.ndarray, numpy.bool_, numpy.number, Iterable[ArrayTree], Mapping[Any, ArrayTree]]], int], Dict[str, Union[jax.Array, numpy.ndarray, numpy.bool_, numpy.number, Iterable[ArrayTree], Mapping[Any, ArrayTree]]]]) – function to extract the extra_scores from multiple samples of the same policy.

  • fitness_extractor (Callable[[jax.Array], jax.Array]) – function to extract the fitness expectation from multiple samples of the same policy.

  • descriptor_extractor (Callable[[jax.Array], jax.Array]) – function to extract the descriptor expectation from multiple samples of the same policy.

  • fitness_reproducibility_extractor (Callable[[jax.Array], jax.Array]) – function to extract the fitness reproducibility from multiple samples of the same policy.

  • descriptor_reproducibility_extractor (Callable[[jax.Array], jax.Array]) – function to extract the descriptor reproducibility from multiple samples of the same policy.

Returns:
  • Tuple[jax.Array, jax.Array, Dict[str, Union[jax.Array, numpy.ndarray, numpy.bool_, numpy.number, Iterable[ArrayTree], Mapping[Any, ArrayTree]]], jax.Array, jax.Array, jax.Array] – The expected fitnesses, descriptors and extra_scores of the individuals The fitnesses and descriptors reproducibility of the individuals A new random key

Source code in qdax/utils/sampling.py
@partial(
    jax.jit,
    static_argnames=(
        "scoring_fn",
        "num_samples",
        "extra_scores_extractor",
        "fitness_extractor",
        "descriptor_extractor",
        "fitness_reproducibility_extractor",
        "descriptor_reproducibility_extractor",
    ),
)
def sampling_reproducibility(
    policies_params: Genotype,
    random_key: RNGKey,
    scoring_fn: Callable[
        [Genotype, RNGKey],
        Tuple[Fitness, Descriptor, ExtraScores, RNGKey],
    ],
    num_samples: int,
    extra_scores_extractor: Callable[
        [ExtraScores, int], ExtraScores
    ] = dummy_extra_scores_extractor,
    fitness_extractor: Callable[[jnp.ndarray], jnp.ndarray] = average,
    descriptor_extractor: Callable[[jnp.ndarray], jnp.ndarray] = average,
    fitness_reproducibility_extractor: Callable[[jnp.ndarray], jnp.ndarray] = std,
    descriptor_reproducibility_extractor: Callable[[jnp.ndarray], jnp.ndarray] = std,
) -> Tuple[Fitness, Descriptor, ExtraScores, Fitness, Descriptor, RNGKey]:
    """Wrap a scoring function to sample it several times and compute both
    the expectation and the reproducibility of the results.

    The samples are gathered through `multi_sample_scoring_function`. The
    same sampled fitnesses and descriptors are then reduced twice: once by
    the expectation extractors and once by the reproducibility extractors.

    Args:
        policies_params: policies to evaluate
        random_key: JAX random key
        scoring_fn: scoring function used for evaluation
        num_samples: number of samples to generate for each individual
        extra_scores_extractor: function to extract the extra_scores from
            multiple samples of the same policy. It also receives
            `num_samples`.
        fitness_extractor: function to extract the fitness expectation from
            multiple samples of the same policy.
        descriptor_extractor: function to extract the descriptor expectation
            from multiple samples of the same policy.
        fitness_reproducibility_extractor: function to extract the fitness
            reproducibility from multiple samples of the same policy.
        descriptor_reproducibility_extractor: function to extract the
            descriptor reproducibility from multiple samples of the same
            policy.

    Returns:
        A tuple made of, in this order: the expected fitnesses, the expected
        descriptors, the extracted extra_scores, the fitnesses
        reproducibility, the descriptors reproducibility and a new random key.
    """

    # Gather every sample in a single call
    sampled = multi_sample_scoring_function(
        policies_params, random_key, scoring_fn, num_samples
    )
    fitness_samples, descriptor_samples, extra_samples, next_key = sampled

    # Reduce the samples to their final scores
    expected_descriptors = descriptor_extractor(descriptor_samples)
    expected_fitnesses = fitness_extractor(fitness_samples)
    merged_extra_scores = extra_scores_extractor(extra_samples, num_samples)

    # Reduce the very same samples to their reproducibility
    descriptor_spread = descriptor_reproducibility_extractor(descriptor_samples)
    fitness_spread = fitness_reproducibility_extractor(fitness_samples)

    return (
        expected_fitnesses,
        expected_descriptors,
        merged_extra_scores,
        fitness_spread,
        descriptor_spread,
        next_key,
    )

train_seq2seq

seq2seq addition example

Inspired by Flax library - https://github.com/google/flax/blob/main/examples/seq2seq/train.py

Copyright 2022 The Flax Authors. Licensed under the Apache License, Version 2.0 (the "License")

get_model(obs_size, teacher_force=False, hidden_size=10)

Returns a seq2seq model.

Parameters:
  • obs_size (int) – the size of the observation.

  • teacher_force (bool) – whether to use teacher forcing.

  • hidden_size (int) – the size of the hidden layer (i.e. the encoding).

Source code in qdax/utils/train_seq2seq.py
def get_model(
    obs_size: int, teacher_force: bool = False, hidden_size: int = 10
) -> Seq2seq:
    """
    Builds a seq2seq model from the given settings.

    Args:
        obs_size: the size of the observation.
        teacher_force: whether to use teacher forcing.
        hidden_size: the size of the hidden layer (i.e. the encoding).

    Returns:
        A `Seq2seq` instance constructed with these three settings.
    """
    model = Seq2seq(
        obs_size=obs_size,
        hidden_size=hidden_size,
        teacher_force=teacher_force,
    )
    return model

get_initial_params(model, random_key, encoder_input_shape)

Returns the initial parameters of a seq2seq model.

Parameters:
  • model (Seq2seq) – the seq2seq model.

  • random_key (Any) – the random number generator.

  • encoder_input_shape (Tuple[int, ...]) – the shape of the encoder input.

Source code in qdax/utils/train_seq2seq.py
def get_initial_params(
    model: Seq2seq, random_key: PRNGKey, encoder_input_shape: Tuple[int, ...]
) -> Dict[str, Any]:
    """
    Initialises a seq2seq model and returns its parameters.

    Args:
        model: the seq2seq model to initialise.
        random_key: the random key; it is split to seed the "params",
            "lstm" and "dropout" random streams given to `model.init`.
        encoder_input_shape: the shape of the encoder input. Both dummy
            inputs given to `model.init` are all-ones float32 arrays of
            this shape.

    Returns:
        The "params" entry of the variables returned by `model.init`.
    """
    # The key is split in four; the first sub-key is left unused.
    _, params_key, lstm_key, dropout_key = jax.random.split(random_key, 4)
    init_rngs = {"params": params_key, "lstm": lstm_key, "dropout": dropout_key}

    # Dummy inputs: only their shape and dtype are chosen here.
    encoder_dummy = jnp.ones(encoder_input_shape, jnp.float32)
    decoder_dummy = jnp.ones(encoder_input_shape, jnp.float32)

    variables = model.init(init_rngs, encoder_dummy, decoder_dummy)
    return variables["params"]  # type: ignore

train_step(state, batch, lstm_random_key)

Trains for one step.

Parameters:
  • state (TrainState) – the training state.

  • batch (Any) – the batch of data.

  • lstm_random_key (Any) – the random number key.

Source code in qdax/utils/train_seq2seq.py
@jax.jit
def train_step(
    state: train_state.TrainState,
    batch: Array,
    lstm_random_key: PRNGKey,
) -> Tuple[train_state.TrainState, jnp.ndarray]:
    """
    Trains for one step.

    The decoder input is the batch shifted by one position along axis 1, with
    its first position overwritten by a sentinel value. The loss is a mean
    squared error between the model output at positions `[:-1]` and the
    decoder input at positions `[1:]` (i.e. `batch[:, :-1]`), averaged over
    the leading axis.

    Args:
        state: the training state.
        batch: the batch of data. It is indexed as a rank-3 array; the shift
            is applied along axis 1.
        lstm_random_key: the random number key. The current `state.step` is
            folded into it before it is split into the "lstm" and "dropout"
            keys.

    Returns:
        The training state after the gradients have been applied, and the
        scalar loss value computed before the update.
    """
    # Fold the step counter into the key, then derive one key per rng stream
    step_key = jax.random.fold_in(lstm_random_key, state.step)
    dropout_key, lstm_key = jax.random.split(step_key, 2)

    # Shift input by one to avoid leakage
    batch_decoder = jnp.roll(batch, shift=1, axis=1)

    # The roll wraps the last position around to index 0: overwrite it with a
    # sentinel value (-1000) acting as the start token
    batch_decoder = batch_decoder.at[:, 0, :].set(-1000)

    def loss_fn(params: Params) -> Tuple[jnp.ndarray, jnp.ndarray]:
        logits, _ = state.apply_fn(
            {"params": params},
            batch,
            batch_decoder,
            rngs={"lstm": lstm_key, "dropout": dropout_key},
        )

        def mean_squared_error(x: jnp.ndarray, y: jnp.ndarray) -> jnp.ndarray:
            # Sum of squared differences divided by the flattened length
            diff = y - x
            return jnp.inner(diff, diff) / x.shape[-1]

        # Drop the last prediction and the sentinel position so that both
        # sides line up, then flatten everything but the leading axis
        predictions = jnp.reshape(logits[:, :-1, ...], (logits.shape[0], -1))
        targets = jnp.reshape(
            batch_decoder[:, 1:, ...], (batch_decoder.shape[0], -1)
        )

        res = jax.vmap(mean_squared_error)(predictions, targets)
        loss = jnp.mean(res, axis=0)
        return loss, logits

    grad_fn = jax.value_and_grad(loss_fn, has_aux=True)
    # The auxiliary logits are only needed inside the loss
    (loss_val, _), grads = grad_fn(state.params)
    state = state.apply_gradients(grads=grads)

    return state, loss_val