import warnings
from functools import partial
from typing import Any, Callable, Literal
import numpy as np
import pandas as pd
from numpy.typing import NDArray
from pybaum import tree_just_flatten
import optimagic as om
from optimagic import deprecations
from optimagic.batch_evaluators import (
BatchEvaluator,
BatchEvaluatorLiteral,
process_batch_evaluator,
)
from optimagic.config import DEFAULT_N_CORES, DEFAULT_PALETTE
from optimagic.deprecations import replace_and_warn_about_deprecated_bounds
from optimagic.optimization.fun_value import (
SpecificFunctionValue,
convert_fun_output_to_function_value,
enforce_return_type,
)
from optimagic.parameters.bounds import pre_process_bounds
from optimagic.parameters.conversion import get_converter
from optimagic.parameters.space_conversion import InternalParams
from optimagic.parameters.tree_registry import get_registry
from optimagic.shared.process_user_function import infer_aggregation_level
from optimagic.typing import AggregationLevel, PyTree
from optimagic.visualization.backends import grid_line_plot, line_plot
from optimagic.visualization.plotting_utilities import LineData, MarkerData
[docs]
def slice_plot(
    func: Callable,
    params: PyTree,
    bounds: om.Bounds | None = None,
    func_kwargs: dict | None = None,
    selector: Callable[[PyTree], PyTree] | None = None,
    n_cores: int = DEFAULT_N_CORES,
    n_gridpoints: int = 20,
    plots_per_row: int = 2,
    param_names: dict[str, str] | None = None,
    share_y: bool = True,
    expand_yrange: float = 0.02,
    share_x: bool = False,
    backend: Literal["plotly", "matplotlib", "bokeh", "altair"] = "plotly",
    template: str | None = None,
    color: str | None = DEFAULT_PALETTE[0],
    title: str | None = None,
    return_dict: bool = False,
    batch_evaluator: BatchEvaluatorLiteral | BatchEvaluator = "joblib",
    # deprecated
    make_subplot_kwargs: dict | None = None,
    lower_bounds: None = None,
    upper_bounds: None = None,
) -> Any:
    """Plot the criterion function along one parameter coordinate at a time.

    For every selected parameter, the function is evaluated on a grid between the
    lower and upper bound of that parameter while all other parameters are held at
    their values in ``params``. The slices are returned either combined into one
    figure with a subplot per parameter or as a dictionary of individual figures.

    # TODO: Use soft bounds to create the grid (if available).
    # TODO: Don't do a function evaluation outside the batch evaluator.

    Args:
        func: Criterion function that takes params and returns a scalar, PyTree or
            FunctionValue object.
        params: A pytree with parameters.
        bounds: Lower and upper bounds on the parameters. They define the grid over
            which the slices are drawn. The most general and preferred way to
            specify bounds is an `optimagic.Bounds` object that collects lower,
            upper, soft_lower and soft_upper bounds. The soft bounds are not used
            for slice plots. Each bound type mirrors the structure of params. See
            the how-to guide on bounds for examples. If params is a flat numpy
            array, any format supported by scipy.optimize.minimize can be used as
            well.
        func_kwargs: Additional keyword arguments passed to func.
        selector: Function that takes params and returns the subset of params for
            which plots should be generated.
        n_cores: Number of cores.
        n_gridpoints: Number of gridpoints on which the criterion function is
            evaluated. This is the number per plotted line.
        plots_per_row: Number of plots per row.
        param_names: Dictionary mapping old parameter names to new ones.
        share_y: If True, the individual plots share the scale on the y axis and
            plots in one row actually share the y axis.
        expand_yrange: The ratio by which the range of the (shared) y axis is
            expanded, such that the axis is not cropped at exactly the maximum of
            the criterion value.
        share_x: If True, set the same x axis range for all plots and share the
            x axis for all plots in one column.
        backend: The backend to use for plotting. Default is "plotly".
        template: The template for the figure. If not specified, the default
            template of the backend is used. For the 'bokeh' and 'altair' backends,
            this changes the global theme, which affects all plots from that
            backend in the session.
        color: The line color.
        title: The figure title. This is not used for the `bokeh` backend, as it
            does not support a title for grid plots.
        return_dict: If True, return a dictionary with one plot per parameter
            instead of combining the individual plots into a figure with subplots.
            In that case ``plots_per_row``, ``share_x``, ``share_y``,
            ``expand_yrange``, ``title`` and ``make_subplot_kwargs`` are not
            forwarded to the individual plots.
        batch_evaluator: See :ref:`batch_evaluators`.
        make_subplot_kwargs: Deprecated. Passing a value triggers a FutureWarning;
            the value is still forwarded to the grid plot.
        lower_bounds: Deprecated. Use ``bounds`` instead.
        upper_bounds: Deprecated. Use ``bounds`` instead.

    Returns:
        The figure object containing the slice plot if `return_dict` is False.
        Otherwise, a dictionary that maps the (possibly renamed) parameter names to
        the individual slice plots.

    Raises:
        ValueError: If a selected parameter does not have finite lower and upper
            bounds.

    """
    # ==================================================================================
    # Process inputs

    # Fold the deprecated lower_bounds / upper_bounds arguments into ``bounds``
    # before normalizing them.
    bounds = pre_process_bounds(
        replace_and_warn_about_deprecated_bounds(
            lower_bounds=lower_bounds,
            upper_bounds=upper_bounds,
            bounds=bounds,
        )
    )

    func, func_eval = _get_processed_func_and_func_eval(func, func_kwargs, params)

    if make_subplot_kwargs is not None:
        deprecations.throw_make_subplot_kwargs_in_slice_plot_future_warning()

    # ==================================================================================
    # Extract backend-agnostic plotting data from results

    plot_data, internal_params = _get_plot_data(
        func=func,
        params=params,
        bounds=bounds,
        func_eval=func_eval,
        selector=selector,
        n_gridpoints=n_gridpoints,
        batch_evaluator=batch_evaluator,
        n_cores=n_cores,
    )

    lines_list, marker_list, xlabels, ylabels = _extract_slice_plot_lines_and_labels(
        plot_data=plot_data,
        internal_params=internal_params,
        func_eval=func_eval,
        param_names=param_names,
        color=color,
    )

    # ==================================================================================
    # Generate the figure

    xrange, yrange = _get_axis_limits(
        plot_data, share_y=share_y, share_x=share_x, expand_yrange=expand_yrange
    )

    if return_dict:
        # One stand-alone figure per parameter, keyed by its x axis label.
        return {
            xlabel: line_plot(
                lines=lines,
                marker=marker,
                backend=backend,
                xlabel=xlabel,
                ylabel=ylabel,
                template=template,
            )
            for lines, marker, xlabel, ylabel in zip(
                lines_list, marker_list, xlabels, ylabels
            )
        }

    n_rows = int(np.ceil(len(lines_list) / plots_per_row))

    if share_y:
        # With a shared y axis only the first plot of each row keeps its label.
        ylabels = [
            "" if idx % plots_per_row else label for idx, label in enumerate(ylabels)
        ]

    return grid_line_plot(
        lines_list=lines_list,
        marker_list=marker_list,
        backend=backend,
        n_rows=n_rows,
        n_cols=plots_per_row,
        xlabels=xlabels,
        xrange=xrange,
        share_x=share_x,
        ylabels=ylabels,
        yrange=yrange,
        share_y=share_y,
        template=template,
        height=300 * n_rows,
        width=400 * plots_per_row,
        plot_title=title,
        make_subplot_kwargs=make_subplot_kwargs,
    )
