diff --git a/docs/source/conf.py b/docs/source/conf.py
index 89f9104f1..8c77a288a 100644
--- a/docs/source/conf.py
+++ b/docs/source/conf.py
@@ -173,6 +173,7 @@
nb_execution_mode = "force" # "off", "force", "cache", "auto"
nb_execution_allow_errors = False
nb_merge_streams = True
+nb_scroll_outputs = True
# Notebook cell execution timeout; defaults to 30.
nb_execution_timeout = 1000
diff --git a/docs/source/how_to/how_to_change_plotting_backend.ipynb b/docs/source/how_to/how_to_change_plotting_backend.ipynb
index 0e04d2cf8..14b86277c 100644
--- a/docs/source/how_to/how_to_change_plotting_backend.ipynb
+++ b/docs/source/how_to/how_to_change_plotting_backend.ipynb
@@ -38,6 +38,10 @@
"\n",
"The returned figure object is a [`matplotlib.axes.Axes`](https://matplotlib.org/stable/api/_as_gen/matplotlib.axes.Axes.html).\n",
"\n",
+ "```{note}\n",
+ "In case of grid plots (such as `convergence_plot` or `slice_plot`), the returned object is a 2-dimensional numpy array of `Axes` objects: [`numpy.ndarray`](https://numpy.org/doc/stable/reference/generated/numpy.ndarray.html)[[`matplotlib.axes.Axes`]](https://matplotlib.org/stable/api/_as_gen/matplotlib.axes.Axes.html) of shape `(n_rows, n_cols)`.\n",
+ "```\n",
+ "\n",
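+ "For example, with the matplotlib backend you can customize an individual subplot by indexing into the returned array. The following sketch assumes optimagic is imported as `om` and that `problems` and `results` are existing benchmark objects:\n",
+ "\n",
+ "```python\n",
+ "axes = om.convergence_plot(problems, results, backend=\"matplotlib\")\n",
+ "# Illustrative: set a custom title on the subplot in row 0, column 0.\n",
+ "axes[0, 0].set_title(\"Custom title\")\n",
+ "```\n",
+ "\n",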
":::\n",
"\n",
"::::"
diff --git a/src/optimagic/visualization/backends.py b/src/optimagic/visualization/backends.py
index 0b10c1710..91212b100 100644
--- a/src/optimagic/visualization/backends.py
+++ b/src/optimagic/visualization/backends.py
@@ -1,5 +1,7 @@
-from typing import TYPE_CHECKING, Any, Literal, Protocol, runtime_checkable
+import itertools
+from typing import TYPE_CHECKING, Any, Literal, Protocol, overload, runtime_checkable
+import numpy as np
import plotly.graph_objects as go
from optimagic.config import IS_MATPLOTLIB_INSTALLED
@@ -25,6 +27,37 @@ def __call__(
legend_properties: dict[str, Any] | None,
margin_properties: dict[str, Any] | None,
horizontal_line: float | None,
+ subplot: Any | None = None,
+    ) -> Any:
+        """Protocol of the line_plot function used for type checking.
+
+        Args:
+            ...: All other argument descriptions can be found in the docstring of
+                the `line_plot` function.
+            subplot: The subplot to which the lines should be plotted. The type of
+                this argument depends on the backend used. If not provided, a new
+                figure is created.
+
+        """
+        ...
+
+
+@runtime_checkable
+class GridLinePlotFunction(Protocol):
+ def __call__(
+ self,
+ lines_list: list[list[LineData]],
+ *,
+ n_rows: int,
+ n_cols: int,
+ titles: list[str] | None,
+ xlabel: str | None,
+ ylabel: str | None,
+ template: str | None,
+ height: int | None,
+ width: int | None,
+ legend_properties: dict[str, Any] | None,
+ margin_properties: dict[str, Any] | None,
) -> Any: ...
@@ -40,28 +73,52 @@ def _line_plot_plotly(
legend_properties: dict[str, Any] | None,
margin_properties: dict[str, Any] | None,
horizontal_line: float | None,
+ subplot: tuple[go.Figure, int, int] | None = None,
) -> go.Figure:
+ """Create a line plot using Plotly.
+
+ Args:
+ ...: All other argument descriptions can be found in the docstring of the
+ `line_plot` function.
+ subplot: A tuple specifying the subplot to which the lines should be plotted.
+ The tuple contains the Plotly `Figure` object, the row index, and the column
+ index of the subplot. If not provided, a new `Figure` object is created.
+
+ Returns:
+ A Plotly Figure object.
+
+ """
if template is None:
template = "simple_white"
- fig = go.Figure()
+ if subplot is None:
+ fig = go.Figure()
+ row, col = None, None
+ else:
+ fig, row, col = subplot
fig.update_layout(
title=title,
-        xaxis_title=xlabel.format(linebreak="<br>") if xlabel else None,
- yaxis_title=ylabel,
template=template,
height=height,
width=width,
legend=legend_properties,
margin=margin_properties,
)
+ fig.update_xaxes(
+        title=xlabel.format(linebreak="<br>") if xlabel else None, row=row, col=col
+ )
+ fig.update_yaxes(
+        title=ylabel.format(linebreak="<br>") if ylabel else None, row=row, col=col
+ )
if horizontal_line is not None:
fig.add_hline(
y=horizontal_line,
line_width=fig.layout.yaxis.linewidth or 1,
opacity=1.0,
+ row=row,
+ col=col,
)
for line in lines:
@@ -72,8 +129,54 @@ def _line_plot_plotly(
line_color=line.color,
mode="lines",
showlegend=line.show_in_legend,
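+            # Group traces by name so that toggling a legend entry toggles the
+            # corresponding line in every subplot of a grid figure.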
+ legendgroup=line.name,
+ )
+ fig.add_trace(trace, row=row, col=col)
+
+ return fig
+
+
+def _grid_line_plot_plotly(
+ lines_list: list[list[LineData]],
+ *,
+ n_rows: int,
+ n_cols: int,
+ titles: list[str] | None,
+ xlabel: str | None,
+ ylabel: str | None,
+ template: str | None,
+ height: int | None,
+ width: int | None,
+ legend_properties: dict[str, Any] | None,
+ margin_properties: dict[str, Any] | None,
+) -> go.Figure:
+ from plotly.subplots import make_subplots
+
+ fig = make_subplots(
+ rows=n_rows,
+ cols=n_cols,
+ subplot_titles=titles,
+ horizontal_spacing=0.3 / n_cols,
+ )
+
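+    # Fill the grid row-major: itertools.product yields (1, 1), (1, 2), ...,
+    # so lines_list[i] lands in row i // n_cols + 1, column i % n_cols + 1.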
+ for lines, (row, col) in zip(
+ lines_list,
+ itertools.product(range(1, n_rows + 1), range(1, n_cols + 1)),
+ strict=False,
+ ):
+ _line_plot_plotly(
+ lines,
+ title=None,
+ xlabel=xlabel,
+ ylabel=ylabel,
+ template=template,
+ height=height,
+ width=width,
+ legend_properties=legend_properties,
+ margin_properties=margin_properties,
+ horizontal_line=None,
+ subplot=(fig, row, col),
)
- fig.add_trace(trace)
return fig
@@ -90,7 +193,21 @@ def _line_plot_matplotlib(
legend_properties: dict[str, Any] | None,
margin_properties: dict[str, Any] | None,
horizontal_line: float | None,
+ subplot: "plt.Axes | None" = None,
) -> "plt.Axes":
+ """Create a line plot using Matplotlib.
+
+ Args:
+ ...: All other argument descriptions can be found in the docstring of the
+ `line_plot` function.
+ subplot: A Matplotlib `Axes` object to which the lines should be plotted.
+ If provided, the plot is drawn on the given `Axes`. If not provided,
+ a new `Figure` and `Axes` are created.
+
+ Returns:
+ A Matplotlib Axes object.
+
+ """
import matplotlib.pyplot as plt
# In interactive environments (like Jupyter), explicitly enable matplotlib's
@@ -105,10 +222,14 @@ def _line_plot_matplotlib(
template = "default"
with plt.style.context(template):
- px = 1 / plt.rcParams["figure.dpi"] # pixel in inches
- fig, ax = plt.subplots(
- figsize=(width * px, height * px) if width and height else None
- )
+ if subplot is None:
+ px = 1 / plt.rcParams["figure.dpi"] # pixel in inches
+ fig, ax = plt.subplots(
+ figsize=(width * px, height * px) if width and height else None,
+ layout="constrained",
+ )
+ else:
+ ax = subplot
if horizontal_line is not None:
ax.axhline(
@@ -128,24 +249,62 @@ def _line_plot_matplotlib(
ax.set(
title=title,
xlabel=xlabel.format(linebreak="\n") if xlabel else None,
- ylabel=ylabel,
+ ylabel=ylabel.format(linebreak="\n") if ylabel else None,
)
- if legend_properties is None:
- legend_properties = {}
- ax.legend(**legend_properties)
-
- fig.tight_layout()
+ if subplot is None and legend_properties is not None:
+ fig.legend(**legend_properties)
return ax
-BACKEND_AVAILABILITY_AND_LINE_PLOT_FUNCTION: dict[
- str, tuple[bool, LinePlotFunction]
-] = {
- "plotly": (True, _line_plot_plotly),
- "matplotlib": (IS_MATPLOTLIB_INSTALLED, _line_plot_matplotlib),
-}
+def _grid_line_plot_matplotlib(
+ lines_list: list[list[LineData]],
+ *,
+ n_rows: int,
+ n_cols: int,
+ titles: list[str] | None,
+ xlabel: str | None,
+ ylabel: str | None,
+ template: str | None,
+ height: int | None,
+ width: int | None,
+ legend_properties: dict[str, Any] | None,
+ margin_properties: dict[str, Any] | None,
+) -> np.ndarray:
+ import matplotlib.pyplot as plt
+
+ px = 1 / plt.rcParams["figure.dpi"] # pixel in inches
+ fig, axes = plt.subplots(
+ nrows=n_rows,
+ ncols=n_cols,
+ squeeze=False, # always return a 2D array of axes
+ figsize=(width * px, height * px) if width and height else None,
+ layout="constrained",
+ )
+
+ for i, (row, col) in enumerate(itertools.product(range(n_rows), range(n_cols))):
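+        # Hide grid cells that have no corresponding entry in lines_list.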
+ if i >= len(lines_list):
+ axes[row, col].set_visible(False)
+ continue
+
+ _line_plot_matplotlib(
+ lines_list[i],
+ title=titles[i] if titles else None,
+ xlabel=xlabel,
+ ylabel=ylabel,
+ template=template,
+ height=None,
+ width=None,
+ legend_properties=None,
+ margin_properties=None,
+ horizontal_line=None,
+ subplot=axes[row, col],
+ )
+
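+    # Attach a single figure-level legend for the whole grid.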
+    fig.legend(**(legend_properties or {}))
+
+ return axes
def line_plot(
@@ -184,6 +343,108 @@ def line_plot(
A figure object corresponding to the specified backend.
"""
+ _line_plot_backend_function = _get_plot_function(backend, grid_plot=False)
+
+ fig = _line_plot_backend_function(
+ lines,
+ title=title,
+ xlabel=xlabel,
+ ylabel=ylabel,
+ template=template,
+ height=height,
+ width=width,
+ legend_properties=legend_properties,
+ margin_properties=margin_properties,
+ horizontal_line=horizontal_line,
+ )
+
+ return fig
+
+
+def grid_line_plot(
+ lines_list: list[list[LineData]],
+ backend: Literal["plotly", "matplotlib"] = "plotly",
+ *,
+ n_rows: int,
+ n_cols: int,
+ titles: list[str] | None = None,
+ xlabel: str | None = None,
+ ylabel: str | None = None,
+ template: str | None = None,
+ height: int | None = None,
+ width: int | None = None,
+ legend_properties: dict[str, Any] | None = None,
+ margin_properties: dict[str, Any] | None = None,
+) -> Any:
+ """Create a grid of line plots corresponding to the specified backend.
+
+ Args:
+ lines_list: A list where each element is a list of objects containing data
+ for the lines in a subplot. The order of sublists determines the order
+ of subplots in the grid (row-wise), and the order of lines within each
+ sublist determines the order of lines in that subplot.
+ backend: The backend to use for plotting.
+ n_rows: Number of rows in the grid.
+ n_cols: Number of columns in the grid.
+ titles: Titles for each subplot in the grid.
+ xlabel: Label for the x-axis of each subplot.
+ ylabel: Label for the y-axis of each subplot.
+ template: Backend-specific template for styling the plots.
+ height: Height of the entire grid plot (in pixels).
+ width: Width of the entire grid plot (in pixels).
+ legend_properties: Backend-specific properties for the legend.
+ margin_properties: Backend-specific properties for the plot margins.
+
+ Returns:
+ A figure object corresponding to the specified backend.
+
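+    Example:
+        A minimal sketch, assuming `line_a` and `line_b` are `LineData` objects
+        that have already been constructed:
+
+        .. code-block:: python
+
+            axes = grid_line_plot(
+                [[line_a], [line_b]],
+                backend="matplotlib",
+                n_rows=1,
+                n_cols=2,
+            )
+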
+ """
+ _grid_line_plot_backend_function = _get_plot_function(backend, grid_plot=True)
+
+ fig = _grid_line_plot_backend_function(
+ lines_list,
+ n_rows=n_rows,
+ n_cols=n_cols,
+ titles=titles,
+ xlabel=xlabel,
+ ylabel=ylabel,
+ template=template,
+ height=height,
+ width=width,
+ legend_properties=legend_properties,
+ margin_properties=margin_properties,
+ )
+
+ return fig
+
+
+BACKEND_AVAILABILITY_AND_LINE_PLOT_FUNCTION: dict[
+ str, tuple[bool, LinePlotFunction, GridLinePlotFunction]
+] = {
+ "plotly": (True, _line_plot_plotly, _grid_line_plot_plotly),
+ "matplotlib": (
+ IS_MATPLOTLIB_INSTALLED,
+ _line_plot_matplotlib,
+ _grid_line_plot_matplotlib,
+ ),
+}
+
+
+@overload
+def _get_plot_function(
+ backend: Literal["plotly", "matplotlib"], grid_plot: Literal[False]
+) -> LinePlotFunction: ...
+
+
+@overload
+def _get_plot_function(
+ backend: Literal["plotly", "matplotlib"], grid_plot: Literal[True]
+) -> GridLinePlotFunction: ...
+
+
+def _get_plot_function(
+ backend: str, grid_plot: bool
+) -> LinePlotFunction | GridLinePlotFunction:
if backend not in BACKEND_AVAILABILITY_AND_LINE_PLOT_FUNCTION:
msg = (
f"Invalid plotting backend '{backend}'. "
@@ -192,9 +453,11 @@ def line_plot(
)
raise InvalidPlottingBackendError(msg)
- _is_backend_available, _line_plot_backend_function = (
- BACKEND_AVAILABILITY_AND_LINE_PLOT_FUNCTION[backend]
- )
+ (
+ _is_backend_available,
+ _line_plot_backend_function,
+ _grid_line_plot_backend_function,
+ ) = BACKEND_AVAILABILITY_AND_LINE_PLOT_FUNCTION[backend]
if not _is_backend_available:
msg = (
@@ -204,17 +467,7 @@ def line_plot(
)
raise NotInstalledError(msg)
- fig = _line_plot_backend_function(
- lines,
- title=title,
- xlabel=xlabel,
- ylabel=ylabel,
- template=template,
- height=height,
- width=width,
- legend_properties=legend_properties,
- margin_properties=margin_properties,
- horizontal_line=horizontal_line,
- )
-
- return fig
+ if grid_plot:
+ return _grid_line_plot_backend_function
+ else:
+ return _line_plot_backend_function
diff --git a/src/optimagic/visualization/convergence_plot.py b/src/optimagic/visualization/convergence_plot.py
index 34dd99baa..ba14d5c29 100644
--- a/src/optimagic/visualization/convergence_plot.py
+++ b/src/optimagic/visualization/convergence_plot.py
@@ -1,39 +1,85 @@
+from typing import Any, Literal
+
import numpy as np
-import plotly.express as px
-import plotly.graph_objects as go
+import pandas as pd
from optimagic.benchmarking.process_benchmark_results import (
process_benchmark_results,
)
-from optimagic.config import PLOTLY_TEMPLATE
+from optimagic.config import DEFAULT_PALETTE
from optimagic.utilities import propose_alternatives
-from optimagic.visualization.plotting_utilities import create_grid_plot, create_ind_dict
+from optimagic.visualization.backends import grid_line_plot, line_plot
+from optimagic.visualization.plotting_utilities import LineData, get_palette_cycle
+
+BACKEND_TO_CONVERGENCE_PLOT_LEGEND_PROPERTIES: dict[str, dict[str, Any]] = {
+ "plotly": {},
+ "matplotlib": {"loc": "outside right upper", "fontsize": "x-small"},
+}
+
+BACKEND_TO_CONVERGENCE_PLOT_MARGIN_PROPERTIES: dict[str, dict[str, int]] = {
+ "plotly": {"l": 10, "r": 10, "t": 30, "b": 10},
+ # "matplotlib": handles margins automatically via constrained layout
+}
+
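+# The "{linebreak}" placeholder below is substituted with a backend-specific line
+# break by the plot functions ("<br>" for plotly, "\n" for matplotlib).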
+OUTCOME_TO_CONVERGENCE_PLOT_YLABEL: dict[str, str] = {
+ "criterion": "Current Function Value",
+ "monotone_criterion": "Best Function Value Found So Far",
+ "criterion_normalized": (
+ "Share of Function Distance to Optimum{linebreak}"
+ "Missing From Current Criterion Value"
+ ),
+ "monotone_criterion_normalized": (
+ "Share of Function Distance to Optimum{linebreak}Missing From Best So Far"
+ ),
+ "parameter_distance": "Distance Between Current and Optimal Parameters",
+ "parameter_distance_normalized": (
+ "Share of Parameter Distance to Optimum{linebreak}"
+ "Missing From Current Parameters"
+ ),
+ "monotone_parameter_distance_normalized": (
+ "Share of the Parameter Distance to Optimum{linebreak}"
+ "Missing From the Best Parameters So Far"
+ ),
+ "monotone_parameter_distance": (
+ "Distance Between the Best Parameters So Far{linebreak}"
+ "and the Optimal Parameters"
+ ),
+}
+
+RUNTIME_MEASURE_TO_CONVERGENCE_PLOT_XLABEL: dict[str, str] = {
+ "n_evaluations": "Number of Function Evaluations",
+ "walltime": "Elapsed Time",
+ "n_batches": "Number of Batches",
+}
def convergence_plot(
- problems,
- results,
+ problems: dict[str, dict[str, Any]],
+ results: dict[tuple[str, str], dict[str, Any]],
*,
- problem_subset=None,
- algorithm_subset=None,
- n_cols=2,
- distance_measure="criterion",
- monotone=True,
- normalize_distance=True,
- runtime_measure="n_evaluations",
- stopping_criterion="y",
- x_precision=1e-4,
- y_precision=1e-4,
- combine_plots_in_grid=True,
- template=PLOTLY_TEMPLATE,
- palette=px.colors.qualitative.Plotly,
-):
+ problem_subset: list[str] | None = None,
+ algorithm_subset: list[str] | None = None,
+ n_cols: int = 2,
+ distance_measure: Literal["criterion", "parameter_distance"] = "criterion",
+ monotone: bool = True,
+ normalize_distance: bool = True,
+ runtime_measure: Literal[
+ "n_evaluations", "walltime", "n_batches"
+ ] = "n_evaluations",
+ stopping_criterion: Literal["x", "y", "x_and_y", "x_or_y"] | None = "y",
+ x_precision: float = 1e-4,
+ y_precision: float = 1e-4,
+ combine_plots_in_grid: bool = True,
+ backend: Literal["plotly", "matplotlib"] = "plotly",
+ template: str | None = None,
+ palette: list[str] | str = DEFAULT_PALETTE,
+) -> Any:
"""Plot convergence of optimizers for a set of problems.
This creates a grid of plots, showing the convergence of the different
algorithms on each problem. The faster a line falls, the faster the algorithm
improved on the problem. The algorithm converged where its line reaches 0
- (if normalize_distance is True) or the horizontal blue line labeled "true solution".
+ (if normalize_distance is True) or the horizontal line labeled "true solution".
Each plot shows on the x axis the runtime_measure, which can be walltime, number
of evaluations or number of batches. Each algorithm's convergence is a line in the
@@ -43,49 +89,52 @@ def convergence_plot(
solution is one.
Args:
- problems (dict): optimagic benchmarking problems dictionary. Keys are the
- problem names. Values contain information on the problem, including the
- solution value.
- results (dict): optimagic benchmarking results dictionary. Keys are
- tuples of the form (problem, algorithm), values are dictionaries of the
- collected information on the benchmark run, including 'criterion_history'
- and 'time_history'.
- problem_subset (list, optional): List of problem names. These must be a subset
- of the keys of the problems dictionary. If provided the convergence plot is
- only created for the problems specified in this list.
- algorithm_subset (list, optional): List of algorithm names. These must be a
- subset of the keys of the optimizer_options passed to run_benchmark. If
- provided only the convergence of the given algorithms are shown.
- n_cols (int): number of columns in the plot of grids. The number
- of rows is determined automatically.
- distance_measure (str): One of "criterion", "parameter_distance".
- monotone (bool): If True the best found criterion value so far is plotted.
+ problems: optimagic benchmarking problems dictionary. Keys are the problem
+ names. Values contain information on the problem, including the solution
+ value.
+ results: optimagic benchmarking results dictionary. Keys are tuples of the form
+ (problem, algorithm), values are dictionaries of the collected information
+ on the benchmark run, including 'criterion_history' and 'time_history'.
+        problem_subset: List of problem names. These must be a subset of the keys of
+            the problems dictionary. If provided, the convergence plot is only
+            created for the problems specified in this list.
+        algorithm_subset: List of algorithm names. These must be a subset of the keys
+            of the optimizer_options passed to run_benchmark. If provided, only the
+            convergence of the given algorithms is shown.
+        n_cols: Number of columns in the grid of plots. The number of rows is
+            determined automatically.
+ distance_measure: One of "criterion", "parameter_distance".
+ monotone: If True the best found criterion value so far is plotted.
If False the particular criterion evaluation of that time is used.
- normalize_distance (bool): If True the progress is scaled by the total distance
- between the start value and the optimal value, i.e. 1 means the algorithm
- is as far from the solution as the start value and 0 means the algorithm
- has reached the solution value.
- runtime_measure (str): "n_evaluations", "walltime" or "n_batches".
- stopping_criterion (str): "x_and_y", "x_or_y", "x", "y" or None. If None, no
- clipping is done.
- x_precision (float or None): how close an algorithm must have gotten to the
- true parameter values (as percent of the Euclidean distance between start
- and solution parameters) before the criterion for clipping and convergence
- is fulfilled.
- y_precision (float or None): how close an algorithm must have gotten to the
- true criterion values (as percent of the distance between start
- and solution criterion value) before the criterion for clipping and
- convergence is fulfilled.
- combine_plots_in_grid (bool): decide whether to return a one
- figure containing subplots for each factor pair or a dictionary
- of individual plots. Default True.
- template (str): The template for the figure. Default is "plotly_white".
- palette: The coloring palette for traces. Default is "qualitative.Plotly".
+        normalize_distance: If True, the progress is scaled by the total distance
+            between the start value and the optimal value, i.e. 1 means the algorithm
+            is as far from the solution as the start value and 0 means the algorithm
+            has reached the solution value.
+        runtime_measure: The runtime measure plotted on the x axis: number of
+            function evaluations, walltime, or number of batches.
+        stopping_criterion: Determines how convergence is defined from the two
+            precisions. If None, no convergence criterion is applied.
+        x_precision: How close an algorithm must have gotten to the true parameter
+            values (as percent of the Euclidean distance between start and solution
+            parameters) before the criterion for clipping and convergence is
+            fulfilled.
+        y_precision: How close an algorithm must have gotten to the true criterion
+            values (as percent of the distance between start and solution criterion
+            value) before the criterion for clipping and convergence is fulfilled.
+        combine_plots_in_grid: Whether to return a single figure containing a subplot
+            for each problem or a dictionary of individual plots. Default is True.
+ backend: The backend to use for plotting. Default is "plotly".
+ template: The template for the figure. If not specified, the default template of
+ the backend is used.
+ palette: The coloring palette for traces. Default is the D3 qualitative palette.
Returns:
- plotly.Figure: The grid plot or dict of individual plots
+ The figure object containing the convergence plot if `combine_plots_in_grid` is
+ True. Otherwise, a dictionary mapping problem names to their respective
+ figure objects is returned.
"""
+ # ==================================================================================
+ # Process inputs
df, _ = process_benchmark_results(
problems=problems,
@@ -95,7 +144,6 @@ def convergence_plot(
y_precision=y_precision,
)
- # handle string provision for single problems / algorithms
if isinstance(problem_subset, str):
problem_subset = [problem_subset]
if isinstance(algorithm_subset, str):
@@ -109,137 +157,151 @@ def convergence_plot(
if algorithm_subset is not None:
df = df[df["algorithm"].isin(algorithm_subset)]
- # plot configuration
+ # ==================================================================================
+ # Extract backend-agnostic plotting data
+
outcome = (
f"{'monotone_' if monotone else ''}"
+ distance_measure
+ f"{'_normalized' if normalize_distance else ''}"
)
- remaining_problems = df["problem"].unique()
- n_rows = int(np.ceil(len(remaining_problems) / n_cols))
-
- # pre - style plots labels
- y_labels = {
- "criterion": "Current Function Value",
- "monotone_criterion": "Best Function Value Found So Far",
- "criterion_normalized": "Share of Function Distance to Optimum
"
- "Missing From Current Criterion Value",
- "monotone_criterion_normalized": "Share of Function Distance to Optimum
"
- "Missing From Best So Far",
- "parameter_distance": "Distance Between Current and Optimal Parameters",
- "parameter_distance_normalized": "Share of Parameter Distance to Optimum
"
- "Missing From Current Parameters",
- "monotone_parameter_distance_normalized": "Share of the Parameter Distance "
- "to Optimum
Missing From the Best Parameters So Far",
- "monotone_parameter_distance": "Distance Between the Best Parameters So Far
"
- "and the Optimal Parameters",
- }
-
- x_labels = {
- "n_evaluations": "Number of Function Evaluations",
- "walltime": "Elapsed Time",
- "n_batches": "Number of Batches",
- }
-
- # container for individual plots
- g_list = []
- # container for titles
+ lines_list, titles = _extract_convergence_plot_lines(
+ df=df,
+ problems=problems,
+ runtime_measure=runtime_measure,
+ outcome=outcome,
+ palette=palette,
+ combine_plots_in_grid=combine_plots_in_grid,
+ )
+
+ n_rows = int(np.ceil(len(lines_list) / n_cols))
+
+ # ==================================================================================
+ # Generate the figure
+
+ if combine_plots_in_grid:
+ fig = grid_line_plot(
+ lines_list,
+ backend=backend,
+ n_rows=n_rows,
+ n_cols=n_cols,
+ titles=titles,
+ xlabel=RUNTIME_MEASURE_TO_CONVERGENCE_PLOT_XLABEL[runtime_measure],
+ ylabel=OUTCOME_TO_CONVERGENCE_PLOT_YLABEL[outcome],
+ template=template,
+ height=320 * n_rows,
+ width=500 * n_cols,
+ legend_properties=BACKEND_TO_CONVERGENCE_PLOT_LEGEND_PROPERTIES.get(
+ backend, None
+ ),
+ margin_properties=BACKEND_TO_CONVERGENCE_PLOT_MARGIN_PROPERTIES.get(
+ backend, None
+ ),
+ )
+
+ return fig
+
+ else:
+ fig_dict = {}
+
+ for i, subplot_lines in enumerate(lines_list):
+ fig = line_plot(
+ subplot_lines,
+ backend=backend,
+ title=titles[i],
+ xlabel=RUNTIME_MEASURE_TO_CONVERGENCE_PLOT_XLABEL[runtime_measure],
+ ylabel=OUTCOME_TO_CONVERGENCE_PLOT_YLABEL[outcome],
+ template=template,
+ height=320,
+ width=500,
+ legend_properties=BACKEND_TO_CONVERGENCE_PLOT_LEGEND_PROPERTIES.get(
+ backend, None
+ ),
+ margin_properties=BACKEND_TO_CONVERGENCE_PLOT_MARGIN_PROPERTIES.get(
+ backend, None
+ ),
+ )
+
+ key = titles[i].replace(" ", "_").lower()
+ fig_dict[key] = fig
+
+ return fig_dict
+
+
+def _extract_convergence_plot_lines(
+ df: pd.DataFrame,
+ problems: dict[str, dict[str, Any]],
+ runtime_measure: str,
+ outcome: str,
+ palette: list[str] | str,
+ combine_plots_in_grid: bool = True,
+) -> tuple[list[list[LineData]], list[str]]:
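+    """Extract backend-agnostic line data and titles, one sublist per problem."""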
+ lines_list = [] # container for all subplots
titles = []
- # creating data traces for plotting faceted/individual plots
- # dropping usage of palette for algoritms, but use the built in pallete
- for prob_name in remaining_problems:
- g_ind = [] # container for data for traces in individual plot
- to_plot = df[df["problem"] == prob_name]
+ for i, (_prob_name, _prob_data) in enumerate(df.groupby("problem", sort=False)):
+ prob_name = str(_prob_name)
+ subplot_lines = [] # container for data of traces in individual subplot
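+        # Restart the palette cycle for each subplot so that a given algorithm
+        # keeps the same color across all problems.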
+ palette_cycle = get_palette_cycle(palette)
+
if runtime_measure == "n_batches":
to_plot = (
- to_plot.groupby(["algorithm", runtime_measure]).min().reset_index()
+ _prob_data.groupby(["algorithm", runtime_measure]).min().reset_index()
)
-
- for i, alg in enumerate(to_plot["algorithm"].unique()):
- temp = to_plot[to_plot["algorithm"] == alg]
- trace_1 = go.Scatter(
- x=temp[runtime_measure],
- y=temp[outcome],
- mode="lines",
- legendgroup=i,
- name=alg,
- line={"color": palette[i]},
+ else:
+ to_plot = _prob_data
+
+ for alg, group in to_plot.groupby("algorithm", sort=False):
+ line_data = LineData(
+ x=group[runtime_measure].to_numpy(),
+ y=group[outcome].to_numpy(),
+ name=str(alg),
+ color=next(palette_cycle),
+ # if combining plots, only show legend in first subplot
+ show_in_legend=(not combine_plots_in_grid) or (i == 0),
)
- g_ind.append(trace_1)
+ subplot_lines.append(line_data)
- if distance_measure == "criterion" and not normalize_distance:
+ if outcome in ("criterion", "monotone_criterion"):
f_opt = problems[prob_name]["solution"]["value"]
- trace_2 = go.Scatter(
- y=[f_opt for i in to_plot[runtime_measure]],
- x=to_plot[runtime_measure],
- mode="lines",
- line={"color": palette[i + 1]},
+ line_data = LineData(
+ x=to_plot[runtime_measure].to_numpy(),
+ y=np.full(to_plot[runtime_measure].shape, f_opt),
name="true solution",
- legendgroup=i + 1,
+ color=next(palette_cycle),
+ # if combining plots, only show legend in first subplot
+ show_in_legend=(not combine_plots_in_grid) or (i == 0),
)
- g_ind.append(trace_2)
+ subplot_lines.append(line_data)
- g_list.append(g_ind)
+ lines_list.append(subplot_lines)
titles.append(prob_name.replace("_", " ").title())
- xaxis_title = [x_labels[runtime_measure] for ind in range(len(g_list))]
- yaxis_title = [y_labels[outcome] for ind in range(len(g_list))]
-
- common_dependencies = {
- "ind_list": g_list,
- "names": titles,
- "clean_legend": True,
- "x_title": xaxis_title,
- "y_title": yaxis_title,
- }
- common_layout = {
- "template": template,
- "margin": {"l": 10, "r": 10, "t": 30, "b": 10},
- }
-
- # Plot with subplots
- if combine_plots_in_grid:
- g = create_grid_plot(
- rows=n_rows,
- cols=n_cols,
- **common_dependencies,
- kws={"height": 320 * n_rows, "width": 500 * n_cols, **common_layout},
- )
- out = g
-
- # Dictionary for individual plots
- else:
- ind_dict = create_ind_dict(
- **common_dependencies,
- kws={"height": 320, "width": 500, "title_x": 0.5, **common_layout},
- )
-
- out = ind_dict
-
- return out
+ return lines_list, titles
-def _check_only_allowed_subset_provided(subset, allowed, name):
+def _check_only_allowed_subset_provided(
+ subset: list[str] | None, allowed: pd.Series | list[str], name: str
+) -> None:
"""Check if all entries of a proposed subset are in a Series.
Args:
- subset (iterable or None): If None, no checks are performed. Else a ValueError
- is raised listing all entries that are not in the provided Series.
- allowed (iterable): allowed entries.
- name (str): name of the provided entries to use for the ValueError.
+        subset: If None, no checks are performed. Otherwise, a ValueError is raised
+            if any entry is missing from the allowed entries, listing all missing
+            entries.
+        allowed: Allowed entries.
+        name: Name of the provided entries, used in the ValueError message.
Raises:
ValueError
"""
- allowed = set(allowed)
+ allowed_set = set(allowed)
if subset is not None:
- missing = [entry for entry in subset if entry not in allowed]
+ missing = [entry for entry in subset if entry not in allowed_set]
if missing:
missing_msg = ""
for entry in missing:
- proposed = propose_alternatives(entry, allowed)
+ proposed = propose_alternatives(entry, allowed_set)
missing_msg += f"Invalid {name}: {entry}. Did you mean {proposed}?\n"
raise ValueError(missing_msg)
diff --git a/src/optimagic/visualization/profile_plot.py b/src/optimagic/visualization/profile_plot.py
index fa9629696..8d828b6a3 100644
--- a/src/optimagic/visualization/profile_plot.py
+++ b/src/optimagic/visualization/profile_plot.py
@@ -15,8 +15,7 @@
BACKEND_TO_PROFILE_PLOT_LEGEND_PROPERTIES: dict[str, dict[str, Any]] = {
"plotly": {"title": {"text": "algorithm"}},
"matplotlib": {
- "bbox_to_anchor": (1.02, 1),
- "loc": "upper left",
+ "loc": "outside right upper",
"fontsize": "x-small",
"title": "algorithm",
},
@@ -24,7 +23,7 @@
BACKEND_TO_PROFILE_PLOT_MARGIN_PROPERTIES: dict[str, dict[str, Any]] = {
"plotly": {"l": 10, "r": 10, "t": 30, "b": 30},
- # "matplotlib": handles margins automatically via tight_layout()
+ # "matplotlib": handles margins automatically via constrained layout
}
diff --git a/tests/optimagic/visualization/test_backends.py b/tests/optimagic/visualization/test_backends.py
index 1e2b7fd7f..263895f29 100644
--- a/tests/optimagic/visualization/test_backends.py
+++ b/tests/optimagic/visualization/test_backends.py
@@ -31,7 +31,7 @@ def test_line_plot_invalid_backend(sample_lines):
def test_line_plot_unavailable_backend(sample_lines, monkeypatch):
# Use monkeypatch to simulate that 'matplotlib' backend is not installed.
monkeypatch.setitem(
- BACKEND_AVAILABILITY_AND_LINE_PLOT_FUNCTION, "matplotlib", (False, None)
+ BACKEND_AVAILABILITY_AND_LINE_PLOT_FUNCTION, "matplotlib", (False, None, None)
)
with pytest.raises(NotInstalledError):
diff --git a/tests/optimagic/visualization/test_convergence_plot.py b/tests/optimagic/visualization/test_convergence_plot.py
index c6d119573..a5a77fe04 100644
--- a/tests/optimagic/visualization/test_convergence_plot.py
+++ b/tests/optimagic/visualization/test_convergence_plot.py
@@ -1,12 +1,44 @@
import pytest
from optimagic import get_benchmark_problems
+from optimagic.benchmarking.process_benchmark_results import process_benchmark_results
from optimagic.benchmarking.run_benchmark import run_benchmark
from optimagic.visualization.convergence_plot import (
_check_only_allowed_subset_provided,
+ _extract_convergence_plot_lines,
convergence_plot,
)
+
+@pytest.fixture()
+def benchmark_results():
+ problems = get_benchmark_problems("example")
+ stop_after_10 = {
+ "stopping_max_criterion_evaluations": 10,
+ "stopping_max_iterations": 10,
+ }
+ optimizers = {
+ "lbfgsb": {"algorithm": "scipy_lbfgsb", "algo_options": stop_after_10},
+ "nm": {"algorithm": "scipy_neldermead", "algo_options": stop_after_10},
+ }
+ results = run_benchmark(
+ problems,
+ optimizers,
+ n_cores=1, # must be 1 for the test to work
+ )
+ return problems, results
+
+
+def test_convergence_plot_default_options(benchmark_results):
+ problems, results = benchmark_results
+
+ convergence_plot(
+ problems=problems,
+ results=results,
+ problem_subset=["bard_good_start"],
+ )
+
+
# integration test to make sure non default argument do not throw Errors
profile_options = [
{"n_cols": 3},
@@ -15,33 +47,19 @@
{"normalize_distance": False},
{"runtime_measure": "walltime"},
{"runtime_measure": "n_batches"},
- {"stopping_criterion": None},
{"stopping_criterion": "x"},
{"stopping_criterion": "x_and_y"},
{"stopping_criterion": "x_or_y"},
{"x_precision": 1e-5},
{"y_precision": 1e-5},
+ {"backend": "matplotlib"},
]
-@pytest.mark.parametrize(
- "options, grid", zip(profile_options, [True, False], strict=False)
-)
-def test_convergence_plot_options(options, grid):
- problems = get_benchmark_problems("example")
- stop_after_10 = {
- "stopping_max_criterion_evaluations": 10,
- "stopping_max_iterations": 10,
- }
- optimizers = {
- "lbfgsb": {"algorithm": "scipy_lbfgsb", "algo_options": stop_after_10},
- "nm": {"algorithm": "scipy_neldermead", "algo_options": stop_after_10},
- }
- results = run_benchmark(
- problems,
- optimizers,
- n_cores=1, # must be 1 for the test to work
- )
+@pytest.mark.parametrize("options", profile_options)
+@pytest.mark.parametrize("grid", [True, False])
+def test_convergence_plot_options(options, grid, benchmark_results):
+ problems, results = benchmark_results
convergence_plot(
problems=problems,
@@ -52,6 +70,18 @@ def test_convergence_plot_options(options, grid):
)
+def test_convergence_plot_stopping_criterion_none(benchmark_results):
+ problems, results = benchmark_results
+
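+    # Passing stopping_criterion=None currently raises an UnboundLocalError;
+    # this test documents the behavior as it stands.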
+ with pytest.raises(UnboundLocalError):
+ convergence_plot(
+ problems=problems,
+ results=results,
+ problem_subset=["bard_good_start"],
+ stopping_criterion=None,
+ )
+
+
def test_check_only_allowed_subset_provided_none():
allowed = ["a", "b", "c"]
_check_only_allowed_subset_provided(None, allowed, "name")
@@ -66,3 +96,29 @@ def test_check_only_allowed_subset_provided_missing():
allowed = ["a", "b", "c"]
with pytest.raises(ValueError):
_check_only_allowed_subset_provided(["d"], allowed, "name")
+
+
+def test_extract_convergence_plot_lines(benchmark_results):
+ problems, results = benchmark_results
+
+ df, _ = process_benchmark_results(
+ problems=problems, results=results, stopping_criterion="y"
+ )
+
+ lines_list, titles = _extract_convergence_plot_lines(
+ df=df,
+ problems=problems,
+ runtime_measure="n_evaluations",
+ outcome="criterion_normalized",
+ palette=["red", "green", "blue"],
+ )
+
+ assert isinstance(lines_list, list) and isinstance(titles, list)
+ assert len(lines_list) == len(titles) == len(problems)
+
+ for subplot_lines in lines_list:
+ assert isinstance(subplot_lines, list) and len(subplot_lines) == 2
+ assert subplot_lines[0].name == "lbfgsb"
+ assert subplot_lines[1].name == "nm"
+ assert subplot_lines[0].color == "red"
+ assert subplot_lines[1].color == "green"