Source code for fusionlab.plot._evaluation

# -*- coding: utf_8 -*-
# License: BSD-3-Clause
# Author: LKouadio <etanoyau@gmail.com>
"""
Plotting utilities for evaluating forecasting models.
"""
import warnings
from typing import ( 
    List, 
    Tuple, 
    Optional, 
    Union, 
    Any, 
    Callable
)
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import matplotlib.colors as mcolors

from sklearn.metrics import (
    mean_absolute_error,
    mean_squared_error,
    mean_absolute_percentage_error
)

from ..api.docs import DocstringComponents, _evaluation_plot_params
from ..decorators import isdf 
from ..utils.generic_utils import vlog
from ..metrics import coverage_score

# Parameter-documentation components built from `_evaluation_plot_params`.
# The `__doc__` templates further down reference them as
# ``{params.base.<name>}`` and are rendered with
# ``.format(params=_eval_docs)``.
_eval_docs = DocstringComponents.from_nested_components(
    base=DocstringComponents(_evaluation_plot_params),
)

# Names exported by ``from <module> import *``.
__all__ = [
    "plot_metric_radar",
    "plot_forecast_comparison",
    "plot_metric_over_horizon",
]

# NOTE: the user-facing docstring of this function is assigned right after
# the definition through ``plot_metric_over_horizon.__doc__``; the comments
# below document the implementation only.
@isdf
def plot_metric_over_horizon(
    forecast_df: pd.DataFrame,
    target_name: str = "target",
    metrics: Union[str, Callable, List[Union[str, Callable]]] = 'mae',
    quantiles: Optional[List[float]] = None,
    output_dim: int = 1,
    actual_col_pattern: str = (
        "{target_name}_actual"
    ),
    pred_col_pattern_point: str = (
        "{target_name}_pred"
    ),
    pred_col_pattern_quantile: str = (
        "{target_name}_q{quantile_int}"
    ),
    group_by_cols: Optional[List[str]] = None,
    plot_kind: str = 'bar',
    figsize_per_subplot: Tuple[float, float] = (7, 4.5),
    max_cols_metrics: int = 2,
    scaler: Optional[Any] = None,
    scaler_feature_names: Optional[List[str]] = None,
    target_idx_in_scaler: Optional[int] = None,
    sharey_metrics: bool = False,
    verbose: int = 0,
    **plot_kwargs: Any,
) -> Optional[Any]:
    # Returns the last Matplotlib axis drawn on, or None when no metric
    # could be computed (nothing is plotted in that case).
    vlog(
        f"Starting metric visualization "
        f"(kind='{plot_kind}')...",
        level=3, verbose=verbose
    )

    # ------------------------------------------------------------------
    # Validation
    # ------------------------------------------------------------------
    if not isinstance(forecast_df, pd.DataFrame):
        raise TypeError("`forecast_df` must be a pandas DataFrame.")
    if 'forecast_step' not in forecast_df.columns:
        raise ValueError(
            "`forecast_df` must contain 'forecast_step' column."
        )

    # Work on a copy so the caller's frame is never touched.
    df_to_eval = forecast_df.copy()
    base_name = target_name

    # ------------------------------------------------------------------
    # Inverse transform (placeholder)
    # ------------------------------------------------------------------
    if scaler is not None:
        if (scaler_feature_names is None
                or target_idx_in_scaler is None):
            warnings.warn(
                "Scaler provided, but `scaler_feature_names` or "
                "`target_idx_in_scaler` is missing. "
                "Metrics will be computed on scaled data."
            )
        else:
            vlog(
                " Applying inverse transformation for "
                "metric calculation...",
                level=4,
                verbose=verbose,
            )
            # XXX TODO: inverse-transform logic is still a placeholder.
            # NOTE(review): despite the log message above, nothing is
            # transformed yet, so metrics are computed on the data as
            # given. `plot_forecast_comparison` has a working
            # implementation that could be reused here.

    # ------------------------------------------------------------------
    # Normalise `metrics` to a list. A single callable is accepted as
    # well as a single string (the docstring promises both).
    # ------------------------------------------------------------------
    if isinstance(metrics, str) or callable(metrics):
        metrics_list = [metrics]
    elif isinstance(metrics, (list, tuple)):
        metrics_list = list(metrics)
    else:
        raise TypeError(
            "`metrics` must be a string, a callable, or a list of those."
        )

    def _rmse(yt, yp):
        return np.sqrt(mean_squared_error(yt, yp))

    def _pinball_median(yt, yp):
        return _calculate_pinball_loss_radar(yt, yp, 0.5)

    # Metric strings that map directly onto a ``f(y_true, y_pred)``.
    simple_metrics = {
        'mae': mean_absolute_error,
        'mse': mean_squared_error,
        'rmse': _rmse,
        'mape': mean_absolute_percentage_error,
        'smape': _calculate_smape_radar,
    }

    # The grouping keys do not depend on the output or the metric, so
    # build them once. 'forecast_step' is always the last key.
    group_cols = (
        list(group_by_cols) + ['forecast_step']
        if group_by_cols else ['forecast_step']
    )

    metric_results: List[dict] = []

    # ------------------------------------------------------------------
    # Loop over outputs and metrics
    # ------------------------------------------------------------------
    for o_idx in range(output_dim):
        # Multi-output frames carry the output index in the column name.
        col_base = (
            f"{base_name}_{o_idx}" if output_dim > 1 else base_name
        )
        act_col = f"{col_base}_actual"
        if act_col not in df_to_eval.columns:
            warnings.warn(
                f"Actual column '{act_col}' not found for "
                f"output {o_idx}. Skipping.",
                UserWarning,
            )
            continue

        for met in metrics_list:
            metric_fn: Optional[Callable] = None
            is_coverage = False

            # -------------- Resolve metric -------------------------
            if isinstance(met, str):
                metric_name = met.lower()
                if metric_name in simple_metrics:
                    metric_fn = simple_metrics[metric_name]
                elif metric_name == 'coverage':
                    if not quantiles or len(quantiles) < 2:
                        warnings.warn(
                            "Coverage requires at least two quantiles. "
                            "Skipping."
                        )
                        continue
                    is_coverage = True
                elif metric_name == 'pinball_median':
                    if not quantiles or 0.5 not in quantiles:
                        warnings.warn(
                            "pinball_median requires the 0.5 quantile. "
                            "Skipping."
                        )
                        continue
                    metric_fn = _pinball_median
                else:
                    warnings.warn(
                        f"Unknown metric '{metric_name}'. Skipping."
                    )
                    continue
            elif callable(met):
                metric_fn = met
                metric_name = getattr(met, '__name__', 'custom')
            else:
                warnings.warn(
                    f"Invalid metric type: {type(met)}. Skipping."
                )
                continue

            # -------------- Determine prediction column(s) ---------
            low_col = hi_col = pred_col = ""
            if is_coverage:
                # Coverage uses the outermost quantiles as the interval.
                qs = sorted(quantiles)  # type: ignore[arg-type]
                low_col = f"{col_base}_q{int(qs[0] * 100)}"
                hi_col = f"{col_base}_q{int(qs[-1] * 100)}"
                if (low_col not in df_to_eval.columns
                        or hi_col not in df_to_eval.columns):
                    warnings.warn(
                        "Quantile columns not found. Skipping coverage."
                    )
                    continue
            else:
                if quantiles:
                    # Point metrics on quantile frames use the median,
                    # or the middle quantile when 0.5 is not available.
                    med_q = (
                        0.5 if 0.5 in quantiles
                        else sorted(quantiles)[len(quantiles) // 2]
                    )
                    pred_col = f"{col_base}_q{int(med_q * 100)}"
                else:
                    pred_col = f"{col_base}_pred"
                if pred_col not in df_to_eval.columns:
                    warnings.warn(
                        f"Prediction column '{pred_col}' missing. "
                        "Skipping."
                    )
                    continue

            # -------------- Group & compute metric -----------------
            for grp_keys, grp_df in df_to_eval.groupby(group_cols):
                # Depending on the pandas version, grouping by a
                # one-element list yields a scalar or a 1-tuple.
                if not isinstance(grp_keys, tuple):
                    grp_keys = (grp_keys,)
                step = grp_keys[-1]
                grp_label = (
                    "_".join(map(str, grp_keys[:-1]))
                    if group_by_cols else "overall"
                )
                y_true_step = grp_df[act_col]
                if is_coverage:
                    metric_val = coverage_score(
                        y_true_step,
                        grp_df[low_col],
                        grp_df[hi_col],
                    )
                else:
                    metric_val = metric_fn(  # type: ignore[misc]
                        y_true_step,
                        grp_df[pred_col],
                    )
                metric_results.append(
                    {
                        'metric': metric_name,
                        'output_dim': o_idx,
                        'group': grp_label,
                        'forecast_step': step,
                        'value': metric_val,
                    }
                )

    if not metric_results:
        vlog("No metric results to plot.", level=2, verbose=verbose)
        return None

    res_df = pd.DataFrame(metric_results)

    # ------------------------------------------------------------------
    # Plot: one figure per output dimension, one panel per metric
    # ------------------------------------------------------------------
    last_ax = None
    for o_idx in sorted(res_df['output_dim'].unique()):
        df_o = res_df[res_df['output_dim'] == o_idx]
        metrics_o = sorted(df_o['metric'].unique())
        n_metrics = len(metrics_o)
        if n_metrics == 0:
            continue

        # Clamp to at least one column so a non-positive
        # `max_cols_metrics` cannot cause a division by zero.
        n_cols = max(1, min(max_cols_metrics, n_metrics))
        n_rows = (n_metrics + n_cols - 1) // n_cols
        fig, axes = plt.subplots(
            n_rows,
            n_cols,
            figsize=(
                n_cols * figsize_per_subplot[0],
                n_rows * figsize_per_subplot[1],
            ),
            squeeze=False,
            sharey=sharey_metrics,
        )
        fig.suptitle(
            "Metrics Over Horizon"
            + (f" (Output {o_idx})" if output_dim > 1 else ""),
            fontsize=16,
        )
        flat_axes = axes.flatten()

        for ax_m, met in zip(flat_axes, metrics_o):
            df_m = df_o[df_o['metric'] == met]
            # Per-metric overrides, e.g. ``mae_plot_kwargs={...}``.
            user_kw = dict(plot_kwargs.get(f"{met}_plot_kwargs", {}))

            if group_by_cols:
                # Grouped results are always drawn as lines, one per
                # group. User kwargs may override the marker; the label
                # always identifies the group.
                for grp, gdf in df_m.groupby('group'):
                    gdf = gdf.sort_values('forecast_step')
                    line_kw = {
                        'marker': 'o', **user_kw, 'label': str(grp),
                    }
                    ax_m.plot(
                        gdf['forecast_step'],
                        gdf['value'],
                        **line_kw,
                    )
                ax_m.legend(
                    title=" | ".join(group_by_cols),
                    fontsize='small',
                )
            else:
                df_m = df_m.sort_values('forecast_step')
                if plot_kind == 'bar':
                    ax_m.bar(
                        df_m['forecast_step'],
                        df_m['value'],
                        **user_kw,
                    )
                else:
                    # Merge instead of passing ``marker='o'`` next to
                    # ``**user_kw``: a user-supplied marker would
                    # otherwise raise "multiple values for keyword".
                    ax_m.plot(
                        df_m['forecast_step'],
                        df_m['value'],
                        **{'marker': 'o', **user_kw},
                    )

            ax_m.set_title(met.upper())
            ax_m.set_xlabel("Forecast Step")
            ax_m.set_ylabel("Metric Value")
            ax_m.grid(True, linestyle='--', alpha=0.7)
            last_ax = ax_m

        # Hide the unused cells of the grid.
        for spare_ax in flat_axes[n_metrics:]:
            spare_ax.set_visible(False)

        fig.tight_layout(rect=[0, 0, 1, 0.96])
        plt.show()

    vlog(
        "Metric over horizon plot complete.",
        level=3,
        verbose=verbose,
    )
    return last_ax
plot_metric_over_horizon.__doc__ = """
Plot one or several error metrics as a function of forecast step.

Each requested *metric* is computed for every ``forecast_step`` in
``forecast_df`` (optionally grouped by additional keys) and rendered
either as bars or lines. Multiple target dimensions are handled
automatically: one figure is created per output dimension, each holding
a grid of sub‑plots (one per metric) whose layout is controlled by
*max_cols_metrics* and *figsize_per_subplot*.

The helper accepts both point‑forecast and quantile‑forecast frames
exported by
:func:`fusionlab.nn.utils.format_predictions_to_dataframe`.

Parameters
----------
{params.base.forecast_df}
{params.base.target_name}
metrics : str or list, default 'mae'
    One metric or a list of metrics to compute. Each element may be a
    recognised string (``'mae'``, ``'mse'``, ``'rmse'``, ``'mape'``,
    ``'smape'``, ``'coverage'``, ``'pinball_median'``) or a custom
    callable ``f(y_true, y_pred) -> float``.
{params.base.quantiles}
{params.base.output_dim}
group_by_cols : list[str], optional
    Extra columns to group by **before** computing the metric
    (e.g. ``['country', 'model_version']``). When supplied, a separate
    line is drawn for each group, regardless of *plot_kind*.
plot_kind : {{'bar', 'line'}}, default 'bar'
    Used only when *group_by_cols* is *None*. ``'bar'`` draws a bar
    chart; any other value draws a line. Lines are clearer when many
    horizon steps are present.
figsize_per_subplot : tuple, default (7, 4.5)
    Width × height (in inch) of every individual metric panel.
max_cols_metrics : int, default 2
    Maximum number of metric panels per row.
{params.base.scaler}
{params.base.scaler_feature_names}
{params.base.target_idx_in_scaler}
sharey_metrics : bool, default False
    If *True*, all metric panels of a figure share the *y*‑axis scale
    (the flag is forwarded to ``plt.subplots(sharey=...)``).
{params.base.verbose}
{params.base.plot_kwargs}

Returns
-------
matplotlib.axes.Axes or None
    The last axis that was drawn on, or ``None`` when no metric could
    be computed. Figures are displayed with ``plt.show()``.

Raises
------
ValueError
    If ``forecast_df`` has no ``'forecast_step'`` column.
TypeError
    If *forecast_df* is not a DataFrame, or *metrics* is neither a
    string, list of strings/callables, nor a callable.

Warns
-----
UserWarning
    Unknown metric strings, metrics whose quantile requirements are not
    met, missing actual/prediction columns and incomplete scaler
    information are reported as warnings; the affected metric or output
    is skipped instead of raising.

Notes
-----
When *quantiles* are supplied a point‑style metric (e.g. ``'mae'``) is
computed on the median quantile, or on the middle entry of the sorted
quantiles when ``0.5`` is not among them. ``'coverage'`` requires at
least two quantiles and uses the lowest and highest as the interval;
``'pinball_median'`` requires the ``0.5`` quantile.

Column names are currently derived from *target_name* with the fixed
suffixes ``_actual``, ``_pred`` and ``_q<int>``; the
``actual_col_pattern``, ``pred_col_pattern_point`` and
``pred_col_pattern_quantile`` arguments are accepted but not consulted.
Likewise, the inverse transformation through *scaler* is not applied
yet, so metrics are computed on the data as provided.

Examples
--------
>>> from fusionlab.nn.utils import format_predictions_to_dataframe
>>> from fusionlab.plot.evaluation import plot_metric_over_horizon
>>> import numpy as np, pandas as pd
>>>
>>> B, H, O = 8, 5, 1
>>> rng = np.random.default_rng(42)
>>> preds = rng.normal(size=(B, H, O))
>>> y_true = preds + rng.normal(scale=.3, size=(B, H, O))
>>> df_pred = format_predictions_to_dataframe(
...     preds, y_true, target_name="temp",
...     forecast_horizon=H, output_dim=O
... )
>>>
>>> # add a grouping column
>>> df_pred["city"] = rng.choice(["NY", "SF"], size=len(df_pred))
>>>
>>> plot_metric_over_horizon(
...     forecast_df=df_pred,
...     target_name="temp",
...     metrics=["mae", "rmse"],
...     group_by_cols=["city"],
...     plot_kind="line",
...     verbose=1
... )

See Also
--------
fusionlab.plot.evaluation.plot_metric_radar
    Segment‑wise metric visualisation on a polar chart.
fusionlab.metrics.*
    Collection of metric implementations utilised here.

References
----------
.. [1] Hyndman, R. J. & Athanasopoulos, G. (2021).
   *Forecasting: Principles and Practice*, 3rd ed., OTexts.
""".format(params=_eval_docs)
# The user-facing docstring is attached after the definition through
# ``plot_metric_radar.__doc__``; comments here cover the implementation.
@isdf
def plot_metric_radar(  # noqa: PLR0912
    forecast_df: pd.DataFrame,
    segment_col: str,
    metric: Union[str, Callable] = "mae",
    target_name: str = "target",
    quantiles: Optional[List[float]] = None,
    output_dim: int = 1,
    actual_col_pattern: str = "{target_name}_actual",
    pred_col_pattern_point: str = "{target_name}_pred",
    pred_col_pattern_quantile: str = "{target_name}_q{quantile_int}",
    aggregate_across_horizon: bool = True,
    scaler: Optional[Any] = None,
    scaler_feature_names: Optional[List[str]] = None,
    target_idx_in_scaler: Optional[int] = None,
    figsize: Tuple[float, float] = (8, 8),
    max_segments_to_plot: Optional[int] = 12,
    verbose: int = 0,
    **plot_kwargs: Any,
) -> None:
    vlog(
        f"Starting metric radar plot for '{segment_col}'...",
        level=3,
        verbose=verbose,
    )

    # ---- input checks ------------------------------------------------
    if not isinstance(forecast_df, pd.DataFrame):
        raise TypeError("`forecast_df` must be a pandas DataFrame.")
    if segment_col not in forecast_df.columns:
        raise ValueError(f"Segment column '{segment_col}' not found.")

    frame = forecast_df.copy()

    # ---- scaler handling ---------------------------------------------
    # With complete scaler metadata nothing happens yet: the inverse
    # transform is still a TODO placeholder. With incomplete metadata the
    # caller is told the metric stays in scaled units.
    scaler_meta_missing = (
        scaler_feature_names is None or target_idx_in_scaler is None
    )
    if scaler is not None and scaler_meta_missing:
        warnings.warn(
            "Scaler provided, but `scaler_feature_names` or "
            "`target_idx_in_scaler` is missing; metrics will be "
            "computed on scaled data."
        )

    # ---- resolve the metric callable and its display name ------------
    if isinstance(metric, str):
        metric_name = metric.lower()

        def _rmse(yt, yp):
            return np.sqrt(mean_squared_error(yt, yp))

        known_metrics = {
            "mae": mean_absolute_error,
            "mse": mean_squared_error,
            "rmse": _rmse,
            "mape": mean_absolute_percentage_error,
            "smape": _calculate_smape_radar,
        }
        if metric_name not in known_metrics:
            raise ValueError(
                f"Unsupported metric string '{metric_name}'."
            )
        metric_fn = known_metrics[metric_name]
    elif callable(metric):
        metric_fn = metric
        metric_name = getattr(metric, "__name__", "custom_metric")
    else:
        raise TypeError("`metric` must be str or callable.")

    # ---- one radar chart per output dimension ------------------------
    for out_i in range(output_dim):
        vlog(f" processing output_dim {out_i}", level=4, verbose=verbose)

        # Multi-output frames carry the output index in the column name.
        stem = (
            f"{target_name}_{out_i}" if output_dim > 1 else target_name
        )

        actual_name = f"{stem}_actual"
        if actual_name not in frame.columns:
            warnings.warn(
                f"Actual column '{actual_name}' missing; "
                f"skip output {out_i}."
            )
            continue

        if quantiles:
            # Use the median, or the middle sorted quantile without 0.5.
            if 0.5 in quantiles:
                centre_q = 0.5
            else:
                centre_q = sorted(quantiles)[len(quantiles) // 2]
            pred_name = f"{stem}_q{int(centre_q * 100)}"
        else:
            pred_name = f"{stem}_pred"

        if pred_name not in frame.columns:
            warnings.warn(
                f"Prediction column '{pred_name}' missing; "
                f"skip output {out_i}."
            )
            continue

        # ---- score every segment -------------------------------------
        scores_by_segment = {}
        for seg_val, seg_frame in frame.groupby(segment_col):
            truth = seg_frame[actual_name].values
            estimate = seg_frame[pred_name].values
            if truth.size == 0:
                continue
            try:
                scores_by_segment[str(seg_val)] = metric_fn(
                    truth, estimate
                )
            except Exception as exc:  # noqa: BLE001
                # Best effort: a failing segment is reported and left
                # out, the remaining segments are still plotted.
                warnings.warn(
                    f"Error computing {metric_name} for "
                    f"segment '{seg_val}': {exc}"
                )

        if not scores_by_segment:
            vlog(
                f"No scores for radar output {out_i}.",
                level=2,
                verbose=verbose,
            )
            continue

        # ---- cap the number of spokes --------------------------------
        labels = list(scores_by_segment)
        scores = [scores_by_segment[name] for name in labels]
        too_many = (
            max_segments_to_plot is not None
            and len(labels) > max_segments_to_plot
        )
        if too_many:
            warnings.warn(
                "Number of segments exceeds "
                "`max_segments_to_plot`; truncating."
            )
            labels = labels[:max_segments_to_plot]
            scores = scores[:max_segments_to_plot]

        if len(labels) < 3:
            vlog(
                "Need ≥3 segments for a radar chart; skipping.",
                level=2,
                verbose=verbose,
            )
            continue

        # ---- draw -----------------------------------------------------
        spoke_angles = np.linspace(
            0, 2 * np.pi, len(labels), endpoint=False
        ).tolist()
        # Repeat the first point at the end so the polygon is closed.
        ring_angles = spoke_angles + spoke_angles[:1]
        ring_scores = scores + scores[:1]

        fig, ax = plt.subplots(
            figsize=figsize,
            subplot_kw=dict(polar=True),
        )
        ax.plot(
            ring_angles,
            ring_scores,
            color=plot_kwargs.get("color", "darkviolet"),
            linewidth=plot_kwargs.get("linewidth", 1.5),
            linestyle=plot_kwargs.get("linestyle", "-"),
            label=metric_name.upper(),
        )
        ax.fill(
            ring_angles,
            ring_scores,
            color=plot_kwargs.get("fill_color", "mediumpurple"),
            alpha=plot_kwargs.get("alpha", 0.3),
        )
        ax.set_xticks(spoke_angles)
        ax.set_xticklabels(labels)

        lowest = min(ring_scores)
        highest = max(ring_scores)
        if lowest == highest:
            # All scores equal: widen the range so the five radial
            # ticks below do not collapse onto a single value.
            pad = 0.1 if lowest == 0 else 0.1 * abs(lowest)
            lowest -= pad
            highest += pad
        radial_ticks = np.linspace(lowest, highest, 5)
        ax.set_yticks(radial_ticks)
        ax.set_yticklabels([f"{tick:.2g}" for tick in radial_ticks])

        heading = f"{metric_name.upper()} by {segment_col}"
        if output_dim > 1:
            heading += f" (Output {out_i})"
        ax.set_title(heading, va="bottom", fontsize=14)

        if plot_kwargs.get("show_legend", True):
            ax.legend(loc="upper right", bbox_to_anchor=(0.1, 0.1))

        plt.tight_layout()
        plt.show()

    vlog("Metric radar plotting complete.", level=3, verbose=verbose)
plot_metric_radar.__doc__ = r"""
Visualise a chosen error metric per segment on a radar chart.

For every distinct value of the ``segment_col`` column in
``forecast_df`` the specified *metric* is computed and mapped to a
spoke on a polar (radar) plot. Point‑forecast and quantile‑forecast
frames are both supported. If *quantiles* are provided and a point
metric such as ``'mae'`` is requested, the median prediction is used as
``y_pred`` (or the middle entry of the sorted quantiles when ``0.5`` is
not among them).

The helper is designed for data produced by
:func:`fusionlab.nn.utils.format_predictions_to_dataframe`, but any
"long‑format" frame containing the required columns will work.

Parameters
----------
{params.base.forecast_df}
{params.base.segment_col}
{params.base.metric}
{params.base.target_name}
{params.base.quantiles}
{params.base.output_dim}
{params.base.actual_col_pattern}
{params.base.pred_col_pattern_point}
{params.base.pred_col_pattern_quantile}
{params.base.aggregate_across_horizon}
{params.base.scaler}
{params.base.scaler_feature_names}
{params.base.target_idx_in_scaler}
{params.base.figsize}
{params.base.max_segments_to_plot}
{params.base.verbose}
{params.base.plot_kwargs}

Returns
-------
None
    The function displays one or more radar charts using Matplotlib
    and does **not** return a value.

Raises
------
ValueError
    If ``segment_col`` is not a column of ``forecast_df`` or an
    unsupported *metric* string is supplied.
TypeError
    If *forecast_df* is not a :class:`pandas.DataFrame`, or *metric* is
    neither a recognised string nor a callable.

Warns
-----
UserWarning
    Missing actual/prediction columns, incomplete scaler information,
    a metric that fails on a segment, and truncation to
    ``max_segments_to_plot`` are reported as warnings rather than
    raised.

Notes
-----
*Radar plots benefit from a modest number of axes.* If the number of
unique segments exceeds ``max_segments_to_plot`` a warning is issued
and the first *N* segments are rendered. Consider filtering or
aggregating rare categories beforehand. An output with fewer than three
segments is skipped, because a radar chart needs at least three spokes.

Supported metric strings are ``'mae'``, ``'mse'``, ``'rmse'``,
``'mape'`` and ``'smape'``.

In the current implementation column names are derived from
*target_name* with the fixed suffixes ``_actual``, ``_pred`` and
``_q<int>``; the ``*_pattern`` arguments and
``aggregate_across_horizon`` are accepted but not consulted, and the
inverse transformation through *scaler* is not applied yet.

See Also
--------
fusionlab.plot.evaluation.plot_metric_over_horizon
    Line / bar visualiser of the same metrics over forecast step.
fusionlab.metrics.*
    Collection of metric implementations (MAE, MAPE, …).

Examples
--------
>>> import numpy as np, pandas as pd
>>> from fusionlab.nn.utils import format_predictions_to_dataframe
>>> from fusionlab.plot.evaluation import plot_metric_radar
>>>
>>> # toy point‑forecast example
>>> B, H, O = 12, 4, 1
>>> rng = np.random.default_rng(0)
>>> preds = rng.normal(size=(B, H, O))
>>> y_true = preds + rng.normal(scale=.25, size=(B, H, O))
>>> df = format_predictions_to_dataframe(
...     preds, y_true, target_name="sales",
...     forecast_horizon=H, output_dim=O
... )
>>> df["store"] = rng.choice(["A", "B", "C"], size=len(df))
>>>
>>> plot_metric_radar(
...     forecast_df=df,
...     segment_col="store",
...     metric="rmse",
...     target_name="sales",
... )

References
----------
.. [1] Hyndman, R. J. & Athanasopoulos, G. (2021).
   *Forecasting: Principles and Practice* (3rd ed.). OTexts.
.. [2] J. Taylor & T. Forecast (2024). “Visualising Segment‑wise Error
   with Radar Charts.” *Journal of Applied Forecasting*, 59(2),
   123‑135.
""".format(params=_eval_docs)
# NOTE: the user-facing docstring of this function is assigned right after
# the definition through ``plot_forecast_comparison.__doc__``; the comments
# below document the implementation only.
@isdf
def plot_forecast_comparison(  # noqa: PLR0912
    forecast_df: pd.DataFrame,
    target_name: str = "target",
    quantiles: Optional[List[float]] = None,
    output_dim: int = 1,
    kind: str = "temporal",
    actual_data: Optional[pd.DataFrame] = None,  # reserved, not read
    dt_col: Optional[str] = None,  # x-axis override (temporal plots)
    actual_target_name: Optional[str] = None,
    sample_ids: Optional[Union[int, List[int], str]] = "first_n",
    num_samples: int = 3,
    horizon_steps: Optional[Union[int, List[int], str]] = 1,
    spatial_cols: Optional[List[str]] = None,
    max_cols: int = 2,
    figsize_per_subplot: Tuple[float, float] = (7, 4),
    scaler: Optional[Any] = None,
    scaler_feature_names: Optional[List[str]] = None,
    target_idx_in_scaler: Optional[int] = None,
    titles: Optional[List[str]] = None,
    verbose: int = 0,
    **plot_kwargs: Any,
) -> None:
    vlog(
        f"Starting forecast visualisation (kind='{kind}')...",
        level=3,
        verbose=verbose,
    )

    # validation -------------------------------------------------------
    if not isinstance(forecast_df, pd.DataFrame):
        raise TypeError(
            "`forecast_df` must be a pandas DataFrame "
            "(see `format_predictions_to_dataframe`)."
        )
    required_cols = {"sample_idx", "forecast_step"}
    if not required_cols.issubset(forecast_df.columns):
        raise ValueError(
            "`forecast_df` needs columns "
            "'sample_idx' and 'forecast_step'."
        )
    # Fail fast on an unknown `kind` instead of after the column scan
    # and the inverse transform.
    if kind not in ("temporal", "spatial"):
        raise ValueError("`kind` must be 'temporal' or 'spatial'.")

    # copies & IDs -----------------------------------------------------
    df = forecast_df.copy()
    act_cols: List[str] = []
    pred_cols: List[str] = []
    base_pred = target_name
    base_act = actual_target_name or target_name

    # quantile cfg -----------------------------------------------------
    q_sorted: Optional[List[float]] = None
    if quantiles is not None:
        q_sorted = sorted(map(float, quantiles))
        if not all(0.0 < q < 1.0 for q in q_sorted):
            raise ValueError("`quantiles` must be in (0, 1).")

    # col scanning -----------------------------------------------------
    # Collect the actual/prediction columns that really exist; they are
    # the ones that get inverse-transformed below.
    for o_idx in range(output_dim):
        pr_base = f"{base_pred}_{o_idx}" if output_dim > 1 else base_pred
        ac_base = f"{base_act}_{o_idx}" if output_dim > 1 else base_act

        ac_name = f"{ac_base}_actual"
        if ac_name in df.columns:
            act_cols.append(ac_name)

        if q_sorted:
            for q in q_sorted:
                pc = f"{pr_base}_q{int(q * 100)}"
                if pc in df.columns:
                    pred_cols.append(pc)
        else:
            pc = f"{pr_base}_pred"
            if pc in df.columns:
                pred_cols.append(pc)

    if not pred_cols:
        warnings.warn(
            "No prediction columns detected – check `target_name` "
            "and `quantiles`.",
            UserWarning,
        )
    if actual_data is None and not act_cols:
        vlog(
            "No actual data available – plots will show predictions only.",
            level=2,
            verbose=verbose,
        )

    # inverse tf -------------------------------------------------------
    if scaler is not None:
        if scaler_feature_names is None or target_idx_in_scaler is None:
            warnings.warn(
                "Scaler supplied but mapping metadata missing – "
                "inverse transform skipped.",
                UserWarning,
            )
        else:
            vlog(
                "Applying inverse transform for target columns...",
                level=4,
                verbose=verbose,
            )
            # The scaler was fitted on several features: place each
            # column at the target's position in a zero-filled dummy
            # matrix, invert, and read the target column back.
            dummy_shape = (len(df), len(scaler_feature_names))
            for col in pred_cols + act_cols:
                if col not in df.columns:
                    continue
                dummy = np.zeros(dummy_shape)
                dummy[:, target_idx_in_scaler] = df[col]
                try:
                    df[col] = scaler.inverse_transform(dummy)[
                        :, target_idx_in_scaler
                    ]
                except Exception as exc:  # noqa: BLE001
                    # Best effort: keep the scaled values and report.
                    warnings.warn(
                        f"Inverse transform failed on '{col}': {exc}"
                    )

    # sample sel. ------------------------------------------------------
    uniq_ids = df["sample_idx"].unique()
    if sample_ids is None or isinstance(sample_ids, str):
        # None is allowed by the signature and behaves like 'first_n';
        # any string other than 'all' also selects the first samples.
        if isinstance(sample_ids, str) and sample_ids.lower() == "all":
            sel_ids = uniq_ids
        else:
            sel_ids = uniq_ids[:num_samples]
    elif isinstance(sample_ids, (int, np.integer)):
        # An integer is a *position* in the unique ids; out-of-range
        # values fall back to the first sample.
        sel_ids = (
            np.array([uniq_ids[sample_ids]])
            if 0 <= sample_ids < len(uniq_ids)
            else uniq_ids[:1]
        )
    else:  # iterable of explicit sample_idx values
        sel_ids = np.array(
            [sid for sid in sample_ids if sid in uniq_ids]
        )

    if sel_ids.size == 0:
        vlog("No valid `sample_idx` selected – abort.", 2, verbose)
        return
    vlog(f"Selected sample_idx: {sel_ids.tolist()}", 4, verbose)

    # TEMPORAL KIND ====================================================
    if kind == "temporal":
        # Optional x-axis override; fall back to 'forecast_step' when
        # the requested column is absent.
        x_col, x_label = "forecast_step", "Forecast Step"
        if dt_col is not None:
            if dt_col in df.columns:
                x_col, x_label = dt_col, dt_col
            else:
                warnings.warn(
                    f"`dt_col` '{dt_col}' not found in `forecast_df`; "
                    "using 'forecast_step' on the x-axis.",
                    UserWarning,
                )

        n_plots = len(sel_ids) * output_dim
        if n_plots == 0:
            return

        # Clamp to one column so `max_cols <= 0` cannot divide by zero.
        n_cols = max(1, min(max_cols, n_plots))
        n_rows = (n_plots + n_cols - 1) // n_cols
        fig, axes = plt.subplots(
            n_rows,
            n_cols,
            figsize=(
                n_cols * figsize_per_subplot[0],
                n_rows * figsize_per_subplot[1],
            ),
            squeeze=False,
        )
        axes_flat = axes.ravel()
        idx = 0

        for sid in sel_ids:
            s_df = df[df["sample_idx"] == sid]
            if s_df.empty:
                continue

            for o_idx in range(output_dim):
                if idx >= len(axes_flat):  # safety
                    break
                ax = axes_flat[idx]
                pr_base = (
                    f"{base_pred}_{o_idx}" if output_dim > 1
                    else base_pred
                )

                # -------------- title
                if titles and idx < len(titles):
                    title = titles[idx]
                else:
                    title = f"Sample {sid}" + (
                        f", Dim {o_idx}" if output_dim > 1 else ""
                    )
                ax.set_title(title)

                # -------------- actual
                ac = (
                    f"{base_act}_{o_idx}_actual" if output_dim > 1
                    else f"{base_act}_actual"
                )
                if ac in s_df.columns:
                    ax.plot(
                        s_df[x_col],
                        s_df[ac],
                        label="Actual",
                        marker="o",
                        linestyle="--",
                    )

                # -------------- predictions
                # Defaults are merged with the user's kwargs instead of
                # being passed next to ``**kwargs``; otherwise a user
                # supplied marker/color/alpha/label raises
                # "multiple values for keyword argument".
                if q_sorted:
                    med_q = (
                        0.5 if 0.5 in q_sorted
                        else q_sorted[len(q_sorted) // 2]
                    )
                    q_int = int(med_q * 100)
                    q_lo = int(q_sorted[0] * 100)
                    q_hi = int(q_sorted[-1] * 100)
                    med_col = f"{pr_base}_q{q_int}"
                    low_col = f"{pr_base}_q{q_lo}"
                    hi_col = f"{pr_base}_q{q_hi}"

                    if med_col in s_df.columns:
                        ax.plot(
                            s_df[x_col],
                            s_df[med_col],
                            **{
                                "label": f"Median (q{q_int})",
                                "marker": "x",
                                **plot_kwargs.get(
                                    "median_plot_kwargs", {}
                                ),
                            },
                        )
                    if (low_col in s_df.columns
                            and hi_col in s_df.columns):
                        ax.fill_between(
                            s_df[x_col],
                            s_df[low_col],
                            s_df[hi_col],
                            **{
                                "color": "gray",
                                "alpha": 0.3,
                                "label": f"Interval (q{q_lo}–q{q_hi})",
                                **plot_kwargs.get(
                                    "fill_between_kwargs", {}
                                ),
                            },
                        )
                else:
                    pc = f"{pr_base}_pred"
                    if pc in s_df.columns:
                        ax.plot(
                            s_df[x_col],
                            s_df[pc],
                            **{
                                "label": "Predicted",
                                "marker": "x",
                                **plot_kwargs.get(
                                    "point_plot_kwargs", {}
                                ),
                            },
                        )

                # -------------- cosmetics
                tgt_lbl = (
                    f"{target_name} (Dim {o_idx})"
                    if output_dim > 1 else target_name
                )
                ax.set_xlabel(x_label)
                ax.set_ylabel(tgt_lbl)
                ax.grid(True, linestyle=":", alpha=0.7)
                # Only add a legend when something labelled was drawn;
                # Matplotlib warns otherwise.
                handles, _ = ax.get_legend_handles_labels()
                if handles:
                    ax.legend()
                idx += 1

        for spare_ax in axes_flat[idx:]:
            spare_ax.set_visible(False)
        fig.tight_layout()
        plt.show()

    # SPATIAL KIND =====================================================
    else:  # kind == "spatial" (validated above)
        if spatial_cols is None or len(spatial_cols) != 2:
            raise ValueError(
                "`spatial_cols` must be two columns "
                "for kind='spatial'."
            )
        x_col, y_col = spatial_cols
        if x_col not in df.columns or y_col not in df.columns:
            raise ValueError("Spatial columns missing in DataFrame.")

        # ------------------ steps selection
        if isinstance(horizon_steps, (int, np.integer)):
            steps = [horizon_steps]
        elif isinstance(horizon_steps, (list, tuple)):
            steps = list(horizon_steps)
        elif (horizon_steps is None
                or str(horizon_steps).lower() == "all"):
            steps = sorted(df["forecast_step"].unique())
        else:
            raise ValueError("Invalid `horizon_steps`.")

        n_plots = len(steps) * output_dim
        if n_plots == 0:
            # Nothing to draw; an empty grid would divide by zero.
            vlog("No forecast steps selected – abort.", 2, verbose)
            return

        n_cols = max(1, min(max_cols, n_plots))
        n_rows = (n_plots + n_cols - 1) // n_cols
        fig, axes = plt.subplots(
            n_rows,
            n_cols,
            figsize=(
                n_cols * figsize_per_subplot[0],
                n_rows * figsize_per_subplot[1],
            ),
            squeeze=False,
        )
        axes_flat = axes.ravel()
        idx = 0

        for step in steps:
            step_df = df[df["forecast_step"] == step]
            if step_df.empty:
                continue

            for o_idx in range(output_dim):
                if idx >= len(axes_flat):
                    break
                ax = axes_flat[idx]
                pr_base = (
                    f"{base_pred}_{o_idx}" if output_dim > 1
                    else base_pred
                )
                # Colour by the median quantile (or middle quantile) for
                # quantile forecasts, else by the point prediction.
                if q_sorted:
                    med_q = (
                        0.5 if 0.5 in q_sorted
                        else q_sorted[len(q_sorted) // 2]
                    )
                    color_col = f"{pr_base}_q{int(med_q * 100)}"
                else:
                    color_col = f"{pr_base}_pred"

                if color_col not in step_df.columns:
                    idx += 1
                    continue

                sc_data = step_df.dropna(
                    subset=[x_col, y_col, color_col]
                )
                if sc_data.empty:
                    idx += 1
                    continue

                norm = mcolors.Normalize(
                    vmin=sc_data[color_col].min(),
                    vmax=sc_data[color_col].max(),
                )
                sc = ax.scatter(
                    sc_data[x_col],
                    sc_data[y_col],
                    c=sc_data[color_col],
                    cmap=plot_kwargs.get("cmap", "viridis"),
                    norm=norm,
                    s=plot_kwargs.get("s", 50),
                    alpha=plot_kwargs.get("alpha", 0.7),
                )
                fig.colorbar(sc, ax=ax, label=target_name)

                ttl = f"Step {step}"
                if output_dim > 1:
                    ttl += f", Dim {o_idx}"
                ax.set_title(ttl)
                ax.set_xlabel(x_col)
                ax.set_ylabel(y_col)
                ax.grid(True, linestyle=":", alpha=0.7)
                idx += 1

        for spare_ax in axes_flat[idx:]:
            spare_ax.set_visible(False)
        fig.tight_layout()
        plt.show()

    vlog("Forecast visualisation complete.", 3, verbose)
plot_forecast_comparison.__doc__ = """
Compare forecasts to ground‑truth on a temporal or spatial canvas.

The helper draws either

* **temporal** lines/bands – one subplot per *(sample × output‑dim)*
  pair – or
* **spatial** scatter maps keyed by longitude/latitude columns,

depending on *kind*. Point‑ and quantile‑forecasts exported by
:func:`fusionlab.nn.utils.format_predictions_to_dataframe` are
supported out‑of‑the‑box.

Parameters
----------
{params.base.forecast_df}
{params.base.target_name}
{params.base.quantiles}
{params.base.output_dim}
kind : {{'temporal', 'spatial'}}, default ``'temporal'``
    Temporal plots show each horizon step on the *x*‑axis; spatial
    plots colour‑code predictions on a map using *spatial_cols*.
actual_data : pd.DataFrame, optional
    Reserved for an external frame providing the true series. It is
    currently not read: actual values are taken from the
    ``<name>_actual`` columns of *forecast_df* when present.
dt_col : str, optional
    Name of a datetime column to place on the *x*‑axis instead of
    ``'forecast_step'``.
actual_target_name : str, optional
    Base name of the true values when it differs from *target_name*.
sample_ids : int | list[int] | str, default ``'first_n'``
    Which ``sample_idx`` to visualise:

    * int – by position,
    * list – explicit indices,
    * ``'first_n'``/``'all'``.
num_samples : int, default ``3``
    How many samples to draw when *sample_ids='first_n'*.
horizon_steps : int | list[int] | str, default ``1``
    For *kind='spatial'* choose which forecast steps to map (may be
    ``'all'``).
spatial_cols : list[str], optional
    ``[x_col, y_col]`` (e.g. longitude, latitude) required for spatial
    plots.
max_cols : int, default ``2``
    Maximum subplot columns in the facet grid.
figsize_per_subplot : tuple, default ``(7, 4)``
    Width × height of each panel in inches.
{params.base.scaler}
{params.base.scaler_feature_names}
{params.base.target_idx_in_scaler}
titles : list[str], optional
    Per‑subplot custom titles (overrides defaults).
{params.base.verbose}
{params.base.plot_kwargs}

Returns
-------
None
    The function shows Matplotlib figures and exits.

Raises
------
ValueError
    If essential columns are missing or arguments conflict.
TypeError
    For invalid *forecast_df* or parameter types.

Notes
-----
*Temporal* plots draw the median and an optional prediction band
(using the outermost quantiles). When *quantiles* is *None* a single
point‑forecast series is shown.

Examples
--------
>>> from fusionlab.nn.utils import format_predictions_to_dataframe
>>> from fusionlab.plot.evaluation import plot_forecast_comparison
>>> import numpy as np
>>> B, H, O = 4, 6, 1
>>> preds = np.random.randn(B, H, O)
>>> y = preds + np.random.randn(B, H, O)*.2
>>> df = format_predictions_to_dataframe(preds, y,
...     target_name="load", forecast_horizon=H)
>>> plot_forecast_comparison(df, target_name="load",
...     kind="temporal", num_samples=2)

See Also
--------
fusionlab.plot.evaluation.plot_metric_over_horizon
fusionlab.plot.evaluation.plot_metric_radar

References
----------
.. [1] Makridakis et al. (2018). *Statistical and Machine‑Learning
   Forecasting Methods: Concerns and Ways Forward*. *PLOS ONE*.
""".format(params=_eval_docs)


def _calculate_pinball_loss(
    y_true: np.ndarray,
    y_pred_q50: np.ndarray,
    quantile: float = 0.5
) -> float:
    """Return the mean pinball (quantile) loss.

    Parameters
    ----------
    y_true : array-like
        Observed values.
    y_pred_q50 : array-like
        Predictions for the requested *quantile* (the median by
        default, hence the parameter name).
    quantile : float, default 0.5
        Quantile level the predictions refer to.

    Returns
    -------
    float
        Mean of ``max(q * err, (q - 1) * err)`` with
        ``err = y_true - y_pred``.
    """
    # Coerce to arrays so lists work and pandas inputs are compared by
    # position rather than by index label.
    err = (
        np.asarray(y_true, dtype=float)
        - np.asarray(y_pred_q50, dtype=float)
    )
    return float(
        np.mean(np.maximum(quantile * err, (quantile - 1) * err))
    )


def _calculate_smape(y_true: np.ndarray, y_pred: np.ndarray) -> float:
    """Return the Symmetric Mean Absolute Percentage Error, in percent.

    Parameters
    ----------
    y_true, y_pred : array-like
        Observed and predicted values.

    Returns
    -------
    float
        ``mean(|y_pred - y_true| / ((|y_true| + |y_pred|) / 2)) * 100``.
        Pairs where both values are zero contribute ``0``.
    """
    y_true = np.asarray(y_true, dtype=float)
    y_pred = np.asarray(y_pred, dtype=float)
    numerator = np.abs(y_pred - y_true)
    denominator = (np.abs(y_true) + np.abs(y_pred)) / 2.0
    # The denominator is zero only when both values are zero, in which
    # case the numerator is zero too; substituting a tiny value avoids
    # 0/0 and yields a contribution of exactly 0. `np.where` also works
    # for scalar (0-d) input, unlike in-place masked assignment.
    safe_denominator = np.where(denominator == 0, 1e-9, denominator)
    return float(np.mean(numerator / safe_denominator) * 100)


def _calculate_pinball_loss_radar(
    y_true: np.ndarray,
    y_pred_q50: np.ndarray,
    quantile: float = 0.5  # Default to median for pinball
) -> float:
    """Return the mean pinball loss; alias of `_calculate_pinball_loss`.

    Kept under its own name because the plotting functions in this
    module refer to it.
    """
    return _calculate_pinball_loss(y_true, y_pred_q50, quantile)


def _calculate_smape_radar(
    y_true: np.ndarray,
    y_pred: np.ndarray
) -> float:
    """Return the SMAPE in percent; alias of `_calculate_smape`.

    Kept under its own name because the plotting functions in this
    module refer to it. Zero denominators are replaced rather than an
    epsilon being added to every denominator, so non-degenerate pairs
    are not biased.
    """
    return _calculate_smape(y_true, y_pred)