def _get_processed_func_and_func_eval(
    func: Callable, func_kwargs: dict | None, params: PyTree
) -> tuple[Callable, SpecificFunctionValue]:
    """Prepare the criterion function and evaluate it once at ``params``.

    Args:
        func: The user provided criterion function.
        func_kwargs: Optional keyword arguments that are bound to ``func``.
        params: The parameters at which ``func`` is evaluated.

    Returns:
        A tuple with the processed function (keyword arguments bound, deprecated
        dictionary output replaced, return type enforced) and the function value
        at ``params`` converted via ``convert_fun_output_to_function_value``.

    """
    criterion = func if func_kwargs is None else partial(func, **func_kwargs)

    # This single evaluation is used both to detect the (deprecated) output format
    # and as the function value at the start parameters.
    evaluation = criterion(params)

    # handle deprecated function output
    if deprecations.is_dict_output(evaluation):
        warnings.warn(
            "Functions that return dictionaries are deprecated in slice_plot and will "
            "raise an error in version 0.6.0. Please pass a function that returns a "
            "FunctionValue object instead and use the `mark` decorators to specify "
            "whether it is a scalar, least-squares or likelihood function.",
            FutureWarning,
        )
        evaluation = deprecations.convert_dict_to_function_value(evaluation)
        criterion = deprecations.replace_dict_output(criterion)

    # Infer the function type and enforce the return type
    if deprecations.is_dict_output(evaluation):
        problem_type = deprecations.infer_problem_type_from_dict_output(evaluation)
    else:
        problem_type = infer_aggregation_level(criterion)

    evaluation = convert_fun_output_to_function_value(evaluation, problem_type)
    criterion = enforce_return_type(problem_type)(criterion)

    return criterion, evaluation
