Utils¶
qdax.utils
special
¶
metrics
¶
Defines functions to retrieve metrics from training processes.
CSVLogger
¶
Logger to save metrics of an experiment in a csv file during the training process.
Source code in qdax/utils/metrics.py
class CSVLogger:
    """Logger to save metrics of an experiment in a csv file
    during the training process.

    The file is created (and truncated if it already exists) together with
    its header row when the logger is built; each call to `log` then
    re-opens the file in append mode and adds one row.
    """

    def __init__(self, filename: str, header: List) -> None:
        """Create the csv logger, create a file and write the
        header.

        Args:
            filename: path to which the file will be saved.
            header: header of the csv file, used as the field names of
                every row written afterwards.
        """
        self._filename = filename
        self._header = header

        # newline="" hands line-ending control to the csv module, as its
        # documentation requires; without it, platforms that translate
        # "\n" on write end up with an empty line after every row.
        with open(self._filename, "w", newline="") as file:
            writer = csv.DictWriter(file, fieldnames=self._header)
            # write the header
            writer.writeheader()

    def log(self, metrics: Dict[str, float]) -> None:
        """Log new metrics to the csv file.

        Args:
            metrics: A dictionary containing the metrics that
                need to be saved. Its keys are matched against the
                header given at construction.
        """
        # same newline handling as for the header (see __init__)
        with open(self._filename, "a", newline="") as file:
            writer = csv.DictWriter(file, fieldnames=self._header)
            # write new metrics in a row
            writer.writerow(metrics)
__init__(self, filename, header)
special
¶
Create the csv logger, create a file and write the header.
Parameters: |
|
---|
Source code in qdax/utils/metrics.py
def __init__(self, filename: str, header: List) -> None:
    """Create the csv logger, create a file and write the
    header.

    An existing file at `filename` is truncated, since it is opened in
    write mode.

    Args:
        filename: path to which the file will be saved.
        header: header of the csv file, used as the field names of the
            rows.
    """
    self._filename = filename
    self._header = header

    # newline="" hands line-ending control to the csv module, as its
    # documentation requires; without it, platforms that translate "\n"
    # on write end up with an empty line after every row.
    with open(self._filename, "w", newline="") as file:
        writer = csv.DictWriter(file, fieldnames=self._header)
        # write the header
        writer.writeheader()
log(self, metrics)
¶
Log new metrics to the csv file.
Parameters: |
|
---|
Source code in qdax/utils/metrics.py
def log(self, metrics: Dict[str, float]) -> None:
    """Log new metrics to the csv file.

    The file is re-opened in append mode on every call, so one row is
    added per call.

    Args:
        metrics: A dictionary containing the metrics that
            need to be saved. Its keys are matched against the header
            stored on the logger.
    """
    # newline="" hands line-ending control to the csv module, as its
    # documentation requires; without it, platforms that translate "\n"
    # on write end up with an empty line after every row.
    with open(self._filename, "a", newline="") as file:
        writer = csv.DictWriter(file, fieldnames=self._header)
        # write new metrics in a row
        writer.writerow(metrics)
default_ga_metrics(repertoire)
¶
Compute the usual GA metrics that one can retrieve from a GA repertoire.
Parameters: |
|
---|
Returns: |
|
---|
Source code in qdax/utils/metrics.py
def default_ga_metrics(
    repertoire: GARepertoire,
) -> Metrics:
    """Gather the standard metrics available from a GA repertoire.

    Args:
        repertoire: a GA repertoire

    Returns:
        a dictionary with a single entry, "max_fitness", holding the
        maximum of the repertoire fitnesses taken along the first axis.
    """
    best_fitness = jnp.max(repertoire.fitnesses, axis=0)
    return {"max_fitness": best_fitness}
default_qd_metrics(repertoire, qd_offset)
¶
Compute the usual QD metrics that one can retrieve from a MAP Elites repertoire.
Parameters: |
|
---|
Returns: |
|
---|
Source code in qdax/utils/metrics.py
def default_qd_metrics(repertoire: MapElitesRepertoire, qd_offset: float) -> Metrics:
    """Gather the standard QD metrics of a MAP-Elites repertoire.

    Args:
        repertoire: a MAP-Elites repertoire
        qd_offset: value added to the QD score once per filled niche, used
            to keep the score positive and increasing with the number of
            individuals.

    Returns:
        a dictionary with the keys "qd_score" (sum of the fitnesses of the
        filled niches, plus the offset for each of them), "max_fitness"
        (highest fitness in the repertoire) and "coverage" (percentage of
        niches that are filled).
    """
    fitnesses = repertoire.fitnesses

    # a niche is empty when its fitness is -inf
    is_empty = fitnesses == -jnp.inf
    # 1.0 for a filled niche, 0.0 for an empty one
    occupancy = 1.0 - is_empty

    qd_score = jnp.sum(fitnesses, where=~is_empty) + qd_offset * jnp.sum(occupancy)

    return {
        "qd_score": qd_score,
        "max_fitness": jnp.max(fitnesses),
        "coverage": 100 * jnp.mean(occupancy),
    }
default_moqd_metrics(repertoire, reference_point)
¶
Compute the MOQD metric given a MOME repertoire and a reference point.
Parameters: |
|
---|
Returns: |
|
---|
Source code in qdax/utils/metrics.py
def default_moqd_metrics(
    repertoire: MOMERepertoire, reference_point: jnp.ndarray
) -> Metrics:
    """Evaluate the MOQD metrics of a MOME repertoire.

    Args:
        repertoire: a MOME repertoire.
        reference_point: point relative to which the hypervolume of each
            pareto front is measured.

    Returns:
        A dictionary with the keys "moqd_score", "max_hypervolume",
        "max_scores", "max_sum_scores", "coverage", "number_solutions" and
        "global_hypervolume".
    """
    fitnesses = repertoire.fitnesses

    # a slot is empty when all its objectives are -inf; a cell is occupied
    # as soon as one of its slots is not empty
    slot_is_empty = jnp.all(fitnesses == -jnp.inf, axis=-1)
    cell_is_occupied = jnp.any(~slot_is_empty, axis=-1)

    def _cell_hypervolume(front: jnp.ndarray) -> jnp.ndarray:
        return compute_hypervolume(front, reference_point=reference_point)

    # hypervolume of the front stored in each cell; -inf for unoccupied cells
    per_cell_hypervolume = jax.vmap(_cell_hypervolume)(fitnesses)
    moqd_scores = jnp.where(cell_is_occupied, per_cell_hypervolume, -jnp.inf)

    # hypervolume of the pareto front taken over the whole repertoire
    global_front, _ = repertoire.compute_global_pareto_front()
    global_hypervolume = compute_hypervolume(
        global_front, reference_point=reference_point
    )

    return {
        "moqd_score": moqd_scores,
        "max_hypervolume": jnp.max(moqd_scores),
        "max_scores": jnp.max(fitnesses, axis=(0, 1)),
        "max_sum_scores": jnp.max(jnp.sum(fitnesses, axis=-1), axis=(0, 1)),
        "coverage": 100 * jnp.mean(cell_is_occupied),
        "number_solutions": jnp.sum(~slot_is_empty),
        "global_hypervolume": global_hypervolume,
    }
pareto_front
¶
Utils to handle pareto fronts.
compute_pareto_dominance(criteria_point, batch_of_criteria)
¶
Returns if a point is pareto dominated given a set of points or not. We use maximization convention here.
criteria_point has shape (num_criteria,) batch_of_criteria has shape (num_points, num_criteria)
Parameters: |
|
---|
Returns: |
|
---|
Source code in qdax/utils/pareto_front.py
def compute_pareto_dominance(
    criteria_point: jnp.ndarray, batch_of_criteria: jnp.ndarray
) -> jnp.ndarray:
    """Tell whether a point is pareto dominated by at least one point of a batch.

    Maximization convention: a batch element dominates the point when it is
    strictly greater on every criterion.

    criteria_point has shape (num_criteria,)
    batch_of_criteria has shape (num_points, num_criteria)

    Args:
        criteria_point: a vector of values.
        batch_of_criteria: a batch of vector of values.

    Returns:
        A boolean, True when the vector is dominated by the batch.
    """
    gaps = jnp.subtract(batch_of_criteria, criteria_point)
    strictly_better = gaps > 0
    # one flag per batch element: better on all the criteria at once
    dominates = jnp.all(strictly_better, axis=-1)
    return jnp.any(dominates)
compute_pareto_front(batch_of_criteria)
¶
Returns an array of boolean that states for each element if it is in the pareto front or not.
Parameters: |
|
---|
Returns: |
|
---|
Source code in qdax/utils/pareto_front.py
def compute_pareto_front(batch_of_criteria: jnp.ndarray) -> jnp.ndarray:
    """Flag, for every point of a batch, whether it belongs to the pareto
    front of that batch.

    Args:
        batch_of_criteria: a batch of points of shape (num_points, num_criteria)

    Returns:
        An array of booleans, one per point, True when the point is on the
        front (i.e. not dominated by the batch).
    """

    def _is_on_front(point: jnp.ndarray) -> jnp.ndarray:
        return ~compute_pareto_dominance(point, batch_of_criteria)

    # each point is compared against the whole batch
    return jax.vmap(_is_on_front)(batch_of_criteria)
compute_masked_pareto_dominance(criteria_point, batch_of_criteria, mask)
¶
Returns if a point is pareto dominated given a set of points or not. We use maximization convention here.
This function is to be used with constant size batches of criteria, thus a mask is used to know which values are padded.
Parameters: |
|
---|
Returns: |
|
---|
Source code in qdax/utils/pareto_front.py
def compute_masked_pareto_dominance(
    criteria_point: jnp.ndarray, batch_of_criteria: jnp.ndarray, mask: Mask
) -> jnp.ndarray:
    """Tell whether a point is pareto dominated by a padded batch of points.

    Maximization convention is used. The batch has a constant size, so a
    mask indicates which of its entries are padding and must be ignored.

    Args:
        criteria_point: values to be evaluated, of shape (num_criteria,)
        batch_of_criteria: set of points to compare with,
            of shape (batch_size, num_criteria)
        mask: mask of shape (batch_size,), 1.0 where there is no element,
            0 otherwise

    Returns:
        Boolean assessing if the point is dominated or not.
    """
    gaps = jnp.subtract(batch_of_criteria, criteria_point)
    # -1 can never pass the "> 0" test, so a padded entry never dominates
    neutral = -jnp.ones_like(gaps)

    def _hide_padding(neutral_column: jnp.ndarray, gap_column: jnp.ndarray) -> jnp.ndarray:
        return jnp.where(mask, neutral_column, gap_column)

    # the mask is applied criterion by criterion (column-wise)
    masked_gaps = jax.vmap(_hide_padding, in_axes=(1, 1), out_axes=1)(neutral, gaps)

    dominates = jnp.all(masked_gaps > 0, axis=-1)
    return jnp.any(dominates)
compute_masked_pareto_front(batch_of_criteria, mask)
¶
Returns an array of boolean that states for each element if it is to be considered or not. This function is to be used with batches of constant size criteria, thus a mask is used to know which values are padded.
Parameters: |
|
---|
Returns: |
|
---|
Source code in qdax/utils/pareto_front.py
def compute_masked_pareto_front(
    batch_of_criteria: jnp.ndarray, mask: Mask
) -> jnp.ndarray:
    """Flag, for every element of a padded batch, whether it is to be
    considered or not. The batch has a constant size, so a mask indicates
    which of its entries are padding.

    Args:
        batch_of_criteria: data points considered
        mask: mask to hide several points

    Returns:
        An array of boolean stating the points to consider: those that are
        neither dominated by the unmasked points nor masked themselves.
    """

    def _is_not_dominated(point: jnp.ndarray) -> jnp.ndarray:
        return ~compute_masked_pareto_dominance(point, batch_of_criteria, mask)

    not_dominated = jax.vmap(_is_not_dominated)(batch_of_criteria)
    # masked entries are dropped from the result
    return not_dominated * ~mask
compute_hypervolume(pareto_front, reference_point)
¶
Compute hypervolume of a pareto front.
Parameters: |
|
---|
Returns: |
|
---|
Source code in qdax/utils/pareto_front.py
def compute_hypervolume(
    pareto_front: ParetoFront[jnp.ndarray], reference_point: jnp.ndarray
) -> jnp.ndarray:
    """Compute hypervolume of a pareto front.

    Only two objectives are supported. The front may be padded with fake
    elements whose fitness is -inf; those are ignored.

    Args:
        pareto_front: a pareto front, shape (pareto_size, num_objectives)
        reference_point: a reference point to compute the volume, of shape
            (num_objectives,)

    Returns:
        The hypervolume of the pareto front.
    """
    # check the number of objectives
    custom_message = (
        "Hypervolume calculation for more than 2 objectives not yet supported."
    )
    chex.assert_axis_dimension(
        tensor=pareto_front,
        axis=1,
        expected=2,
        custom_message=custom_message,
    )

    # concatenate the reference point to prepare for the area computation
    pareto_front = jnp.concatenate(  # type: ignore
        (pareto_front, jnp.expand_dims(reference_point, axis=0)), axis=0
    )

    # get ordered indices for the first objective
    idx = jnp.argsort(pareto_front[:, 0])
    # Note: on a pareto front, this orders the points inversely for the
    # second objective
    sorted_points = pareto_front[idx, :]

    # create the mask - hide fake elements (those having -inf fitness)
    mask = sorted_points != -jnp.inf

    # offset the sorted front with the ref point and zero the fake elements.
    # jnp.where is used rather than a multiplication by the mask because
    # -inf * 0 is NaN in floating point arithmetic, which would propagate
    # to the final sum.
    sorted_front = jnp.where(mask, sorted_points - reference_point, 0.0)

    # compute area rectangles between successive points
    sumdiff = (sorted_front[1:, 0] - sorted_front[:-1, 0]) * sorted_front[1:, 1]

    # remove the irrelevant values - rectangles starting at a fake element
    sumdiff = jnp.where(mask[:-1, 0], sumdiff, 0.0)

    # get the hypervolume by summing the successive areas
    hypervolume = jnp.sum(sumdiff)

    return hypervolume
plotting
¶
get_voronoi_finite_polygons_2d(centroids, radius=None)
¶
Reconstruct infinite voronoi regions in a 2D diagram to finite regions.
Source code in qdax/utils/plotting.py
def get_voronoi_finite_polygons_2d(
    centroids: np.ndarray, radius: Optional[float] = None
) -> Tuple[List, np.ndarray]:
    """Reconstruct infinite voronoi regions in a 2D diagram to finite
    regions.

    Args:
        centroids: the points from which the voronoi diagram is built.
        radius: distance at which the missing endpoint of an infinite
            ridge is placed. Defaults to None, in which case the
            peak-to-peak range of the points is used.

    Raises:
        ValueError: if the points are not two-dimensional.

    Returns:
        A list with, for each input point, the indices of the vertices of
        its (now finite) region, and the array of vertex coordinates those
        indices refer to: the vertices of the diagram followed by the
        endpoints that were added.
    """
    voronoi_diagram = Voronoi(centroids)

    if voronoi_diagram.points.shape[1] != 2:
        raise ValueError("Requires 2D input")

    new_regions = []
    new_vertices = voronoi_diagram.vertices.tolist()

    center = voronoi_diagram.points.mean(axis=0)
    if radius is None:
        # np.ptp rather than the ndarray.ptp method, which NumPy 2.0 removed
        radius = np.ptp(voronoi_diagram.points)

    # Construct a map containing all ridges for a given point:
    # point index -> list of (neighbour point index, vertex 1, vertex 2)
    all_ridges: Dict[int, List[Tuple[int, int, int]]] = {}
    for (p1, p2), (v1, v2) in zip(
        voronoi_diagram.ridge_points, voronoi_diagram.ridge_vertices
    ):
        all_ridges.setdefault(p1, []).append((p2, v1, v2))
        all_ridges.setdefault(p2, []).append((p1, v1, v2))

    # Reconstruct infinite regions
    for p1, region in enumerate(voronoi_diagram.point_region):
        vertices = voronoi_diagram.regions[region]

        # a negative vertex index marks a vertex at infinity
        if all(v >= 0 for v in vertices):
            # finite region
            new_regions.append(vertices)
            continue

        # reconstruct a non-finite region
        ridges = all_ridges[p1]
        new_region = [v for v in vertices if v >= 0]

        for p2, v1, v2 in ridges:
            if v2 < 0:
                # make sure the (possibly) infinite vertex comes first
                v1, v2 = v2, v1
            if v1 >= 0:
                # finite ridge: already in the region
                continue

            # Compute the missing endpoint of an infinite ridge
            t = voronoi_diagram.points[p2] - voronoi_diagram.points[p1]  # tangent
            t /= np.linalg.norm(t)
            n = np.array([-t[1], t[0]])  # normal

            # orient the normal away from the center of the points
            midpoint = voronoi_diagram.points[[p1, p2]].mean(axis=0)
            direction = np.sign(np.dot(midpoint - center, n)) * n
            far_point = voronoi_diagram.vertices[v2] + direction * radius

            new_region.append(len(new_vertices))
            new_vertices.append(far_point.tolist())

        # sort region counterclockwise
        vs = np.asarray([new_vertices[v] for v in new_region])
        c = vs.mean(axis=0)
        angles = np.arctan2(vs[:, 1] - c[1], vs[:, 0] - c[0])
        new_region = np.array(new_region)[np.argsort(angles)]

        # finish
        new_regions.append(new_region.tolist())

    return new_regions, np.asarray(new_vertices)
plot_2d_map_elites_repertoire(centroids, repertoire_fitnesses, minval, maxval, repertoire_descriptors=None, ax=None, vmin=None, vmax=None)
¶
Plot a visual representation of a 2d map elites repertoire.
TODO: Use repertoire as input directly. Because this function is very specific to repertoires.
Parameters: |
|
---|
Exceptions: |
|
---|
Returns: |
|
---|
Source code in qdax/utils/plotting.py
def plot_2d_map_elites_repertoire(
    centroids: jnp.ndarray,
    repertoire_fitnesses: jnp.ndarray,
    minval: jnp.ndarray,
    maxval: jnp.ndarray,
    repertoire_descriptors: Optional[jnp.ndarray] = None,
    ax: Optional[plt.Axes] = None,
    vmin: Optional[float] = None,
    vmax: Optional[float] = None,
) -> Tuple[Optional[Figure], Axes]:
    """Plot a visual representation of a 2d map elites repertoire.

    Each centroid is drawn as its voronoi cell; the cells whose fitness is
    not -inf are colored according to that fitness.

    Note that this function updates the global matplotlib rcParams.

    TODO: Use repertoire as input directly. Because this
    function is very specific to repertoires.

    Args:
        centroids: the centroids of the repertoire
        repertoire_fitnesses: the fitness of the repertoire
        minval: minimum values for the descriptors, either a scalar (used
            for both axes) or one value per axis
        maxval: maximum values for the descriptors, either a scalar (used
            for both axes) or one value per axis
        repertoire_descriptors: the descriptors. Defaults to None. If given,
            the descriptors of the non-empty cells are added as points.
        ax: a matplotlib axe for the figure to plot. Defaults to None, in
            which case a new figure is created.
        vmin: minimum value for the fitness. Defaults to None. If not given,
            the value will be set to the minimum fitness in the repertoire.
        vmax: maximum value for the fitness. Defaults to None. If not given,
            the value will be set to the maximum fitness in the repertoire.

    Raises:
        NotImplementedError: does not work for descriptors dimension different
        from 2.

    Returns:
        A figure and axes object, corresponding to the visualisation of the
        repertoire. The figure is None when an axes was given as input.
    """

    # TODO: check it and fix it if needed
    # a cell is empty when its fitness is -inf
    grid_empty = repertoire_fitnesses == -jnp.inf
    num_descriptors = centroids.shape[1]
    if num_descriptors != 2:
        raise NotImplementedError("Grid plot supports 2 descriptors only for now.")

    my_cmap = cm.viridis

    fitnesses = repertoire_fitnesses
    # default color range: extrema of the fitnesses of the non-empty cells
    if vmin is None:
        vmin = float(jnp.min(fitnesses[~grid_empty]))
    if vmax is None:
        vmax = float(jnp.max(fitnesses[~grid_empty]))

    # set the parameters (applied globally through rcParams)
    font_size = 12
    params = {
        "axes.labelsize": font_size,
        "legend.fontsize": font_size,
        "xtick.labelsize": font_size,
        "ytick.labelsize": font_size,
        "text.usetex": False,
        "figure.figsize": [10, 10],
    }

    mpl.rcParams.update(params)

    # create the plot object - fig stays None if the caller provided the axes
    fig = None
    if ax is None:
        fig, ax = plt.subplots(facecolor="white", edgecolor="white")

    assert (
        len(np.array(minval).shape) < 2
    ), f"minval : {minval} should be float or couple of floats"
    assert (
        len(np.array(maxval).shape) < 2
    ), f"maxval : {maxval} should be float or couple of floats"

    # scalar bounds are shared by both axes, otherwise one bound per axis
    if len(np.array(minval).shape) == 0 and len(np.array(maxval).shape) == 0:
        ax.set_xlim(minval, maxval)
        ax.set_ylim(minval, maxval)
    else:
        ax.set_xlim(minval[0], maxval[0])
        ax.set_ylim(minval[1], maxval[1])

    ax.set(adjustable="box", aspect="equal")

    # create the regions and vertices from centroids
    regions, vertices = get_voronoi_finite_polygons_2d(centroids)

    norm = Normalize(vmin=vmin, vmax=vmax)

    # fill the plot with contours (every cell, empty or not)
    for region in regions:
        polygon = vertices[region]
        ax.fill(*zip(*polygon), alpha=0.05, edgecolor="black", facecolor="white", lw=1)

    # fill the plot with the colors (non-empty cells only); regions are
    # indexed like the fitnesses, i.e. one per centroid
    for idx, fitness in enumerate(fitnesses):
        if fitness > -jnp.inf:
            region = regions[idx]
            polygon = vertices[region]

            ax.fill(*zip(*polygon), alpha=0.8, color=my_cmap(norm(fitness)))

    # if descriptors are specified, add points location
    if repertoire_descriptors is not None:
        descriptors = repertoire_descriptors[~grid_empty]
        ax.scatter(
            descriptors[:, 0],
            descriptors[:, 1],
            c=fitnesses[~grid_empty],
            cmap=my_cmap,
            s=10,
            zorder=0,
        )

    # aesthetic
    ax.set_xlabel("Behavior Dimension 1")
    ax.set_ylabel("Behavior Dimension 2")
    # colorbar placed in a dedicated axes on the right of the plot
    divider = make_axes_locatable(ax)
    cax = divider.append_axes("right", size="5%", pad=0.05)
    cbar = plt.colorbar(mpl.cm.ScalarMappable(norm=norm, cmap=my_cmap), cax=cax)
    cbar.ax.tick_params(labelsize=font_size)

    ax.set_title("MAP-Elites Grid")
    ax.set_aspect("equal")

    return fig, ax
plot_map_elites_results(env_steps, metrics, repertoire, min_bd, max_bd)
¶
Plots three usual QD metrics, namely the coverage, the maximum fitness and the QD-score, along the number of environment steps. This function also plots a visualisation of the final map elites grid obtained. It ensures that those plots are aligned together to give a simple and efficient visualisation of an optimization process.
Parameters: |
|
---|
Returns: |
|
---|
Source code in qdax/utils/plotting.py
def plot_map_elites_results(
    env_steps: jnp.ndarray,
    metrics: Dict,
    repertoire: MapElitesRepertoire,
    min_bd: jnp.ndarray,
    max_bd: jnp.ndarray,
) -> Tuple[Optional[Figure], Axes]:
    """Plots three usual QD metrics, namely the coverage, the maximum fitness
    and the QD-score, along the number of environment steps. This function also
    plots a visualisation of the final map elites grid obtained. It ensures that
    those plots are aligned together to give a simple and efficient visualisation
    of an optimization process.

    Note that this function updates the global matplotlib rcParams.

    Args:
        env_steps: the array containing the number of steps done in the environment.
        metrics: a dictionary containing metrics from the optimization process.
            The keys "coverage", "max_fitness" and "qd_score" are read.
        repertoire: the final repertoire obtained.
        min_bd: the minimal possible values for the bd.
        max_bd: the maximal possible values for the bd.

    Returns:
        The figure holding the four plots, and the axes of the last one (the
        visualisation of the grid), as returned by
        plot_2d_map_elites_repertoire.
    """
    # Customize matplotlib params
    font_size = 16
    params = {
        "axes.labelsize": font_size,
        "axes.titlesize": font_size,
        "legend.fontsize": font_size,
        "xtick.labelsize": font_size,
        "ytick.labelsize": font_size,
        "text.usetex": False,
        "axes.titlepad": 10,
    }

    mpl.rcParams.update(params)

    # Visualize the training evolution and final repertoire
    fig, axes = plt.subplots(nrows=1, ncols=4, figsize=(40, 10))

    # NOTE(review): env_steps is presumably built by the caller as
    # jnp.arange(num_iterations) * episode_length * batch_size - confirm

    # each metric panel is given the same box aspect ratio (0.95)
    axes[0].plot(env_steps, metrics["coverage"])
    axes[0].set_xlabel("Environment steps")
    axes[0].set_ylabel("Coverage in %")
    axes[0].set_title("Coverage evolution during training")
    axes[0].set_aspect(0.95 / axes[0].get_data_ratio(), adjustable="box")

    axes[1].plot(env_steps, metrics["max_fitness"])
    axes[1].set_xlabel("Environment steps")
    axes[1].set_ylabel("Maximum fitness")
    axes[1].set_title("Maximum fitness evolution during training")
    axes[1].set_aspect(0.95 / axes[1].get_data_ratio(), adjustable="box")

    axes[2].plot(env_steps, metrics["qd_score"])
    axes[2].set_xlabel("Environment steps")
    axes[2].set_ylabel("QD Score")
    axes[2].set_title("QD Score evolution during training")
    axes[2].set_aspect(0.95 / axes[2].get_data_ratio(), adjustable="box")

    # the grid is drawn in the last panel; `axes` is rebound to the axes
    # returned by the helper, so the first three panels are only reachable
    # through the figure from here on
    _, axes = plot_2d_map_elites_repertoire(
        centroids=repertoire.centroids,
        repertoire_fitnesses=repertoire.fitnesses,
        minval=min_bd,
        maxval=max_bd,
        repertoire_descriptors=repertoire.descriptors,
        ax=axes[3],
    )

    return fig, axes
multiline(xs, ys, c, ax=None, **kwargs)
¶
Plot lines with different colorings (with c a container of numbers mapped to colormap)
Note
len(xs) == len(ys) == len(c) is the number of line segments; len(xs[i]) == len(ys[i]) is the number of points for each line (indexed by i)
Parameters: |
|
---|
Returns: |
|
---|
Source code in qdax/utils/plotting.py
def multiline(
    xs: Iterable, ys: Iterable, c: Iterable, ax: Optional[Axes] = None, **kwargs: Any
) -> LineCollection:
    """Draw several lines at once, each colored from a number mapped to the
    colormap (the numbers being given by the container c).

    Note:
        len(xs) == len(ys) == len(c) is the number of line segments
        len(xs[i]) == len(ys[i]) is the number of points for each line (indexed by i)

    Args:
        xs: First dimension of the trajectory.
        ys: Second dimension of the trajectory.
        c: Colors - one for each trajectory.
        ax: A matplotlib axe. Defaults to None, in which case the current
            axes is used.
        **kwargs: forwarded to the LineCollection constructor.

    Returns:
        Return a collection of lines corresponding to the trajectories.
    """
    # fall back on the current axes
    target_ax = ax if ax is not None else plt.gca()

    # one (num_points, 2) array per line
    line_points = [np.column_stack([x, y]) for x, y in zip(xs, ys)]
    collection = LineCollection(line_points, **kwargs)

    # the colors are given as an array: a plain list raises an error here
    collection.set_array(np.asarray(c))

    # adding a collection does not rescale xlim/ylim, hence the autoscale
    target_ax.add_collection(collection)
    target_ax.autoscale()

    return collection
plot_skills_trajectory(trajectories, skills, min_values, max_values)
¶
Plots skills trajectories on a single plot with different colors to recognize the skills.
The plot can contain several trajectories of the same skill.
Parameters: |
|
---|
Returns: |
|
---|
Source code in qdax/utils/plotting.py
def plot_skills_trajectory(
    trajectories: jnp.ndarray,
    skills: jnp.ndarray,
    min_values: jnp.ndarray,
    max_values: jnp.ndarray,
) -> Tuple[Figure, Axes]:
    """Draw the trajectories of skills on a single plot, one color per skill.

    Several trajectories of the same skill can be present in the plot.

    Args:
        trajectories: skills trajectories
        skills: skills corresponding to the given trajectories
        min_values: minimum values that can be taken by the steps
            of the trajectory
        max_values: maximum values that can be taken by the steps
            of the trajectory

    Returns:
        A figure and axes.
    """
    # the color of a trajectory is the index of its skill, taken as the
    # argmax along the second axis (from 0 to num skills - 1)
    skill_count = skills.shape[1]
    skill_ids = skills.argmax(axis=1)

    fig, ax = plt.subplots()

    # first and second dimension of the trajectories
    first_dim, second_dim = trajectories
    lines = multiline(
        xs=first_dim, ys=second_dim, c=skill_ids, ax=ax, cmap="gist_rainbow"
    )

    # axes limits, labels and aspect
    ax.set_ylim(min_values[1], max_values[1])
    ax.set_xlim(min_values[0], max_values[0])
    ax.set_xlabel("Behavior Dimension 1")
    ax.set_ylabel("Behavior Dimension 2")
    ax.set_aspect("equal")

    # colorbar with one tick per skill, on the right of the plot
    colorbar_ax = make_axes_locatable(ax).append_axes("right", size="5%", pad=0.05)
    colorbar = fig.colorbar(lines, cax=colorbar_ax)
    colorbar.set_ticks(np.arange(skill_count, dtype=int))

    ax.set_title("Skill trajectories")

    return fig, ax
plot_mome_pareto_fronts(centroids, repertoire, maxval, minval, axes=None, color_style='hsv', with_global=False)
¶
Plot the pareto fronts from all cells of the mome repertoire.
Parameters: |
|
---|
Returns: |
|
---|
Source code in qdax/utils/plotting.py
def plot_mome_pareto_fronts(
    centroids: jnp.ndarray,
    repertoire: MOMERepertoire,
    maxval: float,
    minval: float,
    axes: Optional[plt.Axes] = None,
    color_style: Optional[str] = "hsv",
    with_global: Optional[bool] = False,
) -> plt.Axes:
    """Plot the pareto fronts from all cells of the mome repertoire.

    Two plots are drawn: the fitnesses of the solutions (first axes) and
    their descriptors on top of the voronoi cells (second axes). All the
    solutions of a cell share one color, derived from the cell centroid.

    Args:
        centroids: centroids of the repertoire
        repertoire: mome repertoire
        maxval: maximum values for the descriptors
        minval: minimum values for the descriptors
        axes: matplotlib axes (a pair of them, indexed 0 and 1). Defaults to
            None, in which case a new figure with two axes is created.
        color_style: style of the colors used, "hsv" (color from the polar
            coordinates of the centroid around the center of the descriptor
            space) or "spectral" (color from the first coordinate of the
            centroid). Defaults to "hsv".
        with_global: plot the global pareto front in addition.
            Defaults to False.

    Returns:
        Returns the axes object with the plot.
    """
    fitnesses = repertoire.fitnesses
    repertoire_descriptors = repertoire.descriptors

    assert fitnesses.shape[-1] == repertoire_descriptors.shape[-1] == 2
    assert color_style in ["hsv", "spectral"], "color_style must be hsv or spectral"

    num_centroids = len(centroids)
    # a slot is empty as soon as one of its objectives is -inf
    grid_empty = jnp.any(fitnesses == -jnp.inf, axis=-1)

    # Extract polar coordinates
    if color_style == "hsv":
        # the center of the descriptor space is the midpoint of the bounds
        # (and not half their range, which is only right when minval is 0)
        center = jnp.array([(maxval + minval) / 2] * centroids.shape[1])
        offsets = centroids - center
        polars = jnp.stack(
            (
                jnp.sqrt(jnp.sum(offsets**2, axis=-1))
                / (maxval - minval)
                / jnp.sqrt(2),
                # arctan2 covers the whole circle, so opposite directions
                # get different hues, and it is defined for a centroid
                # located on the vertical line through the center
                jnp.arctan2(offsets[:, 1], offsets[:, 0]),
            ),
            axis=-1,
        )
    elif color_style == "spectral":
        # pyplot.get_cmap is used since matplotlib.cm.get_cmap was removed
        # in Matplotlib 3.9
        cmap = plt.get_cmap("Spectral")

    if axes is None:
        _, axes = plt.subplots(ncols=2, figsize=(12, 6))

    for i in range(num_centroids):
        # skip the cells that do not contain any solution
        if jnp.sum(~grid_empty[i]) > 0:
            cell_scores = fitnesses[i][~grid_empty[i]]
            cell = repertoire_descriptors[i][~grid_empty[i]]
            if color_style == "hsv":
                color = vector_to_rgb(polars[i, 1], polars[i, 0])
            else:
                color = cmap((centroids[i, 0] - minval) / (maxval - minval))
            axes[0].plot(cell_scores[:, 0], cell_scores[:, 1], "o", color=color)
            axes[1].plot(cell[:, 0], cell[:, 1], "o", color=color)

    # create the regions and vertices from centroids
    regions, vertices = get_voronoi_finite_polygons_2d(centroids)

    # fill the plot with contours
    for region in regions:
        polygon = vertices[region]
        axes[1].fill(
            *zip(*polygon), alpha=0.2, edgecolor="black", facecolor="white", lw=1
        )
    axes[0].set_title("Fitness")
    axes[1].set_title("Descriptor")
    axes[1].set_xlim(minval, maxval)
    axes[1].set_ylim(minval, maxval)

    if with_global:
        global_pareto_front, pareto_bool = repertoire.compute_global_pareto_front()
        # descriptors of all the cells put end to end, then filtered with
        # the boolean returned along with the global front
        global_pareto_descriptors = jnp.concatenate(repertoire_descriptors)[pareto_bool]
        axes[0].scatter(
            global_pareto_front[:, 0],
            global_pareto_front[:, 1],
            marker="o",
            edgecolors="black",
            facecolors="none",
            zorder=3,
            label="Global Pareto Front",
        )
        # the front is joined with a dashed line, in increasing order of
        # the first objective
        sorted_index = jnp.argsort(global_pareto_front[:, 0])
        axes[0].plot(
            global_pareto_front[sorted_index, 0],
            global_pareto_front[sorted_index, 1],
            linestyle="--",
            linewidth=2,
            color="k",
            zorder=3,
        )
        axes[1].scatter(
            global_pareto_descriptors[:, 0],
            global_pareto_descriptors[:, 1],
            marker="o",
            edgecolors="black",
            facecolors="none",
            zorder=3,
            label="Global Pareto Descriptor",
        )

    return axes
vector_to_rgb(angle, absolute)
¶
Returns a color based on polar coordinates.
Parameters: |
|
---|
Returns: |
|
---|
Source code in qdax/utils/plotting.py
def vector_to_rgb(angle: float, absolute: float) -> Any:
    """Returns a color based on polar coordinates.

    The color is built in HSV space with a full saturation: the angle gives
    the hue and `absolute` the value.

    Args:
        angle: a given angle, in radians. It is wrapped to [0, 2 * pi).
        absolute: a ref, mapped to the value of the color as
            (absolute + 0.5) / 1.5.

    Returns:
        An appropriate color, converted from HSV to RGB.
    """
    # normalize angle: the remainder of a division by a positive number is
    # never negative, so no further correction is needed after the modulo
    angle = angle % (2 * np.pi)

    # rise absolute to avoid black colours
    absolute = (absolute + 0.5) / 1.5

    return mpl.colors.hsv_to_rgb((angle / 2 / np.pi, 1, absolute))
plot_global_pareto_front(pareto_front, ax=None, label=None, color=None)
¶
Plots the global Pareto Front.
Parameters: |
|
---|
Returns: |
|
---|
Source code in qdax/utils/plotting.py
def plot_global_pareto_front(
    pareto_front: jnp.ndarray,
    ax: Optional[plt.Axes] = None,
    label: Optional[str] = None,
    color: Optional[str] = None,
) -> Tuple[Optional[Figure], plt.Axes]:
    """Scatter the points of the global Pareto Front.

    Args:
        pareto_front: a pareto front; its first two columns are plotted
            against each other.
        ax: a matplotlib ax. Defaults to None, in which case a new figure
            is created.
        label: a given label. Defaults to None.
        color: a color for the plotted points. Defaults to None.

    Returns:
        A figure and an axe. The figure is None when an ax was given as
        input.
    """
    # a figure is only created (and returned) when no ax is provided
    fig = None
    if ax is None:
        fig, ax = plt.subplots(figsize=(6, 6))

    ax.scatter(pareto_front[:, 0], pareto_front[:, 1], color=color, label=label)

    return fig, ax
plot_multidimensional_map_elites_grid(repertoire, minval, maxval, grid_shape, ax=None, vmin=None, vmax=None)
¶
Plot a visual 2D representation of a multidimensional MAP-Elites repertoire (where the dimensionality of descriptors can be greater than 2).
Parameters: |
|
---|
Exceptions: |
|
---|
Returns: |
|
---|
Source code in qdax/utils/plotting.py
def plot_multidimensional_map_elites_grid(
repertoire: MapElitesRepertoire,
minval: jnp.ndarray,
maxval: jnp.ndarray,
grid_shape: Tuple[int, ...],
ax: Optional[plt.Axes] = None,
vmin: Optional[float] = None,
vmax: Optional[float] = None,
) -> Tuple[Optional[Figure], Axes]:
"""Plot a visual 2D representation of a multidimensional MAP-Elites repertoire
(where the dimensionality of descriptors can be greater than 2).
Args:
repertoire: the MAP-Elites repertoire to plot.
minval: minimum values for the descriptors
maxval: maximum values for the descriptors
grid_shape: the resolution of the grid.
ax: a matplotlib axe for the figure to plot. Defaults to None.
vmin: minimum value for the fitness. Defaults to None.
vmax: maximum value for the fitness. Defaults to None.
Raises:
ValueError: the resolution should be an int or a tuple
Returns:
A figure and axes object, corresponding to the visualisation of the
repertoire.
"""
descriptors = repertoire.descriptors
fitnesses = repertoire.fitnesses
is_grid_empty = fitnesses.ravel() == -jnp.inf
num_descriptors = descriptors.shape[1]
if isinstance(grid_shape, tuple):
assert (
len(grid_shape) == num_descriptors
), "grid_shape should have the same length as num_descriptors"
else:
raise ValueError("resolution should be a tuple")
assert np.size(minval) == num_descriptors or np.size(minval) == 1, (
f"minval : {minval} should either be of size 1 "
f"or have the same size as the number of descriptors: {num_descriptors}"
)
assert np.size(maxval) == num_descriptors or np.size(maxval) == 1, (
f"maxval : {maxval} should either be of size 1 "
f"or have the same size as the number of descriptors: {num_descriptors}"
)
non_empty_descriptors = descriptors[~is_grid_empty]
non_empty_fitnesses = fitnesses[~is_grid_empty]
# convert the descriptors to integer coordinates, depending on the resolution.
resolutions_array = jnp.array(grid_shape)
descriptors_integers = jnp.asarray(
jnp.floor(
resolutions_array * (non_empty_descriptors - minval) / (maxval - minval)
),
dtype=jnp.int32,
)
# total number of grid cells along each dimension of the grid
size_grid_x = np.prod(np.array(grid_shape[0::2]))
# convert to int for the 1d case - if not, value 1.0 is given
size_grid_y = np.prod(np.array(grid_shape[1::2]), dtype=int)
# initialise the grid
grid_2d = np.full(
(size_grid_x.item(), size_grid_y.item()),
fill_value=jnp.nan,
)
# put solutions in the grid according to their projected 2-dimensional coordinates
for desc, fit in zip(descriptors_integers, non_empty_fitnesses):
projection_2d = _get_projection_in_2d(desc, grid_shape)
if jnp.isnan(grid_2d[projection_2d]) or fit.item() > grid_2d[projection_2d]:
grid_2d[projection_2d] = fit.item()
# set plot parameters
font_size = 12
params = {
"axes.labelsize": font_size,
"legend.fontsize": font_size,
"xtick.labelsize": font_size,
"ytick.labelsize": font_size,
"text.usetex": False,
"figure.figsize": [10, 10],
}
mpl.rcParams.update(params)
# create the plot object
fig = None
if ax is None:
fig, ax = plt.subplots(facecolor="white", edgecolor="white")
ax.set(adjustable="box", aspect="equal")
my_cmap = cm.viridis
if vmin is None:
vmin = float(jnp.min(non_empty_fitnesses))
if vmax is None:
vmax = float(jnp.max(non_empty_fitnesses))
ax.imshow(
grid_2d.T,
origin="lower",
aspect="equal",
vmin=vmin,
vmax=vmax,
cmap=my_cmap,
)
# aesthetic
ax.set_xlabel("Behavior Dimension 1")
ax.set_ylabel("Behavior Dimension 2")
divider = make_axes_locatable(ax)
cax = divider.append_axes("right", size="5%", pad=0.05)
norm = Normalize(vmin=vmin, vmax=vmax)
cbar = plt.colorbar(mpl.cm.ScalarMappable(norm=norm, cmap=my_cmap), cax=cax)
cbar.ax.tick_params(labelsize=font_size)
ax.set_title("MAP-Elites Grid")
ax.set_aspect("equal")
def _get_ticks_positions(
total_size_grid_axis: int, step_ticks_on_axis: int
) -> jnp.ndarray:
"""
Get the positions of the ticks on the grid axis.
Args:
total_size_grid_axis: total size of the grid axis
step_ticks_on_axis: step of the ticks
Returns:
The positions of the ticks on the plot.
"""
return np.arange(0, total_size_grid_axis + 1, step_ticks_on_axis) - 0.5
# Ticks position
major_ticks_x = _get_ticks_positions(
size_grid_x.item(), step_ticks_on_axis=np.prod(grid_shape[2::2]).item()
)
minor_ticks_x = _get_ticks_positions(
size_grid_x.item(), step_ticks_on_axis=np.prod(grid_shape[4::2]).item()
)
major_ticks_y = _get_ticks_positions(
size_grid_y.item(), step_ticks_on_axis=np.prod(grid_shape[3::2]).item()
)
minor_ticks_y = _get_ticks_positions(
size_grid_y.item(), step_ticks_on_axis=np.prod(grid_shape[5::2]).item()
)
ax.set_xticks(
major_ticks_x,
)
ax.set_xticks(
minor_ticks_x,
minor=True,
)
ax.set_yticks(
major_ticks_y,
)
ax.set_yticks(
minor_ticks_y,
minor=True,
)
# Ticks aesthetics
ax.tick_params(
which="minor",
color="gray",
labelcolor="gray",
size=5,
)
ax.tick_params(
which="major",
labelsize=font_size,
size=7,
)
ax.grid(which="minor", alpha=1.0, color="#000000", linewidth=0.5)
if len(grid_shape) > 2:
ax.grid(which="major", alpha=1.0, color="#000000", linewidth=2.5)
def _get_positions_labels(
_minval: float, _maxval: float, _number_ticks: int, _step_labels_ticks: int
) -> List[str]:
positions = jnp.linspace(_minval, _maxval, num=_number_ticks)
list_str_positions = []
for index_tick, position in enumerate(positions):
if index_tick % _step_labels_ticks != 0:
character = ""
else:
character = f"{position:.2E}"
list_str_positions.append(character)
# forcing the last tick label
list_str_positions[-1] = f"{positions[-1]:.2E}"
return list_str_positions
number_label_ticks = 4
if len(major_ticks_x) // number_label_ticks > 0:
ax.set_xticklabels(
_get_positions_labels(
_minval=minval[0],
_maxval=maxval[0],
_number_ticks=len(major_ticks_x),
_step_labels_ticks=len(major_ticks_x) // number_label_ticks,
)
)
if len(major_ticks_y) // number_label_ticks > 0:
ax.set_yticklabels(
_get_positions_labels(
_minval=minval[1],
_maxval=maxval[1],
_number_ticks=len(major_ticks_y),
_step_labels_ticks=len(major_ticks_y) // number_label_ticks,
)
)
return fig, ax
sampling
¶
Core components of the MAP-Elites-sampling algorithm.
average(quantities)
¶
Default expectation extractor using average.
Source code in qdax/utils/sampling.py
@jax.jit
def average(quantities: jnp.ndarray) -> jnp.ndarray:
    """Default expectation extractor: average taken along axis 1."""
    sample_axis = 1
    return jnp.average(quantities, axis=sample_axis)
median(quantities)
¶
Alternative expectation extractor using median. More robust to outliers than average.
Source code in qdax/utils/sampling.py
@jax.jit
def median(quantities: jnp.ndarray) -> jnp.ndarray:
    """Alternative expectation extractor: median taken along axis 1.

    More robust to outliers than average.
    """
    sample_axis = 1
    return jnp.median(quantities, axis=sample_axis)
mode(quantities)
¶
Alternative expectation extractor using mode. More robust to outliers than average. WARNING: for multidimensional objects such as descriptors, the mode is selected dimension-wise.
Source code in qdax/utils/sampling.py
@jax.jit
def mode(quantities: jnp.ndarray) -> jnp.ndarray:
    """Alternative expectation extractor using mode.

    More robust to outliers than average.

    WARNING: for multidimensional objects such as descriptor, the mode is
    selected independently for each dimension.
    """

    def _most_frequent(column: jnp.ndarray) -> jnp.ndarray:
        # `size` fixes the output length so that jnp.unique can be traced
        candidates, occurrences = jnp.unique(
            column, return_counts=True, size=column.size
        )
        return candidates[jnp.argmax(occurrences)]

    def _individual_mode(samples: jnp.ndarray) -> jnp.ndarray:
        # Flatten to (num_samples, num_dims) so that single-dimension and
        # multi-dimension quantities go through the same path
        flat_samples = jnp.reshape(samples, (samples.shape[0], -1))
        # Vote separately on every dimension (columns of flat_samples)
        per_dimension = jax.vmap(_most_frequent, in_axes=1)(flat_samples)
        return jnp.squeeze(per_dimension)

    # One vote per individual (leading axis of quantities)
    return jax.vmap(_individual_mode)(quantities)
closest(quantities)
¶
Alternative expectation extractor selecting individual that has the minimum distance to all other individuals. This is an approximation of the geometric median. More robust to outliers than average.
Source code in qdax/utils/sampling.py
@jax.jit
def closest(quantities: jnp.ndarray) -> jnp.ndarray:
    """Alternative expectation extractor returning, for each individual,
    the sample whose mean Euclidean distance to all the samples of that
    individual is the smallest. This is an approximation of the geometric
    median.

    More robust to outliers than average.
    """

    def _euclidean(point_a: jnp.ndarray, point_b: jnp.ndarray) -> jnp.ndarray:
        return jnp.sqrt(jnp.sum(jnp.square(point_a - point_b)))

    def _closest_sample(samples: jnp.ndarray) -> jnp.ndarray:
        # pairwise[i, j] is the distance between sample i and sample j
        one_to_all = jax.vmap(_euclidean, in_axes=(None, 0))
        all_to_all = jax.vmap(one_to_all, in_axes=(0, None))
        pairwise = all_to_all(samples, samples)
        mean_distances = jnp.mean(pairwise, axis=0)
        return samples[jnp.argmin(mean_distances)]

    return jax.vmap(_closest_sample)(quantities)
std(quantities)
¶
Default reproducibility extractor using standard deviation.
Source code in qdax/utils/sampling.py
@jax.jit
def std(quantities: jnp.ndarray) -> jnp.ndarray:
    """Default reproducibility extractor: standard deviation along axis 1."""
    sample_axis = 1
    return jnp.std(quantities, axis=sample_axis)
mad(quantities)
¶
Alternative reproducibility extractor using Median Absolute Deviation. More robust to outliers than standard deviation.
Source code in qdax/utils/sampling.py
@jax.jit
def mad(quantities: jnp.ndarray) -> jnp.ndarray:
    """Alternative reproducibility extractor using Median Absolute Deviation.

    More robust to outliers than standard deviation.
    """
    # keepdims leaves a size-1 axis 1, which broadcasts against every sample
    center = jnp.median(quantities, axis=1, keepdims=True)
    absolute_deviations = jnp.abs(quantities - center)
    return jnp.median(absolute_deviations, axis=1)
iqr(quantities)
¶
Alternative reproducibility extractor using Inter-Quartile Range. More robust to outliers than standard deviation.
Source code in qdax/utils/sampling.py
@jax.jit
def iqr(quantities: jnp.ndarray) -> jnp.ndarray:
    """Alternative reproducibility extractor using Inter-Quartile Range.

    More robust to outliers than standard deviation.
    """
    # difference between the 0.75 and the 0.25 quantiles along axis 1
    q3 = jnp.quantile(quantities, 0.75, axis=1)
    q1 = jnp.quantile(quantities, 0.25, axis=1)
    return q3 - q1
dummy_extra_scores_extractor(extra_scores, num_samples)
¶
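Extract the final extra scores of a policy from multiple samples of the same policy in the environment. This dummy implementation just returns the full concatenated extra_scores of all samples, without any extra computation.
Parameters:
| Name | Type | Description |
|---|---|---|
| extra_scores | ExtraScores | extra scores of the samples |
| num_samples | int | the number of samples used |
Returns:
| Type | Description |
|---|---|
| ExtraScores | the new extra scores after extraction |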
Source code in qdax/utils/sampling.py
@partial(jax.jit, static_argnames=("num_samples",))
def dummy_extra_scores_extractor(
    extra_scores: ExtraScores,
    num_samples: int,
) -> ExtraScores:
    """
    Extract the final extra scores of a policy from multiple samples of
    the same policy in the environment.

    This dummy implementation hands back the extra scores of all the
    samples exactly as received, without any extra computation.

    Args:
        extra_scores: extra scores of the samples
        num_samples: the number of samples used (ignored here)

    Returns:
        the extra scores, unchanged
    """
    # num_samples only exists to match the extractor signature
    del num_samples
    return extra_scores
multi_sample_scoring_function(policies_params, random_key, scoring_fn, num_samples)
¶
Wrap scoring_function to perform sampling.
This function returns the fitnesses, descriptors, and extra_scores computed over num_samples evaluations with the scoring_fn.
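Parameters:
| Name | Type | Description |
|---|---|---|
| policies_params | Genotype | policies to evaluate |
| random_key | RNGKey | JAX random key |
| scoring_fn | Callable[[Genotype, RNGKey], Tuple[Fitness, Descriptor, ExtraScores, RNGKey]] | scoring function used for evaluation |
| num_samples | int | number of samples to generate for each individual |
Returns:
| Type | Description |
|---|---|
| Tuple[Fitness, Descriptor, ExtraScores, RNGKey] | (n, num_samples) array of fitnesses, (n, num_samples, num_descriptors) array of descriptors, dict with num_samples extra_scores per individual, JAX random key |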
Source code in qdax/utils/sampling.py
@partial(
    jax.jit,
    static_argnames=(
        "scoring_fn",
        "num_samples",
    ),
)
def multi_sample_scoring_function(
    policies_params: Genotype,
    random_key: RNGKey,
    scoring_fn: Callable[
        [Genotype, RNGKey],
        Tuple[Fitness, Descriptor, ExtraScores, RNGKey],
    ],
    num_samples: int,
) -> Tuple[Fitness, Descriptor, ExtraScores, RNGKey]:
    """
    Wrap scoring_function to perform sampling.

    Evaluates the same policies `num_samples` times with `scoring_fn`, each
    time with a different random key, and returns every evaluation.

    Args:
        policies_params: policies to evaluate
        random_key: JAX random key
        scoring_fn: scoring function used for evaluation
        num_samples: number of samples to generate for each individual

    Returns:
        (n, num_samples) array of fitnesses,
        (n, num_samples, num_descriptors) array of descriptors,
        dict with num_samples extra_scores per individual,
        JAX random key
    """
    # one half of the split goes back to the caller, the other one seeds
    # the per-sample keys
    random_key, sampling_key = jax.random.split(random_key)
    sample_keys = jax.random.split(sampling_key, num=num_samples)

    # in_axes=(None, 0): the policies are shared, only the keys are mapped.
    # out_axes=1: the sample axis becomes axis 1 of every output, i.e. the
    # outputs have shape (batch_size, num_samples, ...)
    batched_scoring_fn = jax.vmap(scoring_fn, in_axes=(None, 0), out_axes=1)
    scores = batched_scoring_fn(policies_params, sample_keys)

    # the keys returned by scoring_fn (last output) are dropped
    all_fitnesses, all_descriptors, all_extra_scores = scores[:3]
    return all_fitnesses, all_descriptors, all_extra_scores, random_key
sampling(policies_params, random_key, scoring_fn, num_samples, extra_scores_extractor=dummy_extra_scores_extractor, fitness_extractor=average, descriptor_extractor=average)
¶
Wrap scoring_function to perform sampling.
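This function returns the expected fitnesses and descriptors for each individual over num_samples evaluations, using the provided extractor functions for the fitness and the descriptor.
Parameters:
| Name | Type | Description |
|---|---|---|
| policies_params | Genotype | policies to evaluate |
| random_key | RNGKey | JAX random key |
| scoring_fn | Callable[[Genotype, RNGKey], Tuple[Fitness, Descriptor, ExtraScores, RNGKey]] | scoring function used for evaluation |
| num_samples | int | number of samples to generate for each individual |
| extra_scores_extractor | Callable[[ExtraScores, int], ExtraScores] | function to extract the extra_scores from multiple samples of the same policy (default: dummy_extra_scores_extractor) |
| fitness_extractor | Callable[[jnp.ndarray], jnp.ndarray] | function to extract the fitness expectation from multiple samples of the same policy (default: average) |
| descriptor_extractor | Callable[[jnp.ndarray], jnp.ndarray] | function to extract the descriptor expectation from multiple samples of the same policy (default: average) |
Returns:
| Type | Description |
|---|---|
| Tuple[Fitness, Descriptor, ExtraScores, RNGKey] | The expected fitnesses, descriptors and extra_scores of the individuals, and a new random key |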
Source code in qdax/utils/sampling.py
@partial(
    jax.jit,
    static_argnames=(
        "scoring_fn",
        "num_samples",
        "extra_scores_extractor",
        "fitness_extractor",
        "descriptor_extractor",
    ),
)
def sampling(
    policies_params: Genotype,
    random_key: RNGKey,
    scoring_fn: Callable[
        [Genotype, RNGKey],
        Tuple[Fitness, Descriptor, ExtraScores, RNGKey],
    ],
    num_samples: int,
    extra_scores_extractor: Callable[
        [ExtraScores, int], ExtraScores
    ] = dummy_extra_scores_extractor,
    fitness_extractor: Callable[[jnp.ndarray], jnp.ndarray] = average,
    descriptor_extractor: Callable[[jnp.ndarray], jnp.ndarray] = average,
) -> Tuple[Fitness, Descriptor, ExtraScores, RNGKey]:
    """Wrap scoring_function to perform sampling.

    Every individual is evaluated `num_samples` times; the samples are then
    reduced to a single fitness, descriptor and extra_scores per individual
    with the provided extractor functions.

    Args:
        policies_params: policies to evaluate
        random_key: JAX random key
        scoring_fn: scoring function used for evaluation
        num_samples: number of samples to generate for each individual
        extra_scores_extractor: function to extract the extra_scores from
            multiple samples of the same policy.
        fitness_extractor: function to extract the fitness expectation from
            multiple samples of the same policy.
        descriptor_extractor: function to extract the descriptor expectation
            from multiple samples of the same policy.

    Returns:
        The expected fitnesses, descriptors and extra_scores of the individuals
        A new random key
    """
    # Collect num_samples evaluations of every individual
    samples = multi_sample_scoring_function(
        policies_params, random_key, scoring_fn, num_samples
    )
    all_fitnesses, all_descriptors, all_extra_scores, random_key = samples

    # Reduce the samples to one value per individual
    descriptors = descriptor_extractor(all_descriptors)
    fitnesses = fitness_extractor(all_fitnesses)
    extra_scores = extra_scores_extractor(all_extra_scores, num_samples)

    return fitnesses, descriptors, extra_scores, random_key
sampling_reproducibility(policies_params, random_key, scoring_fn, num_samples, extra_scores_extractor=dummy_extra_scores_extractor, fitness_extractor=average, descriptor_extractor=average, fitness_reproducibility_extractor=std, descriptor_reproducibility_extractor=std)
¶
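Wrap scoring_function to perform sampling and compute the expectation and reproducibility.
This function returns the expectation and the reproducibility of the fitnesses and descriptors for each individual over num_samples evaluations, using the provided extractor functions for the fitness and the descriptor.
Parameters:
| Name | Type | Description |
|---|---|---|
| policies_params | Genotype | policies to evaluate |
| random_key | RNGKey | JAX random key |
| scoring_fn | Callable[[Genotype, RNGKey], Tuple[Fitness, Descriptor, ExtraScores, RNGKey]] | scoring function used for evaluation |
| num_samples | int | number of samples to generate for each individual |
| extra_scores_extractor | Callable[[ExtraScores, int], ExtraScores] | function to extract the extra_scores from multiple samples of the same policy (default: dummy_extra_scores_extractor) |
| fitness_extractor | Callable[[jnp.ndarray], jnp.ndarray] | function to extract the fitness expectation from multiple samples of the same policy (default: average) |
| descriptor_extractor | Callable[[jnp.ndarray], jnp.ndarray] | function to extract the descriptor expectation from multiple samples of the same policy (default: average) |
| fitness_reproducibility_extractor | Callable[[jnp.ndarray], jnp.ndarray] | function to extract the fitness reproducibility from multiple samples of the same policy (default: std) |
| descriptor_reproducibility_extractor | Callable[[jnp.ndarray], jnp.ndarray] | function to extract the descriptor reproducibility from multiple samples of the same policy (default: std) |
Returns:
| Type | Description |
|---|---|
| Tuple[Fitness, Descriptor, ExtraScores, Fitness, Descriptor, RNGKey] | The expected fitnesses, descriptors and extra_scores of the individuals, the fitnesses and descriptors reproducibility of the individuals, and a new random key |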
Source code in qdax/utils/sampling.py
@partial(
    jax.jit,
    static_argnames=(
        "scoring_fn",
        "num_samples",
        "extra_scores_extractor",
        "fitness_extractor",
        "descriptor_extractor",
        "fitness_reproducibility_extractor",
        "descriptor_reproducibility_extractor",
    ),
)
def sampling_reproducibility(
    policies_params: Genotype,
    random_key: RNGKey,
    scoring_fn: Callable[
        [Genotype, RNGKey],
        Tuple[Fitness, Descriptor, ExtraScores, RNGKey],
    ],
    num_samples: int,
    extra_scores_extractor: Callable[
        [ExtraScores, int], ExtraScores
    ] = dummy_extra_scores_extractor,
    fitness_extractor: Callable[[jnp.ndarray], jnp.ndarray] = average,
    descriptor_extractor: Callable[[jnp.ndarray], jnp.ndarray] = average,
    fitness_reproducibility_extractor: Callable[[jnp.ndarray], jnp.ndarray] = std,
    descriptor_reproducibility_extractor: Callable[[jnp.ndarray], jnp.ndarray] = std,
) -> Tuple[Fitness, Descriptor, ExtraScores, Fitness, Descriptor, RNGKey]:
    """Wrap scoring_function to perform sampling and compute the
    expectation and reproducibility.

    Every individual is evaluated `num_samples` times; the same samples are
    used both to extract the expected fitness/descriptor/extra_scores and to
    extract the reproducibility of the fitness and of the descriptor, with
    the provided extractor functions.

    Args:
        policies_params: policies to evaluate
        random_key: JAX random key
        scoring_fn: scoring function used for evaluation
        num_samples: number of samples to generate for each individual
        extra_scores_extractor: function to extract the extra_scores from
            multiple samples of the same policy.
        fitness_extractor: function to extract the fitness expectation from
            multiple samples of the same policy.
        descriptor_extractor: function to extract the descriptor expectation
            from multiple samples of the same policy.
        fitness_reproducibility_extractor: function to extract the fitness
            reproducibility from multiple samples of the same policy.
        descriptor_reproducibility_extractor: function to extract the descriptor
            reproducibility from multiple samples of the same policy.

    Returns:
        The expected fitnesses, descriptors and extra_scores of the individuals
        The fitnesses and descriptors reproducibility of the individuals
        A new random key
    """
    # Collect num_samples evaluations of every individual
    samples = multi_sample_scoring_function(
        policies_params, random_key, scoring_fn, num_samples
    )
    all_fitnesses, all_descriptors, all_extra_scores, random_key = samples

    # Expectation of the scores over the samples
    descriptors = descriptor_extractor(all_descriptors)
    fitnesses = fitness_extractor(all_fitnesses)
    extra_scores = extra_scores_extractor(all_extra_scores, num_samples)

    # Reproducibility of the scores over the same samples
    descriptors_reproducibility = descriptor_reproducibility_extractor(all_descriptors)
    fitnesses_reproducibility = fitness_reproducibility_extractor(all_fitnesses)

    results = (
        fitnesses,
        descriptors,
        extra_scores,
        fitnesses_reproducibility,
        descriptors_reproducibility,
        random_key,
    )
    return results
train_seq2seq
¶
seq2seq addition example
Inspired by Flax library - https://github.com/google/flax/blob/main/examples/seq2seq/train.py
Copyright 2022 The Flax Authors. Licensed under the Apache License, Version 2.0 (the "License")
get_model(obs_size, teacher_force=False, hidden_size=10)
¶
Returns a seq2seq model.
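Parameters:
| Name | Type | Description |
|---|---|---|
| obs_size | int | the size of the observation. |
| teacher_force | bool | whether to use teacher forcing (default: False). |
| hidden_size | int | the size of the hidden layer, i.e. the encoding (default: 10). |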
Source code in qdax/utils/train_seq2seq.py
def get_model(
    obs_size: int, teacher_force: bool = False, hidden_size: int = 10
) -> Seq2seq:
    """
    Returns a seq2seq model.

    Args:
        obs_size: the size of the observation.
        teacher_force: whether to use teacher forcing.
        hidden_size: the size of the hidden layer (i.e. the encoding).

    Returns:
        A `Seq2seq` instance built with the given settings.
    """
    model = Seq2seq(
        obs_size=obs_size,
        hidden_size=hidden_size,
        teacher_force=teacher_force,
    )
    return model
get_initial_params(model, random_key, encoder_input_shape)
¶
Returns the initial parameters of a seq2seq model.
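Parameters:
| Name | Type | Description |
|---|---|---|
| model | Seq2seq | the seq2seq model. |
| random_key | PRNGKey | the random number generator. |
| encoder_input_shape | Tuple[int, ...] | the shape of the encoder input. |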
Source code in qdax/utils/train_seq2seq.py
def get_initial_params(
    model: Seq2seq, random_key: PRNGKey, encoder_input_shape: Tuple[int, ...]
) -> Dict[str, Any]:
    """
    Returns the initial parameters of a seq2seq model.

    Args:
        model: the seq2seq model.
        random_key: the random number generator.
        encoder_input_shape: the shape of the encoder input.

    Returns:
        The "params" entry of the variables produced by `model.init`.
    """
    # the first key of the split is not used; the three others seed the
    # "params", "lstm" and "dropout" streams of the initialisation
    _, params_key, lstm_key, dropout_key = jax.random.split(random_key, 4)
    init_rngs = {"params": params_key, "lstm": lstm_key, "dropout": dropout_key}

    # model.init receives two all-ones float32 inputs of the encoder shape
    dummy_input = jnp.ones(encoder_input_shape, jnp.float32)
    variables = model.init(init_rngs, dummy_input, dummy_input)
    return variables["params"]  # type: ignore
train_step(state, batch, lstm_random_key)
¶
Trains for one step.
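Parameters:
| Name | Type | Description |
|---|---|---|
| state | train_state.TrainState | the training state. |
| batch | Array | the batch of data. |
| lstm_random_key | PRNGKey | the random number key. |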
Source code in qdax/utils/train_seq2seq.py
@jax.jit
def train_step(
    state: train_state.TrainState,
    batch: Array,
    lstm_random_key: PRNGKey,
) -> Tuple[train_state.TrainState, jnp.ndarray]:
    """
    Trains for one step.

    Args:
        state: the training state.
        batch: the batch of data.
        lstm_random_key: the random number key.

    Returns:
        The training state after the gradients have been applied, and the
        value of the loss computed on the batch.
    """
    # Derive keys that are specific to the current training step
    lstm_key = jax.random.fold_in(lstm_random_key, state.step)
    dropout_key, lstm_key = jax.random.split(lstm_key, 2)

    # Shift input by one to avoid leakage
    batch_decoder = jnp.roll(batch, shift=1, axis=1)

    # Large number as zero token
    batch_decoder = batch_decoder.at[:, 0, :].set(-1000)

    def loss_fn(params: Params) -> Tuple[jnp.ndarray, jnp.ndarray]:
        logits, _ = state.apply_fn(
            {"params": params},
            batch,
            batch_decoder,
            rngs={"lstm": lstm_key, "dropout": dropout_key},
        )

        def mean_squared_error(x: jnp.ndarray, y: jnp.ndarray) -> jnp.ndarray:
            return jnp.inner(y - x, y - x) / x.shape[-1]

        # Compare the logits (last step dropped) with the decoder input
        # (first step dropped), each flattened to one vector per batch element
        res = jax.vmap(mean_squared_error)(
            jnp.reshape(logits.at[:, :-1, ...].get(), (logits.shape[0], -1)),
            jnp.reshape(
                batch_decoder.at[:, 1:, ...].get(), (batch_decoder.shape[0], -1)
            ),
        )
        loss = jnp.mean(res, axis=0)
        return loss, logits

    # has_aux=True: loss_fn returns (loss, logits) and only the loss is
    # differentiated
    grad_fn = jax.value_and_grad(loss_fn, has_aux=True)
    (loss_val, _logits), grads = grad_fn(state.params)
    state = state.apply_gradients(grads=grads)

    return state, loss_val