Source code for optimagic.visualization.history_plots

import inspect
import itertools
from dataclasses import dataclass
from pathlib import Path
from typing import Any, Callable, Literal

import numpy as np
from pybaum import leaf_names, tree_flatten, tree_just_flatten, tree_unflatten

from optimagic.config import DEFAULT_PALETTE
from optimagic.logging.logger import LogReader, SQLiteLogOptions
from optimagic.optimization.algorithm import Algorithm
from optimagic.optimization.history import History
from optimagic.optimization.optimize_result import OptimizeResult
from optimagic.parameters.tree_registry import get_registry
from optimagic.typing import IterationHistory, PyTree
from optimagic.visualization.backends import line_plot
from optimagic.visualization.plotting_utilities import LineData, get_palette_cycle

BACKEND_TO_HISTORY_PLOT_LEGEND_PROPERTIES: dict[str, dict[str, Any]] = {
    "plotly": {
        "yanchor": "top",
        "xanchor": "right",
        "y": 0.95,
        "x": 0.95,
    },
    "matplotlib": {
        "loc": "upper right",
    },
    "bokeh": {
        "location": "top_right",
    },
    "altair": {
        "orient": "top-right",
    },
}


ResultOrPath = OptimizeResult | str | Path


def criterion_plot(
    results: ResultOrPath | list[ResultOrPath] | dict[str, ResultOrPath],
    names: list[str] | str | None = None,
    max_evaluations: int | None = None,
    backend: Literal["plotly", "matplotlib", "bokeh", "altair"] = "plotly",
    template: str | None = None,
    palette: list[str] | str = DEFAULT_PALETTE,
    stack_multistart: bool = False,
    monotone: bool = False,
    show_exploration: bool = False,
) -> Any:
    """Plot the criterion history of an optimization.

    Args:
        results: An optimization result with collected history, a list or dict of
            such results, or path(s) to log file(s). If a dict is passed, its keys
            are used as names in the legend.
        names: Names used to label the results in the legend. If not provided,
            dict keys or positional indices are used.
        max_evaluations: Clip the criterion history after that many entries.
        backend: The backend to use for plotting. Default is "plotly".
        template: The template for the figure. If not specified, the default
            template of the backend is used. For the 'bokeh' and 'altair' backends,
            this changes the global theme, which affects all plots from that
            backend in the session.
        palette: The coloring palette for traces. Default is the D3 qualitative
            palette.
        stack_multistart: Whether to combine multistart histories into a single
            history. Default is False.
        monotone: If True, the criterion plot becomes monotone in the sense that at
            each iteration the current best criterion value is displayed. Default
            is False.
        show_exploration: If True, exploration samples of a multistart optimization
            are visualized. Default is False.

    Returns:
        The figure object containing the criterion plot.

    """
    # ==================================================================================
    # Process inputs

    palette_cycle = get_palette_cycle(palette)
    dict_of_optimize_results_or_paths = _harmonize_inputs_to_dict(results, names)

    # ==================================================================================
    # Extract backend-agnostic plotting data from results

    list_of_optimize_data = _retrieve_optimization_data_from_results(
        results=dict_of_optimize_results_or_paths,
        stack_multistart=stack_multistart,
        show_exploration=show_exploration,
        plot_name="criterion_plot",
    )

    lines, multistart_lines = _extract_criterion_plot_lines(
        data=list_of_optimize_data,
        max_evaluations=max_evaluations,
        palette_cycle=palette_cycle,
        stack_multistart=stack_multistart,
        monotone=monotone,
    )

    # ==================================================================================
    # Generate the figure

    fig = line_plot(
        lines=multistart_lines + lines,
        backend=backend,
        xlabel="No. of criterion evaluations",
        ylabel="Criterion value",
        template=template,
        legend_properties=BACKEND_TO_HISTORY_PLOT_LEGEND_PROPERTIES.get(backend, None),
    )

    return fig
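

# The following usage sketch is illustrative and not part of optimagic's public
# API. It assumes an optimization run with the default history collection; the
# sphere objective and the "scipy_lbfgsb" algorithm are example choices, not
# requirements.
def _example_criterion_plot_usage() -> Any:  # hypothetical helper, never called
    """Minimal sketch of how criterion_plot is typically used."""
    import optimagic as om

    res = om.minimize(
        fun=lambda x: x @ x,
        params=np.arange(5.0),
        algorithm="scipy_lbfgsb",
    )
    # Show the best criterion value found so far at each evaluation and clip the
    # history after 50 evaluations.
    return criterion_plot(res, monotone=True, max_evaluations=50)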


def _harmonize_inputs_to_dict(
    results: ResultOrPath | list[ResultOrPath] | dict[str, ResultOrPath],
    names: list[str] | str | None,
) -> dict[str, ResultOrPath]:
    """Convert all valid inputs for results and names to dict[str, OptimizeResult]."""
    # convert scalar case to list case
    if not isinstance(names, list) and names is not None:
        names = [names]

    if isinstance(results, (OptimizeResult, str, Path)):
        results = [results]

    if names is not None and len(names) != len(results):
        raise ValueError("len(results) needs to be equal to len(names).")

    # handle dict case
    if isinstance(results, dict):
        if names is not None:
            results_dict = dict(zip(names, list(results.values()), strict=False))
        else:
            results_dict = results
    # unlabeled iterable of results
    else:
        if names is None:
            names = [str(i) for i in range(len(results))]
        results_dict = dict(zip(names, results, strict=False))

    # convert keys to strings
    results_dict = {_convert_key_to_str(k): v for k, v in results_dict.items()}

    return results_dict


def _convert_key_to_str(key: Any) -> str:
    if inspect.isclass(key) and issubclass(key, Algorithm):
        out = str(key.name)
    elif isinstance(key, Algorithm):
        out = str(key.name)
    else:
        out = str(key)
    return out


def params_plot(
    result: ResultOrPath,
    selector: Callable[[PyTree], PyTree] | None = None,
    max_evaluations: int | None = None,
    backend: Literal["plotly", "matplotlib", "bokeh", "altair"] = "plotly",
    template: str | None = None,
    palette: list[str] | str = DEFAULT_PALETTE,
    show_exploration: bool = False,
) -> Any:
    """Plot the params history of an optimization.

    Args:
        result: An optimization result with collected history, or path to it.
        selector: A callable that takes params and returns a subset of params. If
            provided, only the selected subset of params is plotted.
        max_evaluations: Clip the params history after that many entries.
        backend: The backend to use for plotting. Default is "plotly".
        template: The template for the figure. If not specified, the default
            template of the backend is used. For the 'bokeh' and 'altair' backends,
            this changes the global theme, which affects all plots from that
            backend in the session.
        palette: The coloring palette for traces. Default is the D3 qualitative
            palette.
        show_exploration: If True, exploration samples of a multistart optimization
            are visualized. Default is False.

    Returns:
        The figure object containing the params plot.

    """
    # ==================================================================================
    # Process inputs

    palette_cycle = get_palette_cycle(palette)

    # ==================================================================================
    # Extract backend-agnostic plotting data from results

    optimize_data = _retrieve_optimization_data_from_single_result(
        result=result,
        stack_multistart=True,
        show_exploration=show_exploration,
        plot_name="params_plot",
    )

    lines = _extract_params_plot_lines(
        data=optimize_data,
        selector=selector,
        max_evaluations=max_evaluations,
        palette_cycle=palette_cycle,
    )

    # ==================================================================================
    # Generate the figure

    fig = line_plot(
        lines=lines,
        backend=backend,
        xlabel="No. of criterion evaluations",
        ylabel="Parameter value",
        template=template,
        legend_properties=BACKEND_TO_HISTORY_PLOT_LEGEND_PROPERTIES.get(backend, None),
    )

    return fig
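

# Another illustrative sketch, not part of the module's public API. It assumes a
# dict-valued params pytree so that ``selector`` can pick out a sub-tree; the
# objective and algorithm choice are again only assumptions.
def _example_params_plot_usage() -> Any:  # hypothetical helper, never called
    """Minimal sketch of how params_plot restricts traces via ``selector``."""
    import optimagic as om

    res = om.minimize(
        fun=lambda p: p["a"] ** 2 + (p["b"] ** 2).sum(),
        params={"a": 2.0, "b": np.ones(3)},
        algorithm="scipy_lbfgsb",
    )
    # Plot only the history of the "b" parameters; each selected leaf becomes one
    # line in the figure.
    return params_plot(res, selector=lambda p: p["b"])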


@dataclass(frozen=True)
class _PlottingMultistartHistory:
    """Data container for an optimization history and metadata.

    Contains local histories in case of multistart optimization. This dataclass is
    only used internally.

    """

    history: History
    name: str | None
    start_params: PyTree
    is_multistart: bool
    local_histories: list[History] | list[IterationHistory] | None
    stacked_local_histories: History | None


def _retrieve_optimization_data_from_results(
    results: dict[str, ResultOrPath],
    stack_multistart: bool,
    show_exploration: bool,
    plot_name: str,
) -> list[_PlottingMultistartHistory]:
    # Retrieves data from multiple results by iterating over the results dictionary
    # and calling the single result retrieval function.
    data = []
    for name, res in results.items():
        _data = _retrieve_optimization_data_from_single_result(
            result=res,
            stack_multistart=stack_multistart,
            show_exploration=show_exploration,
            plot_name=plot_name,
            res_name=name,
        )
        data.append(_data)
    return data


def _retrieve_optimization_data_from_single_result(
    result: ResultOrPath,
    stack_multistart: bool,
    show_exploration: bool,
    plot_name: str,
    res_name: str | None = None,
) -> _PlottingMultistartHistory:
    """Retrieve data from a single result (OptimizeResult or database).

    Args:
        result: An optimization result with collected history, or path to it.
        stack_multistart: Whether to combine multistart histories into a single
            history. Default is False.
        show_exploration: If True, exploration samples of a multistart optimization
            are visualized. Default is False.
        plot_name: Name of the plotting function that calls this function. Used for
            raising errors.
        res_name: Name of the result.

    Returns:
        A data object containing the history, metadata, and local histories of the
        optimization result.

    """
    if isinstance(result, OptimizeResult):
        data = _retrieve_optimization_data_from_result_object(
            res=result,
            stack_multistart=stack_multistart,
            show_exploration=show_exploration,
            plot_name=plot_name,
            res_name=res_name,
        )
    elif isinstance(result, (str, Path)):
        data = _retrieve_optimization_data_from_database(
            res=result,
            stack_multistart=stack_multistart,
            show_exploration=show_exploration,
            res_name=res_name,
        )
    else:
        msg = (
            "result must be an OptimizeResult or a path to a log file, "
            f"but is type {type(result)}."
        )
        raise TypeError(msg)

    return data


def _retrieve_optimization_data_from_result_object(
    res: OptimizeResult,
    stack_multistart: bool,
    show_exploration: bool,
    plot_name: str,
    res_name: str | None = None,
) -> _PlottingMultistartHistory:
    """Retrieve optimization data from result object.

    Args:
        res: An optimization result object.
        stack_multistart: Whether to combine multistart histories into a single
            history. Default is False.
        show_exploration: If True, exploration samples of a multistart optimization
            are visualized. Default is False.
        plot_name: Name of the plotting function that calls this function. Used for
            raising errors.
        res_name: Name of the result.

    Returns:
        A data object containing the history, metadata, and local histories of the
        optimization result.

    """
    if res.history is None:
        msg = (
            f"{plot_name} requires an optimize result with history. Enable history "
            "collection by setting collect_history=True when calling maximize or "
            "minimize."
        )
        raise ValueError(msg)

    if res.multistart_info:
        local_histories = [
            opt.history
            for opt in res.multistart_info.local_optima
            if opt.history is not None
        ]

        if stack_multistart:
            stacked = _get_stacked_local_histories(local_histories, res.direction)
            if show_exploration:
                fun = res.multistart_info.exploration_results[::-1] + stacked.fun
                params = res.multistart_info.exploration_sample[::-1] + stacked.params

                stacked = History(
                    direction=stacked.direction,
                    fun=fun,
                    params=params,
                    # TODO: This needs to be fixed
                    start_time=len(fun) * [None],  # type: ignore
                    stop_time=len(fun) * [None],  # type: ignore
                    batches=len(fun) * [None],  # type: ignore
                    task=len(fun) * [None],  # type: ignore
                )
        else:
            stacked = None
    else:
        local_histories = None
        stacked = None

    data = _PlottingMultistartHistory(
        history=res.history,
        name=res_name,
        start_params=res.start_params,
        is_multistart=res.multistart_info is not None,
        local_histories=local_histories,
        stacked_local_histories=stacked,
    )
    return data


def _retrieve_optimization_data_from_database(
    res: str | Path,
    stack_multistart: bool,
    show_exploration: bool,
    res_name: str | None = None,
) -> _PlottingMultistartHistory:
    """Retrieve optimization data from a database.

    Args:
        res: A path to an optimization database.
        stack_multistart: Whether to combine multistart histories into a single
            history. Default is False.
        show_exploration: If True, exploration samples of a multistart optimization
            are visualized. Default is False.
        res_name: Name of the result.

    Returns:
        A data object containing the history, metadata, and local histories of the
        optimization result.

    """
    reader: LogReader = LogReader.from_options(SQLiteLogOptions(res))
    _problem_table = reader.problem_df
    direction = _problem_table["direction"].tolist()[-1]

    multistart_history = reader.read_multistart_history(direction)
    _history = multistart_history.history
    local_histories = multistart_history.local_histories
    exploration = multistart_history.exploration

    if stack_multistart and local_histories is not None:
        stacked = _get_stacked_local_histories(local_histories, direction, _history)
        if show_exploration:
            stacked["params"] = exploration["params"][::-1] + stacked["params"]  # type: ignore
            stacked["criterion"] = exploration["criterion"][::-1] + stacked["criterion"]  # type: ignore
    else:
        stacked = None

    history = History(
        direction=direction,
        fun=_history["fun"],
        params=_history["params"],
        start_time=_history["time"],
        # TODO (@janosg): Retrieve `stop_time` from `hist` once it is available.
        # https://github.com/optimagic-dev/optimagic/pull/553
        stop_time=len(_history["fun"]) * [None],  # type: ignore
        task=len(_history["fun"]) * [None],  # type: ignore
        batches=list(range(len(_history["fun"]))),
    )

    data = _PlottingMultistartHistory(
        history=history,
        name=res_name,
        start_params=reader.read_start_params(),
        is_multistart=local_histories is not None,
        local_histories=local_histories,
        stacked_local_histories=stacked,
    )
    return data


def _get_stacked_local_histories(
    local_histories: list[History] | list[IterationHistory],
    direction: Any,
    history: History | IterationHistory | None = None,
) -> History:
    """Stack local histories.

    Local histories is a list of history objects that share the same structure. We
    transform this to a dictionary of lists. Finally, when the data is read from
    the database, we append the best history at the end.

""" stacked: dict[str, list[Any]] = {"criterion": [], "params": [], "runtime": []} for hist in local_histories: stacked["criterion"].extend(hist.fun) stacked["params"].extend(hist.params) stacked["runtime"].extend(hist.time) # append additional history is necessary if history is not None: stacked["criterion"].extend(history.fun) stacked["params"].extend(history.params) stacked["runtime"].extend(history.time) return History( direction=direction, fun=stacked["criterion"], params=stacked["params"], start_time=stacked["runtime"], # TODO (@janosg): Retrieve `stop_time` from `hist` once it is available for the # IterationHistory. # https://github.com/optimagic-dev/optimagic/pull/553 stop_time=len(stacked["criterion"]) * [None], # type: ignore task=len(stacked["criterion"]) * [None], # type: ignore batches=list(range(len(stacked["criterion"]))), ) def _extract_criterion_plot_lines( data: list[_PlottingMultistartHistory], max_evaluations: int | None, palette_cycle: "itertools.cycle[str]", stack_multistart: bool, monotone: bool, ) -> tuple[list[LineData], list[LineData]]: """Extract lines for criterion plot from data. Args: data: Data retrieved from results or database. max_evaluations: Clip the criterion history after that many entries. palette_cycle: Cycle of colors for plotting. stack_multistart: Whether to combine multistart histories into a single history. Default is False. monotone: If True, the criterion plot becomes monotone in the sense that at each iteration the current best criterion value is displayed. Returns: Tuple containing - lines: Main optimization paths. - multistart_lines: Multistart optimization paths. """ fun_or_monotone_fun = "monotone_fun" if monotone else "fun" # Collect multistart optimization paths multistart_lines: list[LineData] = [] plot_multistart = len(data) == 1 and data[0].is_multistart and not stack_multistart if plot_multistart and data[0].local_histories: for i, local_history in enumerate(data[0].local_histories): history = getattr(local_history, fun_or_monotone_fun) if max_evaluations is not None and len(history) > max_evaluations: history = history[:max_evaluations] line_data = LineData( x=np.arange(len(history)), y=history, color="#bab0ac", name=str(i), show_in_legend=False, ) multistart_lines.append(line_data) # Collect main optimization paths lines: list[LineData] = [] for _data in data: if stack_multistart and _data.stacked_local_histories is not None: _history = _data.stacked_local_histories else: _history = _data.history history = getattr(_history, fun_or_monotone_fun) if max_evaluations is not None and len(history) > max_evaluations: history = history[:max_evaluations] line_data = LineData( x=np.arange(len(history)), y=history, color=next(palette_cycle), name="best result" if plot_multistart else _data.name, show_in_legend=not plot_multistart, ) lines.append(line_data) return lines, multistart_lines def _extract_params_plot_lines( data: _PlottingMultistartHistory, selector: Callable[[PyTree], PyTree] | None, max_evaluations: int | None, palette_cycle: "itertools.cycle[str]", ) -> list[LineData]: """Extract lines for params plot from data. Args: data: Data retrieved from results or database. selector: A callable that takes params and returns a subset of params. If provided, only the selected subset of params is plotted. max_evaluations: Clip the criterion history after that many entries. palette_cycle: Cycle of colors for plotting. Returns: lines: Parameter histories. 
""" if data.stacked_local_histories is not None: history = data.stacked_local_histories.params else: history = data.history.params start_params = data.start_params registry = get_registry(extended=True) hist_arr = np.array([tree_just_flatten(p, registry=registry) for p in history]).T names = leaf_names(start_params, registry=registry) if selector is not None: flat, treedef = tree_flatten(start_params, registry=registry) helper = tree_unflatten(treedef, list(range(len(flat))), registry=registry) selected = np.array(tree_just_flatten(selector(helper), registry=registry)) names = [names[i] for i in selected] hist_arr = hist_arr[selected] lines: list[LineData] = [] for name, _data in zip(names, hist_arr, strict=False): if max_evaluations is not None and len(_data) > max_evaluations: plot_data = _data[:max_evaluations] else: plot_data = _data line_data = LineData( x=np.arange(len(plot_data)), y=plot_data, color=next(palette_cycle), name=name, show_in_legend=True, ) lines.append(line_data) return lines