def _get_plot_data(
    func: Callable,
    params: PyTree,
    bounds: om.Bounds | None,
    func_eval: SpecificFunctionValue,
    selector: Callable[[PyTree], PyTree] | None,
    n_gridpoints: int,
    batch_evaluator: BatchEvaluatorLiteral | BatchEvaluator,
    n_cores: int,
) -> tuple[pd.DataFrame, InternalParams]:
    """Evaluate the function on a grid for every selected parameter.

    Args:
        func: The processed criterion function.
        params: The parameters around which the slices are taken.
        bounds: Bounds on the parameters; they define the ends of each grid.
        func_eval: The function value at ``params``.
        selector: Optional function that selects the parameters to plot.
        n_gridpoints: Number of grid points between lower and upper bound.
        batch_evaluator: Batch evaluator used for the function evaluations.
        n_cores: Number of cores passed to the batch evaluator.

    Returns:
        A tuple of a DataFrame with the columns "name", "Parameter Value" and
        "Function Value" (one row per evaluated point plus one row per selected
        parameter for the value at ``params``) and the internal parameters.

    Raises:
        ValueError: If a selected parameter has a non-finite lower or upper bound.

    """
    converter, internal_params = get_converter(
        params=params,
        constraints=None,
        bounds=bounds,
        func_eval=func_eval,
        solver_type="value",
    )

    # Without a selector, every internal parameter gets a slice.
    selected = np.arange(len(internal_params.values), dtype=int)

    if selector is not None:
        # Apply the selector to a params-shaped tree that holds positions instead
        # of values; flattening its output yields the positions of the selection.
        position_tree = converter.params_from_internal(selected)
        registry = get_registry(extended=True)
        flat_selection = tree_just_flatten(selector(position_tree), registry=registry)
        # Ensure the result is a 1D array
        selected = np.array(flat_selection, dtype=int).ravel()

    if not np.all(np.isfinite(internal_params.lower_bounds[selected])):
        raise ValueError("All selected parameters must have finite lower bounds.")

    if not np.all(np.isfinite(internal_params.upper_bounds[selected])):
        raise ValueError("All selected parameters must have finite upper bounds.")

    evaluation_points: list[PyTree] = []
    metadata: list[dict[str, Any]] = []

    for pos in selected:
        grid = np.linspace(
            internal_params.lower_bounds[pos],
            internal_params.upper_bounds[pos],
            n_gridpoints,
        )
        name = internal_params.names[pos]
        current_value = internal_params.values[pos]

        for grid_value in grid:
            if grid_value == current_value:
                # The value at params is already known from func_eval and is
                # appended after the batch evaluation.
                continue

            x = internal_params.values.copy()
            x[pos] = grid_value
            evaluation_points.append(converter.params_from_internal(x))
            metadata.append({"name": name, "Parameter Value": grid_value})

    func_values = _retrieve_func_values(
        func, evaluation_points, batch_evaluator, n_cores
    )

    # Append one row per selected parameter for the evaluation at params itself.
    # The order of the function values has to match the order of the metadata.
    func_values += [func_eval.internal_value(AggregationLevel.SCALAR)] * len(selected)
    metadata.extend(
        {
            "name": internal_params.names[pos],
            "Parameter Value": internal_params.values[pos],
        }
        for pos in selected
    )

    plot_data = pd.DataFrame(metadata)
    plot_data["Function Value"] = func_values  # type: ignore[assignment]

    return plot_data, internal_params
