# -*- coding: utf-8 -*-
# License: BSD-3-Clause
# Author: LKouadio <etanoyau@gmail.com>
from numbers import Real
import warnings
from typing import (
Sequence, Optional,
Union, Tuple, List,
Literal, Any, Dict
)
import matplotlib.pyplot as plt
from matplotlib.ticker import MaxNLocator
from matplotlib.colors import to_rgba
import numpy as np
import pandas as pd
from sklearn.utils.validation import check_array, check_consistent_length
from sklearn.metrics import (
mean_absolute_error,
mean_squared_error,
mean_absolute_percentage_error
)
from ..api.docs import DocstringComponents, _shared_metric_plot_params
from ..api.types import (
MetricFunctionType,
PlotKind,
MetricType,
PlotKindWIS,
PlotKindTheilU
)
from ..core.handlers import columns_manager
from ..core.io import _get_valid_kwargs
from ..core.checks import exist_features
from ..core.diagnose_q import validate_quantiles
from ..utils.generic_utils import are_all_values_in_bounds
__all__= [
'plot_coverage',
'plot_crps',
'plot_mean_interval_width',
'plot_prediction_stability',
'plot_quantile_calibration',
'plot_theils_u_score',
'plot_time_weighted_metric',
'plot_weighted_interval_score',
'plot_qce_donut',
'plot_radar_scores'
]
_param_docs = DocstringComponents.from_nested_components(
base=DocstringComponents(_shared_metric_plot_params),
)
[docs]
def plot_theils_u_score(
y_true: np.ndarray,
y_pred: np.ndarray,
metric_values: Optional[Union[float, np.ndarray]] = None,
metric_kws: Optional[Dict[str, Any]] = None,
kind: PlotKindTheilU = 'summary_bar',
figsize: Tuple[float, float] = (8, 6),
title: Optional[str] = "Theil's U Statistic",
ylabel: Optional[str] = None,
bar_color: Union[str, List[str]] = 'chocolate',
bar_width: float = 0.8,
score_annotation_format: str = "{:.4f}",
reference_line_at_1: bool = True,
reference_line_props: Optional[Dict[str, Any]] = None,
show_grid: bool = True,
grid_props: Optional[Dict[str, Any]] = None,
ax: Optional[plt.Axes] = None,
verbose: int = 0,
**kwargs: Any
) -> plt.Axes:
# *********************************************************
from ..metrics._registry import get_metric
theils_u_score = get_metric("theils_u_score")
# *********************************************************
# --- 1. Input Validation and Preparation ---
# y_true, y_pred: (T,), (N,T), or (N,O,T)
# Metric function handles detailed shape validation.
# Here, we primarily pass them through.
y_true_arr = check_array(
y_true, ensure_2d=False, allow_nd=True,
dtype="numeric", force_all_finite=False, copy=False
)
y_pred_arr = check_array(
y_pred, ensure_2d=False, allow_nd=True,
dtype="numeric", force_all_finite=False, copy=False
)
if y_true_arr.shape != y_pred_arr.shape:
raise ValueError("y_true and y_pred must have the same shape.")
# Determine n_outputs for labeling if raw_values used
# Based on the shape convention of theils_u_score's y_true_proc
n_outputs = 1
if y_true_arr.ndim == 3: # (N,O,T)
n_outputs = y_true_arr.shape[1] # noqa
elif y_true_arr.ndim == 1 and y_true_arr.shape[0] < 2 : # (T,) with T<2
pass # Will be caught by metric
elif y_true_arr.ndim == 2 and y_true_arr.shape[1] < 2: # (N,T) with T<2
pass # will be caught by metric
if y_true_arr.size == 0 or \
(y_true_arr.ndim > 0 and y_true_arr.shape[-1] < 2):
# Metric itself will raise error for T<2, but good to catch empty early
warnings.warn(
"Input data is empty or has fewer than 2 time steps. "
"Cannot generate Theil's U plot."
)
if ax is None: _, ax = plt.subplots(figsize=figsize)
ax.set_title(title or "Theil's U Plot (No/Invalid Data)")
if show_grid: ax.grid(**(grid_props or {}))
return ax
# --- Plotting Setup ---
if ax is None:
fig, ax = plt.subplots(figsize=figsize) # type: ignore
plot_title_str = title
# --- Metric Calculation Handling ---
current_metric_kws = metric_kws or {}
default_kws_for_metric = {
'nan_policy': 'propagate',
'multioutput': 'uniform_average',
'eps': 1e-8,
'verbose': 0 # Metric's internal verbose
}
# --- Plotting Logic ---
if kind == 'summary_bar':
scores_to_plot: Union[float, np.ndarray]
if metric_values is not None:
scores_to_plot = metric_values
if verbose > 0:
print(f"Using pre-computed Theil's U: {scores_to_plot}")
else:
# For summary bar, respect user's multioutput choice
summary_bar_default_kws = default_kws_for_metric.copy()
if 'multioutput' not in current_metric_kws:
summary_bar_default_kws['multioutput'] = 'uniform_average'
effective_kws = {**summary_bar_default_kws, **current_metric_kws}
cleaned_kws = _get_valid_kwargs(
theils_u_score, effective_kws
)
scores_to_plot = theils_u_score(
y_true_arr, y_pred_arr, **cleaned_kws
)
if verbose > 0:
print(f"Computed Theil's U for summary: {scores_to_plot}")
scores_arr_bar: np.ndarray
x_labels_bar: List[str]
multioutput_used = (current_metric_kws or {}).get(
'multioutput', default_kws_for_metric['multioutput'])
if np.isscalar(scores_to_plot) or \
(isinstance(scores_to_plot, np.ndarray) and \
scores_to_plot.ndim == 0):
scores_arr_bar = np.array([scores_to_plot])
x_labels_bar = ["Theil's U"]
elif isinstance(scores_to_plot, np.ndarray) and \
scores_to_plot.ndim == 1:
scores_arr_bar = scores_to_plot
x_labels_bar = [f'Output {i}' for i in range(len(scores_arr_bar))]
if plot_title_str and multioutput_used == 'raw_values':
plot_title_str += " (Per Output)"
else:
raise TypeError(
f"Unexpected type/shape for Theil's U scores: "
f"{type(scores_to_plot)}"
)
bars = ax.bar(x_labels_bar, scores_arr_bar, color=bar_color,
width=bar_width, **kwargs.get('bar_kwargs', {}))
ax.set_ylabel(ylabel or "Theil's U Statistic")
# Reference line at U=1
if reference_line_at_1:
ref_line_defaults = {
'color': 'black', 'linestyle': '--', 'linewidth': 1
}
ref_props = {**ref_line_defaults, **(reference_line_props or {})}
ax.axhline(1, **ref_props) # type: ignore
# Auto-adjust y-limits
if scores_arr_bar.size > 0 and not np.all(np.isnan(scores_arr_bar)):
min_val = np.nanmin(scores_arr_bar)
max_val = np.nanmax(scores_arr_bar)
# Ensure y=1 is visible if reference line is plotted
plot_min_y = min(min_val, 0.8 if reference_line_at_1 else min_val)
plot_max_y = max(max_val, 1.2 if reference_line_at_1 else max_val)
padding = 0.1 * abs(plot_max_y - plot_min_y) \
if abs(plot_max_y - plot_min_y) > 1e-6 else 0.1
ax.set_ylim(plot_min_y - padding, plot_max_y + padding + 0.05)
for bar_obj in bars:
yval = bar_obj.get_height()
if not np.isnan(yval):
x = bar_obj.get_x() + bar_obj.get_width() / 2.0
dy = 3 if yval >= 0 else -10
va = 'bottom' if yval >= 0 else 'top'
ax.annotate(
score_annotation_format.format(yval),
xy=(x, yval),
xytext=(0, dy),
textcoords="offset points",
ha="center",
va=va,
)
else:
# Currently, only 'summary_bar' is implemented for Theil's U.
# Other kinds like distribution of squared errors could be added.
raise ValueError(
f"Unknown plot kind: '{kind}'. "
"Currently only 'summary_bar' is supported for Theil's U."
)
if plot_title_str:
ax.set_title(plot_title_str)
if show_grid:
current_grid_props = grid_props if grid_props is not None \
else {'linestyle': ':', 'alpha': 0.7}
ax.grid(**current_grid_props)
else:
ax.grid(False)
return ax
plot_theils_u_score.__doc__=r"""
Visualise Theil’s U statistic.
A single‑bar (or multi‑bar) summary plot that benchmarks a model’s
error against a naïve “last‑value’’ forecast.
*U < 1* implies the model improves upon the naïve baseline;
*U = 1* indicates parity;
*U > 1* denotes under‑performance.
Parameters
----------
{params.base.y_true}
{params.base.y_pred}
metric_values : float or ndarray, optional
Pre‑computed Theil’s U statistic(s). When supplied the helper
skips internal evaluation and plots the given number(s) verbatim.
metric_kws : dict, optional
Additional keyword arguments forwarded to
:func:`fusionlab.metrics.theils_u_score`
(e.g. ``multioutput='raw_values'``).
kind : {{'summary_bar'}}, default ``'summary_bar'``
Currently only a bar‑chart summary is available. Additional kinds
may be added in future releases.
reference_line_at_1 : bool, default ``True``
Draw a horizontal reference line at *U = 1* to highlight the
naïve‑benchmark threshold.
reference_line_props : dict, optional
Matplotlib style overrides for the reference line
(colour, linestyle, linewidth …).
{params.base.figsize}
{params.base.title}
{params.base.ylabel}
{params.base.bar_color}
{params.base.bar_width}
{params.base.score_annotation_format}
{params.base.show_grid}
{params.base.grid_props}
{params.base.ax}
{params.base.verbose}
{params.base.kwargs}
Returns
-------
matplotlib.axes.Axes
The axes object with the rendered plot.
Notes
-----
For a univariate series the statistic is
.. math::
U = \sqrt{{\frac{{\sum_{{t=2}}^{{T}} \bigl(y_t - \hat y_t\bigr)^2}}
{{\sum_{{t=2}}^{{T}} \bigl(y_t - y_{{t-1}}\bigr)^2}}}}
where :math:`y_{{t-1}}` is the naïve forecast.
The helper calls :func:`fusionlab.metrics.theils_u_score` for the
computation.
Examples
--------
>>> import numpy as np, matplotlib.pyplot as plt
>>> from fusionlab.plot.evaluation import plot_theils_u_score
>>> rng = np.random.default_rng(0)
>>> y_true = rng.normal(size=100)
>>> y_pred = y_true + rng.normal(scale=.2, size=100)
>>> plot_theils_u_score(y_true, y_pred,
... bar_color='steelblue',
... figsize=(6, 4))
>>> plt.show()
See Also
--------
fusionlab.metrics.theils_u_score
Metric implementation.
fusionlab.plot.evaluation.plot_crps
Continuous Ranked Probability Score visualiser.
fusionlab.plot.evaluation.plot_wis
Weighted Interval Score plot.
References
----------
.. [1] H. Theil, *Applied Economic Forecasting*, North‑Holland, 1966.
.. [2] Makridakis, Wheelwright & Hyndman,
*Forecasting: Methods and Applications*, 3rd ed., 1998.
""".format(params=_param_docs)
def _calculate_per_sample_output_wis(
y_true_so: np.ndarray, # Shape (N, O)
y_median_so: np.ndarray, # Shape (N, O)
y_lower_sok: np.ndarray, # Shape (N, O, K)
y_upper_sok: np.ndarray, # Shape (N, O, K)
alphas_k: np.ndarray, # Shape (K,)
nan_policy: Literal['omit', 'propagate', 'raise'] = 'propagate',
warn_invalid_bounds: bool = True,
# Note: sample_weight is applied *after* these per-sample scores
# when aggregating for the final metric value. Here we want raw per-sample.
) -> np.ndarray:
"""
Calculate per-sample, per-output WIS (non-time-weighted).
Returns array of shape (N, O).
NaNs in inputs are handled based on nan_policy.
"""
n_samples, n_outputs = y_true_so.shape
K_intervals = alphas_k.shape[0]
# Expand y_true_so, y_median_so for broadcasting with K: (N,O,1)
y_t_exp_sok = y_true_so[..., np.newaxis]
# y_m_exp_sok = y_median_so[..., np.newaxis] # Not directly used in IS part
# Base NaN mask from y_true, y_median (N,O)
nan_mask_base_so = np.isnan(y_true_so) | np.isnan(y_median_so)
# NaN mask from bounds (N,O), True if any K for that S,O is NaN
nan_mask_bounds_so = np.any(
np.isnan(y_lower_sok) | np.isnan(y_upper_sok), axis=2
)
combined_nan_mask_so = nan_mask_base_so | nan_mask_bounds_so # (N,O)
# Initialize WIS values per sample-output
wis_values_so = np.full((n_samples, n_outputs), np.nan)
# Calculate for valid (non-NaN according to policy) entries
if nan_policy == 'raise' and np.any(combined_nan_mask_so):
raise ValueError("NaNs found in inputs for per-sample WIS.")
# For 'omit', we'd filter rows. For 'propagate', NaNs will flow.
# This helper focuses on calculation; omit filtering happens before calling.
# If called with data already filtered by 'omit', combined_nan_mask_so
# for the passed data should be all False.
mae_term_so = np.abs(y_median_so - y_true_so) # (N,O)
if K_intervals > 0:
interval_width_sok = y_upper_sok - y_lower_sok # (N,O,K)
if warn_invalid_bounds and np.any(y_lower_sok > y_upper_sok):
warnings.warn(
"y_lower > y_upper found in inputs for per-sample WIS. "
"Widths will be negative, affecting score.", UserWarning
)
alphas_exp_k = alphas_k.reshape(1, 1, -1) # (1,1,K)
wis_sharp_sok = (alphas_exp_k / 2.0) * interval_width_sok
wis_under_sok = (y_lower_sok - y_t_exp_sok) * \
(y_t_exp_sok < y_lower_sok)
wis_over_sok = (y_t_exp_sok - y_upper_sok) * \
(y_t_exp_sok > y_upper_sok)
sum_interval_comps_so = np.sum( # Sum over K
wis_sharp_sok + wis_under_sok + wis_over_sok, axis=2
) # (N,O)
wis_values_so = (mae_term_so + sum_interval_comps_so) / \
(K_intervals + 1.0)
else: # K_intervals is 0, WIS is just MAE of median
wis_values_so = mae_term_so
if nan_policy == 'propagate':
wis_values_so = np.where(
combined_nan_mask_so, np.nan, wis_values_so
)
# If nan_policy='omit', NaNs should have been filtered before this helper.
# If nan_policy='raise', error would have been raised.
return wis_values_so
[docs]
def plot_weighted_interval_score(
y_true: np.ndarray,
y_median: np.ndarray,
y_lower: np.ndarray,
y_upper: np.ndarray,
alphas: np.ndarray,
metric_values: Optional[Union[float, np.ndarray]] = None,
metric_kws: Optional[Dict[str, Any]] = None,
kind: PlotKindWIS = 'summary_bar',
output_idx: Optional[int] = None,
hist_bins: Union[int, Sequence[Real], str] = 'auto',
hist_color: str = 'mediumseagreen',
hist_edgecolor: str = 'black',
figsize: Tuple[float, float] = (10, 6),
title: Optional[str] = "Weighted Interval Score (WIS)",
xlabel: Optional[str] = None,
ylabel: Optional[str] = None,
bar_color: Union[str, List[str]] = 'mediumseagreen',
bar_width: float = 0.8,
score_annotation_format: str = "{:.4f}",
show_score_on_title: bool = True,
show_grid: bool = True,
grid_props: Optional[Dict[str, Any]] = None,
ax: Optional[plt.Axes] = None,
verbose: int = 0,
**kwargs: Any
) -> plt.Axes:
# ****************************************************************
from ..metrics._registry import get_metric
weighted_interval_score = get_metric("weighted_interval_score")
# ****************************************************************
# --- 1. Input Validation and Preparation ---
y_true_arr = check_array(
y_true, ensure_2d=False, dtype="numeric",
force_all_finite=False, copy=True)
y_median_arr = check_array(
y_median, ensure_2d=False, dtype="numeric",
force_all_finite=False, copy=True)
y_lower_arr = check_array(
y_lower, ensure_2d=False, allow_nd=True,
dtype="numeric", force_all_finite=False, copy=True)
y_upper_arr = check_array(
y_upper, ensure_2d=False, allow_nd=True,
dtype="numeric", force_all_finite=False, copy=True)
alphas_arr = check_array(
alphas, ensure_2d=False, dtype="numeric",
force_all_finite=True)
are_all_values_in_bounds(
alphas_arr , bounds= (0, 1), nan_policy='raise',
message = "All alpha values must be in (0,1)."
)
if alphas_arr.ndim > 1: alphas_arr = alphas_arr.squeeze()
if alphas_arr.ndim == 0: alphas_arr = alphas_arr.reshape(1,)
K_intervals = alphas_arr.shape[0]
# Reshape inputs for consistent processing:
# y_true_proc, y_median_proc: (N, O)
# y_lower_proc, y_upper_proc: (N, O, K)
y_true_ndim_orig = y_true_arr.ndim
if y_true_ndim_orig == 1: # (N,)
y_true_proc = y_true_arr.reshape(-1, 1)
y_median_proc = y_median_arr.reshape(-1, 1)
if y_lower_arr.ndim == 2 and \
y_lower_arr.shape[1] == K_intervals: # (N,K)
y_lower_proc = y_lower_arr.reshape(y_lower_arr.shape[0], 1, -1)
y_upper_proc = y_upper_arr.reshape(y_upper_arr.shape[0], 1, -1)
else:
raise ValueError("Shape mismatch for 1D y_true with bounds.")
elif y_true_ndim_orig == 2: # (N,O)
y_true_proc = y_true_arr
y_median_proc = y_median_arr
if y_lower_arr.ndim == 3 and \
y_lower_arr.shape[1] == y_true_proc.shape[1] and \
y_lower_arr.shape[2] == K_intervals: # (N,O,K)
y_lower_proc, y_upper_proc = y_lower_arr, y_upper_arr
else:
raise ValueError("Shape mismatch for 2D y_true with bounds.")
else:
raise ValueError("y_true must be 1D or 2D.")
# Check consistency for all processed shapes
shapes_to_match = (y_true_proc.shape[0], y_true_proc.shape[1]) # (N,O)
if not (y_median_proc.shape == shapes_to_match and \
y_lower_proc.shape[:2] == shapes_to_match and \
y_upper_proc.shape[:2] == shapes_to_match and \
y_lower_proc.shape[2] == K_intervals and \
y_upper_proc.shape[2] == K_intervals):
raise ValueError("Processed input shapes are inconsistent.")
n_samples, n_outputs = y_true_proc.shape
if n_samples == 0:
warnings.warn("Input arrays are empty. Cannot generate plot.")
if ax is None: _, ax = plt.subplots(figsize=figsize)
ax.set_title(title or "WIS Plot (No Data)")
if show_grid: ax.grid(**(grid_props or {}))
return ax
# --- Plotting Setup ---
if ax is None:
fig, ax = plt.subplots(figsize=figsize) # type: ignore
plot_title_str = title
# --- Metric Calculation Handling ---
current_metric_kws = metric_kws or {}
default_kws_for_metric = {
'nan_policy': 'propagate',
'multioutput': 'uniform_average',
'warn_invalid_bounds': True,
'eps': 1e-8,
'verbose': 0
}
# --- Plotting Logic ---
if kind == 'scores_histogram':
# Calculate per-sample, per-output WIS values for histogram
# We need to handle NaNs before calling _calculate_per_sample_output_wis
# if nan_policy is 'omit'.
nan_policy_hist = current_metric_kws.get(
'nan_policy', default_kws_for_metric['nan_policy']
)
warn_bounds_hist = current_metric_kws.get(
'warn_invalid_bounds', default_kws_for_metric['warn_invalid_bounds']
)
y_t_hist, y_m_hist = y_true_proc, y_median_proc
y_l_hist, y_u_hist = y_lower_proc, y_upper_proc
s_weights_hist = (current_metric_kws or {}).get('sample_weight', None)
# NaN mask for inputs to _calculate_per_sample_output_wis
nan_mask_base_so_hist = np.isnan(y_t_hist) | np.isnan(y_m_hist)
nan_mask_bounds_so_hist = np.any(
np.isnan(y_l_hist) | np.isnan(y_u_hist), axis=2
)
combined_nan_mask_so_hist = nan_mask_base_so_hist | nan_mask_bounds_so_hist
if np.any(combined_nan_mask_so_hist):
if nan_policy_hist == 'raise':
raise ValueError("NaNs found in inputs for histogram.")
elif nan_policy_hist == 'omit':
rows_with_nan = combined_nan_mask_so_hist.any(axis=1)
rows_to_keep = ~rows_with_nan
if not np.any(rows_to_keep):
ax.text(0.5,0.5,"All samples omitted due to NaNs.",
ha='center',va='center',transform=ax.transAxes)
if show_grid: ax.grid(**(grid_props or {}))
ax.set_title(plot_title_str or "WIS Scores (No Data)")
return ax
y_t_hist = y_t_hist[rows_to_keep]
y_m_hist = y_m_hist[rows_to_keep]
y_l_hist = y_l_hist[rows_to_keep]
y_u_hist = y_u_hist[rows_to_keep]
if s_weights_hist is not None:
s_weights_hist = s_weights_hist[rows_to_keep]
# If 'propagate', _calculate_per_sample_output_wis will handle it.
if y_t_hist.shape[0] == 0: # All samples omitted
ax.text(0.5,0.5,"No valid samples for WIS histogram.",
ha='center',va='center',transform=ax.transAxes)
if show_grid: ax.grid(**(grid_props or {}))
ax.set_title(plot_title_str or "WIS Scores (No Data)")
return ax
# Calculate raw per-sample-output WIS scores
# Pass the nan_policy for _calculate to use internally for propagation
per_so_wis_scores = _calculate_per_sample_output_wis(
y_t_hist, y_m_hist, y_l_hist, y_u_hist, alphas_arr,
nan_policy=nan_policy_hist, # This ensures NaNs propagate if needed
warn_invalid_bounds=warn_bounds_hist
) # Shape (N_calc, O)
# Select output for histogram
scores_to_plot_hist: np.ndarray
current_output_label = ""
if n_outputs > 1:
if output_idx is None:
raise ValueError(
"For multi-output data and kind='scores_histogram', "
"'output_idx' must be specified."
)
if not (0 <= output_idx < n_outputs):
raise ValueError(f"output_idx {output_idx} out of bounds.")
scores_to_plot_hist = per_so_wis_scores[:, output_idx]
current_output_label = f" (Output {output_idx})"
else: # Single output
scores_to_plot_hist = per_so_wis_scores.ravel()
valid_scores_for_hist = scores_to_plot_hist[
~np.isnan(scores_to_plot_hist)
]
if valid_scores_for_hist.size > 0:
ax.hist(valid_scores_for_hist, bins=hist_bins,
color=hist_color, edgecolor=hist_edgecolor,
**kwargs.get('hist_kwargs', {}))
if show_score_on_title: # Show mean of the *plotted* scores
mean_of_plotted_wis = np.mean(valid_scores_for_hist)
score_text = f"Mean WIS: {mean_of_plotted_wis:.4f}"
current_title = plot_title_str or \
"Distribution of WIS Values"
plot_title_str = (
f"{current_title}{current_output_label}\n({score_text})"
)
else:
ax.text(0.5,0.5, "No valid WIS values for histogram.",
ha='center', va='center', transform=ax.transAxes)
current_title = plot_title_str or \
"Distribution of WIS Values"
plot_title_str = f"{current_title}{current_output_label} (No Data)"
ax.set_xlabel(xlabel or 'WIS per Sample')
ax.set_ylabel(ylabel or 'Frequency')
ax.yaxis.set_major_locator(MaxNLocator(integer=True))
elif kind == 'summary_bar':
scores_to_plot_bar: Union[float, np.ndarray]
if metric_values is not None:
scores_to_plot_bar = metric_values
if verbose > 0:
print(f"Using pre-computed WIS values: {scores_to_plot_bar}")
else:
summary_bar_default_kws = default_kws_for_metric.copy()
if 'multioutput' not in current_metric_kws:
summary_bar_default_kws['multioutput'] = 'uniform_average'
effective_kws = {**summary_bar_default_kws, **current_metric_kws}
cleaned_kws = _get_valid_kwargs(
weighted_interval_score, effective_kws
)
scores_to_plot_bar = weighted_interval_score(
y_true_arr, y_lower_arr, y_upper_arr, y_median_arr, alphas_arr,
**cleaned_kws
)
if verbose > 0:
print(f"Computed WIS for summary: {scores_to_plot_bar}")
scores_arr_bar: np.ndarray
x_labels_bar: List[str]
multioutput_used = (current_metric_kws or {}).get(
'multioutput', default_kws_for_metric['multioutput'])
if np.isscalar(scores_to_plot_bar) or \
(isinstance(scores_to_plot_bar, np.ndarray) and \
scores_to_plot_bar.ndim == 0):
scores_arr_bar = np.array([scores_to_plot_bar])
x_labels_bar = ['Mean WIS']
elif isinstance(scores_to_plot_bar, np.ndarray) and \
scores_to_plot_bar.ndim == 1:
scores_arr_bar = scores_to_plot_bar
x_labels_bar = [f'Output {i}' for i in range(len(scores_arr_bar))]
if plot_title_str and multioutput_used == 'raw_values':
plot_title_str += " (Per Output)"
else:
raise TypeError(
f"Unexpected type/shape for WIS scores: {type(scores_to_plot_bar)}"
)
bars = ax.bar(x_labels_bar, scores_arr_bar, color=bar_color,
width=bar_width, **kwargs.get('bar_kwargs', {}))
ax.set_ylabel(ylabel or 'Weighted Interval Score (WIS)')
if scores_arr_bar.size > 0 and not np.all(np.isnan(scores_arr_bar)):
min_val = np.nanmin(scores_arr_bar)
max_val = np.nanmax(scores_arr_bar)
padding = 0.1 * abs(max_val - min_val) if abs(max_val-min_val)>1e-6 else 0.1
ax.set_ylim(min(0, min_val - padding), max_val + padding + 0.05)
for bar_obj in bars:
yval = bar_obj.get_height()
if not np.isnan(yval):
x = bar_obj.get_x() + bar_obj.get_width() / 2.0
# choose offset and vertical alignment based on sign
offset = 3 if yval >= 0 else -10
va = 'bottom' if yval >= 0 else 'top'
ax.annotate(
score_annotation_format.format(yval),
xy=(x, yval),
xytext=(0, offset),
textcoords="offset points",
ha="center",
va=va,
)
else:
raise ValueError(
f"Unknown plot kind: '{kind}'. Choose 'scores_histogram' "
"or 'summary_bar'."
)
if plot_title_str:
ax.set_title(plot_title_str)
if show_grid:
current_grid_props = grid_props if grid_props is not None \
else {'linestyle': ':', 'alpha': 0.7}
ax.grid(**current_grid_props)
else:
ax.grid(False)
return ax
plot_weighted_interval_score.__doc__ =r"""
Visualise Weighted Interval Score (WIS).
WIS aggregates interval widths and coverage penalties across a set of
central prediction intervals, producing a proper scoring rule that
simultaneously rewards *sharpness* and *calibration* of probabilistic
forecasts [1]_.
The helper provides two complementary views:
* **'summary_bar'** – one bar per output (or a single bar for the
uniform average).
* **'scores_histogram'** – the distribution of per‑sample WIS values
for a selected output.
Parameters
----------
{params.base.y_true}
y_median : ndarray
Median (50 % quantile) forecast, shape compatible with
``y_true``.
{params.base.y_lower}
{params.base.y_upper}
alphas : ndarray of shape (K,)
Alpha levels that define the nominal coverage of each prediction
interval: :math:`\alpha_k = 1 - (q_{{k+1}} - q_k)`. Must satisfy
``0 < α < 1`` and be strictly increasing.
metric_values : float or ndarray, optional
Pre‑computed WIS value(s). When supplied, plotting is performed
without recalculating the metric.
metric_kws : dict, optional
Extra keyword arguments forwarded to
:func:`fusionlab.metrics.weighted_interval_score`
(e.g. ``multioutput='raw_values'``).
kind : {{'summary_bar', 'scores_histogram'}}, default ``'summary_bar'``
Style of visualisation.
output_idx : int, optional
Index of the target variable to visualise when
``kind='scores_histogram'`` on multi‑output data.
hist_bins : int | sequence | str, default ``'auto'``
Binning strategy for the histogram (passed to
:func:`matplotlib.pyplot.hist`).
hist_color : str, default ``'mediumseagreen'``
hist_edgecolor : str, default ``'black'``
Bar‑face and edge colours for the histogram.
{params.base.figsize}
title : str, optional
Custom figure title. If *None*, a context‑aware title is generated.
{params.base.xlabel}
{params.base.ylabel}
{params.base.bar_color}
{params.base.bar_width}
{params.base.score_annotation_format}
show_score_on_title : bool, default ``True``
Append the mean WIS to the title when
``kind='scores_histogram'``.
{params.base.show_grid}
{params.base.grid_props}
{params.base.ax}
{params.base.verbose}
{params.base.kwargs}
Returns
-------
matplotlib.axes.Axes
The axes object containing the plot.
Notes
-----
The weighted interval score for a single observation and :math:`K`
central prediction intervals is
.. math::
\mathrm{{WIS}} \;=\;
\frac{{1}}{{K + 0.5}}\;\Bigl[
\lvert y - \hat{{y}}_{{0.5}}\rvert\;+\;
\sum_{{k=1}}^{{K}} \alpha_k
\bigl\{{\, (y < l_k)\,(l_k - y)
+ (y > u_k)\,(y - u_k)
+ (u_k - l_k) \bigr\}}
\Bigr],
where :math:`[l_k, u_k]` is the :math:`(1-\alpha_k)` central interval.
Lower WIS indicates a sharper, better‑calibrated forecast.
Examples
--------
>>> import numpy as np, matplotlib.pyplot as plt
>>> from fusionlab.plot.evaluation import plot_weighted_interval_score
>>> rng = np.random.default_rng(0)
>>> y_true = rng.normal(size=100)
>>> y_med = y_true + rng.normal(scale=.1, size=100)
>>> y_lower = y_med - 1.0
>>> y_upper = y_med + 1.0
>>> alphas = np.array([0.2])
>>> plot_weighted_interval_score(y_true, y_med,
... y_lower, y_upper, alphas,
... kind='summary_bar',
... bar_color='slateblue')
>>> plt.show()
See Also
--------
fusionlab.metrics.weighted_interval_score
Numerical implementation of WIS.
fusionlab.plot.evaluation.plot_crps
Continuous Ranked Probability Score visualiser.
fusionlab.plot.evaluation.plot_theils_u_score
Deterministic relative‑skill bar plot.
References
----------
.. [1] Bracher, J. et al. *Evaluating Probabilistic Forecasts with
Scoring Rules.* *arXiv preprint* arXiv:2101.05552, 2021.
""".format(params=_param_docs)
def _get_metric_function(
metric_type: MetricType
) -> MetricFunctionType:
"""Helper to retrieve the appropriate metric function."""
from ..metrics._registry import get_metric
if metric_type == 'mae':
return get_metric("time_weighted_mean_absolute_error")
elif metric_type == 'accuracy':
return get_metric("time_weighted_accuracy_score")
elif metric_type == 'interval_score':
return get_metric("time_weighted_interval_score")
else:
# This case should ideally be caught by Literal type hinting
# or earlier validation.
raise ValueError(f"Unknown metric_type: {metric_type}")
def _calculate_per_timestep_values(
metric_type: MetricType,
y_true_sot: np.ndarray, # Shape (N, O, T)
y_pred_sot: Optional[np.ndarray] = None, # Shape (N, O, T)
y_median_sot: Optional[np.ndarray] = None, # Shape (N, O, T)
y_lower_sokt: Optional[np.ndarray] = None, # Shape (N, O, K, T)
y_upper_sokt: Optional[np.ndarray] = None, # Shape (N, O, K, T)
alphas_k: Optional[np.ndarray] = None, # Shape (K,)
nan_policy: Literal['omit', 'propagate', 'raise'] = 'propagate',
verbose : int =0,
) -> np.ndarray:
"""
Calculate per-timestep, un-time-weighted metric values.
Returns array of shape (N, O, T).
NaNs in inputs are handled based on nan_policy.
"""
n_samples, n_outputs, n_timesteps = y_true_sot.shape
per_timestep_vals = np.full((n_samples, n_outputs, n_timesteps), np.nan)
# Base NaN mask from y_true (N,O,T)
nan_mask_base = np.isnan(y_true_sot)
if metric_type == 'mae':
if y_pred_sot is None:
raise ValueError("y_pred is required for MAE.")
nan_mask_pred = np.isnan(y_pred_sot)
combined_nan_mask = nan_mask_base | nan_mask_pred
abs_errors = np.abs(y_pred_sot - y_true_sot)
if nan_policy == 'propagate':
per_timestep_vals = np.where(combined_nan_mask, np.nan, abs_errors)
elif nan_policy == 'omit': # Omit NaNs per S,O,T point for this calculation
per_timestep_vals = np.where(combined_nan_mask, np.nan, abs_errors)
elif nan_policy == 'raise' and np.any(combined_nan_mask):
raise ValueError("NaNs found in inputs for per-timestep MAE.")
else: # No NaNs or policy handled
per_timestep_vals = abs_errors
elif metric_type == 'accuracy':
if y_pred_sot is None:
raise ValueError("y_pred is required for accuracy.")
nan_mask_pred = np.isnan(y_pred_sot)
combined_nan_mask = nan_mask_base | nan_mask_pred
correct_preds = (y_true_sot == y_pred_sot).astype(float)
if nan_policy == 'propagate':
per_timestep_vals = np.where(combined_nan_mask, np.nan, correct_preds)
elif nan_policy == 'omit':
per_timestep_vals = np.where(combined_nan_mask, np.nan, correct_preds)
elif nan_policy == 'raise' and np.any(combined_nan_mask):
raise ValueError("NaNs found for per-timestep accuracy.")
else:
per_timestep_vals = correct_preds
elif metric_type == 'interval_score':
if not all(v is not None for v in [
y_median_sot, y_lower_sokt, y_upper_sokt, alphas_k
]):
raise ValueError(
"y_median, y_lower, y_upper, alphas are required for "
"interval_score."
)
# Assert K_intervals > 0
K_intervals = alphas_k.shape[0] # type: ignore
if K_intervals == 0 and verbose > 0: # type: ignore
warnings.warn("TWIS with K=0 intervals; effectively Time-Weighted MAE.")
nan_mask_median = np.isnan(y_median_sot) # type: ignore
nan_mask_bounds = np.any( # True if any K for that S,O,T is NaN
np.isnan(y_lower_sokt) | np.isnan(y_upper_sokt), axis=2 # type: ignore
) # (N,O,T)
combined_nan_mask = nan_mask_base | nan_mask_median | nan_mask_bounds
# Calculate WIS_sot (non-time-weighted)
mae_term_sot = np.abs(y_median_sot - y_true_sot) # (N,O,T) type: ignore
# Expand y_true_sot for broadcasting with K: (N,O,1,T)
y_t_exp_sokt = y_true_sot[..., np.newaxis, :]
y_t_exp_sokt = np.swapaxes(y_t_exp_sokt, 2, 3) # (N,O,T,1)
# Reshape alphas for broadcasting: (1,1,K,1)
alphas_exp_k = alphas_k.reshape(1, 1, -1, 1) # type: ignore
# Interval components: (N,O,K,T)
interval_width_sokt = y_upper_sokt - y_lower_sokt # type: ignore
wis_sharp_sokt = (alphas_exp_k / 2.0) * interval_width_sokt
wis_under_sokt = (y_lower_sokt - y_t_exp_sokt) * \
(y_t_exp_sokt < y_lower_sokt) # type: ignore
wis_over_sokt = (y_t_exp_sokt - y_upper_sokt) * \
(y_t_exp_sokt > y_upper_sokt) # type: ignore
sum_interval_wis_comps_sot = np.sum( # Sum over K
wis_sharp_sokt + wis_under_sokt + wis_over_sokt, axis=2
) # (N,O,T)
wis_sot = (mae_term_sot + sum_interval_wis_comps_sot) / \
(K_intervals + 1.0) if K_intervals > 0 else mae_term_sot
if nan_policy == 'propagate':
per_timestep_vals = np.where(combined_nan_mask, np.nan, wis_sot)
elif nan_policy == 'omit':
per_timestep_vals = np.where(combined_nan_mask, np.nan, wis_sot)
elif nan_policy == 'raise' and np.any(combined_nan_mask):
raise ValueError("NaNs found for per-timestep interval score.")
else:
per_timestep_vals = wis_sot
else:
raise ValueError(f"Unsupported metric_type for per-timestep: {metric_type}")
return per_timestep_vals
[docs]
def plot_time_weighted_metric(
metric_type: MetricType,
y_true: np.ndarray,
y_pred: Optional[np.ndarray] = None,
y_median: Optional[np.ndarray] = None,
y_lower: Optional[np.ndarray] = None,
y_upper: Optional[np.ndarray] = None,
alphas: Optional[np.ndarray] = None,
time_weights: Optional[Union[Sequence[float], str]] = 'inverse_time',
metric_values: Optional[Union[float, np.ndarray]] = None,
metric_kws: Optional[Dict[str, Any]] = None,
kind: PlotKind = 'summary_bar',
output_idx: Optional[int] = None,
sample_idx: Optional[int] = None,
figsize: Tuple[float, float] = (12, 6),
title: Optional[str] = None,
xlabel: Optional[str] = None,
ylabel: Optional[str] = None,
profile_line_color: str = 'royalblue',
profile_line_style: str = '-',
profile_marker: Optional[str] = 'o',
time_weights_color: str = 'gray',
show_time_weights_on_profile: bool = False,
bar_color: Union[str, List[str]] = 'royalblue',
bar_width: float = 0.8,
score_annotation_format: str = "{:.4f}",
show_score_on_title: bool = True,
show_grid: bool = True,
grid_props: Optional[Dict[str, Any]] = None,
ax: Optional[plt.Axes] = None,
verbose: int = 0,
**kwargs: Any
) -> plt.Axes:
# --- 1. Input Validation and Metric Function Selection ---
metric_func = _get_metric_function(metric_type)
# Basic validation for y_true
y_true_arr = check_array(y_true, ensure_2d=False, allow_nd=True,
dtype="numeric" if metric_type != 'accuracy' else None,
force_all_finite=False, copy=True)
# Prepare metric_kws, ensuring plot's time_weights is used
current_metric_kws = (metric_kws or {}).copy()
current_metric_kws['time_weights'] = time_weights # Override/set
default_overall_metric_kws = {
'nan_policy': 'propagate',
'multioutput': 'uniform_average',
'eps': 1e-8,
'verbose': 0 # Metric's internal verbose
}
# For overall score calculation if metric_values is None
overall_score_kws = {
**default_overall_metric_kws,
**current_metric_kws
}
# For per-timestep value calculation (nan_policy and eps are relevant)
per_timestep_nan_policy = overall_score_kws.get(
'nan_policy', 'propagate'
) # type: ignore
# per_timestep_eps = overall_score_kws.get('eps', 1e-8) # type: ignore
# --- 2. Reshape Inputs to Standard (N, O, T) or (N, O, K, T) ---
# This part needs to be robust for different y_true_arr.ndim
# and corresponding prediction shapes.
y_true_ndim_orig = y_true_arr.ndim
if y_true_ndim_orig == 1: # (T,) -> (1,1,T)
y_true_proc = y_true_arr.reshape(1,1,-1)
n_samples, n_outputs, n_timesteps = 1, 1, y_true_arr.shape[0]
elif y_true_ndim_orig == 2: # (N,T) -> (N,1,T)
y_true_proc = y_true_arr.reshape(y_true_arr.shape[0], 1, -1)
n_samples, n_outputs, n_timesteps = (
y_true_arr.shape[0], 1, y_true_arr.shape[1]
)
elif y_true_ndim_orig == 3: # (N,O,T)
y_true_proc = y_true_arr
n_samples, n_outputs, n_timesteps = y_true_proc.shape
else:
raise ValueError("y_true must be 1D, 2D, or 3D.")
# Process prediction arrays based on metric_type
y_pred_proc = y_median_proc = y_lower_proc = y_upper_proc = None
alphas_proc = None
if metric_type in ['mae', 'accuracy']:
if y_pred is None:
raise ValueError("y_pred is required for MAE/Accuracy.")
y_p = check_array(
y_pred, ensure_2d=False, allow_nd=True,
dtype=y_true_proc.dtype,
force_all_finite=False
)
if y_p.shape != y_true_arr.shape:
raise ValueError("Shape mismatch: y_true vs y_pred.")
y_pred_proc = y_p.reshape(n_samples, n_outputs, n_timesteps)
elif metric_type == 'interval_score':
if not all(
v is not None for v in [y_median, y_lower, y_upper, alphas]):
raise ValueError(
"y_median, y_lower, y_upper,"
" alphas required for interval_score.")
y_m = check_array(
y_median, ensure_2d=False,
allow_nd=True, dtype="numeric",
force_all_finite=False) # type: ignore
y_l = check_array(y_lower,
ensure_2d=False, allow_nd=True,
dtype="numeric",
force_all_finite=False) # type: ignore
y_u = check_array(
y_upper, ensure_2d=False,
allow_nd=True,
dtype="numeric", force_all_finite=False
) # type: ignore
alphas_proc = check_array(
alphas, ensure_2d=False,
dtype="numeric",
force_all_finite=True
) # type: ignore
if alphas_proc.ndim > 1:
alphas_proc = alphas_proc.squeeze()
if alphas_proc.ndim == 0:
alphas_proc = alphas_proc.reshape(1,)
K_intervals = alphas_proc.shape[0]
if y_m.shape != y_true_arr.shape:
raise ValueError("Shape mismatch: y_true vs y_median.")
y_median_proc = y_m.reshape(
n_samples, n_outputs, n_timesteps)
# Expected y_lower/upper: (N,O,K,T) or compatible
expected_bounds_shape_prefix = (
n_samples, n_outputs, K_intervals)
if (
y_l.ndim == 2 and K_intervals==y_l.shape[0]
and n_timesteps==y_l.shape[1] and n_samples==1
and n_outputs==1
): #(K,T)
y_lower_proc = y_l.reshape(1,1,K_intervals,n_timesteps)
y_upper_proc = y_u.reshape(1,1,K_intervals,n_timesteps)
elif (
y_l.ndim == 3 and y_l.shape[:1]==(n_samples,)
and y_l.shape[1]==K_intervals
and n_outputs==1
): #(N,K,T)
y_lower_proc = y_l.reshape(n_samples,1,K_intervals,n_timesteps)
y_upper_proc = y_u.reshape(n_samples,1,K_intervals,n_timesteps)
elif y_l.ndim == 4 and y_l.shape[:3] == expected_bounds_shape_prefix : # (N,O,K,T)
y_lower_proc, y_upper_proc = y_l, y_u
else:
raise ValueError(
f"y_lower/y_upper shape incompatible. Expected"
" compatible with (N,O,K,T)="
f"{(n_samples,n_outputs,K_intervals,n_timesteps)},"
f" got {y_l.shape}")
if (
y_lower_proc.shape[3] != n_timesteps
or y_upper_proc.shape[3] != n_timesteps
):
raise ValueError("Timestep dimension mismatch in bounds.")
if n_samples == 0 or n_timesteps == 0: # Handled after reshaping for clarity
warnings.warn("Effective data is empty. Cannot generate plot.")
if ax is None: _, ax = plt.subplots(figsize=figsize)
ax.set_title(title or f"{metric_type.upper()} Plot (No Data)")
if show_grid: ax.grid(**(grid_props or {}))
return ax
# --- Plotting Setup ---
if ax is None:
fig, ax = plt.subplots(figsize=figsize) # type: ignore
plot_title_str = title if title is not None else \
f"Time-Weighted {metric_type.replace('_',' ').title()}"
# --- Plotting Logic ---
if kind == 'time_profile':
if n_timesteps < 1 and metric_type != 'interval_score': # IS can have K=0
warnings.warn("Need at least 1 timestep for time_profile plot.")
# Fallback to empty plot
ax.set_title(plot_title_str + " (Not Enough Data)")
if show_grid: ax.grid(**(grid_props or {}))
return ax
per_timestep_sot = _calculate_per_timestep_values(
metric_type, y_true_proc, y_pred_proc, y_median_proc,
y_lower_proc, y_upper_proc, alphas_proc,
nan_policy=per_timestep_nan_policy
) # (N,O,T)
# Aggregate over samples if sample_idx is None
profile_data_ot: np.ndarray # Shape (O,T)
sample_weights_for_avg = (current_metric_kws or {}).get('sample_weight', None)
if sample_idx is not None:
if not (0 <= sample_idx < n_samples):
raise ValueError(f"sample_idx {sample_idx} out of bounds.")
profile_data_ot = per_timestep_sot[sample_idx, :, :] # (O,T)
plot_title_str += f" (Sample {sample_idx})"
else: # Average over samples
if sample_weights_for_avg is not None:
s_w = check_array(sample_weights_for_avg, ensure_2d=False,
dtype="numeric", force_all_finite=True)
check_consistent_length(per_timestep_sot, s_w)
if np.sum(s_w) < default_overall_metric_kws['eps']: # type: ignore
profile_data_ot = np.full((n_outputs, n_timesteps), np.nan)
else: # Weighted average, careful with NaNs in per_timestep_sot
profile_data_ot = np.ma.average(
np.ma.masked_invalid(per_timestep_sot), # type: ignore
axis=0, weights=s_w
)
if isinstance(profile_data_ot, np.ma.MaskedArray):
profile_data_ot = profile_data_ot.filled(np.nan)
else:
profile_data_ot = np.nanmean(per_timestep_sot, axis=0)
# Select output
profile_to_plot_t: np.ndarray # Shape (T,)
current_output_label = ""
output_idx_to_use = 0
if n_outputs > 1:
if output_idx is None:
warnings.warn(
"Multi-output data for time_profile without specified "
"'output_idx'. Plotting first output or average if applicable."
)
# Default to first output or average if appropriate for metric
# For now, let's default to first output for profile plot
profile_to_plot_t = profile_data_ot[0, :]
current_output_label = " (Output 0)"
elif not (0 <= output_idx < n_outputs):
raise ValueError(f"output_idx {output_idx} out of bounds.")
else:
profile_to_plot_t = profile_data_ot[output_idx, :]
current_output_label = f" (Output {output_idx})"
output_idx_to_use = output_idx
else: # Single output
profile_to_plot_t = profile_data_ot.ravel() # Should be (1,T) -> (T,)
time_steps_x = np.arange(n_timesteps)
ax.plot(time_steps_x, profile_to_plot_t,
color=profile_line_color, linestyle=profile_line_style,
marker=profile_marker if profile_marker else '',
label=f"{metric_type.upper()} Profile",
**kwargs.get('plot_kwargs', {}))
ax.set_xlabel(xlabel or "Time Step")
ax.set_ylabel(ylabel or f"Per-Timestep {metric_type.upper()}")
ax.xaxis.set_major_locator(MaxNLocator(integer=True))
if show_time_weights_on_profile and n_timesteps > 0:
# Process actual time_weights for plotting
w_t_plot: np.ndarray
if time_weights is None:
w_t_plot = np.full(n_timesteps, 1.0/n_timesteps if n_timesteps > 0 else 0)
elif isinstance(time_weights, str) and time_weights == 'inverse_time':
if n_timesteps == 0: w_t_plot = np.array([])
else:
w_raw_plot = 1./np.arange(1,n_timesteps+1)
sum_w_plot = np.sum(w_raw_plot)
w_t_plot = w_raw_plot/(
sum_w_plot
if sum_w_plot > default_overall_metric_kws['eps']
else 1) # type: ignore
else:
w_t_plot = check_array(
time_weights,ensure_2d=False,
dtype="numeric",
force_all_finite=True
) # type: ignore
if w_t_plot.shape[0]!=n_timesteps: raise ValueError(
"time_weights length mismatch for plot.")
sum_w_t_plot = np.sum(w_t_plot)
if sum_w_t_plot < default_overall_metric_kws['eps']: # type: ignore
w_t_plot = np.zeros(
n_timesteps) if not np.any(
w_t_plot!=0) else w_t_plot
else: w_t_plot = w_t_plot / sum_w_t_plot
if w_t_plot.size == n_timesteps : # Ensure it was processed correctly
ax2 = ax.twinx()
ax2.bar(time_steps_x, w_t_plot, alpha=0.3, width=0.8,
color=time_weights_color, label='Time Weights')
ax2.set_ylabel('Time Weight', color=time_weights_color)
ax2.tick_params(axis='y', labelcolor=time_weights_color)
# Ensure legend includes items from both axes
lines, labels = ax.get_legend_handles_labels()
lines2, labels2 = ax2.get_legend_handles_labels()
ax.legend(lines + lines2, labels + labels2, loc='best')
if show_score_on_title:
score_for_title: Optional[Union[float, np.ndarray]] = None
if metric_values is not None: # Use pre-computed
# If metric_values is array (raw multioutput), select the one for title
if isinstance(metric_values, np.ndarray) and n_outputs > 1:
score_for_title =(
metric_values[output_idx_to_use]
if output_idx is not None
and output_idx_to_use < len(metric_values)
else np.nanmean(metric_values)
)
else: score_for_title = metric_values
else: # Calculate overall score for the plotted output/average
title_score_kws = {**overall_score_kws}
# For title, if specific output/sample plotted, score that.
# If averaged profile, then overall score.
# This can get complex. Simplest: show overall score from metric_kws.
if n_outputs > 1 and output_idx is not None:
title_score_kws['multioutput'] = 'raw_values'
else: # Single output or averaged profile
title_score_kws['multioutput'] = 'uniform_average'
cleaned_title_kws = _get_valid_kwargs(metric_func, title_score_kws)
# Prepare inputs for metric_func based on metric_type
metric_inputs = {'y_true': y_true_arr}
if metric_type in ['mae', 'accuracy']: metric_inputs['y_pred'] = y_pred # type: ignore
elif metric_type == 'interval_score':
metric_inputs.update({
'y_median': y_median, 'y_lower': y_lower, # type: ignore
'y_upper': y_upper, 'alphas': alphas # type: ignore
})
try:
calculated_scores = metric_func(**metric_inputs, **cleaned_title_kws)
if (
isinstance(calculated_scores, np.ndarray)
and n_outputs > 1
and output_idx is not None
):
score_for_title = calculated_scores[output_idx_to_use]
else:
score_for_title = float(np.ravel(calculated_scores)[0]) # type: ignore
except Exception as e:
warnings.warn(f"Could not calculate score for title: {e}")
if (
score_for_title is not None
and not np.isnan(score_for_title)
): # type: ignore
score_text = f"Overall Score: {score_for_title:.4f}"
plot_title_str = (
f"{plot_title_str}{current_output_label}\n({score_text})"
)
elif kind == 'summary_bar':
scores_to_plot: Union[float, np.ndarray]
if metric_values is not None:
scores_to_plot = metric_values
if verbose > 0: print(f"Using pre-computed scores: {scores_to_plot}")
else:
summary_bar_kws = {**overall_score_kws, **current_metric_kws}
# Respect user's multioutput for summary bar
if 'multioutput' not in current_metric_kws:
summary_bar_kws['multioutput'] = 'uniform_average'
cleaned_kws = _get_valid_kwargs(metric_func, summary_bar_kws)
metric_inputs = {'y_true': y_true_arr}
if metric_type == 'mae': metric_inputs['y_pred'] = y_pred # type: ignore
elif metric_type == 'accuracy': metric_inputs['y_pred'] = y_pred # type: ignore
elif metric_type == 'interval_score':
metric_inputs.update({
'y_median': y_median, 'y_lower': y_lower, # type: ignore
'y_upper': y_upper, 'alphas': alphas # type: ignore
})
scores_to_plot = metric_func(**metric_inputs, **cleaned_kws)
if verbose > 0: print(f"Computed scores for summary: {scores_to_plot}")
scores_arr_bar: np.ndarray
x_labels_bar: List[str]
multioutput_used = (current_metric_kws or {}).get(
'multioutput', default_overall_metric_kws['multioutput'])
if np.isscalar(scores_to_plot) or \
(isinstance(scores_to_plot, np.ndarray) and scores_to_plot.ndim == 0):
scores_arr_bar = np.array([scores_to_plot])
x_labels_bar = [f"Overall {metric_type.upper()}"]
elif isinstance(scores_to_plot, np.ndarray) and scores_to_plot.ndim == 1:
scores_arr_bar = scores_to_plot
x_labels_bar = [f'Output {i}' for i in range(len(scores_arr_bar))]
if plot_title_str and multioutput_used == 'raw_values':
plot_title_str += " (Per Output)"
else:
raise TypeError(f"Unexpected scores type/shape: {type(scores_to_plot)}")
bars = ax.bar(x_labels_bar, scores_arr_bar, color=bar_color,
width=bar_width, **kwargs.get('bar_kwargs', {}))
ax.set_ylabel(ylabel or f"{metric_type.upper()} Score")
if scores_arr_bar.size > 0 and not np.all(np.isnan(scores_arr_bar)):
min_val = np.nanmin(scores_arr_bar)
max_val = np.nanmax(scores_arr_bar)
padding = 0.1 * abs(max_val - min_val) if abs(max_val - min_val) > 1e-6 else 0.1
# Adjust y_lim based on metric type (accuracy vs error)
if metric_type == 'accuracy':
ax.set_ylim(max(0, min_val - padding), min(1, max_val + padding + 0.05))
else: # MAE, Interval Score (errors, can be >1 or <0 for IS)
ax.set_ylim(min_val - padding if min_val < 0 else 0,
max_val + padding + 0.05)
for bar_obj in bars:
yval = bar_obj.get_height()
if not np.isnan(yval):
x = bar_obj.get_x() + bar_obj.get_width() / 2.0
offset_y = 3 if yval >= 0 else -10
va = 'bottom' if yval >= 0 else 'top'
ax.annotate(
score_annotation_format.format(yval),
xy=(x, yval),
xytext=(0, offset_y),
textcoords="offset points",
ha="center",
va=va,
)
else:
raise ValueError(f"Unknown plot kind: '{kind}'.")
if plot_title_str:
ax.set_title(plot_title_str)
if show_grid:
current_grid_props = grid_props if grid_props is not None \
else {'linestyle': ':', 'alpha': 0.7}
ax.grid(**current_grid_props)
else:
ax.grid(False)
return ax
plot_time_weighted_metric.__doc__=r"""
Visualise time‑weighted error / accuracy metrics (MAE, classification
accuracy, or interval‑based scores) as either
* a **summary bar** of the overall time‑weighted score, or one bar per
output dimension; or
* a **time‑profile** curve that shows how the metric evolves over the
forecasting horizon, optionally overlaid with the weight
distribution.
The helper delegates numeric computation to the corresponding
metric in :pymod:`fusionlab.metrics` and applies the chosen *time
weights* before visualisation.
Parameters
----------
metric_type : {{'mae', 'accuracy', 'interval_score'}}
Which metric to compute and plot.
* ``'mae'`` – Mean Absolute Error.
* ``'accuracy'`` – Classification accuracy.
* ``'interval_score'`` – Weighted interval score
(requires median, bounds, and ``alphas``).
{params.base.y_true}
{params.base.y_pred}
{params.base.y_median}
{params.base.y_lower}
{params.base.y_upper}
{params.base.alphas}
time_weights : 1‑D sequence, ``'inverse_time'`` or ``None``,\
default ``'inverse_time'``
* **array‑like** – explicit non‑negative weights for each
timestep *(length T)*. They are automatically normalised to sum
to 1.
* ``'inverse_time'`` – use
:math:`w_t \propto 1 / (t + 1)` (early timesteps matter more).
* ``None`` – uniform weights *(1/T)*.
metric_values : float or ndarray, optional
Pre‑computed time‑weighted score(s) to plot, bypassing internal
metric evaluation.
{params.base.metric_kws}
kind : {{'summary_bar', 'time_profile'}}, default ``'summary_bar'``
* **summary_bar** – bar plot of the overall score.
* **time_profile** – line plot of the per‑timestep metric
(averaged over samples), optionally with the weight profile.
output_idx : int, optional
Output dimension to plot when the data are multi‑output and
``kind='time_profile'``.
{params.base.sample_idx}
figsize, title, xlabel, ylabel
{{see shared parameters below}}
Time‑profile styling
^^^^^^^^^^^^^^^^^^^^
profile_line_color : str, default ``'royalblue'``
profile_line_style : str, default ``'-'``
profile_marker : str or ``None``, default ``'o'``
Matplotlib properties for the metric curve.
time_weights_color : str, default ``'gray'``
show_time_weights_on_profile : bool, default ``False``
If *True*, draws a semi‑transparent bar chart of the
normalised weights on a secondary y‑axis.
Summary‑bar styling
^^^^^^^^^^^^^^^^^^^
{params.base.bar_color}
{params.base.bar_width}
{params.base.score_annotation_format}
Common plot controls
^^^^^^^^^^^^^^^^^^^^
{params.base.figsize}
{params.base.show_score_on_title}
{params.base.show_grid}
{params.base.grid_props}
{params.base.ax}
{params.base.verbose}
{params.base.kwargs}
Returns
-------
matplotlib.axes.Axes
The axes object with the rendered figure.
Notes
-----
Let :math:`w_t` be the *normalised* time weight for horizon *t*.
For MAE the time‑weighted score is
.. math::
\text{{TW‑MAE}} \;=\; \sum_{{t=1}}^{{T}} w_t \,
\lvert y_t - \hat y_t\rvert .
Analogous definitions apply for accuracy (with the 0‑1 loss) and
for the weighted interval score (using the per‑timestep WIS).
If *``kind='time_profile'``* the helper first computes the unweighted
metric value for each timestep, then applies the ``time_weights`` when
plotting or when aggregating to a single title score.
Examples
--------
>>> import numpy as np, matplotlib.pyplot as plt
>>> from fusionlab.plot.evaluation import plot_time_weighted_metric
>>> T = 24
>>> y_true = np.sin(np.linspace(0, 3*np.pi, T))
>>> y_pred = y_true + np.random.normal(0, 0.1, T)
>>> ax = plot_time_weighted_metric(
... metric_type='mae',
... y_true=y_true,
... y_pred=y_pred,
... kind='time_profile',
... show_time_weights_on_profile=True,
... figsize=(8, 4))
>>> plt.show()
See Also
--------
fusionlab.metrics.time_weighted_mae
fusionlab.metrics.time_weighted_accuracy
fusionlab.metrics.time_weighted_interval_score
fusionlab.plot.evaluation.plot_weighted_interval_score
References
----------
.. [1] Tay, F.E.H., *et al.* “Application of Weighted Metrics in Time‑
Series Forecast Evaluation,” *International Journal of Forecasting*,
vol 35, 2019.
""".format(params=_param_docs)
[docs]
def plot_quantile_calibration(
y_true: np.ndarray,
y_pred_quantiles: np.ndarray,
quantiles: np.ndarray,
qce_values: Optional[Union[float, np.ndarray]] = None,
metric_kws: Optional[Dict[str, Any]] = None,
kind: Literal['reliability_diagram', 'summary_bar'] = 'reliability_diagram',
output_idx: Optional[int] = None,
perfect_calib_color: str = 'red',
observed_prop_color: str = 'blue',
observed_prop_marker: str = 'o',
figsize: Tuple[float, float] = (8, 8),
title: Optional[str] = "Quantile Calibration Error (QCE)",
xlabel: Optional[str] = None,
ylabel: Optional[str] = None,
bar_color: Union[str, List[str]] = 'darkcyan',
bar_width: float = 0.8,
score_annotation_format: str = "{:.4f}",
show_score: bool = True,
show_grid: bool = True,
grid_props: Optional[Dict[str, Any]] = None,
ax: Optional[plt.Axes] = None,
verbose: int = 0,
**kwargs: Any
) -> plt.Axes:
# *************************************************************************
from ..metrics._registry import get_metric
quantile_calibration_error = get_metric("quantile_calibration_error")
# **************************************************************************
# --- Input Validation and Preparation ---
# y_true: (N,), (N,O)
# y_pred_quantiles: (N,Q) or (N,O,Q)
# quantiles: (Q,)
y_true_arr = check_array(
y_true, ensure_2d=False, allow_nd=True,
dtype="numeric", force_all_finite=False, copy=True
)
y_pred_q_arr = check_array(
y_pred_quantiles, ensure_2d=False, allow_nd=True,
dtype="numeric", force_all_finite=False, copy=True
)
q_arr = check_array(
quantiles, ensure_2d=False, dtype="numeric",
force_all_finite=True # Quantiles must be finite and in (0,1)
)
are_all_values_in_bounds(
q_arr , bounds= (0, 1), nan_policy='raise',
message = "All quantile values must be in (0,1)."
)
if q_arr.ndim > 1: q_arr = q_arr.squeeze()
if q_arr.ndim == 0: q_arr = q_arr.reshape(1,)
n_quantiles = q_arr.shape[0]
# Reshape inputs for consistent processing:
# y_true_proc: (N, O)
# y_pred_proc: (N, O, Q)
y_true_ndim_orig = y_true_arr.ndim
if y_true_ndim_orig == 1: # (N,)
y_true_proc = y_true_arr.reshape(-1, 1) # (N,1)
if y_pred_q_arr.ndim == 2 and \
y_pred_q_arr.shape[1] == n_quantiles: # (N,Q)
y_pred_proc = y_pred_q_arr.reshape(
y_pred_q_arr.shape[0], 1, -1 # (N,1,Q)
)
else:
raise ValueError(
"If y_true is 1D, y_pred_quantiles must be 2D (N,Q)."
)
elif y_true_ndim_orig == 2: # (N,O)
y_true_proc = y_true_arr
if y_pred_q_arr.ndim == 3 and \
y_pred_q_arr.shape[1] == y_true_proc.shape[1] and \
y_pred_q_arr.shape[2] == n_quantiles: # (N,O,Q)
y_pred_proc = y_pred_q_arr
else:
raise ValueError(
"If y_true is 2D (N,O), y_pred_quantiles must be 3D (N,O,Q)."
)
else:
raise ValueError("y_true must be 1D or 2D.")
if y_true_proc.shape[0] != y_pred_proc.shape[0]: # Samples mismatch
raise ValueError("y_true and y_pred_quantiles n_samples mismatch.")
n_samples, n_outputs, _ = y_pred_proc.shape
if n_samples == 0:
warnings.warn("Input arrays are empty. Cannot generate plot.")
if ax is None: _, ax = plt.subplots(figsize=figsize)
ax.set_title(title or "QCE Plot (No Data)")
if show_grid: ax.grid(**(grid_props or {}))
return ax
if n_quantiles == 0 and kind != 'summary_bar':
warnings.warn("No quantiles provided. Cannot generate plot.")
if ax is None: _, ax = plt.subplots(figsize=figsize)
ax.set_title(title or f"QCE {kind} (No Quantiles)")
if show_grid: ax.grid(**(grid_props or {}))
return ax
# --- Plotting Setup ---
if ax is None:
fig, ax = plt.subplots(figsize=figsize) # type: ignore
plot_title_str = title # Use a mutable string
# --- Metric Calculation Handling ---
current_metric_kws = metric_kws or {}
default_kws_for_metric = {
'nan_policy': 'propagate',
'multioutput': 'uniform_average',
'eps': 1e-8,
'verbose': 0
}
# --- Plotting Logic ---
if kind == 'reliability_diagram':
if n_outputs > 1 and output_idx is None:
raise ValueError(
"For multi-output data and kind='reliability_diagram', "
"'output_idx' must be specified."
)
if output_idx is not None and not (0 <= output_idx < n_outputs):
raise ValueError(f"output_idx {output_idx} out of bounds.")
# Select data for the specific output
current_output_to_plot = output_idx if output_idx is not None else 0
y_t_plot = y_true_proc[:, current_output_to_plot]
y_p_plot = y_pred_proc[:, current_output_to_plot, :]
# Calculate observed proportions
# Handle NaNs based on metric_kws for calculating proportions
nan_policy_plot = current_metric_kws.get(
'nan_policy', default_kws_for_metric['nan_policy']
)
sample_weight_plot = current_metric_kws.get('sample_weight', None)
eps_plot = current_metric_kws.get('eps', default_kws_for_metric['eps'])
# Create indicators: (N_samples, N_quantiles)
indicators = (
y_t_plot[:, np.newaxis] <= y_p_plot
).astype(float)
nan_mask_yt_exp = np.isnan(y_t_plot[:, np.newaxis])
nan_mask_yp = np.isnan(y_p_plot)
nan_mask_sq = nan_mask_yt_exp | nan_mask_yp # (N_samples, N_quantiles)
if np.any(nan_mask_sq):
if nan_policy_plot == 'raise':
raise ValueError("NaNs found in data for reliability diagram.")
elif nan_policy_plot == 'omit':
# Omit samples if *any* of their quantiles or true value is NaN
rows_with_nan = nan_mask_sq.any(axis=1) # (N_samples,)
rows_to_keep = ~rows_with_nan
if not np.any(rows_to_keep):
ax.text(0.5,0.5,"All samples omitted due to NaNs.",
ha='center',va='center',transform=ax.transAxes)
if show_grid: ax.grid(**(grid_props or {}))
ax.set_title(plot_title_str or "Reliability Diagram (No Data)")
return ax
indicators = indicators[rows_to_keep]
if sample_weight_plot is not None:
sample_weight_plot = sample_weight_plot[rows_to_keep]
nan_mask_sq = nan_mask_sq[rows_to_keep] # For propagate consistency
# If 'propagate', NaNs in indicators are handled by nanmean/average
if indicators.shape[0] == 0: # All samples omitted
ax.text(0.5,0.5,"No valid samples for reliability diagram.",
ha='center',va='center',transform=ax.transAxes)
if show_grid: ax.grid(**(grid_props or {}))
ax.set_title(plot_title_str or "Reliability Diagram (No Data)")
return ax
if nan_policy_plot == 'propagate':
indicators = np.where(nan_mask_sq, np.nan, indicators)
observed_proportions: np.ndarray
if sample_weight_plot is not None:
sum_sw = np.sum(sample_weight_plot)
if sum_sw < eps_plot:
observed_proportions = np.full(n_quantiles, np.nan)
else:
# Weighted average, careful with NaNs in indicators
temp_props = []
for q_idx in range(n_quantiles):
valid_inds_q = indicators[:, q_idx]
finite_mask_q = ~np.isnan(valid_inds_q)
if np.any(finite_mask_q):
sum_finite_weights = np.sum(sample_weight_plot[finite_mask_q])
if sum_finite_weights >= eps_plot:
prop = np.sum(
valid_inds_q[finite_mask_q] * \
sample_weight_plot[finite_mask_q]
) / sum_finite_weights
temp_props.append(prop)
else: temp_props.append(np.nan)
else: temp_props.append(np.nan)
observed_proportions = np.array(temp_props)
else:
observed_proportions = np.nanmean(indicators, axis=0)
# Plotting reliability diagram
ax.plot([0, 1], [0, 1], linestyle='--', color=perfect_calib_color,
label='Perfect Calibration')
ax.plot(q_arr, observed_proportions, marker=observed_prop_marker,
linestyle='-', color=observed_prop_color,
label='Observed Proportion')
if show_score:
score_for_title: Optional[float] = None
if qce_values is not None:
if n_outputs > 1 and output_idx is not None:
if isinstance(qce_values, np.ndarray) and \
qce_values.ndim == 1 and output_idx < len(qce_values):
score_for_title = qce_values[output_idx]
elif np.isscalar(qce_values) or \
(isinstance(qce_values, np.ndarray) and qce_values.size==1):
score_for_title = float(np.ravel(qce_values)[0])
else: # Calculate score
title_kws = {**default_kws_for_metric, **current_metric_kws}
if n_outputs > 1 and output_idx is not None:
# Score for specific output
title_kws['multioutput'] = 'raw_values'
cleaned_kws = _get_valid_kwargs(quantile_calibration_error, title_kws)
try:
all_output_scores = quantile_calibration_error(
y_true_arr, y_pred_q_arr, q_arr, **cleaned_kws
)
score_for_title = all_output_scores[output_idx]
except Exception as e:
warnings.warn(f"Could not calculate QCE for title: {e}")
else: # Overall score
title_kws['multioutput'] = 'uniform_average'
cleaned_kws = _get_valid_kwargs(quantile_calibration_error, title_kws)
try:
score_for_title = quantile_calibration_error(
y_true_arr, y_pred_q_arr, q_arr, **cleaned_kws
)
except Exception as e:
warnings.warn(f"Could not calculate QCE for title: {e}")
if score_for_title is not None and not np.isnan(score_for_title):
score_text = f"Avg. QCE: {score_for_title:.4f}"
current_title = plot_title_str or "Quantile Reliability Diagram"
output_label_title = f" (Output {output_idx})" \
if n_outputs > 1 and output_idx is not None else ""
plot_title_str = f"{current_title}{output_label_title}\n({score_text})"
ax.set_xlabel(xlabel or "Nominal Quantile Level (q)")
ax.set_ylabel(ylabel or "Observed Proportion (y <= Q_pred(q))")
ax.legend(loc='best')
ax.set_xlim([-0.05, 1.05])
ax.set_ylim([-0.05, 1.05])
elif kind == 'summary_bar':
scores_to_plot: Union[float, np.ndarray]
if qce_values is not None:
scores_to_plot = qce_values
if verbose > 0: print(f"Using pre-computed QCE: {scores_to_plot}")
else:
summary_bar_default_kws = default_kws_for_metric.copy()
if 'multioutput' not in current_metric_kws:
summary_bar_default_kws['multioutput'] = 'uniform_average'
effective_kws = {**summary_bar_default_kws, **current_metric_kws}
cleaned_kws = _get_valid_kwargs(quantile_calibration_error, effective_kws)
scores_to_plot = quantile_calibration_error(
y_true_arr, y_pred_q_arr, q_arr, **cleaned_kws
)
if verbose > 0: print(f"Computed QCE for summary: {scores_to_plot}")
scores_arr_bar: np.ndarray
x_labels_bar: List[str]
multioutput_used = (current_metric_kws or {}).get(
'multioutput', default_kws_for_metric['multioutput'])
if np.isscalar(scores_to_plot) or \
(isinstance(scores_to_plot, np.ndarray) and scores_to_plot.ndim == 0):
scores_arr_bar = np.array([scores_to_plot])
x_labels_bar = ['Mean QCE']
elif isinstance(scores_to_plot, np.ndarray) and scores_to_plot.ndim == 1:
scores_arr_bar = scores_to_plot
x_labels_bar = [f'Output {i}' for i in range(len(scores_arr_bar))]
if plot_title_str and multioutput_used == 'raw_values':
plot_title_str += " (Per Output)"
else:
raise TypeError(f"Unexpected QCE scores type/shape: {type(scores_to_plot)}")
bars = ax.bar(x_labels_bar, scores_arr_bar, color=bar_color,
width=bar_width, **kwargs.get('bar_kwargs', {}))
ax.set_ylabel(ylabel or 'Quantile Calibration Error (QCE)')
if scores_arr_bar.size > 0 and not np.all(np.isnan(scores_arr_bar)):
max_val = np.nanmax(scores_arr_bar)
ax.set_ylim(0, max(0.1, max_val + 0.1 * max_val + 0.05))
for bar_obj in bars:
yval = bar_obj.get_height()
if not np.isnan(yval):
x = bar_obj.get_x() + bar_obj.get_width() / 2.0
ax.annotate(
score_annotation_format.format(yval),
xy=(x, yval),
xytext=(0, 3), # offset in points
textcoords="offset points",
ha="center",
va="bottom",
)
else:
raise ValueError(
f"Unknown plot kind: '{kind}'. Choose 'reliability_diagram' "
"or 'summary_bar'."
)
if plot_title_str:
ax.set_title(plot_title_str)
if show_grid:
current_grid_props = grid_props if grid_props is not None \
else {'linestyle': ':', 'alpha': 0.7}
ax.grid(**current_grid_props)
else:
ax.grid(False)
return ax
plot_quantile_calibration.__doc__=r"""
Visualise Quantile Calibration Error (QCE).
Two complementary views are supported:
* **'reliability_diagram'** – plots the observed proportion
:math:`\Pr(y \le \hat q)` against the nominal quantile level
*q*. Perfect calibration lies on the diagonal.
* **'summary_bar'** – one bar per output (or an overall bar) showing
the time‑weighted QCE score.
Parameters
----------
{params.base.y_true}
{params.base.y_pred_quantiles}
{params.base.quantiles}
qce_values : float or ndarray, optional
Pre‑computed QCE value(s). If supplied, the helper skips internal
metric evaluation.
metric_kws : dict, optional
Extra keyword arguments passed to
:func:`fusionlab.metrics.quantile_calibration_error`.
kind : {{'reliability_diagram', 'summary_bar'}},
default ``'reliability_diagram'``
Choose the visualisation style.
output_idx : int, optional
Output dimension to plot when the data contain multiple outputs
and ``kind='reliability_diagram'``.
perfect_calib_color : str, default ``'red'``
Line colour for the 45‑degree “perfect calibration’’ reference.
observed_prop_color : str, default ``'blue'``
observed_prop_marker : str, default ``'o'``
Style for the observed‑proportion curve.
{params.base.figsize}
{params.base.title}
{params.base.xlabel}
{params.base.ylabel}
{params.base.bar_color}
{params.base.bar_width}
{params.base.score_annotation_format}
show_score : bool, default ``True``
Display the average QCE on the plot or title.
{params.base.show_grid}
{params.base.grid_props}
{params.base.ax}
{params.base.verbose}
{params.base.kwargs}
Returns
-------
matplotlib.axes.Axes
Axes containing the calibration plot.
Notes
-----
The *quantile calibration error* for one output is
.. math::
\mathrm{{QCE}} \;=\;
\frac{{1}}{{Q}}\sum_{{k=1}}^{{Q}}
\bigl|\,
\hat F_y(q_k) \;-\; q_k
\bigr|,
where :math:`\hat F_y(q_k)` is the empirical cdf evaluated at the
predicted quantile :math:`\hat q_k`.
Examples
--------
>>> import numpy as np, matplotlib.pyplot as plt
>>> from fusionlab.plot.evaluation import plot_quantile_calibration
>>> rng = np.random.default_rng(0)
>>> y_true = rng.normal(size=500)
>>> qs = np.array([0.1, 0.5, 0.9])
>>> y_pred_q = np.quantile(
... y_true[:, None] + rng.normal(scale=.1, size=(500, 3)),
... qs, axis=1).T
>>> plot_quantile_calibration(
... y_true, y_pred_q, qs, kind='reliability_diagram')
>>> plt.show()
See Also
--------
fusionlab.metrics.quantile_calibration_error
Numeric implementation of QCE.
fusionlab.plot.evaluation.plot_weighted_interval_score
Visualises interval‑based probabilistic scores.
fusionlab.plot.evaluation.plot_time_weighted_metric
Time‑weighted MAE / accuracy / interval‑score plots.
References
----------
.. [1] Gneiting, T. & Katzfuss, M. *Probabilistic Forecasting,*
*Annu. Rev. Stat. Appl.*, 2014.
""".format(params=_param_docs)
[docs]
def plot_coverage(
y_true: np.ndarray,
y_lower: np.ndarray,
y_upper: np.ndarray,
coverage_values: Optional[Union[float, np.ndarray]] = None,
metric_kws: Optional[Dict[str, Any]] = None,
sample_indices: Optional[np.ndarray] = None,
output_index: Optional[int] = None,
kind: Literal['intervals', 'summary_bar'] = 'intervals',
figsize: Tuple[float, float] = (12, 6),
title: Optional[str] = "Prediction Interval Coverage",
xlabel: str = 'Sample Index',
ylabel: str = 'Value',
covered_color: str = 'mediumseagreen',
uncovered_color: str = 'salmon',
line_color: Optional[str] = 'dimgray',
line_style: str = '--',
line_width: float = 0.8,
marker: str = 'o',
marker_size: int = 30,
interval_color: str = 'skyblue',
interval_alpha: float = 0.5,
legend: bool = True,
show_score: bool = True,
bar_color: Union[str, List[str]] = 'cornflowerblue',
bar_width: float = 0.8,
score_annotation_format: str = "{:.2%}",
show_grid: bool = True,
grid_props: Optional[Dict[str, Any]] = None,
ax: Optional[plt.Axes] = None,
verbose: int = 0,
**kwargs: Any
) -> plt.Axes:
# ************************************************
from ..metrics._registry import get_metric
coverage_score = get_metric("coverage_score")
# ************************************************
# --- Input Validation and Preparation ---
y_true_arr = check_array(
y_true, ensure_2d=False, force_all_finite=False,
dtype="numeric", copy=True)
y_lower_arr = check_array(
y_lower, ensure_2d=False, force_all_finite=False,
dtype="numeric", copy=True)
y_upper_arr = check_array(
y_upper, ensure_2d=False, force_all_finite=False,
dtype="numeric", copy=True
)
if not (y_true_arr.shape == y_lower_arr.shape == y_upper_arr.shape):
raise ValueError(
"y_true, y_lower, and y_upper must have the same shape."
)
if y_true_arr.ndim > 2:
raise ValueError(
"Inputs y_true, y_lower, y_upper must be 1D or 2D."
)
n_samples = y_true_arr.shape[0]
n_outputs = y_true_arr.shape[1] if y_true_arr.ndim == 2 else 1
if n_samples == 0:
warnings.warn("Input arrays are empty. Cannot generate plot.")
if ax is None:
_, ax = plt.subplots(figsize=figsize)
ax.set_title(title or "Coverage Plot (No Data)")
if show_grid:
ax.grid(**(grid_props or {})) # Apply grid even for empty
return ax
# --- Plotting Setup ---
if ax is None:
# Create new figure and axes if none provided
fig, ax = plt.subplots(figsize=figsize) # type: ignore
plot_title_str = title # Use a mutable string for title
# --- Metric Calculation Handling ---
# Consolidate metric_kws for internal calls
current_metric_kws = metric_kws or {}
# Define default kws for coverage_score if not provided by user
# These are used if coverage_values is None.
default_kws_for_metric = {
'nan_policy': 'propagate',
'multioutput': 'uniform_average', # Default for single title score
'eps': 1e-8,
'verbose': 0 # Metric's internal verbose, not plot's
}
if kind == 'intervals':
y_t_plot, y_l_plot, y_u_plot = y_true_arr, y_lower_arr, y_upper_arr
current_output_label = ""
output_idx_to_use = 0 # Default for 1D or (N,1) case
if y_true_arr.ndim == 2: # Multi-output or (N,1)
if n_outputs > 1: # Truly multi-output
if output_index is None:
raise ValueError(
"For 2D y_true with >1 output and kind='intervals', "
"'output_index' must be specified."
)
if not (0 <= output_index < n_outputs):
raise ValueError(
f"output_index {output_index} is out of bounds for "
f"{n_outputs} outputs."
)
output_idx_to_use = output_index
# else: n_outputs is 1, output_idx_to_use remains 0
y_t_plot = y_true_arr[:, output_idx_to_use]
y_l_plot = y_lower_arr[:, output_idx_to_use]
y_u_plot = y_upper_arr[:, output_idx_to_use]
if n_outputs > 1: # Add label only if truly multi-output
current_output_label = f" (Output {output_idx_to_use})"
if sample_indices is None:
x_indices = np.arange(n_samples)
else:
x_indices = check_array(sample_indices, ensure_2d=False, copy=False)
check_consistent_length(x_indices, y_t_plot)
nan_mask_plot = np.isnan(y_t_plot) | \
np.isnan(y_l_plot) | \
np.isnan(y_u_plot)
valid_indices = ~nan_mask_plot
x_plot = x_indices[valid_indices]
y_t_plot_valid = y_t_plot[valid_indices]
y_l_plot_valid = y_l_plot[valid_indices]
y_u_plot_valid = y_u_plot[valid_indices]
covered_mask = (y_t_plot_valid >= y_l_plot_valid) & \
(y_t_plot_valid <= y_u_plot_valid)
ax.fill_between(
x_plot, y_l_plot_valid, y_u_plot_valid,
color=interval_color, alpha=interval_alpha,
label='Prediction Interval', **kwargs.get('fill_between_kwargs', {})
)
if line_color: # Renamed from true_line_color
ax.plot(
x_plot, y_t_plot_valid, color=line_color,
linestyle=line_style, linewidth=line_width, # Renamed
label='True Values (line)', **kwargs.get('plot_kwargs', {})
)
ax.scatter(
x_plot[covered_mask], y_t_plot_valid[covered_mask],
color=covered_color, marker=marker, s=marker_size, # Renamed
label='Covered True Value', zorder=3,
**kwargs.get('scatter_kwargs', {})
)
ax.scatter(
x_plot[~covered_mask], y_t_plot_valid[~covered_mask],
color=uncovered_color, marker=marker, s=marker_size, # Renamed
label='Uncovered True Value', zorder=3,
**kwargs.get('scatter_kwargs', {})
)
if show_score: # Renamed from show_score_on_title
score_for_title: Optional[float] = None
if coverage_values is not None:
if y_true_arr.ndim == 2 and n_outputs > 1:
if isinstance(coverage_values, np.ndarray) and \
coverage_values.ndim == 1 and \
output_idx_to_use < len(coverage_values):
score_for_title = coverage_values[output_idx_to_use]
elif np.isscalar(coverage_values): # If user passed overall avg
score_for_title = float(coverage_values)
elif np.isscalar(coverage_values) or \
(isinstance(coverage_values, np.ndarray) and coverage_values.size ==1):
score_for_title = float(np.ravel(coverage_values)[0])
else: # Calculate score for the title
# For title, always calculate a single score for the plotted output
title_kws = {**default_kws_for_metric, **current_metric_kws}
# Ensure multioutput is 'uniform_average' for single score title
title_kws['multioutput'] = 'uniform_average'
cleaned_title_kws = _get_valid_kwargs(coverage_score, title_kws)
score_data_y_true = (
y_true_arr[:, output_idx_to_use]
if y_true_arr.ndim == 2 else y_true_arr
)
score_data_y_lower = (
y_lower_arr[:, output_idx_to_use]
if y_lower_arr.ndim == 2 else y_lower_arr
)
score_data_y_upper = (
y_upper_arr[:, output_idx_to_use]
if y_upper_arr.ndim == 2 else y_upper_arr
)
try:
score_for_title = coverage_score(
score_data_y_true, score_data_y_lower,
score_data_y_upper,
**cleaned_title_kws
)
if verbose > 0:
print(f"Coverage score (for title, output "
f"{output_idx_to_use if y_true_arr.ndim==2 and n_outputs > 1 else ''})"
f": {score_for_title:.4f}")
except Exception as e:
warnings.warn(f"Could not calculate score for title: {e}")
if score_for_title is not None and not np.isnan(score_for_title):
score_text = f"Coverage: {score_for_title:.2%}"
if plot_title_str:
plot_title_str = f"{plot_title_str}{current_output_label}\n({score_text})"
else:
plot_title_str = f"Coverage{current_output_label} ({score_text})"
ax.set_xlabel(xlabel)
ax.set_ylabel(ylabel)
if legend:
ax.legend()
elif kind == 'summary_bar':
scores_to_plot: Union[float, np.ndarray]
if coverage_values is not None:
scores_to_plot = coverage_values
if verbose > 0:
print(f"Using pre-computed coverage values: {scores_to_plot}")
else:
# For summary bar, respect user's multioutput choice in metric_kws
# Default to 'uniform_average' if not specified
summary_bar_default_kws = default_kws_for_metric.copy()
if 'multioutput' not in current_metric_kws: # if user did not specify
summary_bar_default_kws['multioutput'] = 'uniform_average'
effective_kws = {**summary_bar_default_kws, **current_metric_kws}
cleaned_kws = _get_valid_kwargs(coverage_score, effective_kws)
scores_to_plot = coverage_score(
y_true_arr, y_lower_arr, y_upper_arr,
**cleaned_kws
)
if verbose > 0:
print(f"Computed coverage score(s) for summary: {scores_to_plot}")
scores_arr: np.ndarray
x_labels: List[str]
multioutput_used_for_score = (current_metric_kws or {}).get(
'multioutput', default_kws_for_metric['multioutput'])
if np.isscalar(scores_to_plot) or \
(isinstance(scores_to_plot, np.ndarray) and scores_to_plot.ndim == 0):
scores_arr = np.array([scores_to_plot])
x_labels = ['Overall Coverage']
elif isinstance(scores_to_plot, np.ndarray) and scores_to_plot.ndim == 1:
scores_arr = scores_to_plot
x_labels = [f'Output {i}' for i in range(len(scores_arr))]
if plot_title_str and multioutput_used_for_score == 'raw_values':
plot_title_str += " (Per Output)"
else:
raise TypeError(
f"Unexpected type or shape for scores: {type(scores_to_plot)}"
)
bars = ax.bar(x_labels, scores_arr, color=bar_color, width=bar_width,
**kwargs.get('bar_kwargs', {}))
ax.set_ylabel('Coverage Score')
ax.set_ylim(0, max(1.1, np.nanmax(
scores_arr) + 0.1 if scores_arr.size > 0 and not np.all(
np.isnan(scores_arr)) else 1.1) )
for bar_val in bars:
yval = bar_val.get_height()
if not np.isnan(yval):
ax.text(bar_val.get_x() + bar_val.get_width()/2.0, yval + 0.02,
score_annotation_format.format(yval),
ha='center', va='bottom')
else:
raise ValueError(
f"Unknown plot kind: '{kind}'. "
"Choose 'intervals' or 'summary_bar'."
)
if plot_title_str:
ax.set_title(plot_title_str)
if show_grid:
current_grid_props = grid_props if grid_props is not None \
else {'linestyle': ':', 'alpha': 0.7}
ax.grid(**current_grid_props)
else:
ax.grid(False)
return ax
plot_coverage.__doc__ =r"""
Visualise prediction‑interval coverage in two ways:
* **'intervals'** – true values overlaid on their prediction
intervals, coloured by whether each point is covered.
* **'summary_bar'** – bar chart of the empirical coverage rate
(overall or per output).
Parameters
------------
{params.base.y_true}
{params.base.y_lower}
{params.base.y_upper}
coverage_values : float or ndarray, optional
Pre‑computed coverage score(s). If supplied, the helper skips the
internal call to :func:`fusionlab.metrics.coverage_score`.
metric_kws : dict, optional
Extra keyword arguments forwarded to
:func:`fusionlab.metrics.coverage_score` when
``coverage_values`` is *None*.
sample_indices : ndarray, optional
Custom x‑axis locations for the **'intervals'** plot. Must match
the first dimension of ``y_true``.
output_index : int, optional
Output dimension to visualise when the data contain multiple
outputs and ``kind='intervals'``.
kind : {{'intervals', 'summary_bar'}}, default ``'intervals'``
Select the visualisation style.
{params.base.figsize}
title : str, optional
Figure title. If *None*, a context‑aware default is generated.
xlabel : str, default ``'Sample Index'``
ylabel : str, default ``'Value'``
Interval‑plot styling
^^^^^^^^^^^^^^^^^^^^^
covered_color : str, default ``'mediumseagreen'``
uncovered_color : str, default ``'salmon'``
line_color : str or None, default ``'dimgray'``
line_style : str, default ``'--'``
line_width : float, default 0.8
marker : str, default ``'o'``
marker_size : int, default 30
interval_color : str, default ``'skyblue'``
interval_alpha : float, default 0.5
legend : bool, default ``True``
show_score : bool, default ``True``
Append the empirical coverage (as a percentage) to the title of
the *intervals* plot.
Summary‑bar styling
^^^^^^^^^^^^^^^^^^^
{params.base.bar_color}
{params.base.bar_width}
{params.base.score_annotation_format}
{params.base.show_grid}
{params.base.grid_props}
{params.base.ax}
{params.base.verbose}
{params.base.kwargs}
Returns
-------
matplotlib.axes.Axes
Axes containing the coverage visualisation.
Notes
-----
The empirical coverage for one output is
.. math::
\widehat C \;=\;
\frac{{1}}{{N}} \sum_{{i=1}}^{{N}}
\mathbb{{1}}\{{\,y_i \in [\ell_i, u_i]\,\}},
where :math:`[\ell_i, u_i]` is the prediction interval for sample *i*.
The helper colours covered points with *covered_color* and uncovered
points with *uncovered_color*.
Examples
--------
>>> import numpy as np, matplotlib.pyplot as plt
>>> from fusionlab.plot.evaluation import plot_coverage
>>> rng = np.random.default_rng(1)
>>> y_true = rng.normal(size=50)
>>> y_lower = y_true - 1.0
>>> y_upper = y_true + 1.0
>>> plot_coverage(y_true, y_lower, y_upper, kind='intervals',
... figsize=(8, 4))
>>> plt.show()
See Also
--------
fusionlab.metrics.coverage_score
Numerical implementation of empirical coverage.
fusionlab.plot.evaluation.plot_weighted_interval_score
Visualises interval sharpness and calibration jointly.
fusionlab.plot.evaluation.plot_quantile_calibration
Reliability diagrams for quantile forecasts.
References
----------
.. [1] Gneiting, T. & Raftery, A.E. (2007). *Strictly Proper Scoring
Rules, Prediction, and Estimation*. *JASA* 102(477), 359‑378.
""".format(params=_param_docs)
[docs]
def plot_crps(
y_true: np.ndarray,
y_pred_ensemble: np.ndarray,
crps_values: Optional[Union[float, np.ndarray]] = None,
metric_kws: Optional[Dict[str, Any]] = None,
kind: Literal['ensemble_ecdf', 'scores_histogram', 'summary_bar'] = 'summary_bar',
sample_idx: int = 0,
output_idx: int = 0,
ecdf_color: str = 'dodgerblue',
true_value_color: str = 'red',
ensemble_marker_color: str = 'gray',
hist_bins: Union[int, Sequence[Real], str] = 'auto',
hist_color: str = 'skyblue',
hist_edgecolor: str = 'black',
figsize: Tuple[float, float] = (10, 6),
title: Optional[str] = "Continuous Ranked Probability Score (CRPS)",
xlabel: Optional[str] = None,
ylabel: Optional[str] = None,
bar_color: Union[str, List[str]] = 'cornflowerblue',
bar_width: float = 0.8,
score_annotation_format: str = "{:.4f}",
show_score: bool = True,
show_grid: bool = True,
grid_props: Optional[Dict[str, Any]] = None,
ax: Optional[plt.Axes] = None,
verbose: int = 0,
**kwargs: Any
) -> plt.Axes:
# *************************************************
from ..metrics._registry import get_metric
continuous_ranked_probability_score = get_metric(
"continuous_ranked_probability_score")
# *************************************************
# --- Input Validation and Preparation ---
# y_true: (N,), (N,O)
# y_pred_ensemble: (N,M), (N,O,M)
y_true_arr = check_array(
y_true, ensure_2d=False, allow_nd=True,
dtype="numeric", force_all_finite=False, copy=True
)
y_pred_ensemble_arr = check_array(
y_pred_ensemble, ensure_2d=False, allow_nd=True,
dtype="numeric", force_all_finite=False, copy=True
)
# Reshape y_true and y_pred_ensemble for consistent processing
# Target shape for y_true_proc: (N, O)
# Target shape for y_pred_proc: (N, O, M)
# y_true_ndim_orig = y_true_arr.ndim
if y_true_arr.ndim == 1: # (N,) implies single output
y_true_proc = y_true_arr.reshape(-1, 1) # (N, 1)
if y_pred_ensemble_arr.ndim == 2: # (N, M)
y_pred_proc = y_pred_ensemble_arr.reshape(
y_pred_ensemble_arr.shape[0], 1, -1 # (N, 1, M)
)
elif y_pred_ensemble_arr.ndim == 3 and \
y_pred_ensemble_arr.shape[1] == 1: # (N, 1, M)
y_pred_proc = y_pred_ensemble_arr
else:
raise ValueError(
"If y_true is 1D, y_pred_ensemble must be 2D (N,M) "
"or 3D (N,1,M)."
)
elif y_true_arr.ndim == 2: # (N, O)
y_true_proc = y_true_arr
if y_pred_ensemble_arr.ndim == 3 and \
y_pred_ensemble_arr.shape[1] == y_true_proc.shape[1]: # (N,O,M)
y_pred_proc = y_pred_ensemble_arr
else:
raise ValueError(
"If y_true is 2D (N,O), y_pred_ensemble must be 3D (N,O,M) "
"with matching number of outputs."
)
else:
raise ValueError("y_true must be 1D or 2D.")
if y_true_proc.shape[:2] != y_pred_proc.shape[:2]:
raise ValueError(
"Mismatch in n_samples or n_outputs between y_true and "
"y_pred_ensemble after processing."
)
n_samples, n_outputs, n_members = y_pred_proc.shape
if n_samples == 0:
warnings.warn("Input arrays are empty. Cannot generate plot.")
if ax is None: _, ax = plt.subplots(figsize=figsize)
ax.set_title(title or "CRPS Plot (No Data)")
if show_grid: ax.grid(**(grid_props or {}))
return ax
if n_members == 0 and kind != 'summary_bar': # summary_bar might use precomputed
warnings.warn(
"No ensemble members in y_pred_ensemble. "
f"Cannot generate '{kind}' plot."
)
if ax is None: _, ax = plt.subplots(figsize=figsize)
ax.set_title(title or f"CRPS {kind} (No Ensemble Members)")
if show_grid: ax.grid(**(grid_props or {}))
return ax
# --- Plotting Setup ---
if ax is None:
fig, ax = plt.subplots(figsize=figsize) # type: ignore
plot_title_str = title # Use a mutable string
# --- Metric Calculation Handling ---
current_metric_kws = metric_kws or {}
default_kws_for_metric = {
'nan_policy': 'propagate',
'multioutput': 'uniform_average', # Default for overall score
'verbose': 0 # Metric's internal verbose
}
# For 'scores_histogram', we need per-sample CRPS values.
# The crps_score function, as refactored, returns per-sample, per-output
# scores before the final aggregation if 'multioutput' is 'raw_values'
# and then averages over samples.
# We need to ensure we get raw per-sample scores if kind is histogram.
# --- Plotting Logic ---
if kind == 'ensemble_ecdf':
if not (0 <= sample_idx < n_samples):
raise ValueError(f"sample_idx {sample_idx} out of bounds.")
if not (0 <= output_idx < n_outputs):
raise ValueError(f"output_idx {output_idx} out of bounds.")
sample_ensemble = y_pred_proc[sample_idx, output_idx, :]
sample_true_value = y_true_proc[sample_idx, output_idx]
if np.isnan(sample_true_value) or np.all(np.isnan(sample_ensemble)):
ax.text(0.5, 0.5, "Data for ECDF contains NaNs",
ha='center', va='center', transform=ax.transAxes)
else:
# Remove NaNs from ensemble for ECDF calculation
sample_ensemble_valid = sample_ensemble[~np.isnan(sample_ensemble)]
if sample_ensemble_valid.size == 0:
ax.text(0.5, 0.5, "No valid ensemble members for ECDF",
ha='center', va='center', transform=ax.transAxes)
else:
sorted_ensemble = np.sort(sample_ensemble_valid)
ecdf_y = np.arange(1, len(sorted_ensemble) + 1) / len(sorted_ensemble)
ax.step(sorted_ensemble, ecdf_y, where='post',
color=ecdf_color, label='Ensemble ECDF')
# Plot ensemble members as rug plot or faint points
ax.plot(sorted_ensemble, np.zeros_like(sorted_ensemble) - 0.05,
'|', color=ensemble_marker_color, markersize=10,
alpha=0.5, label='Ensemble Members')
ax.axvline(sample_true_value, color=true_value_color,
linestyle='--', label=f'True Value: {sample_true_value:.2f}')
if show_score:
instance_crps = None
if crps_values is not None: # User provided per-instance scores
# Assuming crps_values could be (N,O) or (N,)
if crps_values.ndim == 2 and \
sample_idx < crps_values.shape[0] and \
output_idx < crps_values.shape[1]:
instance_crps = crps_values[sample_idx, output_idx]
elif crps_values.ndim == 1 and n_outputs == 1 and \
sample_idx < crps_values.shape[0]:
instance_crps = crps_values[sample_idx]
else: # Calculate for this instance
instance_kws = {**default_kws_for_metric, **current_metric_kws}
instance_kws['multioutput'] = 'raw_values' # Get raw for this one
cleaned_kws = _get_valid_kwargs(
continuous_ranked_probability_score, instance_kws)
try:
# Need to pass single sample/output to crps_score
# crps_score expects (N,O) for y_true, (N,O,M) for y_pred
temp_y_true = y_true_proc[sample_idx:sample_idx+1, output_idx:output_idx+1]
temp_y_pred = y_pred_proc[sample_idx:sample_idx+1, output_idx:output_idx+1, :]
calculated_scores = continuous_ranked_probability_score(
temp_y_true, temp_y_pred, **cleaned_kws
) # Should be scalar or (1,) array
instance_crps = float(np.ravel(calculated_scores)[0])
except Exception as e:
warnings.warn(f"Could not calculate CRPS for ECDF title: {e}")
if instance_crps is not None and not np.isnan(instance_crps):
score_text = f"CRPS: {instance_crps:.4f}"
current_title = plot_title_str or "Ensemble ECDF vs True Value"
plot_title_str = f"{current_title}\n({score_text})"
ax.set_xlabel(xlabel or 'Value')
ax.set_ylabel(ylabel or 'Cumulative Probability')
ax.legend(loc='best')
ax.set_ylim(-0.1, 1.1)
elif kind == 'scores_histogram':
# This requires per-sample (and potentially per-output) CRPS values
per_sample_output_crps: Optional[np.ndarray] = None
if crps_values is not None:
if crps_values.ndim == 1 and len(crps_values) == n_samples and n_outputs == 1:
per_sample_output_crps = crps_values # Assumed (N,) for single output
elif crps_values.ndim == 2 and crps_values.shape[0] == n_samples \
and crps_values.shape[1] == n_outputs:
per_sample_output_crps = crps_values # (N,O)
else:
warnings.warn("Provided `crps_values` shape incompatible for histogram.")
else:
hist_kws = {**default_kws_for_metric, **current_metric_kws} # noqa
# XXX TODO:
# To get per-sample, per-output scores, need to call
# crps_score differently
# The current continuous_ranked_probability_score refactor
# returns (N_calc, O) before final aggregation
# This is not directly exposed. We need to call it in a loop
# or modify continuous_ranked_probability_score
# For now, let's assume we can get per-sample scores for a chosen
# output_idx Or average over outputs if output_idx is None.
# Let's compute CRPS for each sample, for a specific
# output_idx or averaged
all_sample_crps_list = []
for s_idx in range(n_samples):
temp_y_true = y_true_proc[s_idx:s_idx+1, ...] # (1,O)
temp_y_pred = y_pred_proc[s_idx:s_idx+1, ...] # (1,O,M)
# Kws for single sample CRPS calculation
single_sample_kws = {**default_kws_for_metric, **current_metric_kws}
single_sample_kws['multioutput'] = 'raw_values' # get (O,) scores
cleaned_kws_ss = _get_valid_kwargs(
continuous_ranked_probability_score, single_sample_kws)
try:
s_crps = continuous_ranked_probability_score(
temp_y_true, temp_y_pred, **cleaned_kws_ss) # (O,)
if n_outputs > 1 and output_idx is not None:
if 0 <= output_idx < n_outputs:
all_sample_crps_list.append(s_crps[output_idx])
else: # Should not happen if validated
all_sample_crps_list.append(np.nan)
else: # Single output or average over outputs for this sample
all_sample_crps_list.append(np.nanmean(s_crps))
except Exception:
all_sample_crps_list.append(np.nan)
per_sample_output_crps = np.array(all_sample_crps_list)
if per_sample_output_crps is not None:
valid_crps_for_hist = per_sample_output_crps[
~np.isnan(per_sample_output_crps)]
if valid_crps_for_hist.size > 0:
ax.hist(valid_crps_for_hist, bins=hist_bins,
color=hist_color, edgecolor=hist_edgecolor,
**kwargs.get('hist_kwargs', {}))
avg_crps_for_title = np.nanmean(valid_crps_for_hist)
if show_score and not np.isnan(avg_crps_for_title):
current_title = plot_title_str or "Distribution of CRPS Values"
plot_title_str = f"{current_title}\n(Mean CRPS: {avg_crps_for_title:.4f})"
else:
ax.text(0.5,0.5, "No valid CRPS scores for histogram",
ha='center', va='center', transform=ax.transAxes)
else:
ax.text(0.5,0.5, "CRPS scores not available for histogram",
ha='center', va='center', transform=ax.transAxes)
ax.set_xlabel(xlabel or 'CRPS per Sample')
ax.set_ylabel(ylabel or 'Frequency')
ax.yaxis.set_major_locator(MaxNLocator(integer=True))
elif kind == 'summary_bar':
scores_to_plot: Union[float, np.ndarray]
if crps_values is not None:
scores_to_plot = crps_values
if verbose > 0: print(f"Using pre-computed CRPS: {scores_to_plot}")
else:
summary_bar_default_kws = default_kws_for_metric.copy()
# Respect user's multioutput for summary bar
if 'multioutput' not in current_metric_kws:
summary_bar_default_kws['multioutput'] = 'uniform_average'
effective_kws = {**summary_bar_default_kws, **current_metric_kws}
cleaned_kws = _get_valid_kwargs(
continuous_ranked_probability_score, effective_kws)
scores_to_plot = continuous_ranked_probability_score(
y_true_proc, y_pred_proc, **cleaned_kws
)
if verbose > 0: print(f"Computed CRPS for summary: {scores_to_plot}")
scores_arr: np.ndarray
x_labels: List[str]
multioutput_used = (current_metric_kws or {}).get(
'multioutput', default_kws_for_metric['multioutput'])
if np.isscalar(scores_to_plot) or \
(isinstance(scores_to_plot, np.ndarray) and scores_to_plot.ndim == 0):
scores_arr = np.array([scores_to_plot])
x_labels = ['Overall CRPS']
elif isinstance(scores_to_plot, np.ndarray) and scores_to_plot.ndim == 1:
scores_arr = scores_to_plot
x_labels = [f'Output {i}' for i in range(len(scores_arr))]
if plot_title_str and multioutput_used == 'raw_values':
plot_title_str += " (Per Output)"
else:
raise TypeError(f"Unexpected scores type/shape: {type(scores_to_plot)}")
bars = ax.bar(x_labels, scores_arr, color=bar_color, width=bar_width,
**kwargs.get('bar_kwargs', {}))
ax.set_ylabel(ylabel or 'CRPS')
min_score = np.nanmin(scores_arr) if scores_arr.size > 0 else 0
max_score = np.nanmax(scores_arr) if scores_arr.size > 0 else 0.1
ax.set_ylim(min(0, min_score - 0.1 * abs(min_score)),
max_score + 0.1 * abs(max_score) + 0.05)
for bar in bars:
yval = bar.get_height()
if not np.isnan(yval):
x = bar.get_x() + bar.get_width() / 2.0
label = score_annotation_format.format(yval)
ax.annotate(
label,
xy=(x, yval),
xytext=(0, 3 if yval >= 0 else -10),
textcoords="offset points",
ha="center",
va="bottom" if yval >= 0 else "top",
)
else:
raise ValueError(f"Unknown plot kind: '{kind}'.")
if plot_title_str:
ax.set_title(plot_title_str)
if show_grid:
current_grid_props = grid_props if grid_props is not None \
else {'linestyle': ':', 'alpha': 0.7}
ax.grid(**current_grid_props)
else:
ax.grid(False)
return ax
plot_crps.__doc__=r"""
Visualise the Continuous Ranked Probability Score (CRPS) for
ensemble forecasts.
Three complementary views are available:
* **'ensemble_ecdf'** – ECDF of a single ensemble, the true value,
and the per‑instance CRPS.
* **'scores_histogram'** – distribution of per‑sample CRPS values.
* **'summary_bar'** – bar chart of the overall CRPS (or one bar per
output).
Parameters
----------
{params.base.y_true}
y_pred_ensemble : ndarray
Ensemble predictions. Shape *(N, M)* for a single output or
*(N, O, M)* for multiple outputs, where *M* is the number of
ensemble members.
crps_values : float or ndarray, optional
Pre‑computed CRPS value(s). If supplied, the helper skips internal
calls to :func:`fusionlab.metrics.continuous_ranked_probability_score`.
{params.base.metric_kws}
kind : {{'ensemble_ecdf', 'scores_histogram', 'summary_bar'}},
default ``'summary_bar'``
Style of plot to generate.
sample_idx : int, default 0
Index of the sample to display when ``kind='ensemble_ecdf'``.
output_idx : int, default 0
Output dimension to display when ``kind='ensemble_ecdf'``.
ecdf_color : str, default ``'dodgerblue'``
true_value_color : str, default ``'red'``
ensemble_marker_color : str, default ``'gray'``
Styling parameters for the ECDF plot.
hist_bins : int | sequence | str, default ``'auto'``
hist_color : str, default ``'skyblue'``
hist_edgecolor : str, default ``'black'``
Histogram styling parameters.
{params.base.figsize}
{params.base.title}
{params.base.xlabel}
{params.base.ylabel}
{params.base.bar_color}
{params.base.bar_width}
{params.base.score_annotation_format}
show_score : bool, default ``True``
Append the numeric CRPS to the title (where applicable).
{params.base.show_grid}
{params.base.grid_props}
{params.base.ax}
{params.base.verbose}
{params.base.kwargs}
Returns
-------
matplotlib.axes.Axes
Axes containing the CRPS visualisation.
Notes
-----
For one observation with ensemble members
:math:`x_1,\dots,x_M` and true value :math:`y`, the
sample‑based CRPS is
.. math::
\operatorname{{CRPS}} \;=\;
\frac{{1}}{{M}}\sum_{{j=1}}^M |x_j - y|
\;-\;\frac{{1}}{{2M^2}}\sum_{{i=1}}^M\sum_{{j=1}}^M |x_i - x_j|.
Lower scores indicate sharper and better‑calibrated
probabilistic forecasts.
Examples
--------
>>> import numpy as np, matplotlib.pyplot as plt
>>> from fusionlab.plot.evaluation import plot_crps
>>> rng = np.random.default_rng(0)
>>> y_true = rng.normal(size=500)
>>> ens = y_true[:, None] + rng.normal(scale=.5, size=(500, 20))
>>> plot_crps(y_true, ens, kind='scores_histogram')
>>> plt.show()
See Also
--------
fusionlab.metrics.continuous_ranked_probability_score
Numeric computation of sample‑based CRPS.
fusionlab.plot.evaluation.plot_quantile_calibration
Reliability diagrams for quantile forecasts.
fusionlab.plot.evaluation.plot_weighted_interval_score
Interval‑based sharpness and calibration plot.
References
----------
.. [1] Hersbach, H. (2000). *Decomposition of the Continuous Ranked
Probability Score for Ensemble Prediction Systems*.
*Weather and Forecasting*, 15(5), 559‑570.
""".format(params=_param_docs)
[docs]
def plot_mean_interval_width(
y_lower: np.ndarray,
y_upper: np.ndarray,
miw_values: Optional[Union[float, np.ndarray]] = None,
metric_kws: Optional[Dict[str, Any]] = None,
kind: Literal['widths_histogram', 'summary_bar'] = 'summary_bar',
output_idx: Optional[int] = None,
hist_bins: Union[int, Sequence[Real], str] = 'auto',
hist_color: str = 'mediumpurple',
hist_edgecolor: str = 'black',
figsize: Tuple[float, float] = (10, 6),
title: Optional[str] = "Mean Interval Width (Sharpness)",
xlabel: Optional[str] = None,
ylabel: Optional[str] = None,
bar_color: Union[str, List[str]] = 'mediumpurple',
bar_width: float = 0.8,
score_annotation_format: str = "{:.4f}",
show_score: bool = True,
show_grid: bool = True,
grid_props: Optional[Dict[str, Any]] = None,
ax: Optional[plt.Axes] = None,
verbose: int = 0,
**kwargs: Any
) -> plt.Axes:
# *************************************************
from ..metrics._registry import get_metric
mean_interval_width_score = get_metric(
"mean_interval_width_score")
# *************************************************
# --- Input Validation and Preparation ---
# y_lower, y_upper: (N,), (N,O)
y_lower_arr = check_array(
y_lower, ensure_2d=False, allow_nd=True,
dtype="numeric", force_all_finite=False, copy=True
)
y_upper_arr = check_array(
y_upper, ensure_2d=False, allow_nd=True,
dtype="numeric", force_all_finite=False, copy=True
)
if y_lower_arr.shape != y_upper_arr.shape:
raise ValueError(
"y_lower and y_upper must have the same shape."
)
if y_lower_arr.ndim > 2:
raise ValueError("Inputs y_lower/y_upper must be 1D or 2D.")
# Reshape for consistent processing: (N, O)
y_lower_proc = y_lower_arr.reshape(
-1, 1) if y_lower_arr.ndim == 1 else y_lower_arr
y_upper_proc = y_upper_arr.reshape(
-1, 1) if y_upper_arr.ndim == 1 else y_upper_arr
n_samples, n_outputs = y_lower_proc.shape
if n_samples == 0:
warnings.warn("Input arrays are empty. Cannot generate plot.")
if ax is None: _, ax = plt.subplots(figsize=figsize)
ax.set_title(title or "Mean Interval Width (No Data)")
if show_grid: ax.grid(**(grid_props or {}))
return ax
# --- Plotting Setup ---
if ax is None:
fig, ax = plt.subplots(figsize=figsize) # type: ignore
plot_title_str = title # Use a mutable string
# --- Metric Calculation Handling ---
current_metric_kws = metric_kws or {}
default_kws_for_metric = {
'nan_policy': 'propagate',
'multioutput': 'uniform_average', # Default for overall score
'warn_invalid_bounds': True,
'eps': 1e-8,
'verbose': 0 # Metric's internal verbose
}
# --- Plotting Logic ---
if kind == 'widths_histogram':
# For histogram, we need individual widths, calculated from inputs.
# `miw_values` (if mean scores) is not used for this kind.
# Handle NaNs based on metric_kws for calculating widths
# This logic needs to align with how mean_interval_width_score handles it
nan_policy_hist = current_metric_kws.get(
'nan_policy', default_kws_for_metric['nan_policy']
)
temp_y_lower, temp_y_upper = y_lower_proc, y_upper_proc
nan_mask_inputs = np.isnan(temp_y_lower) | np.isnan(temp_y_upper)
if np.any(nan_mask_inputs):
if nan_policy_hist == 'raise':
raise ValueError(
"NaNs found in y_lower/y_upper for histogram."
)
elif nan_policy_hist == 'omit':
# Omit rows where *any* output has NaN for this sample
rows_with_nan = nan_mask_inputs.any(axis=1)
rows_to_keep = ~rows_with_nan
if not np.any(rows_to_keep):
ax.text(0.5,0.5,"All samples omitted due to NaNs.",
ha='center', va='center', transform=ax.transAxes)
if show_grid: ax.grid(**(grid_props or {}))
ax.set_title(plot_title_str or "Interval Widths (No Data)")
return ax
temp_y_lower = temp_y_lower[rows_to_keep]
temp_y_upper = temp_y_upper[rows_to_keep]
# Update nan_mask_inputs for propagate logic if it were used
nan_mask_inputs = nan_mask_inputs[rows_to_keep]
individual_widths = temp_y_upper - temp_y_lower # (N_calc, O)
if nan_policy_hist == 'propagate':
# nan_mask_inputs is (N_calc, O) if omit was applied to it
individual_widths = np.where(
nan_mask_inputs, np.nan, individual_widths
)
# Select output for histogram if multi-output
widths_to_plot: np.ndarray
current_output_label = ""
if n_outputs > 1:
if output_idx is None:
raise ValueError(
"For multi-output data and kind='widths_histogram', "
"'output_idx' must be specified."
)
if not (0 <= output_idx < n_outputs):
raise ValueError(
f"output_idx {output_idx} out of bounds for "
f"{n_outputs} outputs."
)
widths_to_plot = individual_widths[:, output_idx]
current_output_label = f" (Output {output_idx})"
else: # Single output
widths_to_plot = individual_widths.ravel()
valid_widths_for_hist = widths_to_plot[
~np.isnan(widths_to_plot)
]
if valid_widths_for_hist.size > 0:
ax.hist(valid_widths_for_hist, bins=hist_bins,
color=hist_color, edgecolor=hist_edgecolor,
**kwargs.get('hist_kwargs', {}))
if show_score:
mean_of_plotted_widths = np.mean(valid_widths_for_hist)
score_text = f"Mean Width: {mean_of_plotted_widths:.4f}"
current_title = plot_title_str or \
"Distribution of Interval Widths"
plot_title_str = (
f"{current_title}{current_output_label}\n({score_text})"
)
else:
ax.text(0.5,0.5, "No valid interval widths for histogram.",
ha='center', va='center', transform=ax.transAxes)
current_title = plot_title_str or \
"Distribution of Interval Widths"
plot_title_str = f"{current_title}{current_output_label} (No Data)"
ax.set_xlabel(xlabel or 'Interval Width')
ax.set_ylabel(ylabel or 'Frequency')
ax.yaxis.set_major_locator(MaxNLocator(integer=True))
elif kind == 'summary_bar':
scores_to_plot: Union[float, np.ndarray]
if miw_values is not None:
scores_to_plot = miw_values
if verbose > 0:
print(f"Using pre-computed MIW values: {scores_to_plot}")
else:
# For summary bar, respect user's multioutput choice in metric_kws
summary_bar_default_kws = default_kws_for_metric.copy()
if 'multioutput' not in current_metric_kws:
summary_bar_default_kws['multioutput'] = 'uniform_average'
effective_kws = {**summary_bar_default_kws, **current_metric_kws}
cleaned_kws = _get_valid_kwargs(
mean_interval_width_score, effective_kws
)
scores_to_plot = mean_interval_width_score(
y_lower_arr, # Use original full arrays for score calculation
y_upper_arr,
**cleaned_kws
)
if verbose > 0:
print(f"Computed MIW score(s) for summary: {scores_to_plot}")
scores_arr: np.ndarray
x_labels: List[str]
multioutput_used_for_score = (current_metric_kws or {}).get(
'multioutput', default_kws_for_metric['multioutput']
)
if np.isscalar(scores_to_plot) or \
(isinstance(scores_to_plot, np.ndarray) and scores_to_plot.ndim == 0):
scores_arr = np.array([scores_to_plot])
x_labels = ['Mean Interval Width']
elif isinstance(scores_to_plot, np.ndarray) and scores_to_plot.ndim == 1:
scores_arr = scores_to_plot
x_labels = [f'Output {i}' for i in range(len(scores_arr))]
if plot_title_str and multioutput_used_for_score == 'raw_values':
plot_title_str += " (Per Output)"
else:
raise TypeError(
f"Unexpected type or shape for MIW scores: {type(scores_to_plot)}"
)
bars = ax.bar(x_labels, scores_arr, color=bar_color,
width=bar_width, **kwargs.get('bar_kwargs', {}))
ax.set_ylabel(ylabel or 'Mean Interval Width')
# Auto-adjust y-limits for better visualization
if scores_arr.size > 0 and not np.all(np.isnan(scores_arr)):
min_val = np.nanmin(scores_arr)
max_val = np.nanmax(scores_arr)
padding = 0.1 * (max_val - min_val) if (
max_val - min_val) > 1e-6 else 0.1
ax.set_ylim(min(
0, min_val - padding), max_val + padding + 0.05)
for bar_obj in bars:
yval = bar_obj.get_height()
if not np.isnan(yval):
x = bar_obj.get_x() + bar_obj.get_width() / 2.0
label = score_annotation_format.format(yval)
ax.annotate(
label,
xy=(x, yval),
xytext=(0, 3 if yval >= 0 else -10),
textcoords="offset points",
ha="center",
va="bottom" if yval >= 0 else "top",
)
else:
raise ValueError(
f"Unknown plot kind: '{kind}'. Choose 'widths_histogram' "
"or 'summary_bar'."
)
if plot_title_str:
ax.set_title(plot_title_str)
if show_grid:
current_grid_props = grid_props if grid_props is not None \
else {'linestyle': ':', 'alpha': 0.7}
ax.grid(**current_grid_props)
else:
ax.grid(False)
return ax
plot_mean_interval_width.__doc__ =r"""
Visualise Mean Interval Width (MIW) – a simple sharpness measure
equal to the average distance between lower and upper prediction‐
interval bounds.
Two complementary views are implemented:
* **'widths_histogram'** – distribution of individual interval widths
for a chosen output.
* **'summary_bar'** – bar chart of the averaged width (overall or one
bar per output).
Parameters
----------
{params.base.y_lower}
{params.base.y_upper}
miw_values : float or ndarray, optional
Pre‑computed MIW score(s). If supplied the helper skips the
internal call to
:func:`fusionlab.metrics.mean_interval_width_score`.
metric_kws : dict, optional
Extra keyword arguments forwarded to
:func:`fusionlab.metrics.mean_interval_width_score`.
kind : {{'widths_histogram', 'summary_bar'}},
default ``'summary_bar'``
Select the visualisation style.
output_idx : int, optional
Output dimension to plot when ``kind='widths_histogram'`` on
multi‑output data.
hist_bins : int | sequence | str, default ``'auto'``
hist_color : str, default ``'mediumpurple'``
hist_edgecolor : str, default ``'black'``
Styling options for the histogram.
{params.base.figsize}
{params.base.title}
{params.base.xlabel}
{params.base.ylabel}
{params.base.bar_color}
{params.base.bar_width}
{params.base.score_annotation_format}
show_score : bool, default ``True``
Display the mean width on the histogram title.
{params.base.show_grid}
{params.base.grid_props}
{params.base.ax}
{params.base.verbose}
{params.base.kwargs}
Returns
-------
matplotlib.axes.Axes
Axes containing the MIW visualisation.
Notes
-----
For a single observation the interval width is simply
.. math::
w_i \;=\; u_i \;-\; \ell_i ,
where :math:`u_i` and :math:`\ell_i` are the upper and lower bounds.
The mean interval width over *N* samples is
.. math::
\text{{MIW}} \;=\; \frac{{1}}{{N}}\sum_{{i=1}}^{{N}} w_i.
Lower MIW indicates a *sharper* forecast, but should always be
interpreted together with coverage diagnostics.
Examples
--------
>>> import numpy as np, matplotlib.pyplot as plt
>>> from fusionlab.plot.evaluation import plot_mean_interval_width
>>> rng = np.random.default_rng(1)
>>> y_l = rng.normal(loc=-1.0, scale=.5, size=200)
>>> y_u = y_l + rng.uniform(1.5, 2.5, size=200)
>>> plot_mean_interval_width(
... y_lower=y_l, y_upper=y_u, kind='widths_histogram',
... figsize=(8, 4))
>>> plt.show()
See Also
--------
fusionlab.metrics.mean_interval_width_score
Numeric computation of MIW.
fusionlab.plot.evaluation.plot_coverage
Shows how many observations fall inside the intervals.
fusionlab.plot.evaluation.plot_weighted_interval_score
Combines width with calibration penalties.
References
----------
.. [1] Gneiting, T. & Katzfuss, M. (2014). *Probabilistic Forecasting*.
*Ann. Rev. Stat. Appl.*, 1, 125‑151 — section 4.1, “Sharpness”.
""".format(params=_param_docs)
[docs]
def plot_prediction_stability(
y_pred: np.ndarray,
pss_values: Optional[Union[float, np.ndarray]] = None,
metric_kws: Optional[Dict[str, Any]] = None,
kind: Literal['scores_histogram', 'summary_bar'] = 'summary_bar',
output_idx: Optional[int] = None,
hist_bins: Union[int, Sequence[Real], str] = 'auto',
hist_color: str = 'teal',
hist_edgecolor: str = 'black',
figsize: Tuple[float, float] = (10, 6),
title: Optional[str] = "Prediction Stability Score (PSS)",
xlabel: Optional[str] = None,
ylabel: Optional[str] = None,
bar_color: Union[str, List[str]] = 'teal',
bar_width: float = 0.8,
score_annotation_format: str = "{:.4f}",
show_score: bool = True,
show_grid: bool = True,
grid_props: Optional[Dict[str, Any]] = None,
ax: Optional[plt.Axes] = None,
verbose: int = 0,
**kwargs: Any
) -> plt.Axes:
# **********************************************************************
from ..metrics._registry import get_metric
prediction_stability_score = get_metric("prediction_stability_score")
# **********************************************************************
# --- Input Validation and Preparation ---
# y_pred: (T,), (N,T), or (N,O,T)
y_pred_arr = check_array(
y_pred, ensure_2d=False, allow_nd=True,
dtype="numeric", force_all_finite=False, copy=True
)
# Reshape for consistent processing: (N, O, T)
y_pred_ndim_orig = y_pred_arr.ndim
if y_pred_ndim_orig == 1: # (T,)
y_pred_proc = y_pred_arr.reshape(1, 1, -1)
elif y_pred_ndim_orig == 2: # (N, T)
y_pred_proc = y_pred_arr.reshape(y_pred_arr.shape[0], 1, -1)
elif y_pred_ndim_orig == 3: # (N, O, T)
y_pred_proc = y_pred_arr
else:
raise ValueError(
"y_pred must be 1D, 2D (n_samples, n_timesteps), or 3D "
"(n_samples, n_outputs, n_timesteps)."
)
n_samples, n_outputs, n_timesteps = y_pred_proc.shape
if n_samples == 0:
warnings.warn("Input y_pred is empty. Cannot generate plot.")
if ax is None: _, ax = plt.subplots(figsize=figsize)
ax.set_title(title or "PSS Plot (No Data)")
if show_grid: ax.grid(**(grid_props or {}))
return ax
if n_timesteps < 2 and kind != 'summary_bar':
# summary_bar might use precomputed values
warnings.warn(
"PSS requires at least 2 time steps for histogram. "
"Plot may be empty or misleading."
)
if ax is None: _, ax = plt.subplots(figsize=figsize)
ax.set_title(title or f"PSS {kind} (Not Enough Timesteps)")
if show_grid: ax.grid(**(grid_props or {}))
return ax
# --- Plotting Setup ---
if ax is None:
fig, ax = plt.subplots(figsize=figsize) # type: ignore
plot_title_str = title # Use a mutable string
# --- Metric Calculation Handling ---
current_metric_kws = metric_kws or {}
default_kws_for_metric = {
'nan_policy': 'propagate',
'multioutput': 'uniform_average', # Default for overall score
'verbose': 0 # Metric's internal verbose
}
# --- Plotting Logic ---
if kind == 'scores_histogram':
# For histogram, we need PSS per trajectory (per sample, per output).
# The prediction_stability_score function's internal `pss_per_trajectory`
# (before sample averaging) is what we need.
# To get these, we call prediction_stability_score in a way that
# allows us to capture these intermediate values if not directly exposed.
# However, the refactored `prediction_stability_score` calculates
# `pss_per_trajectory` of shape (N_calc, O). This is suitable.
# We need to calculate these raw per-trajectory scores.
# `pss_values` if provided should be these raw scores.
per_trajectory_scores: Optional[np.ndarray] = None
if pss_values is not None:
# Assume pss_values, if provided for histogram, are already
# per-trajectory (N,O) or (N,) if single output.
if pss_values.ndim == 1 and n_outputs == 1 and \
len(pss_values) == n_samples :
per_trajectory_scores = pss_values.reshape(-1,1) # Ensure (N,1)
elif pss_values.ndim == 2 and \
pss_values.shape[0] == n_samples and \
pss_values.shape[1] == n_outputs:
per_trajectory_scores = pss_values
else:
warnings.warn(
"Provided `pss_values` shape incompatible for histogram. "
"Recalculating per-trajectory scores."
)
if per_trajectory_scores is None:
# Calculate per-trajectory scores.
# This requires a bit of care as the main metric returns aggregated scores.
# We can simulate by calling for each sample if needed, or use internals.
# The current `prediction_stability_score` structure:
# y_pred_proc (N,O,T) -> diffs (N,O,T-1) -> pss_per_trajectory (N,O)
# This `pss_per_trajectory` is exactly what we need.
# So, we can effectively run parts of the metric logic here.
nan_policy_hist = current_metric_kws.get(
'nan_policy', default_kws_for_metric['nan_policy']
)
temp_y_pred = y_pred_proc.copy() # Use copy for local NaN handling
nan_mask_sot = np.isnan(temp_y_pred)
nan_mask_so_hist = nan_mask_sot.any(axis=2) # (N,O)
if np.any(nan_mask_so_hist):
if nan_policy_hist == 'raise':
raise ValueError(
"NaNs found in y_pred for histogram."
)
elif nan_policy_hist == 'omit':
rows_with_nan = nan_mask_so_hist.any(axis=1) # (N,)
rows_to_keep = ~rows_with_nan
if not np.any(rows_to_keep):
ax.text(0.5,0.5,"All samples omitted due to NaNs.",
ha='center',va='center',transform=ax.transAxes)
if show_grid: ax.grid(**(grid_props or {}))
ax.set_title(plot_title_str or "PSS Scores (No Data)")
return ax
temp_y_pred = temp_y_pred[rows_to_keep]
nan_mask_so_hist = nan_mask_so_hist[rows_to_keep]
if temp_y_pred.shape[0] == 0 or temp_y_pred.shape[2] < 2:
ax.text(0.5,0.5,"Not enough data for PSS histogram.",
ha='center',va='center',transform=ax.transAxes)
if show_grid: ax.grid(**(grid_props or {}))
ax.set_title(plot_title_str or "PSS Scores (No Data)")
return ax
diffs_hist = np.abs(
temp_y_pred[..., 1:] - temp_y_pred[..., :-1]
)
per_trajectory_scores = np.mean(diffs_hist, axis=2) # (N_calc, O)
if nan_policy_hist == 'propagate':
per_trajectory_scores = np.where(
nan_mask_so_hist, np.nan, per_trajectory_scores
)
# Select output for histogram
scores_to_plot_hist: np.ndarray
current_output_label = ""
if n_outputs > 1:
if output_idx is None:
raise ValueError(
"For multi-output data and kind='scores_histogram', "
"'output_idx' must be specified."
)
if not (0 <= output_idx < n_outputs):
raise ValueError(
f"output_idx {output_idx} out of bounds for "
f"{n_outputs} outputs."
)
scores_to_plot_hist = per_trajectory_scores[:, output_idx]
current_output_label = f" (Output {output_idx})"
else: # Single output
scores_to_plot_hist = per_trajectory_scores.ravel()
valid_scores_for_hist = scores_to_plot_hist[
~np.isnan(scores_to_plot_hist)
]
if valid_scores_for_hist.size > 0:
ax.hist(valid_scores_for_hist, bins=hist_bins,
color=hist_color, edgecolor=hist_edgecolor,
**kwargs.get('hist_kwargs', {}))
if show_score:
mean_of_plotted_pss = np.mean(valid_scores_for_hist)
score_text = f"Mean PSS: {mean_of_plotted_pss:.4f}"
current_title = plot_title_str or \
"Distribution of PSS Values"
plot_title_str = (
f"{current_title}{current_output_label}\n({score_text})"
)
else:
ax.text(0.5,0.5, "No valid PSS values for histogram.",
ha='center', va='center', transform=ax.transAxes)
current_title = plot_title_str or \
"Distribution of PSS Values"
plot_title_str = f"{current_title}{current_output_label} (No Data)"
ax.set_xlabel(xlabel or 'PSS per Trajectory')
ax.set_ylabel(ylabel or 'Frequency')
ax.yaxis.set_major_locator(MaxNLocator(integer=True))
elif kind == 'summary_bar':
scores_to_plot_bar: Union[float, np.ndarray]
if pss_values is not None:
scores_to_plot_bar = pss_values
if verbose > 0:
print(f"Using pre-computed PSS values: {scores_to_plot_bar}")
else:
summary_bar_default_kws = default_kws_for_metric.copy()
if 'multioutput' not in current_metric_kws:
summary_bar_default_kws['multioutput'] = 'uniform_average'
effective_kws = {**summary_bar_default_kws, **current_metric_kws}
cleaned_kws = _get_valid_kwargs(
prediction_stability_score, effective_kws
)
scores_to_plot_bar = prediction_stability_score(
y_pred_arr, # Use original y_pred_arr for metric call
**cleaned_kws
)
if verbose > 0:
print(f"Computed PSS for summary: {scores_to_plot_bar}")
scores_arr_bar: np.ndarray
x_labels_bar: List[str]
multioutput_used = (current_metric_kws or {}).get(
'multioutput', default_kws_for_metric['multioutput'])
if np.isscalar(scores_to_plot_bar) or \
(isinstance(scores_to_plot_bar, np.ndarray) and \
scores_to_plot_bar.ndim == 0):
scores_arr_bar = np.array([scores_to_plot_bar])
x_labels_bar = ['Mean PSS']
elif isinstance(scores_to_plot_bar, np.ndarray) and \
scores_to_plot_bar.ndim == 1:
scores_arr_bar = scores_to_plot_bar
x_labels_bar = [f'Output {i}' for i in range(len(scores_arr_bar))]
if plot_title_str and multioutput_used == 'raw_values':
plot_title_str += " (Per Output)"
else:
raise TypeError(
f"Unexpected type/shape for PSS scores: {type(scores_to_plot_bar)}"
)
bars = ax.bar(x_labels_bar, scores_arr_bar, color=bar_color,
width=bar_width, **kwargs.get('bar_kwargs', {}))
ax.set_ylabel(ylabel or 'Prediction Stability Score (PSS)')
if scores_arr_bar.size > 0 and not np.all(np.isnan(scores_arr_bar)):
min_val = np.nanmin(scores_arr_bar)
max_val = np.nanmax(scores_arr_bar)
padding = 0.1 * (max_val - min_val) if (max_val-min_val)>1e-6 else 0.1
ax.set_ylim(min(0, min_val - padding), max_val + padding + 0.05)
for bar_obj in bars:
yval = bar_obj.get_height()
if not np.isnan(yval):
x = bar_obj.get_x() + bar_obj.get_width() / 2.0
dy = 3 if yval >= 0 else -10
va = 'bottom' if yval >= 0 else 'top'
ax.annotate(
score_annotation_format.format(yval),
xy=(x, yval),
xytext=(0, dy),
textcoords="offset points",
ha="center",
va=va,
)
else:
raise ValueError(
f"Unknown plot kind: '{kind}'. Choose 'scores_histogram' "
"or 'summary_bar'."
)
if plot_title_str:
ax.set_title(plot_title_str)
if show_grid:
current_grid_props = grid_props if grid_props is not None \
else {'linestyle': ':', 'alpha': 0.7}
ax.grid(**current_grid_props)
else:
ax.grid(False)
return ax
plot_prediction_stability.__doc__=r"""
Visualise the Prediction Stability Score (PSS) — the average
absolute change between successive time steps in a forecast
trajectory. Lower PSS ⇒ smoother (more stable) predictions.
Two complementary views are provided:
* **'scores_histogram'** – distribution of per‑trajectory PSS values
for a chosen output.
* **'summary_bar'** – bar chart of the mean PSS (overall or one bar
per output).
Parameters
----------
y_pred : ndarray
Model predictions. Accepts
* 1‑D ``(T,)`` – single trajectory, one output;
* 2‑D ``(N, T)`` – *N* trajectories, one output;
* 3‑D ``(N, O, T)`` – *N* trajectories, *O* outputs.
The final dimension is the temporal axis (*T ≥ 2* for PSS).
pss_values : float or ndarray, optional
Pre‑computed PSS value(s). If supplied the helper skips internal
calls to
:func:`fusionlab.metrics.prediction_stability_score`.
metric_kws : dict, optional
Extra keyword arguments forwarded to the metric function.
kind : {{'scores_histogram', 'summary_bar'}},
default ``'summary_bar'``
Select the visualisation style.
output_idx : int, optional
Output dimension to plot when ``kind='scores_histogram'`` on
multi‑output data.
hist_bins : int | sequence | str, default ``'auto'``
hist_color : str, default ``'teal'``
hist_edgecolor : str, default ``'black'``
Styling options for the histogram.
{params.base.figsize}
{params.base.title}
{params.base.xlabel}
{params.base.ylabel}
{params.base.bar_color}
{params.base.bar_width}
{params.base.score_annotation_format}
show_score : bool, default ``True``
Display the mean PSS on the histogram title.
{params.base.show_grid}
{params.base.grid_props}
{params.base.ax}
{params.base.verbose}
{params.base.kwargs}
Returns
-------
matplotlib.axes.Axes
Axes containing the stability visualisation.
Notes
-----
For one trajectory :math:`(\hat y_{{1}},\dots,\hat y_{{T}})` the stability
score is
.. math::
\operatorname{{PSS}}
\;=\;
\frac{{1}}{{T-1}}\sum_{{t=2}}^{{T}}
\bigl|\hat y_{{t}} - \hat y_{{t-1}}\bigr|.
The helper first reshapes ``y_pred`` to *(N, O, T)*, computes the
per‑trajectory scores, and then aggregates or plots them according to
``kind``.
Examples
--------
>>> import numpy as np, matplotlib.pyplot as plt
>>> from fusionlab.plot.evaluation import plot_prediction_stability
>>> rng = np.random.default_rng(0)
>>> preds = rng.normal(size=(200, 30)) # 200 series, 30 time steps
>>> plot_prediction_stability(
... preds, kind='scores_histogram', figsize=(8, 4))
>>> plt.show()
See Also
--------
fusionlab.metrics.prediction_stability_score
Numeric implementation of PSS.
fusionlab.plot.evaluation.plot_time_weighted_metric
Time‑weighted MAE, accuracy, and interval‑score plots.
References
----------
.. [1] Hyndman, R.J. & Athanasopoulos, G. *Forecasting: Principles and
Practice*, 3rd ed., OTexts, 2021 — section 2.6, “Stability”.
""".format(params=_param_docs)
def plot_qce_donut(
df: pd.DataFrame,
actual_col: str,
quantile_cols: List[str],
quantile_levels: List[float],
metric_kws: Optional[Dict[str, Any]] = None,
figsize: Tuple[float, float] = (8, 8),
title: Optional[str] = "Quantile Calibration Error Contributions",
colors: Optional[List[str]] = None,
center_text_format: str = "Avg QCE:\n{:.4f}",
segment_label_format: str = "{name}\n({value:.2f})", # {name}, {value}, {percent}
startangle: float = 90,
counterclock: bool = False,
wedgeprops: Optional[Dict[str, Any]] = None,
donut_width: float = 0.4, # Width of the donut ring
value_annotations: bool=True,
# Legend and labels
show_legend: bool = True,
legend_title: Optional[str] = "Quantiles",
legend_loc: str = "center left",
legend_bbox_to_anchor: Tuple[float, float] = (0.95, 0.5),
# Common params
# Grid and labels
show_grid: bool = True,
grid_props: Optional[Dict[str, Any]] = None,
ax: Optional[plt.Axes] = None,
verbose: int = 0,
**kwargs: Any # For future matplotlib extensions
) -> plt.Axes:
"""
Visualizes Quantile Calibration Error (QCE) components as a donut chart.
(Full docstring to be added later for detailed parameter explanation)
"""
df, quantile_levels = _validate_qce_plot_inputs(
df, actual_col=actual_col,
quantile_cols= quantile_cols,
quantile_levels= quantile_levels,
error_policy= 'raise'
)
y_true_np = df[actual_col].to_numpy(dtype=float)
y_pred_quantiles_np = df[quantile_cols].to_numpy(dtype=float) # (N, Q)
quantile_levels_np = np.array(quantile_levels, dtype=float) # (Q,)
n_samples, n_quantiles = y_pred_quantiles_np.shape
if n_samples != len(y_true_np):
raise ValueError("Length mismatch: actual_col vs quantile_cols.")
# --- 2. Handle NaNs and Sample Weights (from metric_kws) ---
current_metric_kws = metric_kws or {}
nan_policy = current_metric_kws.get('nan_policy', 'propagate')
sample_weight = current_metric_kws.get('sample_weight', None)
eps = current_metric_kws.get('eps', 1e-8)
# Create a mask for NaNs across y_true and all relevant y_pred_quantiles
# nan_mask_samples is 1D (n_samples,)
nan_mask_true = np.isnan(y_true_np)
nan_mask_preds = np.isnan(y_pred_quantiles_np).any(axis=1)
combined_nan_mask_samples = nan_mask_true | nan_mask_preds
y_t_calc = y_true_np
y_p_calc = y_pred_quantiles_np
s_weights_calc = sample_weight
if np.any(combined_nan_mask_samples):
if nan_policy == 'raise':
raise ValueError("NaNs found in input data.")
elif nan_policy == 'propagate':
# If any NaN leads to an overall NaN result for donut chart
warnings.warn(
"NaNs found with nan_policy='propagate'. Donut chart "
"may not be meaningful if miscalibrations become NaN."
)
# Calculations below will propagate NaNs
elif nan_policy == 'omit':
if verbose > 0:
print("NaNs detected. Omitting samples with NaNs.")
rows_to_keep = ~combined_nan_mask_samples
if not np.any(rows_to_keep):
if verbose > 0: warnings.warn("All samples omitted due to NaNs.")
# Create an empty plot or return early
if ax is None: _, ax = plt.subplots(figsize=figsize)
ax.set_title(title or "QCE Donut (No Data)")
return ax # type: ignore
y_t_calc = y_true_np[rows_to_keep]
y_p_calc = y_pred_quantiles_np[rows_to_keep]
if s_weights_calc is not None:
s_weights_calc = check_array(s_weights_calc, ensure_2d=False,
dtype="numeric", force_all_finite=True)
check_consistent_length(y_true_np, s_weights_calc)
s_weights_calc = s_weights_calc[rows_to_keep]
if y_t_calc.shape[0] == 0: # All samples omitted or original empty
if verbose > 0: warnings.warn("No valid samples for QCE calculation.")
if ax is None: _, ax = plt.subplots(figsize=figsize)
ax.set_title(title or "QCE Donut (No Valid Data)")
return ax # type: ignore
# --- 3. Calculate Per-Quantile Miscalibration ---
indicators = (y_t_calc[:, np.newaxis] <= y_p_calc).astype(float)
observed_proportions_q: np.ndarray
if s_weights_calc is not None:
sum_sw = np.sum(s_weights_calc)
if sum_sw < eps:
warnings.warn(f"Sum of sample_weight ({sum_sw}) < eps ({eps}). "
"Observed proportions may be unstable or NaN.")
observed_proportions_q = np.full(n_quantiles, np.nan)
else:
# Weighted average, careful with NaNs in indicators if propagate was used
# and some y_t_calc or y_p_calc were NaN
# Assuming NaNs in indicators will propagate with np.average if weights are numbers
observed_proportions_q = np.average(
indicators, axis=0, weights=s_weights_calc
)
else:
observed_proportions_q = np.nanmean(indicators, axis=0)
miscalibrations_q = np.abs(
observed_proportions_q - quantile_levels_np
) # Shape (Q,)
# Handle cases where all miscalibrations are NaN or zero
valid_miscal_mask = ~np.isnan(miscalibrations_q)
if not np.any(valid_miscal_mask) or \
np.sum(miscalibrations_q[valid_miscal_mask]) < eps :
# If all are NaN or sum is effectively zero, donut chart is not meaningful
warnings.warn(
"All per-quantile miscalibrations are NaN or sum to near zero. "
"Donut chart cannot be generated meaningfully."
)
if ax is None: _, ax = plt.subplots(figsize=figsize)
ax.set_title(title or "QCE Donut (No Miscalibration / Data Issues)")
# Add text to center indicating the situation
center_text = "No Miscalibration\nor Data Issues"
if np.any(np.isnan(miscalibrations_q)):
center_text = "Data Issues\n(NaNs)"
elif np.sum(miscalibrations_q[valid_miscal_mask]) < eps :
center_text = "Perfect Calibration\n(Avg QCE ~ 0)"
ax.text(0.5, 0.5, center_text, ha='center', va='center',
transform=ax.transAxes, fontsize='large')
ax.set_aspect('equal') # Ensure circle
ax.set_xticks([])
ax.set_yticks([])
return ax # type: ignore
# Use quantile_cols for labels if available and match length
segment_names = quantile_cols if len(quantile_cols) == n_quantiles \
else [f"q={q:.2f}" for q in quantile_levels_np]
# Filter out NaN miscalibrations for plotting pie
plot_miscalibrations = miscalibrations_q[valid_miscal_mask]
plot_segment_names = [
name for i, name in enumerate(segment_names) if valid_miscal_mask[i]
]
plot_colors = None
if colors:
plot_colors = [
c for i,c in enumerate(colors) if valid_miscal_mask[i]
] if len(colors) == n_quantiles else colors
# --- 4. Plotting ---
if ax is None:
fig, ax = plt.subplots(figsize=figsize) # type: ignore
ax.set_title(title or "QCE Donut")
wedges, texts, autotexts = ax.pie(
plot_miscalibrations,
labels=None, # Labels handled by legend or custom placement
autopct=lambda pct: segment_label_format.format(
name="", # Name will be in legend
value= (pct/100.)*np.sum(plot_miscalibrations), # Value this segment represents
percent=pct
) if value_annotations else None,
startangle=startangle,
counterclock=counterclock,
colors=plot_colors,
wedgeprops=wedgeprops or dict(width=donut_width, edgecolor='w')
)
# Center circle to make it a donut
center_circle = plt.Circle((0,0), 1-donut_width, fc='white')
ax.add_artist(center_circle)
# Text in the center
avg_qce = np.nanmean(miscalibrations_q) # Mean of non-NaN miscalibrations
if not np.isnan(avg_qce):
center_label = center_text_format.format(avg_qce)
ax.text(0, 0, center_label, ha='center', va='center',
fontsize='large', fontweight='bold')
ax.axis('equal') # Equal aspect ratio ensures that pie is drawn as a circle.
if show_legend:
# Create legend with proper labels (quantile levels or names)
# Use plot_segment_names which correspond to plot_miscalibrations
legend_handles = []
for i, w in enumerate(wedges):
# Ensure color is RGBA for legend patch
face_color = w.get_facecolor()
if isinstance(face_color, tuple) and len(face_color) == 4:
patch_color = face_color
else: # Convert if it's a named color string or RGB
patch_color = to_rgba(face_color) # type: ignore
legend_handles.append(
plt.Rectangle((0,0),1,1, facecolor=patch_color)
)
ax.legend(
legend_handles,
plot_segment_names,
title=legend_title,
loc=legend_loc,
bbox_to_anchor=legend_bbox_to_anchor
)
# Grid is not typically used for pie/donut charts
if show_grid:
if verbose > 0:
warnings.warn("'show_grid' is True, but grids are not "
"standard for donut charts.")
# ax.grid(**(grid_props or {})) # Usually off for pie
# Remove default ticks and labels for pie chart axes
ax.set_xticks([])
ax.set_yticks([])
return ax
def _validate_qce_plot_inputs(
df: pd.DataFrame,
actual_col: str,
quantile_cols: List[str],
quantile_levels: List[Real],
error_policy: Literal['raise', 'warn', 'ignore'] = 'raise'
) -> Tuple[pd.DataFrame, np.ndarray]:
"""
Validates inputs for QCE plotting functions.
Checks DataFrame type, column existence, and quantile properties.
Parameters
----------
df : pd.DataFrame
Input DataFrame containing actual and predicted quantile values.
actual_col : str
Name of the column containing true observed values.
quantile_cols : List[str]
List of column names corresponding to predicted quantiles.
quantile_levels : List[Real]
List of nominal quantile levels (e.g., [0.1, 0.5, 0.9]).
error_policy : {'raise', 'warn', 'ignore'}, default='raise'
Policy for handling feature existence errors from `exist_features`.
Returns
-------
Tuple[pd.DataFrame, np.ndarray]
The validated input DataFrame and a NumPy array of validated
and processed quantile levels.
Raises
------
TypeError
If `df` is not a pandas DataFrame, or if `quantile_cols`
or `quantile_levels` are not lists of the correct types.
ValueError
If columns are missing (and `error_policy='raise'`), if
lengths of `quantile_cols` and `quantile_levels` mismatch,
or if `quantile_levels` are not strictly between 0 and 1.
"""
if not isinstance(df, pd.DataFrame):
# For non-existence errors, _report_condition is better,
# but for fundamental type errors, direct raise is common.
# Match this with how _report_condition is used elsewhere.
# For now, direct raise for type errors on df.
raise TypeError("Input 'df' must be a pandas DataFrame.")
# Validate existence of actual_col
exist_features(df, features=actual_col, error=error_policy)
# Validate quantile_cols type and existence
if not (isinstance(quantile_cols, list) and
all(isinstance(qc, str) for qc in quantile_cols)):
raise TypeError(
"'quantile_cols' must be a list of strings."
)
if not quantile_cols: # Check if the list is empty
raise ValueError("'quantile_cols' cannot be empty.")
exist_features(df, features=quantile_cols, error=error_policy)
# Validate and process quantile_levels
# We expect quantile_levels to be numeric and in (0,1) for QCE.
# validate_quantiles with mode='strict' ensures they are in [0,1].
# An additional check for strict (0,1) is needed.
try:
# Use asarray=True to get a NumPy array for easier processing
# Pass round_digits and dtype if they are relevant from a higher context
# or use defaults within validate_quantiles.
# For QCE, high precision is good, so float64 might be better if available.
validated_levels_np = validate_quantiles(
quantile_levels,
asarray=True,
mode="strict", # Ensures values are in [0,1] and numeric
# Default round_digits and dtype from validate_quantiles used
)
except (TypeError, ValueError) as e:
# Catch errors from validate_quantiles (e.g., non-numeric, out of [0,1])
# Re-raise as ValueError for consistency, or let original error propagate
raise ValueError(
f"Validation of 'quantile_levels' failed: {e}"
) from e
# After validate_quantiles(mode='strict'), levels are in [0,1].
# Check for strict (0,1) as QCE is typically not for 0 or 1.
are_all_values_in_bounds(
validated_levels_np, bounds =(0, 1), closed='neither',
message =(
"All 'quantile_levels' must be strictly between 0 and 1 "
"(exclusive of 0 and 1)."
)
)
# Check for length consistency
if len(quantile_cols) != len(validated_levels_np):
raise ValueError(
"Length of 'quantile_cols' ({}) must match the length of "
"validated 'quantile_levels' ({}).".format(
len(quantile_cols), len(validated_levels_np)
)
)
# df is returned as is, as it's not modified by these checks,
# but its columns are confirmed to exist.
# validated_levels_np is returned as it's processed.
return df, validated_levels_np
def plot_radar_scores(
data_values: Optional[Union[List[Real], Dict[str, Real], Real]] = None,
category_names: Optional[List[str]] = None,
y_true: Optional[np.ndarray] = None,
y_pred: Optional[np.ndarray] = None,
metric_functions: Optional[Union[
MetricFunctionType, List[MetricFunctionType]]] = None,
metric_kwargs_list: Optional[Union[
Dict[str, Any], List[Dict[str, Any]]]] = None,
normalize_values: bool = False,
plot_target_type: Literal['metric'] = 'metric', # For future expansion
# Plotting customizations
figsize: Tuple[float, float] = (8, 8),
title: Optional[str] = "Metric Scores Radar Plot",
value_annotations: bool = True,
annotation_format: str = "{:.2f}",
fill_radar: bool = True,
fill_alpha: float = 0.25,
line_color: Optional[str] = None, # Auto-cycles if multiple lines
line_width: float = 2,
marker: Optional[str] = 'o',
# Radial axis customization
r_min: Optional[Real] = None,
r_max: Optional[Real] = None,
r_ticks_count: int = 5,
# Grid and labels
show_grid: bool = True,
grid_props: Optional[Dict[str, Any]] = None,
category_label_props: Optional[Dict[str, Any]] = None,
value_label_props: Optional[Dict[str, Any]] = None,
legend_label: Optional[str] = None, # For when plotting single entity
ax: Optional[plt.Axes] = None,
verbose: int = 0,
**kwargs: Any # For future matplotlib extensions
) -> plt.Axes:
"""
Generates a radar plot to visualize multiple scores or attributes.
Primarily designed for comparing metric scores.
"""
# --- 1. Input Processing and Score Calculation ---
final_values_to_plot: np.ndarray
final_category_names: List[str]
if data_values is not None:
if isinstance(data_values, dict):
final_category_names = list(data_values.keys())
final_values_to_plot = np.array(list(data_values.values()),
dtype=float)
else: # List, scalar, or array-like
# Use column_manager to handle scalar or list-like
managed_values = columns_manager(
data_values, # force_array=True, to_list=False
)
if hasattr(managed_values, '__iter__'):
managed_values = np.array ( managed_values)
if managed_values is None or managed_values.size == 0:
raise ValueError(
"If 'data_values' is not a dict, it must not be empty "
"after processing."
)
final_values_to_plot = managed_values.astype(float)
if category_names is not None:
if len(category_names) < len(final_values_to_plot):
if verbose > 0:
warnings.warn(
"Length of 'category_names' is less than "
"'data_values'. Auto-generating remaining names."
)
base_name = f"{plot_target_type}_"
num_missing = len(final_values_to_plot) - len(category_names)
auto_names = [f"{base_name}{i+len(category_names)+1}"
for i in range(num_missing)]
final_category_names = category_names + auto_names
elif len(category_names) > len(final_values_to_plot):
if verbose > 0:
warnings.warn(
"Length of 'category_names' is greater than "
"'data_values'. Truncating 'category_names'."
)
final_category_names = category_names[:len(final_values_to_plot)]
else:
final_category_names = category_names
else: # Auto-generate all names
base_name = f"{plot_target_type}_"
final_category_names = [f"{base_name}{i+1}" for i in
range(len(final_values_to_plot))]
elif plot_target_type == 'metric':
if y_true is None or y_pred is None:
raise ValueError(
"If 'data_values' is None and 'plot_target_type' is 'metric',"
" 'y_true' and 'y_pred' must be provided."
)
y_t = check_array(y_true, ensure_2d=False, force_all_finite=True,
dtype="numeric")
y_p = check_array(y_pred, ensure_2d=False, force_all_finite=True,
dtype="numeric")
check_consistent_length(y_t, y_p)
# Default metric functions if none provided
if metric_functions is None:
metric_funcs_to_use: List[MetricFunctionType] = [
mean_absolute_error,
lambda yt, yp, **kws: np.sqrt(mean_squared_error(yt, yp, **kws)), # RMSE
mean_absolute_percentage_error
]
default_names = ["MAE", "RMSE", "MAPE"]
# Ensure default kwargs for RMSE's mean_squared_error if any
default_metric_kws_list: List[Dict] = [{}, {'squared': False}, {}]
# User can override these defaults via metric_kwargs_list
if metric_kwargs_list is None:
metric_kws_list_proc = default_metric_kws_list
elif isinstance(metric_kwargs_list, dict): # Apply to all defaults
metric_kws_list_proc = [
{**dkw, **metric_kwargs_list} for dkw in default_metric_kws_list
]
elif isinstance(metric_kwargs_list, list) and \
len(metric_kwargs_list) == len(metric_funcs_to_use):
metric_kws_list_proc = [
{**dkw, **ukw} for dkw, ukw in zip(
default_metric_kws_list, metric_kwargs_list) # type: ignore
]
else:
raise ValueError(
"Invalid 'metric_kwargs_list' for default functions.")
final_category_names = category_names if category_names and \
len(category_names)==len(default_names) else default_names
elif callable(metric_functions): # Single function
metric_funcs_to_use = [metric_functions]
metric_kws_list_proc = [metric_kwargs_list or {}] \
if isinstance(metric_kwargs_list, dict) or metric_kwargs_list is None \
else metric_kwargs_list # Should be list of one dict
if category_names:
final_category_names = category_names
else:
func_name = getattr(metric_functions, '__name__', 'metric_1')
final_category_names = [func_name]
elif isinstance(metric_functions, list): # List of functions
metric_funcs_to_use = metric_functions
if metric_kwargs_list is None:
metric_kws_list_proc = [{} for _ in metric_funcs_to_use]
elif isinstance(metric_kwargs_list, list) and \
len(metric_kwargs_list) == len(metric_funcs_to_use):
metric_kws_list_proc = metric_kwargs_list
else:
raise ValueError(
"'metric_kwargs_list' must be a list of dicts matching "
"'metric_functions' length, or None."
)
if category_names and len(category_names) == len(metric_funcs_to_use):
final_category_names = category_names
elif category_names is None:
final_category_names = [
getattr(f, '__name__', f"{plot_target_type}_{i+1}")
for i, f in enumerate(metric_funcs_to_use)
]
else:
raise ValueError(
"Length of 'category_names' must match 'metric_functions'."
)
else:
raise TypeError(
"'metric_functions' must be None, a callable, or list of callables.")
# Compute scores
computed_scores = []
for func, kws_for_func in zip(metric_funcs_to_use, metric_kws_list_proc):
valid_kws = _get_valid_kwargs(func, kws_for_func)
try:
score = func(y_t, y_p, **valid_kws)
computed_scores.append(score)
except Exception as e:
warnings.warn(
"Error computing metric "
f"{getattr(func,'__name__','unknown')}:"
f" {e}. Skipping."
)
computed_scores.append(np.nan)
final_values_to_plot = np.array(
computed_scores, dtype=float
)
else:
raise ValueError(
f"Unsupported 'plot_target_type': {plot_target_type}. "
"Currently only 'metric' is supported."
)
if final_values_to_plot.ndim > 1:
# This implies multiple sets of values (e.g., from multi-output metrics)
# The current design plots one radar line.
# For now, require scalar values per category.
raise ValueError(
"Each category for the radar plot must correspond to a single "
"scalar value. Received multi-dimensional values."
)
if len(final_values_to_plot) != len(final_category_names):
raise ValueError(
"Mismatch between number of values to plot and category names. "
f"Got {len(final_values_to_plot)} values and "
f"{len(final_category_names)} names."
)
num_vars = len(final_category_names)
if num_vars < 3:
warnings.warn(
"Radar plots are typically used for 3 or more categories. "
"Consider a bar chart for fewer categories."
)
# Fallback or proceed? For now, proceed if user insists.
if num_vars == 0:
if ax is None: _, ax = plt.subplots(figsize=figsize)
ax.set_title(title or "Radar Plot (No Data)")
if show_grid: ax.grid(**(grid_props or {}))
return ax
# --- 2. Normalization (if requested) ---
plot_values = final_values_to_plot.copy()
original_values_for_annotation = final_values_to_plot.copy()
# Handle NaNs before normalization or plotting
nan_mask = np.isnan(plot_values)
if np.all(nan_mask): # All values are NaN
warnings.warn("All values for radar plot are NaN. Plot will be empty.")
# Proceed to draw empty radar axes
# Replace NaNs with a value for plotting structure, but they won't be "seen"
# if we are careful with plotting (e.g., don't connect across NaNs).
# Or, for simplicity in radar, they might be plotted at 0 or min.
# For now, let's make them 0 for structure, but annotations will show NaN.
plot_values_for_structure = np.nan_to_num(plot_values, nan=0.0)
if normalize_values:
# Min-max scale to [0, 1] for better shape comparison
# Only use non-NaN values for finding min/max
valid_plot_vals = plot_values[~nan_mask]
if valid_plot_vals.size > 0:
min_val = np.min(valid_plot_vals)
max_val = np.max(valid_plot_vals)
if max_val - min_val > 1e-8: # Avoid division by zero if all same
# Apply scaling to non-NaN original values
scaled_non_nan = (valid_plot_vals - min_val) / (max_val - min_val)
# Put scaled values back, keep NaNs as NaNs in plot_values
# plot_values will be used for plotting line/fill
# original_values_for_annotation keeps original scale for text
temp_scaled_values = np.full_like(plot_values, np.nan)
temp_scaled_values[~nan_mask] = scaled_non_nan
plot_values = temp_scaled_values
plot_values_for_structure = np.nan_to_num(plot_values, nan=0.0)
if verbose > 0:
print(f"Values normalized. Original range: [{min_val:.2f}, {max_val:.2f}].")
elif verbose > 0: # All valid values are the same
warnings.warn(
"All valid values are identical after NaN handling; "
"normalization to [0,1] range results in all zeros or ones. "
"Radar shape might not be informative."
)
# Set all to 0.5 to show a regular polygon if all same
plot_values[~nan_mask] = 0.5
plot_values_for_structure = np.nan_to_num(plot_values, nan=0.0)
else: # All values were NaN
if verbose > 0: warnings.warn("All values are NaN, cannot normalize.")
# If normalized, radial ticks should be 0 to 1
if r_min is None: r_min = 0
if r_max is None: r_max = 1
# --- 3. Radar Plot Creation ---
if ax is None:
fig, ax = plt.subplots(figsize=figsize, subplot_kw=dict(polar=True)) # type: ignore
elif not hasattr(ax, 'plot'): # Check if it's a Matplotlib Axes
raise TypeError("`ax` must be a Matplotlib Axes object.")
# If ax is provided, assume it's already polar if needed, or make it so.
# This is tricky. Standard practice is to create polar on a figure.
# For simplicity, if ax is passed, we'll try to use it as is.
# User should ensure it's a polar Axes if passing one.
angles = np.linspace(0, 2 * np.pi, num_vars, endpoint=False).tolist()
# Make the plot close by appending the first value and angle
plot_data_closed = np.concatenate(
(plot_values_for_structure, [plot_values_for_structure[0]]))
angles_closed = angles + [angles[0]]
original_values_closed = np.concatenate( # noqa
(original_values_for_annotation, [original_values_for_annotation[0]])
)
nan_mask_closed = np.concatenate((nan_mask, [nan_mask[0]]))
# Plotting the data line
# To handle NaNs gracefully (not connecting across them):
# Plot segments between non-NaN points.
segments = np.ma.masked_where( # noqa
nan_mask_closed, plot_data_closed).tolist()
# For a single radar line, line_color can be a string
# If comparing multiple entities on same radar later, this would need cycling
current_line_color = line_color if line_color is not None else \
next(ax._get_lines.prop_cycler)['color'] # type: ignore
ax.plot(angles_closed, plot_data_closed, color=current_line_color,
linewidth=line_width, marker=marker if marker else '',
label=legend_label if legend_label else plot_target_type.title())
if fill_radar:
ax.fill(angles_closed, plot_data_closed, color=current_line_color,
alpha=fill_alpha)
# Category labels
ax.set_xticks(angles)
category_label_defaults = {'size': 'medium'}
category_label_final_props = {
**category_label_defaults,
**(category_label_props or {})
}
ax.set_xticklabels(final_category_names, **category_label_final_props)
# Radial axis (y-axis in polar)
if r_min is not None or r_max is not None:
ax.set_ylim(r_min, r_max)
# Set radial ticks. MaxNLocator helps get "nice" ticks.
# Get current ylim if not set by user, to inform tick locator
# from matplotlib.ticker import MaxNLocator
current_r_lim = ax.get_ylim()
# Try your preferred prune option, otherwise fall back
try:
locator = MaxNLocator(nbins=r_ticks_count, prune='min')
except ValueError:
# 'min' isn’t supported; fall back to no pruning (or use 'lower'/'upper' as you see fit)
locator = MaxNLocator(nbins=r_ticks_count)
# Now generate your tick locations
r_tick_locs = locator.tick_values(current_r_lim[0], current_r_lim[1])
# current_r_lim = ax.get_ylim()
# r_tick_locs = MaxNLocator(
# nbins=r_ticks_count, prune='min' # 'min' to ensure 0 is often a tick
# ).tick_values(current_r_lim[0], current_r_lim[1])
# Filter out ticks outside the explicit r_min/r_max if they were set
if r_min is not None:
r_tick_locs = r_tick_locs[r_tick_locs >= r_min]
if r_max is not None:
r_tick_locs = r_tick_locs[r_tick_locs <= r_max]
# Ensure 0 is a tick if within range, and if r_min is not set above 0
if (r_min is None or r_min <=0) and 0 not in r_tick_locs and current_r_lim[0] <=0:
r_tick_locs = np.unique(np.sort(np.concatenate(([0], r_tick_locs))))
ax.set_yticks(r_tick_locs)
value_label_defaults = {'size': 'small'}
value_label_final_props = {
**value_label_defaults,
**(value_label_props or {})
}
ax.set_yticklabels(
[annotation_format.format(tick) for tick in r_tick_locs],
**value_label_final_props
)
# Value annotations on the radar points
if value_annotations:
for i, (angle, val_plot, val_orig) in enumerate(zip(
angles, plot_values, original_values_for_annotation
)):
if not np.isnan(val_plot) and not np.isnan(val_orig):
# Use original value for annotation text
annotation_text = annotation_format.format(val_orig)
ax.text(angle, val_plot, annotation_text,
ha='center', va='bottom' if val_plot >=0 else 'top',
fontsize=category_label_final_props.get('size', 'small'), # type: ignore
bbox=dict(facecolor='white', alpha=0.5,
edgecolor='none', pad=1.0)
) # Add background for readability
if title:
ax.set_title(title, va='bottom',
fontdict={'fontsize': plt.rcParams['axes.titlesize'], # type: ignore
'fontweight': plt.rcParams['axes.titleweight']}) # type: ignore
if show_grid:
grid_final_props = grid_props if grid_props is not None \
else {'linestyle': '--', 'alpha': 0.7, 'linewidth':0.5}
ax.grid(**grid_final_props)
else:
ax.grid(False)
if legend_label: # If a single entity is plotted, legend might be useful
ax.legend(loc='upper right', bbox_to_anchor=(0.1, 0.1))
return ax
def _calculate_qce_miscalibrations(
y_true_period: np.ndarray, # (N_period,)
y_pred_q_period: np.ndarray, # (N_period, Q)
quantile_levels: np.ndarray, # (Q,)
sample_weight_period: Optional[np.ndarray], # (N_period,)
eps: float
) -> np.ndarray:
"""Calculates per-quantile miscalibrations for a given period."""
indicators = (
y_true_period[:, np.newaxis] <= y_pred_q_period
).astype(float) # (N_period, Q)
observed_proportions_q: np.ndarray
if sample_weight_period is not None:
sum_sw = np.sum(sample_weight_period)
if sum_sw < eps:
return np.full(quantile_levels.shape[0], np.nan)
# Weighted average, handling NaNs from inputs if they exist
# NaNs in indicators should result from NaNs in y_true or y_pred_q
temp_props_list = []
for q_idx in range(quantile_levels.shape[0]):
valid_inds_q = indicators[:, q_idx]
finite_mask_q = ~np.isnan(valid_inds_q)
if np.any(finite_mask_q):
current_weights = sample_weight_period[finite_mask_q]
sum_finite_weights = np.sum(current_weights)
if sum_finite_weights >= eps:
prop = np.sum(
valid_inds_q[finite_mask_q] * current_weights
) / sum_finite_weights
temp_props_list.append(prop)
else:
temp_props_list.append(np.nan)
else:
temp_props_list.append(np.nan)
observed_proportions_q = np.array(temp_props_list)
else:
observed_proportions_q = np.nanmean(indicators, axis=0)
miscalibrations_q = np.abs(
observed_proportions_q - quantile_levels
)
return miscalibrations_q
# --- Main Plotting Function ---
# make actual_col to be optional set to None,
# because some metrics works with the prediction only.
# also if metric is not provided and actual_col is None,
# the you plot the uncertainty distribution of quantile every years ( like average)
# and plot rather than qce instead.
# so revise
# --- Main Plotting Function ---
def plot_nested_quantiles(
df: pd.DataFrame,
quantile_cols: List[str],
quantile_levels: List[Real],
actual_col: Optional[str] = None, # Now optional
dt_col: Optional[str] = None,
periods: Optional[List[Any]] = None, # Renamed
metric_func: Optional[MetricFunctionType] = None,
metric_kws: Optional[Dict[str, Any]] = None,
# Plotting customizations
figsize: Tuple[float, float] = (10, 10),
title: Optional[str] = None, # Default title set based on mode
colors: Optional[List[str]] = None,
show_center_text: bool = True,
center_text_format: str = "Avg:\n{value:.3f}",
segment_label_format: str = "{name}\n{value:.2f}",
show_segment_labels: bool = True,
startangle: float = 90,
counterclock: bool = False,
wedgeprops: Optional[Dict[str, Any]] = None, # User can set gap here
donut_width: float = 0.3,
donut_ring_spacing: float = 0.05,
donut_base_radius: float = 0.3,
segment_explode: Optional[Union[float, List[float]]] = None,
# Legend
show_overall_legend: bool = True,
legend_title: Optional[str] = "Quantiles",
legend_loc: str = "center left",
legend_bbox_to_anchor: Tuple[float, float] = (1.05, 0.5),
# Common params
show_grid: bool=False, # for API consistency
grid_props: dict =None, # for API consistency
ax: Optional[plt.Axes] = None,
verbose: int = 0,
**kwargs: Any
) -> plt.Axes:
"""
Visualizes quantile-based metrics or average quantile values
over periods as nested donut charts.
(Full docstring to be expanded later)
"""
# --- 1. Determine Plotting Mode ---
plot_mode: Literal['qce', 'custom_metric', 'avg_quantiles']
if actual_col is None and metric_func is None:
plot_mode = 'avg_quantiles'
default_title = "" #"Average Predicted Quantile Values by Period"
elif metric_func is not None:
plot_mode = 'custom_metric'
func_name = getattr(metric_func, '__name__', 'Custom Metric')
default_title = f"{func_name} Evolution Over Periods"
else: # actual_col is provided and metric_func is None
plot_mode = 'qce'
default_title = "Quantile Calibration Error by Period"
final_title = title if title is not None else default_title
# --- 2. Input Validation ---
if not isinstance(df, pd.DataFrame):
raise TypeError("Input 'df' must be a pandas DataFrame.")
cols_to_check_existence = quantile_cols[:]
if actual_col: # Only check if provided
cols_to_check_existence.append(actual_col)
if dt_col:
cols_to_check_existence.append(dt_col)
exist_features(df, features=cols_to_check_existence, error='raise')
if not (isinstance(quantile_cols, list) and
all(isinstance(qc, str) for qc in quantile_cols)):
raise TypeError("'quantile_cols' must be a list of strings.")
if not quantile_cols:
raise ValueError("'quantile_cols' cannot be empty.")
try:
q_levels_np = validate_quantiles(
quantile_levels, asarray=True, mode="strict"
)
# For QCE and custom metrics expecting strict (0,1) quantiles
if plot_mode != 'avg_quantiles': # Avg quantiles don't impose this
if not np.all((q_levels_np > 0) & (q_levels_np < 1)):
raise ValueError(
"All 'quantile_levels' must be strictly in (0,1) "
"for metric calculation."
)
except (TypeError, ValueError) as e:
raise ValueError(
f"Validation of 'quantile_levels' failed: {e}"
) from e
if len(quantile_cols) != len(q_levels_np):
raise ValueError(
"Length of 'quantile_cols' must match 'quantile_levels'."
)
# --- 3. Data Preparation ---
metric_kws = metric_kws or {}
eps = metric_kws.get('eps', 1e-8)
sample_weight_col = metric_kws.get('sample_weight_col', None)
if sample_weight_col:
exist_features(df, features=sample_weight_col, error='raise')
if dt_col:
unique_periods = sorted(df[dt_col].unique())
periods_to_use = periods # Use renamed parameter
if periods_to_use is not None:
periods_val = [p for p in periods_to_use if p in unique_periods]
if not periods_val:
raise ValueError(
"None of the specified 'periods' found in data."
)
else:
periods_val = unique_periods
else:
periods_val = ["Overall"]
num_periods = len(periods_val)
if num_periods == 0:
warnings.warn("No periods to plot.")
if ax is None: _, ax = plt.subplots(figsize=figsize)
ax.set_title(final_title + " (No Data)")
return ax # type: ignore
# --- 4. Calculate Values per Period and Quantile ---
period_values_dict: Dict[Any, np.ndarray] = {}
for period_name in periods_val:
period_df_slice = df if dt_col is None else \
df[df[dt_col] == period_name]
if period_df_slice.empty:
period_values_dict[period_name] = np.full(len(q_levels_np), np.nan)
if verbose > 0:
warnings.warn(
f"No data for period '{period_name}'. Values set to NaN."
)
continue
y_p_q_period = period_df_slice[quantile_cols].to_numpy(dtype=float)
s_weights_period = period_df_slice[sample_weight_col].to_numpy(dtype=float) \
if sample_weight_col else None
nan_policy_metric = metric_kws.get('nan_policy', 'propagate')
# NaN handling for y_p_q_period and potentially y_t_period
# This needs to be done carefully based on the plot_mode
current_period_values: np.ndarray
if plot_mode == 'avg_quantiles':
# NaNs in y_p_q_period for this mode
nan_mask_preds_period = np.isnan(y_p_q_period) # (N_period, Q)
if np.any(nan_mask_preds_period):
if nan_policy_metric == 'raise':
raise ValueError(
f"NaNs found in quantile_cols for period '{period_name}'."
)
elif nan_policy_metric == 'omit':
# Omit rows if ANY quantile in that row is NaN
rows_with_nan_in_preds = nan_mask_preds_period.any(axis=1)
keep_rows = ~rows_with_nan_in_preds
if not np.any(keep_rows):
current_period_values = np.full(len(q_levels_np), np.nan)
else:
y_p_q_period_clean = y_p_q_period[keep_rows]
s_weights_period_clean = s_weights_period[keep_rows] \
if s_weights_period is not None else None
if s_weights_period_clean is not None and \
np.sum(s_weights_period_clean) < eps:
current_period_values = np.full(len(q_levels_np), np.nan)
else:
current_period_values = np.average(
y_p_q_period_clean, axis=0,
weights=s_weights_period_clean
)
else: # propagate
current_period_values = np.nanmean(y_p_q_period, axis=0)
else: # No NaNs in y_p_q_period
if s_weights_period is not None and np.sum(s_weights_period) < eps:
current_period_values = np.full(len(q_levels_np), np.nan)
else:
current_period_values = np.average(
y_p_q_period, axis=0, weights=s_weights_period
)
else: # 'qce' or 'custom_metric'
y_t_period = period_df_slice[actual_col].to_numpy(dtype=float) # type: ignore
nan_mask_true_period = np.isnan(y_t_period)
nan_mask_preds_period_any = np.isnan(y_p_q_period).any(axis=1)
combined_nan_samples = nan_mask_true_period | nan_mask_preds_period_any
if np.any(combined_nan_samples):
if nan_policy_metric == 'raise':
raise ValueError(
f"NaNs found in data for period '{period_name}'."
)
elif nan_policy_metric == 'omit':
keep_rows = ~combined_nan_samples
if not np.any(keep_rows):
current_period_values = np.full(len(q_levels_np), np.nan)
else:
y_t_period = y_t_period[keep_rows]
y_p_q_period = y_p_q_period[keep_rows]
if s_weights_period is not None:
s_weights_period = s_weights_period[keep_rows]
# If 'propagate', handled by metric func or helper
if y_t_period.shape[0] == 0:
current_period_values = np.full(len(q_levels_np), np.nan)
elif plot_mode == 'qce':
current_period_values = _calculate_qce_miscalibrations(
y_t_period, y_p_q_period, q_levels_np,
s_weights_period, eps
)
else: # 'custom_metric'
try:
# Custom metric func signature might vary.
# Assuming it can handle these inputs or uses _get_valid_kwargs.
# For simplicity, pass all relevant, let metric handle.
metric_args = {
'y_true_period': y_t_period,
'y_pred_q_period': y_p_q_period,
'quantile_levels': q_levels_np,
'sample_weight_period': s_weights_period,
'eps': eps
}
# Allow metric_kws to override these if needed
final_metric_args = {**metric_args, **(metric_kws or {})}
cleaned_metric_args = _get_valid_kwargs(
metric_func, final_metric_args) # type: ignore
current_period_values = metric_func(
**cleaned_metric_args) # type: ignore
except Exception as e:
warnings.warn(
f"Error calling custom metric_func for period "
f"'{period_name}': {e}. Values set to NaN."
)
current_period_values = np.full(len(q_levels_np), np.nan)
period_values_dict[period_name] = current_period_values
# --- 5. Plotting Setup ---
if ax is None:
fig, ax = plt.subplots(figsize=figsize) # type: ignore
ax.set_title(final_title)
ax.axis('equal')
ax.set_xticks([])
ax.set_yticks([])
if colors is None:
prop_cycle = plt.rcParams['axes.prop_cycle']
default_colors = prop_cycle.by_key()['color']
plot_colors = [default_colors[i % len(default_colors)]
for i in range(len(q_levels_np))]
elif len(colors) < len(q_levels_np):
warnings.warn("Not enough colors for quantiles. Colors will cycle.")
plot_colors = [colors[i % len(colors)] for i in range(len(q_levels_np))]
else:
plot_colors = colors[:len(q_levels_np)]
if segment_explode is None:
explodes = [0] * len(q_levels_np)
elif isinstance(segment_explode, float):
explodes = [segment_explode] * len(q_levels_np)
elif isinstance(segment_explode, list) and \
len(segment_explode) == len(q_levels_np):
explodes = segment_explode
else:
warnings.warn("Invalid 'segment_explode'. Using no explosion.")
explodes = [0] * len(q_levels_np)
# Default wedge properties for gap
base_wedgeprops = {'edgecolor': 'white', 'linewidth': 1.5}
if wedgeprops: # User can override/add to this
base_wedgeprops.update(wedgeprops)
# --- 6. Plotting Nested Donuts ---
current_outer_radius = donut_base_radius + \
num_periods * donut_width + \
max(0, num_periods - 1) * donut_ring_spacing
legend_items_overall = []
for i, period_name in enumerate(periods_val): # Use validated periods_val
values_for_period = period_values_dict[period_name]
valid_scores_mask = ~np.isnan(values_for_period)
if not np.any(valid_scores_mask):
if verbose > 0:
warnings.warn(
f"All values for period '{period_name}' are NaN. "
"Skipping donut for this period."
)
current_outer_radius -= (donut_width + donut_ring_spacing)
continue
plot_data_period = values_for_period[valid_scores_mask]
# Use original quantile_cols for names if plot_mode is 'avg_quantiles'
# and lengths match, otherwise generate from q_levels_np
if plot_mode == 'avg_quantiles' and \
len(quantile_cols) == len(q_levels_np):
period_segment_names_all = quantile_cols
else:
period_segment_names_all = [f"q={q:.2f}" for q in q_levels_np]
period_segment_names = [ # noqa
period_segment_names_all[j] for j, keep in enumerate(valid_scores_mask) if keep
]
period_colors = [
plot_colors[j] for j, keep in enumerate(valid_scores_mask) if keep
]
period_explodes = [
explodes[j] for j, keep in enumerate(valid_scores_mask) if keep
]
sum_plot_data = np.sum(plot_data_period)
# For pie chart, values should be positive. If metric can be negative,
# this visualization might not be ideal or needs transformation.
# Assuming metric scores (like QCE) or avg quantiles are non-negative.
if sum_plot_data < eps and not np.all(plot_data_period < eps):
if verbose > 0:
warnings.warn(
f"Sum of values for period '{period_name}' is near zero, "
"but individual values are not all zero. "
"Donut segments may not be visible or meaningful."
)
# If all plot_data_period are zero or negative, pie chart is problematic
if np.all(plot_data_period <= eps) and plot_data_period.size > 0 :
if verbose > 0:
warnings.warn(
f"All values for period '{period_name}' are zero or negative. "
"Drawing empty/minimal donut ring."
)
# Draw a simple ring to indicate the period without segments
ring = plt.Circle((0,0), current_outer_radius - donut_width/2,
width=donut_width, color='lightgray', fill=False,
linestyle='--', ec='gray')
ax.add_artist(ring)
# Add period label to the ring
angle_for_label = startangle + \
(counterclock * 2 -1) * (i * 360/num_periods) # Distribute labels
text_x = (current_outer_radius - donut_width/2) * \
np.cos(np.deg2rad(angle_for_label))
text_y = (current_outer_radius - donut_width/2) * \
np.sin(np.deg2rad(angle_for_label))
ax.text(text_x, text_y, str(period_name), ha='center', va='center',
fontsize='small', color='dimGray')
elif plot_data_period.size > 0 : # Ensure there's data to plot
wedges, *texts_autotexts = ax.pie(
plot_data_period,
radius=current_outer_radius,
labels=None,
autopct=(lambda pct: segment_label_format.format(
name="",
value=(pct/100.)*sum_plot_data if sum_plot_data > eps else 0,
percent=pct)
) if show_segment_labels and sum_plot_data > eps else None,
startangle=startangle,
counterclock=counterclock,
colors=period_colors,
explode=period_explodes,
wedgeprops={**base_wedgeprops, 'width': donut_width} # type: ignore
)
if i == 0 and show_overall_legend and plot_data_period.size > 0:
# Create legend items based on the *original* full set of quantiles
# and their assigned colors, to ensure consistency.
for q_idx, q_level_val in enumerate(q_levels_np):
# Use original quantile_cols if available and matches, else q=...
leg_name = quantile_cols[q_idx] if \
len(quantile_cols) == len(q_levels_np) else f"q={q_level_val:.2f}"
legend_items_overall.append(
(plt.Rectangle((0,0),1,1, facecolor=plot_colors[q_idx]), leg_name)
)
if show_center_text and plot_data_period.size > 0:
# For the innermost donut, show overall average for that period
avg_value_period = np.nanmean(values_for_period)
if not np.isnan(avg_value_period):
if i == num_periods -1 :
center_text_val = center_text_format.format(value=avg_value_period)
if plot_mode == 'avg_quantiles' and dt_col is None: # Single overall donut
center_text_val = "Avg. Quantile\nValues"
elif plot_mode == 'avg_quantiles':
center_text_val = f"{period_name}\nAvg. Values"
ax.text(0, 0, center_text_val,
ha='center', va='center', fontsize='medium',
fontweight='bold',
path_effects=kwargs.get('center_text_path_effects', None))
elif verbose > 1 and dt_col is not None:
print(
f"Period '{period_name}' Avg Score/Value: {avg_value_period:.3f}"
)
current_outer_radius -= (donut_width + donut_ring_spacing)
# --- 7. Final Touches (Legend, Grid) ---
if show_overall_legend and legend_items_overall:
# Remove duplicate legend items if colors/names were cycled
unique_legend_items = []
seen_labels = set()
for handle, label in legend_items_overall:
if label not in seen_labels:
unique_legend_items.append((handle, label))
seen_labels.add(label)
if unique_legend_items:
handles, labels = zip(*unique_legend_items)
ax.legend(handles, labels, title=legend_title,
loc=legend_loc, bbox_to_anchor=legend_bbox_to_anchor)
if show_grid:
if verbose > 0:
warnings.warn(
"'show_grid' for donut chart might not be conventional."
)
ax.grid(**(grid_props or {'linestyle': ':', 'alpha': 0.3}))
else:
ax.grid(False)
return ax