def _retrieve_func_values(
    func: Callable,
    evaluation_points: list[PyTree],
    batch_evaluator: BatchEvaluatorLiteral | BatchEvaluator,
    n_cores: int,
) -> list[float | NDArray[np.float64]]:
    """Evaluate ``func`` at all evaluation points with the batch evaluator.

    Args:
        func: The function to evaluate.
        evaluation_points: The arguments at which ``func`` is evaluated.
        batch_evaluator: Name of a batch evaluator or a batch evaluator callable.
        n_cores: Number of cores passed to the batch evaluator.

    Returns:
        A list with one scalar function value per evaluation point. Entries for
        which the batch evaluator returned a string are replaced by NaN.

    """
    evaluator = process_batch_evaluator(batch_evaluator)

    raw_results = evaluator(
        func=func,
        arguments=evaluation_points,
        error_handling="continue",
        n_cores=n_cores,
    )

    scalar_values = []
    for result in raw_results:
        if isinstance(result, str):
            # add NaNs where an evaluation failed (a string result is treated as
            # a failed evaluation)
            scalar_values.append(np.nan)
        else:
            scalar_values.append(result.internal_value(AggregationLevel.SCALAR))

    return scalar_values
def _extract_slice_plot_lines_and_labels(
    plot_data: pd.DataFrame,
    internal_params: InternalParams,
    func_eval: SpecificFunctionValue,
    param_names: dict[str, str] | None,
    color: str | None,
) -> tuple[list[list[LineData]], list[MarkerData], list[str], list[str]]:
    """Build the lines, markers and axis labels of all slice plots.

    Args:
        plot_data: DataFrame with the columns "name", "Parameter Value" and
            "Function Value".
        internal_params: The internal parameters; their names and values locate
            the marker of each plot.
        func_eval: The function value at the parameters around which the slices
            are taken.
        param_names: Optional mapping from parameter names to display names.
        color: Color used for lines and markers.

    Returns:
        A tuple with four lists that each contain one entry per parameter, in the
        order in which the parameters first appear in ``plot_data``: the lines of
        the plot, the marker at the current parameter value, the x axis label and
        the y axis label.

    """
    lines_list: list[list[LineData]] = []
    marker_list: list[MarkerData] = []
    xlabels: list[str] = []
    ylabels: list[str] = []

    # sort=False keeps the parameters in their order of first appearance.
    for group_key, group in plot_data.groupby("name", sort=False):
        ordered = group.sort_values("Parameter Value")

        # Display name of the parameter, renamed if the user asked for it.
        label = str(group_key)
        if param_names is not None:
            label = param_names.get(label, label)

        slice_line = LineData(
            x=ordered["Parameter Value"].to_numpy(),
            y=ordered["Function Value"].to_numpy(),
            color=color,
            name=label,
            show_in_legend=False,
        )
        lines_list.append([slice_line])

        # The marker is looked up via the original (not the renamed) name.
        if internal_params.names is not None:
            pos = internal_params.names.index(group_key)

        marker_list.append(
            MarkerData(
                x=float(internal_params.values[pos]),
                y=float(func_eval.internal_value(AggregationLevel.SCALAR)),
                color=color,
            )
        )

        xlabels.append(label)
        ylabels.append("Function Value")

    return lines_list, marker_list, xlabels, ylabels
def _get_axis_limits(
plot_data: pd.DataFrame, share_y: bool, share_x: bool, expand_yrange: float
) -> tuple[tuple[float, float] | None, tuple[float, float] | None]:
if share_y:
lb = plot_data["Function Value"].min()
ub = plot_data["Function Value"].max()
y_range = ub - lb
ub += y_range * expand_yrange
lb -= y_range * expand_yrange
yrange = (lb, ub)
else:
yrange = None
if share_x:
lb = plot_data["Parameter Value"].min()
ub = plot_data["Parameter Value"].max()
xrange = (lb, ub)
else:
xrange = None
return xrange, yrange