# Source code for fusionlab.plot.forecast

# -*- coding: utf-8 -*-
# License: BSD-3-Clause
# Author: LKouadio <etanoyau@gmail.com>

"""
Convenience helpers for **quick‑look** visual inspection of model
forecasts.
"""

from __future__ import annotations

import os
import re
import logging 
import warnings
import datetime 
from typing import ( 
    Any, List, Optional, 
    Tuple, Union, Dict, 
    Callable, 
    Literal,  
    Mapping,
    Sequence,
)

import matplotlib.pyplot as plt
import matplotlib.colors as mcolors

from matplotlib.cm import ScalarMappable
import numpy as np
import pandas as pd

from ..core.checks import ( 
    check_non_emptiness, 
    check_spatial_columns, 
    is_in_if
)
from ..core.handlers import columns_manager
from ..utils.forecast_utils import ( 
    format_forecast_dataframe, 
    get_value_prefixes, 
    get_value_prefixes_in, 
    detect_forecast_type, 
    get_step_names, 
) 
from ..utils.calibrate import (
    calibrate_forecasts, 
    calibrate_probability_forecast
    
)
from ..utils.generic_utils import ( 
    _coerce_dt_kw, 
    get_actual_column_name,
    vlog, save_figure, 
    normalize_model_inputs
)
from ..utils.validator import ( 
    assert_xy_in, is_frame, 
    validate_positive_integer 
)

# Names exported when this module is star-imported
# (``from fusionlab.plot.forecast import *``).
__all__= [
    "forecast_view",
    "plot_forecast_by_step", 
    "plot_forecasts", 
    "visualize_forecasts", 
    "plot_reliability_diagram", 
    "plot_calibration_comparison"
    
 ]

@check_non_emptiness
def plot_calibration_comparison(
    *data: Union[pd.DataFrame, Mapping[str, pd.DataFrame]],
    quantiles: Optional[Sequence[float]] = None,
    q_prefix: Optional[str] = None,
    actual_col: Optional[str] = None,
    prob_col: Optional[str] = None,
    method: Literal["isotonic","logistic"] = "isotonic",
    out_prefix: str = "calib",
    grid_mode: Literal["unit","range"] = "unit",
    grid_size: int = 1001,
    group_by: Optional[str] = None,
    bins: int = 10,
    bin_strategy: Literal["uniform","quantile"] = "uniform",
    show_grid: bool = True,
    grid_props: Optional[Dict[str, Any]] = None,
    figsize: Optional[Tuple[float, float]] = None,
    savefig: Optional[str] = None,
    save_fmts: Union[str, List[str]] = ".png",
    verbose: int = 1,
    _logger: Optional[Union[logging.Logger, Callable[[str], None]]] = None,
) -> plt.Axes:
    """
    Plot raw vs calibrated reliability curves for one or more models.

    For every model the original ("raw") reliability curve and the
    post-processed ("calibrated") curve are drawn on the same axes,
    together with the diagonal of perfect calibration.  Both
    quantile forecasts and direct probability forecasts are supported.

    Parameters
    ----------
    *data : DataFrame or dict or list of DataFrames
        One or more forecast tables, normalised to a
        ``{label: DataFrame}`` mapping by ``normalize_model_inputs``.
        Each table must contain either:
        - Quantile columns named "{q_prefix}_qXX" plus
          actuals in `actual_col`, or
        - A probability column `prob_col` plus binary
          outcomes in `actual_col`.
        You may supply:
        - A single DataFrame
        - A dict mapping labels to DataFrames
        - A list/tuple of DataFrames
        - Multiple DataFrame args
    quantiles : sequence of float, optional
        Nominal quantile levels (e.g. [0.1,0.5,0.9]).  The quantile
        mode is used only when `quantiles` (non-empty), `q_prefix`
        and `actual_col` are all given; it then takes precedence
        over `prob_col`.
    q_prefix : str, optional
        Prefix used to identify quantile columns.  E.g. if
        q_prefix="subsidence", looks for "subsidence_q10", etc.
        The ``XX`` suffix is ``round(q * 100)``.
    actual_col : str, optional
        Name of the column containing true values.  For
        quantiles this is continuous; for probabilities it
        should be 0/1 flags.
    prob_col : str, optional
        Name of a direct probability forecast column (0-1).
        Used when the quantile mode is not fully specified;
        forecasts are then binned into `bins` intervals.  The
        calibrated values are read from ``"{prob_col}_calib"``.
    method : {'isotonic','logistic'}, default 'isotonic'
        Calibration method, forwarded to ``calibrate_forecasts``
        or ``calibrate_probability_forecast``.
    out_prefix : str, default 'calib'
        Prefix of the calibrated quantile columns; they are read
        from ``"{out_prefix}_{q_prefix}_qXX"``.
    grid_mode : {'unit','range'}, default 'unit'
        Forwarded to ``calibrate_forecasts``.  Which domain to
        build the inversion grid over:
          - 'unit' -> np.linspace(0,1,grid_size)
          - 'range' -> np.linspace(min(raw), max(raw), grid_size)
    grid_size : int, default 1001
        Number of points in the inversion grid (forwarded to
        ``calibrate_forecasts``).
    group_by : str, optional
        Forwarded to ``calibrate_forecasts``.  If provided,
        calibrate separately per `df[group_by]`
        (e.g. 'forecast_step').
    bins : int, default 10
        Number of bins when using `prob_col`.
    bin_strategy : {'uniform','quantile'}, default 'uniform'
        How to form probability bins: equal-width on [0, 1], or
        edges at the quantiles of the raw forecasts (any other
        value is treated as 'quantile').
    show_grid : bool, default True
        Whether to display grid lines.
    grid_props : dict, optional
        Keyword args passed to `Axes.grid()`.  Defaults to
        {'linestyle':':','alpha':0.7}.
    figsize : tuple, optional
        Matplotlib figure size in inches.
    savefig : str, optional
        Target passed to `save_figure`.  When given, the figure is
        saved and closed; otherwise ``plt.show()`` is called.
    save_fmts : str or list of str, default '.png'
        File extension(s) used by `save_figure`.
    verbose : int, default 1
        Verbosity level for internal logging via `vlog`.
    _logger : Logger or callable, optional
        Forwarded to `vlog` and `save_figure`.

    Returns
    -------
    matplotlib.axes.Axes
        The Axes containing the raw and calibrated curves.

    Raises
    ------
    ValueError
        If neither (quantiles + q_prefix + actual_col) nor
        (prob_col + actual_col) is provided.
    KeyError
        If a required column is missing from a table.

    Notes
    -----
    - For quantiles, the curve shows the nominal level against the
      fraction of rows where ``actual <= predicted quantile``.
    - For probabilities, each curve shows the mean forecast of a
      bin against the observed frequency in that bin.  Raw and
      calibrated forecasts are binned with the same edges; values
      on or outside the outer edges are assigned to the first/last
      bin instead of being discarded.  Empty bins yield NaN and
      appear as gaps in the line.
    - If an error occurs while the curves are drawn, the figure is
      closed before the exception propagates.

    See Also
    --------
    calibrate_forecasts : Post-process quantile forecasts.
    calibrate_probability_forecast : Calibrate 0-1 forecasts
        with isotonic or logistic scaling.
    plot_reliability_diagram : Single-curve reliability plot.

    References
    ----------
    Bröcker, J. & Smith, L. A., 2007. Scoring Probabilistic
      Forecasts: The Importance of Being Proper.
      Weather and Forecasting, 22(2), pp.382-388.
    Wilks, D. S., 2011. Statistical Methods in the
      Atmospheric Sciences (3rd ed.). Academic Press.

    Examples
    --------
    >>> import pandas as pd
    >>> from fusionlab.plot.forecast import plot_calibration_comparison

    # 1) Single-model quantile calibration (default isotonic, unit grid)
    >>> df = pd.DataFrame({
    ...     'subsidence_q10': [1, 2, 3, 4],
    ...     'subsidence_q50': [2, 3, 4, 5],
    ...     'subsidence_q90': [3, 4, 5, 6],
    ...     'subsidence_actual': [1.5, 3.5, 4.2, 5.8]
    ... })
    >>> plot_calibration_comparison(
    ...     df,
    ...     quantiles=[0.1, 0.5, 0.9],
    ...     q_prefix='subsidence',
    ...     actual_col='subsidence_actual'
    ... )

    # 2) Single-model probability calibration (20 quantile bins)
    >>> pdf = pd.DataFrame({
    ...     'p_event': [0.1, 0.4, 0.8, 0.9, 0.3],
    ...     'event_flag': [0, 1, 1, 1, 0]
    ... })
    >>> plot_calibration_comparison(
    ...     pdf,
    ...     prob_col='p_event',
    ...     actual_col='event_flag',
    ...     bins=20,
    ...     bin_strategy='quantile'
    ... )

    # 3) Compare two models via a dict of DataFrames
    >>> df2 = df.copy()  # pretend a second model
    >>> plot_calibration_comparison(
    ...     {'XTFT': df, 'PINN': df2},
    ...     quantiles=[0.1, 0.5, 0.9],
    ...     q_prefix='subsidence',
    ...     actual_col='subsidence_actual'
    ... )

    # 4) Multiple unnamed DataFrames as separate models
    >>> plot_calibration_comparison(
    ...     df, df2,
    ...     quantiles=[0.1, 0.5, 0.9],
    ...     q_prefix='subsidence',
    ...     actual_col='subsidence_actual'
    ... )

    # 5) Per-horizon (step) calibration
    >>> df_long = pd.DataFrame({
    ...     'sample_idx': [0,0,0,1,1,1],
    ...     'forecast_step': [1,2,3,1,2,3],
    ...     'subsidence_q10': [ .1, .2, .3, .1, .2, .3],
    ...     'subsidence_q50': [ .5, .6, .7, .5, .6, .7],
    ...     'subsidence_q90': [ .9, 1.0, 1.1, .9, 1.0, 1.1],
    ...     'subsidence_actual': [ .2, .5, .8, .4, .7, 1.0]
    ... })
    >>> plot_calibration_comparison(
    ...     df_long,
    ...     quantiles=[0.1, 0.5, 0.9],
    ...     q_prefix='subsidence',
    ...     actual_col='subsidence_actual',
    ...     group_by='forecast_step'
    ... )

    # 6) Logistic scaling & range grid
    >>> plot_calibration_comparison(
    ...     df,
    ...     quantiles=[0.1, 0.5, 0.9],
    ...     q_prefix='subsidence',
    ...     actual_col='subsidence_actual',
    ...     method='logistic',
    ...     grid_mode='range',
    ...     grid_size=500
    ... )

    # 7) List of DataFrames (treated like multiple models)
    >>> plot_calibration_comparison(
    ...     [pdf, pdf.copy()],
    ...     prob_col='p_event',
    ...     actual_col='event_flag'
    ... )

    """
    models = normalize_model_inputs(*data)

    if grid_props is None:
        grid_props = {"linestyle": ":", "alpha": 0.7}

    # ``quantiles`` may be a NumPy array, whose truth value is ambiguous,
    # so test its length explicitly instead of relying on ``bool()``.
    has_quantiles = quantiles is not None and len(quantiles) > 0

    def _q_col(frame, stem, q):
        """Return the ``{stem}_qXX`` column name for quantile level ``q``."""
        # ``int(q * 100)`` truncates (0.29 * 100 == 28.999...), so round
        # first.  The truncated spelling is only used as a fallback when
        # it is the column that actually exists in ``frame``.
        rounded = f"{stem}_q{int(round(q * 100))}"
        if rounded in frame.columns:
            return rounded
        legacy = f"{stem}_q{int(q * 100)}"
        return legacy if legacy in frame.columns else rounded

    def _binned_reliability(forecast, outcome, edges):
        """Return per-bin (mean forecast, observed frequency) lists."""
        n_bins = len(edges) - 1
        if n_bins < 1:
            return [], []
        forecast = np.asarray(forecast, dtype=float)
        outcome = np.asarray(outcome, dtype=float)
        usable = np.isfinite(forecast)
        forecast, outcome = forecast[usable], outcome[usable]
        # ``digitize(right=True) - 1`` maps a value equal to the lowest
        # edge to -1 and values above the last edge to ``n_bins``; clip so
        # those samples land in the outer bins instead of vanishing.
        idx = np.clip(
            np.digitize(forecast, edges, right=True) - 1, 0, n_bins - 1
        )
        mean_fc, freq = [], []
        for i in range(n_bins):
            in_bin = idx == i
            if in_bin.any():
                mean_fc.append(forecast[in_bin].mean())
                freq.append(outcome[in_bin].mean())
            else:
                mean_fc.append(np.nan)
                freq.append(np.nan)
        return mean_fc, freq

    fig, ax = plt.subplots(figsize=figsize)
    try:
        ax.plot([0, 1], [0, 1], "--", color="gray", label="Perfect")

        for name, df in models.items():
            # NOTE(review): ``level=verbose`` ties the message level to the
            # verbosity itself; the sibling plot uses ``level=1`` -- confirm
            # which one is intended against ``vlog``'s signature.
            vlog(f"Processing model '{name}'", level=verbose,
                 verbose=verbose, _logger=_logger)

            # Quantile-based calibration
            if has_quantiles and q_prefix and actual_col:
                df_calib = calibrate_forecasts(
                    df, quantiles, q_prefix, actual_col,
                    method=method,
                    out_prefix=out_prefix,
                    grid_mode=grid_mode,
                    grid_size=grid_size,
                    group_by=group_by
                )
                nom = list(quantiles)
                calib_stem = f"{out_prefix}_{q_prefix}"
                emp_raw = [
                    np.mean(df[actual_col] <= df[_q_col(df, q_prefix, q)])
                    for q in nom
                ]
                emp_cal = [
                    np.mean(
                        df[actual_col]
                        <= df_calib[_q_col(df_calib, calib_stem, q)]
                    )
                    for q in nom
                ]
                ax.plot(nom, emp_raw, marker="o", label=f"{name} raw")
                ax.plot(nom, emp_cal, marker="x", label=f"{name} calib")

            # Probability-based calibration
            elif prob_col and actual_col:
                df_calib = calibrate_probability_forecast(
                    df, prob_col, actual_col, method=method
                )
                y_pred = df[prob_col].to_numpy()
                y_cal = df_calib[f"{prob_col}_calib"].to_numpy()
                y_true = df[actual_col].to_numpy()

                if bin_strategy == "uniform":
                    edges = np.linspace(0, 1, bins + 1)
                else:
                    # Duplicate quantiles collapse, so fewer than ``bins``
                    # bins may result for heavily tied forecasts.
                    edges = np.unique(np.quantile(
                        y_pred, np.linspace(0, 1, bins + 1)
                    ))

                x_raw, emp_raw = _binned_reliability(y_pred, y_true, edges)
                x_cal, emp_cal = _binned_reliability(y_cal, y_true, edges)

                ax.plot(x_raw, emp_raw, marker="o", label=f"{name} raw")
                ax.plot(x_cal, emp_cal, marker="x", label=f"{name} calib")

            else:
                raise ValueError(
                    "Must provide either (quantiles+q_prefix+actual_col) "
                    "or (prob_col+actual_col)"
                )
    except Exception:
        # Do not leave a half-drawn figure registered with pyplot.
        plt.close(fig)
        raise

    if show_grid:
        ax.grid(True, **grid_props)
    ax.set_xlim(0, 1)
    ax.set_ylim(0, 1)
    ax.set_xlabel("Nominal probability")
    ax.set_ylabel("Empirical frequency")
    ax.set_title("Raw vs Calibrated Reliability")
    ax.legend()

    if savefig:
        save_figure(
            fig,
            savefile=savefig,
            save_fmts=save_fmts,
            dpi=300,
            bbox_inches="tight",
            verbose=verbose,
            _logger=_logger,
        )
        plt.close(fig)
    else:
        plt.show()

    return ax

@check_non_emptiness
def plot_reliability_diagram(
    *data: Union[pd.DataFrame, Mapping[str, pd.DataFrame]],
    quantiles: Optional[Sequence[float]] = None,
    q_prefix: Optional[str] = None,
    actual_col: Optional[str] = None,
    prob_col: Optional[str] = None,
    bins: int = 10,
    bin_strategy: Literal["uniform","quantile"] = "uniform",
    index_col: str = "sample_idx",
    step_col: str = "forecast_step",
    time_col: str = "coord_t",
    xlabel: str = "Nominal probability",
    ylabel: str = "Empirical frequency",
    title: str = "Reliability Diagram",
    dt_col: Optional[str] = None,
    show_grid: bool = True,
    grid_props: Optional[Dict[str, Any]] = None,
    figsize: Optional[Tuple[float, float]] = None,
    savefig: Optional[str] = None,
    save_fmts: Union[str, List[str]] = ".png",
    verbose: int = 1,
    _logger: Optional[Union[logging.Logger,
                           Callable[[str], None]]] = None
) -> plt.Axes:
    """
    Plot reliability (calibration) curves for quantile or probability
    forecasts across one or more models.

    Parameters
    ----------
    *data : DataFrame or dict or list of DataFrames
        One or more forecast tables, normalised to a
        ``{label: DataFrame}`` mapping by ``normalize_model_inputs``.
        Each table must contain either:
        - quantile columns named ``{q_prefix}_qXX`` plus
          ``actual_col`` for true values, or
        - a probability column ``prob_col`` (0-1) plus
          ``actual_col`` for binary events.
        You may pass:
        - A single DataFrame
        - A dict mapping model names to DataFrames
        - A list/tuple of DataFrames
        - Multiple DataFrame arguments
    quantiles : list of float, optional
        Nominal quantile levels (e.g. [0.1,0.5,0.9]).  If provided,
        ``q_prefix`` and ``actual_col`` must be set and the quantile
        mode takes precedence over ``prob_col``.  The curve plots
        q vs empirical coverage.
    q_prefix : str, optional
        Prefix of quantile columns (e.g. 'subsidence' to find
        'subsidence_q10', 'subsidence_q50', etc.).  The ``XX``
        suffix is ``round(q * 100)``.
    actual_col : str, optional
        Column name of true values (continuous for quantiles or
        binary for probabilities).  In probability mode, when it is
        not given, the name is obtained from
        ``get_actual_column_name(df)``.
    prob_col : str, optional
        Name of a direct probability forecast (0-1).  Used only
        when ``quantiles`` is None; reliability is then computed by
        binning forecasts into ``bins`` groups.
    bins : int, default 10
        Number of bins when ``prob_col`` is used.
    bin_strategy : {'uniform','quantile'}, default 'uniform'
        How to choose probability bins:
        - 'uniform': equal-width in [0,1]
        - 'quantile': edges at the forecast quantiles (any value
          other than 'uniform' selects this strategy)
    index_col : str, default 'sample_idx'
        Name of the sample identifier column (unused internally).
    step_col : str, default 'forecast_step'
        Name of the integer step column (unused internally).
    time_col : str, default 'coord_t'
        Name of the datetime column.  Only passed through
        ``_coerce_dt_kw``; not otherwise used for plotting.
    xlabel : str, default 'Nominal probability'
        X-axis label.
    ylabel : str, default 'Empirical frequency'
        Y-axis label.
    title : str, default 'Reliability Diagram'
        Plot title.
    dt_col : str, optional
        Alternate name for the datetime column.  Handled via
        ``_coerce_dt_kw`` to unify with ``time_col``.
    show_grid : bool, default True
        Whether to display the grid.
    grid_props : dict, optional
        Passed to ``Axes.grid()``.  Defaults to
        ``{'linestyle':':','alpha':0.7}``.
    figsize : (float,float), optional
        Figure size in inches.
    savefig : str, optional
        Target passed to ``save_figure``.  When given, the figure
        is saved and closed; otherwise ``plt.show()`` is called.
    save_fmts : str or list of str, default '.png'
        File format(s) for saving (e.g. ['.png','.pdf']).
    verbose : int, default 1
        Verbosity level for internal logging via ``vlog``.
    _logger : Logger or callable, optional
        Forwarded to ``save_figure``.

    Returns
    -------
    matplotlib.axes.Axes
        The Axes object with the reliability curves.

    Raises
    ------
    ValueError
        If ``quantiles`` is given without ``q_prefix`` and
        ``actual_col``, or if neither ``quantiles`` nor ``prob_col``
        is provided.
    KeyError
        If a quantile column or ``prob_col`` is missing from a table.

    Notes
    -----
    - When using ``quantiles``, the model is evaluated at each
      q-level by computing the fraction of true values <= predicted
      q-quantile.
    - When using ``prob_col``, forecasts are binned and the mean
      forecast of each bin is plotted against the empirical
      frequency in that bin.  Forecasts on or outside the outer
      bin edges are assigned to the first/last bin rather than
      being discarded.  Empty bins yield NaN and appear as gaps.
    - Multiple models can be compared by passing a dict or list of
      DataFrames; each appears as a separate line.
    - If an error occurs while the curves are drawn, the figure is
      closed before the exception propagates.

    See Also
    --------
    plot_calibration_comparison : Raw vs calibrated reliability
        curves on the same axes.

    References
    ----------
    Bröcker, J., & Smith, L. A. (2007). Scoring Probabilistic
    Forecasts: The Importance of Being Proper. Weather and
    Forecasting, 22(2), 382-388.
    Wilks, D. S. (2011). Statistical Methods in the Atmospheric
    Sciences (3rd ed.). Academic Press.

    Examples
    --------
    >>> import pandas as pd
    >>> from fusionlab.plot.forecast import plot_reliability_diagram
    >>> # 1) Single-model quantiles
    >>> df = pd.DataFrame({
    ...     'subsidence_q10': [1,2,3,4],
    ...     'subsidence_q50': [2,3,4,5],
    ...     'subsidence_q90': [3,4,5,6],
    ...     'subsidence_actual': [1.5, 3.5, 4.2, 5.8]
    ... })
    >>> plot_reliability_diagram(
    ...     df,
    ...     quantiles=[0.1,0.5,0.9],
    ...     q_prefix='subsidence',
    ...     actual_col='subsidence_actual'
    ... )
    >>> # 2) Probability forecasts, 20 quantile bins
    >>> pdf = pd.DataFrame({
    ...     'p_event': [0.1,0.4,0.8,0.9,0.3],
    ...     'event_flag': [0,1,1,1,0]
    ... })
    >>> plot_reliability_diagram(
    ...     pdf,
    ...     prob_col='p_event',
    ...     actual_col='event_flag',
    ...     bins=20, bin_strategy='quantile'
    ... )
    >>> # 3) Compare two models
    >>> df1 = df.copy()
    >>> df2 = df.copy()
    >>> plot_reliability_diagram(
    ...     {'XTFT': df1, 'PINN': df2},
    ...     quantiles=[0.1,0.5,0.9],
    ...     q_prefix='subsidence',
    ...     actual_col='subsidence_actual',
    ...     figsize=(8,5)
    ... )

    """
    # canonicalise column name (kept for its argument handling; the
    # resulting name is not needed by the plotting code below)
    kw = _coerce_dt_kw(
        dt_col=dt_col, time_col=time_col, _time_default=time_col)
    time_col = kw.get("dt_col", time_col)        # adopt canonical name

    # normalize input to a dict of {label: df}
    models = normalize_model_inputs(*data)

    # default grid properties
    if grid_props is None:
        grid_props = {"linestyle": ":", "alpha": 0.7}

    def _q_col(frame, stem, q):
        """Return the ``{stem}_qXX`` column name for quantile level ``q``."""
        # ``int(q * 100)`` truncates (0.29 * 100 == 28.999...), so round
        # first.  The truncated spelling is only used as a fallback when
        # it is the column that actually exists in ``frame``.
        rounded = f"{stem}_q{int(round(q * 100))}"
        if rounded in frame.columns:
            return rounded
        legacy = f"{stem}_q{int(q * 100)}"
        return legacy if legacy in frame.columns else rounded

    # prepare figure
    fig, ax = plt.subplots(figsize=figsize)

    try:
        # diagonal
        ax.plot([0, 1], [0, 1], linestyle="--", color="gray",
                label="Perfect")

        # iterate models
        for name, df in models.items():
            vlog(f"Processing model '{name}'", level=1, verbose=verbose)

            if quantiles is not None:
                # --- quantile calibration ---
                if not q_prefix or not actual_col:
                    raise ValueError(
                        "Must provide q_prefix and actual_col for quantiles")

                nom = []
                emp = []
                for q in quantiles:
                    col = _q_col(df, q_prefix, q)
                    if col not in df.columns:
                        raise KeyError(
                            f"Missing column '{col}' in model '{name}'")
                    y_pred = df[col].to_numpy()
                    y_true = df[actual_col].to_numpy()
                    nom.append(q)
                    emp.append(np.mean(y_true <= y_pred))

                ax.plot(nom, emp, marker="o", label=name)

            elif prob_col is not None:
                # --- probability forecast calibration ---
                if prob_col not in df.columns:
                    raise KeyError(f"Missing prob_col '{prob_col}'"
                                   f" in model '{name}'")
                # if actual_col not provided, try to infer
                acol = actual_col or get_actual_column_name(df)
                p = np.asarray(df[prob_col].to_numpy(), dtype=float)
                y = np.asarray(df[acol].to_numpy(), dtype=float)

                # non-finite forecasts cannot be binned
                usable = np.isfinite(p)
                p, y = p[usable], y[usable]

                # choose bins
                if bin_strategy == "uniform":
                    bins_edges = np.linspace(0, 1, bins + 1)
                else:  # 'quantile'; duplicate edges collapse
                    bins_edges = np.unique(
                        np.quantile(p, np.linspace(0, 1, bins + 1)))

                n_bins = len(bins_edges) - 1
                p_bar = []
                y_bar = []
                if n_bins >= 1:
                    # ``digitize(right=True) - 1`` maps a forecast equal
                    # to the lowest edge to -1 and one above the last
                    # edge to ``n_bins``; clip so those samples fall in
                    # the outer bins instead of being dropped.
                    bin_idx = np.clip(
                        np.digitize(p, bins_edges, right=True) - 1,
                        0, n_bins - 1)
                    # mean forecast and empirical frequency per bin
                    for i in range(n_bins):
                        mask = bin_idx == i
                        if not mask.any():
                            p_bar.append(np.nan)
                            y_bar.append(np.nan)
                        else:
                            p_bar.append(p[mask].mean())
                            y_bar.append(y[mask].mean())
                # x is the mean forecast of each bin, not the bin centre
                ax.plot(p_bar, y_bar, marker="s", label=name)

            else:
                raise ValueError(
                    "Either 'quantiles' or 'prob_col' must be provided")
    except Exception:
        # Do not leave a half-drawn figure registered with pyplot.
        plt.close(fig)
        raise

    # final touches
    if show_grid:
        ax.grid(True, **grid_props)

    ax.set_xlim(0, 1)
    ax.set_ylim(0, 1)
    ax.set_xlabel(xlabel)
    ax.set_ylabel(ylabel)
    ax.set_title(title)
    ax.legend()

    # save or show
    if savefig:
        save_figure(
            fig,
            savefile=savefig,
            save_fmts=save_fmts,
            dpi=300,
            bbox_inches="tight",
            verbose=verbose,
            _logger=_logger,
        )
        plt.close(fig)
    else:
        plt.show()

    return ax

def plot_forecast_by_step(
    df: pd.DataFrame,
    value_prefixes: Optional[List[str]] = None,
    kind: str = "dual",
    steps: Optional[Union[int, List[int]]] = None,  # step(s) to plot
    step_names: Optional[Union[Dict[int, str], List[str]]] = None,
    spatial_cols: Optional[Tuple[str, str]] = None,
    max_cols: Union[int, str] = "auto",
    cmap: str = "viridis",
    cbar: str = "uniform",
    axis_off: bool = False,
    show_grid: bool = True,
    grid_props: Optional[Dict] = None,
    figsize: Optional[Tuple[float, float]] = None,
    savefig: Optional[str] = None,
    save_fmts: Union[str, List[str]] = ".png",
    _logger: Optional[Union[logging.Logger, Callable[[str], None]]] = None,
    verbose: int = 1,
) -> None:
    """Plots forecast data, organizing subplots by forecast step.

    One figure is created per value prefix.  Each row of a figure
    corresponds to one ``forecast_step`` and each column to one metric
    (the actual values first, when plotted, followed by the predicted
    metrics).  Subplots are spatial scatter plots when `spatial_cols`
    are given and present in `df`, and temporal line plots otherwise.

    Parameters
    ----------
    df : pd.DataFrame
        Input long-format DataFrame. Must contain a 'forecast_step'
        column and value columns (e.g., 'subsidence_q50').
    value_prefixes : list of str, optional
        The base names of metrics to plot (e.g., ``['subsidence']``).
        If ``None``, prefixes are auto-detected from column names
        (ignoring 'sample_idx', 'forecast_step' and `spatial_cols`).
        A separate figure is generated for each prefix.
    kind : {'dual', 'pred_only'}, default 'dual'
        Determines what to plot:

        - ``'dual'``: Plots both actual values and predictions
          side-by-side for comparison (when an actual column exists).
        - ``'pred_only'``: Plots only the predicted values.

        The legacy values ``'spatial'`` and ``'temporal'`` are still
        accepted: both plot predictions only; ``'temporal'`` forces
        line plots, and ``'spatial'`` warns when the spatial columns
        are missing before falling back to line plots.
    steps : int or list of int, optional
        A specific forecast step or list of steps to visualize. If
        ``None``, all unique steps found in the DataFrame are plotted.
        Requested steps that are absent from `df` are ignored.
    step_names : dict or list, optional
        Custom labels for the forecast steps in plot titles.

        - If a dict, maps step numbers to names
          (e.g., ``{1: 'Year 1'}``); the name replaces the default
          ``'Step <n>'`` label.
        - If a list, maps step index to names (e.g., ``['Y1', 'Y2']``,
          where index 0 corresponds to step 1); the label becomes
          ``'Step <n>: <name>'``.
    spatial_cols : tuple of str, optional
        Tuple of column names for spatial coordinates, e.g.,
        ``('coord_x', 'coord_y')``. If provided and found, spatial
        scatter plots are created. If ``None`` or columns are not
        found, the function falls back to temporal line plots.
    max_cols : int or 'auto', default 'auto'
        Controls the number of subplots per row. If ``'auto'``, it's
        determined by the number of metrics (e.g., actual + quantiles)
        being plotted for each prefix.  With an integer, metrics that
        do not fit in the row are not drawn.
    cmap : str, default 'viridis'
        The Matplotlib colormap for spatial plots.
    cbar : {'uniform', 'individual'}, default 'uniform'
        Controls color bar scaling for spatial plots.

        - ``'uniform'``: global min/max over all value columns are
          passed to every subplot and a single color bar is added
          to each spatial figure.
        - ``'individual'``: no shared limits are passed and no
          figure-level color bar is added.
    axis_off : bool, default False
        Forwarded to the subplot helpers (turn off axes).
    show_grid : bool, default True
        Forwarded to the subplot helpers (display a grid).
    grid_props : dict, optional
        Properties for the grid, e.g., ``{'linestyle': '--'}``.
        Defaults to ``{'linestyle': ':', 'alpha': 0.7}``.
    figsize : tuple of (float, float), optional
        The size of the figure for each `value_prefix`. If ``None``,
        it is ``(n_cols * 4, n_rows * 3.5)``.
    savefig : str, optional
        Path used to save the figure(s).  With a single prefix the
        path is used as given.  With several prefixes,
        ``'_<prefix>_by_step'`` is inserted before the extension
        (e.g., 'my_plot_subsidence_by_step.png') so that the figures
        do not overwrite one another.  When omitted, ``plt.show()``
        is called for each figure.
    save_fmts : str or list of str, default '.png'
        Format(s) to save the figure, e.g., ``['.png', '.pdf']``.
    _logger : Logger or callable, optional
        Forwarded to ``vlog`` and ``save_figure``.
    verbose : int, default 1
        Controls the verbosity of logging messages.

    Returns
    -------
    None
        This function displays and/or saves Matplotlib figures and
        does not return any value.

    Raises
    ------
    ValueError
        If no value prefixes can be auto-detected, or if `df` has no
        'forecast_step' column.

    See Also
    --------
    forecast_view : A related function that plots by year/time
        instead of by step index.

    Notes
    -----
    - The function expects a long-format DataFrame where each row is a
      unique combination of a sample and a forecast step.
    - Shared color limits are computed column-wise while skipping
      NaNs, so a row with a missing value in one column still
      contributes its other values.

    Examples
    --------
    >>> import numpy as np
    >>> import pandas as pd
    >>> from fusionlab.plot.forecast import plot_forecast_by_step
    >>> data = {
    ...     'sample_idx': list(range(4)) * 2,
    ...     'forecast_step': [1] * 4 + [2] * 4,
    ...     'coord_x': np.random.rand(8),
    ...     'coord_y': np.random.rand(8),
    ...     'subsidence_q50': np.random.randn(8),
    ...     'subsidence_actual': np.random.randn(8),
    ... }
    >>> df_step_example = pd.DataFrame(data)
    >>> plot_forecast_by_step(
    ...     df=df_step_example,
    ...     value_prefixes=['subsidence'],
    ...     spatial_cols=('coord_x', 'coord_y'),
    ...     step_names={1: 'Year 1 Forecast', 2: 'Year 2 Forecast'},
    ...     kind='dual'
    ... )
    """
    # 1. Detect prefixes if not provided.
    if value_prefixes is None:
        vlog("Auto-detecting value prefixes…", level=1, verbose=verbose,
             logger=_logger)
        # Define default columns to exclude from prefix detection.
        exclude_from_prefix = ['sample_idx', 'forecast_step']
        if spatial_cols:
            exclude_from_prefix.extend(spatial_cols)
        value_prefixes = get_value_prefixes_in(
            df, exclude_cols=exclude_from_prefix)
        if not value_prefixes:
            raise ValueError("Could not auto-detect value prefixes.")
        vlog(f"Detected prefixes: {value_prefixes}", level=2,
             verbose=verbose, logger=_logger)

    # Ensure forecast_step column exists.
    if "forecast_step" not in df.columns:
        raise ValueError("DataFrame must contain 'forecast_step'.")

    # Collect metric suffixes found in columns.
    metrics = _get_metrics_from_cols(df.columns, value_prefixes)
    actual_metrics = sorted([m for m in metrics if "actual" in m])
    pred_metrics = sorted([m for m in metrics if "actual" not in m])

    # Determine if actuals will be plotted.
    plot_actuals = kind == "dual" and bool(actual_metrics)

    # Spatial rendering depends on the coordinates being available, not
    # on ``kind``: 'dual' and 'pred_only' must also get scatter plots.
    has_spatial_coords = bool(spatial_cols) and all(
        c in df.columns for c in spatial_cols
    )
    if kind == 'spatial' and not has_spatial_coords:
        warnings.warn(
            f"kind='spatial' but columns {spatial_cols} not found. "
            "Falling back to temporal line plots."
        )
    # Legacy ``kind='temporal'`` remains an explicit opt-out.
    use_spatial = has_spatial_coords and kind != 'temporal'
    if use_spatial:
        vlog("Plotting spatial data.", level=2, verbose=verbose,
             logger=_logger)
    else:
        vlog("Plotting temporal data.", level=2, verbose=verbose,
             logger=_logger)

    # Unique steps to visualize.
    steps_to_plot = sorted(df["forecast_step"].unique())
    if steps is not None:
        steps = columns_manager(steps, empty_as_none=False)
        steps = [validate_positive_integer(v, f"Step {v}") for v in steps]
        # ``or []`` guards against a None result when nothing matches.
        steps_to_plot = sorted(
            is_in_if(steps_to_plot, steps, return_intersect=True) or []
        )

    n_rows = len(steps_to_plot)
    if n_rows == 0:
        vlog("No forecast steps to plot.", level=0, verbose=verbose,
             logger=_logger)
        return

    # 2. Compute global color limits if cbar is 'uniform'.
    vmin, vmax = None, None
    if cbar == "uniform":
        cols_for_cbar = [
            f"{p}_{m}"
            for p in value_prefixes
            for m in metrics
            if f"{p}_{m}" in df.columns
        ]
        if cols_for_cbar:
            # Column-wise min/max already skip NaNs; a prior ``dropna()``
            # would discard whole rows (and can leave nothing at all).
            lo = df[cols_for_cbar].min().min()
            hi = df[cols_for_cbar].max().max()
            if pd.notna(lo) and pd.notna(hi):
                vmin, vmax = lo, hi

    # Shared kwargs for plotting helpers.
    plot_kwargs = dict(
        cmap=cmap,
        vmin=vmin,
        vmax=vmax,
        axis_off=axis_off,
        show_grid=show_grid,
        grid_props=grid_props or {"linestyle": ":", "alpha": 0.7},
        verbose=verbose,
    )

    def _draw(ax, frame, column, title):
        """Render one subplot with the spatial or temporal helper."""
        opts = dict(plot_kwargs, title=title)
        if use_spatial:
            _plot_spatial_subplot(ax, frame, *spatial_cols, column, **opts)
        else:
            _plot_temporal_subplot(ax, frame, column, **opts)

    # Several prefixes would otherwise be written to the same file.
    suffix_paths = len(value_prefixes) > 1

    # 3. Loop over each prefix and produce a grid of sub-plots.
    for prefix in value_prefixes:
        # Select relevant metrics for this prefix.
        prefix_preds = [
            m for m in pred_metrics if f"{prefix}_{m}" in df.columns
        ]
        prefix_actual = next(
            (m for m in actual_metrics if f"{prefix}_{m}" in df.columns),
            None,
        )
        show_actual = bool(plot_actuals and prefix_actual)

        # Determine number of columns for the grid.
        if max_cols == "auto":
            n_cols = len(prefix_preds) + (1 if show_actual else 0)
        else:
            n_cols = int(max_cols)
        if n_cols == 0:
            continue

        # Create figure and axes grid.
        fig_size = figsize or (n_cols * 4, n_rows * 3.5)
        fig, axes = plt.subplots(
            n_rows, n_cols,
            figsize=fig_size,
            squeeze=False,
            constrained_layout=True,
        )
        fig.suptitle(
            f"Forecast for '{prefix.upper()}' by Step",
            fontsize=16, weight='bold'
        )

        # Iterate over each forecast step (row).
        for r_idx, step in enumerate(steps_to_plot):
            step_df = df[df["forecast_step"] == step]
            step_lbl = f"Step {step}"
            # Custom label overrides.
            if isinstance(step_names, dict):
                step_lbl = step_names.get(step, step_lbl)
            elif (
                isinstance(step_names, (list, tuple))
                # lower bound keeps step 0 from wrapping to the last name
                and 0 <= step - 1 < len(step_names)
            ):
                step_lbl = f"Step {step}: {step_names[step - 1]}"

            c_idx = 0
            # Plot actuals if requested and available.
            if show_actual:
                _draw(
                    axes[r_idx][c_idx], step_df,
                    f"{prefix}_{prefix_actual}",
                    f"Actual ({step_lbl})",
                )
                c_idx += 1

            # Plot prediction metrics.
            for metric in prefix_preds:
                if c_idx >= n_cols:
                    break
                _draw(
                    axes[r_idx][c_idx], step_df,
                    f"{prefix}_{metric}",
                    f"{metric.replace('_', ' ').title()} ({step_lbl})",
                )
                c_idx += 1

            # Blank unused axes in this row.
            for j in range(c_idx, n_cols):
                axes[r_idx][j].axis("off")

        # Add a single color-bar if using uniform scaling; it is only
        # meaningful for color-mapped (spatial) plots.
        if (
            use_spatial and cbar == "uniform"
            and vmin is not None and vmax is not None
        ):
            fig.colorbar(
                ScalarMappable(
                    norm=mcolors.Normalize(vmin=vmin, vmax=vmax),
                    cmap=cmap
                ),
                ax=axes.ravel().tolist(),
                orientation="vertical",
                shrink=0.8,
                pad=0.04
            )

        # 4. Save figure to disk if requested.
        if savefig:
            target = savefig
            if suffix_paths:
                root, ext = os.path.splitext(savefig)
                target = f"{root}_{prefix}_by_step{ext}"
            save_figure(
                fig,
                savefile=target,
                save_fmts=save_fmts,
                dpi=300,
                bbox_inches="tight",
                verbose=verbose,
                _logger=_logger,
            )
            plt.close(fig)
        else:
            plt.show()
@check_non_emptiness
def forecast_view_in(
    forecast_df: pd.DataFrame,
    value_prefixes: Optional[List[str]] = None,
    kind: str = 'dual',
    view_quantiles: Optional[List[Union[str, float]]] = None,
    view_years: Optional[List[Union[str, int]]] = None,
    spatial_cols: Tuple[str, str] = ('coord_x', 'coord_y'),
    time_col: str = 'coord_t',
    max_cols: Union[int, str] = 'auto',
    cmap: str = 'viridis',
    cbar: str = 'uniform',
    axis_off: bool = False,
    show_grid: bool = True,
    grid_props: Optional[Dict] = None,
    figsize: Optional[Tuple[float, float]] = None,
    savefig: Optional[str] = None,
    save_fmts: Union[str, List[str]] = '.png',
    verbose: int = 1
):
    """Generates and displays spatial forecast visualizations.

    This function creates a grid of scatter plots to visualize
    spatio-temporal forecast data. The input is first passed through
    :func:`format_forecast_dataframe` with ``to_wide=True``; plots are
    then arranged by year (rows) and metric, e.g. different quantiles
    (columns).

    Parameters
    ----------
    forecast_df : pd.DataFrame
        Input DataFrame containing forecast data.
    value_prefixes : list of str, optional
        The base names of the metrics to plot (e.g.,
        ``['subsidence', 'GWL']``). If ``None``, they are inferred with
        :func:`get_value_prefixes`. A separate figure is generated for
        each prefix.
    kind : {'dual', 'pred_only'}, default 'dual'
        Determines what to plot:

        - ``'dual'``: Plots both actual values and predictions
          side-by-side for comparison.
        - ``'pred_only'``: Plots only the predicted values. Any value
          other than ``'dual'`` behaves this way.
    view_quantiles : list of str or float, optional
        A list to filter which quantiles are displayed. Values can be
        floats (e.g., ``[0.1, 0.5]``) or strings (e.g.,
        ``['q10', 'q50']`` or ``['10', '50']``). If ``None``, all
        detected quantiles are plotted. If nothing matches, the
        forecast is treated as deterministic and the ``'pred'`` column
        is plotted.
    view_years : list of str or int, optional
        A list of years to display. If ``None``, all detected years in
        the forecast are plotted.
    spatial_cols : tuple of str, default ('coord_x', 'coord_y')
        Names of the columns used for the x and y axes of the scatter
        plots.
    time_col : str, default 'coord_t'
        The name of the column representing the time dimension,
        forwarded to :func:`format_forecast_dataframe`.
    max_cols : int or 'auto', default 'auto'
        Controls the number of subplots per row.

        - If ``'auto'``, the number of columns equals the number of
          quantiles being plotted (plus one if ``kind='dual'``).
        - If an integer, it fixes the total number of columns of the
          grid; prediction panels that do not fit are skipped.
    cmap : str, default 'viridis'
        The Matplotlib colormap to use for the scatter plots.
    cbar : {'uniform', 'individual'}, default 'uniform'
        Controls the color bar scaling:

        - ``'uniform'``: a single color range is computed from every
          column referenced in the parsed plot structure and one color
          bar is drawn per figure.
        - ``'individual'``: no shared range is computed and no
          figure-level color bar is drawn (per-subplot color bars would
          be a future enhancement).
    axis_off : bool, default False
        If ``True``, turns off the axes for all subplots.
    show_grid : bool, default True
        If ``True`` and `axis_off` is ``False``, a grid is displayed on
        the subplots.
    grid_props : dict, optional
        Properties for the grid. Defaults to
        ``{'linestyle': ':', 'alpha': 0.7}``.
    figsize : tuple of (float, float), optional
        The size of the figure for each prefix. If ``None``, it is
        ``(n_cols * 3.5, n_rows * 3)``.
    savefig : str, optional
        If a file path is provided (e.g., 'my_forecast.png'), the
        figure(s) will be saved at 300 dpi. The prefix name is appended
        to the file stem (e.g., 'my_forecast_subsidence.png') and the
        extension is taken from `save_fmts`.
    save_fmts : str or list of str, default '.png'
        The format(s) to save the figure in (e.g.,
        ``['.png', '.pdf']``). A leading dot is optional.
    verbose : int, default 1
        Controls the verbosity of logging messages.

    Returns
    -------
    None
        This function does not return any value. It displays and/or
        saves Matplotlib figures directly.

    Raises
    ------
    ValueError
        If `value_prefixes` is ``None`` and none can be auto-detected.

    See Also
    --------
    format_forecast_dataframe : The utility used to detect and pivot
        the input DataFrame.
    get_value_prefixes : The utility used to auto-detect metric
        prefixes.

    Notes
    -----
    - Each row corresponds to a year from `view_years`; columns
      correspond to the actuals (if ``kind='dual'``) and each of the
      selected quantiles.
    - If an actual column for a specific year is not available, the
      most recent known 'actual' column (initially the static actual,
      if any) is used to facilitate comparison across the row.
    - A separate figure is generated for each prefix in
      `value_prefixes`.
    """
    # Auto-detect value prefixes if not provided.
    if value_prefixes is None:
        vlog("`value_prefixes` not provided. Auto-detecting...",
             level=1, verbose=verbose)
        value_prefixes = get_value_prefixes(forecast_df)
        if not value_prefixes:
            raise ValueError(
                "Could not auto-detect any value prefixes. Please "
                "provide them explicitly."
            )
        vlog(f"Detected prefixes: {value_prefixes}",
             level=2, verbose=verbose)

    value_prefixes = columns_manager(value_prefixes)

    # Use the format utility to ensure we have a wide DataFrame.
    df_wide = format_forecast_dataframe(
        df=forecast_df,
        to_wide=True,
        value_prefixes=value_prefixes,
        # Pass relevant pivot args if df is long
        id_vars=[
            c for c in ['sample_idx', *spatial_cols]
            if c in forecast_df.columns
        ],
        time_col=time_col,
        static_actuals_cols=[
            c for c in forecast_df.columns
            if c.endswith('_actual') and c.count('_') == 1
        ],
        verbose=verbose,
    )

    vlog("Parsing wide DataFrame columns to build plot structure.",
         level=1, verbose=verbose)
    plot_structure = _parse_wide_df_columns(df_wide, value_prefixes)

    # --- Filter data based on user preferences ---
    all_years = sorted({
        y
        for p_data in plot_structure.values()
        for y in p_data
        if isinstance(y, str) and y.isdigit()
    })
    years_to_plot = (
        [str(y) for y in view_years] if view_years else all_years
    )

    # Identify available and requested quantiles.
    all_quantiles = sorted({
        suffix
        for p_data in plot_structure.values()
        for y_data in p_data.values()
        if isinstance(y_data, dict)
        for suffix in y_data
        if suffix.startswith('q')
    })

    if view_quantiles:
        # Normalize requested quantiles to the 'qXX' labels.
        requested = set()
        for q in view_quantiles:
            if isinstance(q, (float, np.floating)):
                # round() rather than int(): 0.29 * 100 evaluates to
                # 28.999999999999996 and must map to 'q29', not 'q28'.
                requested.add(f"q{int(round(q * 100))}")
            else:
                label = str(q)
                requested.add(
                    label if label.startswith('q') else f"q{label}"
                )
        quantiles_to_plot = [q for q in all_quantiles if q in requested]
    else:
        quantiles_to_plot = all_quantiles

    # Handle deterministic case (no quantiles found).
    is_deterministic = not quantiles_to_plot
    if is_deterministic:
        # Fall back to the default prediction column,
        # e.g. 'subsidence_2022_pred'.
        quantiles_to_plot = ['pred']
        vlog("No quantiles found. Assuming deterministic forecast.",
             level=1, verbose=verbose)

    # --- Setup Plot Layout ---
    plot_actuals = (kind == 'dual')
    n_rows = len(years_to_plot)
    if n_rows == 0:
        vlog("No years to plot after filtering.",
             level=0, verbose=verbose)
        return

    # --- Prepare for plotting ---
    x_col, y_col = spatial_cols
    grid_props = grid_props or {'linestyle': ':', 'alpha': 0.7}

    # Determine uniform color range if needed.
    vmin, vmax = None, None
    if cbar == 'uniform':
        all_plot_cols = []

        def _collect_cols(node):
            """Recursively collect column names from plot_structure nodes."""
            if isinstance(node, str):
                all_plot_cols.append(node)
            elif isinstance(node, dict):
                for sub in node.values():
                    _collect_cols(sub)

        for p_data in plot_structure.values():
            _collect_cols(p_data)

        # Keep only the references that actually exist in df_wide.
        valid_plot_cols = [
            c for c in all_plot_cols if c in df_wide.columns
        ]
        col_mins = [df_wide[c].dropna().min() for c in valid_plot_cols]
        col_maxs = [df_wide[c].dropna().max() for c in valid_plot_cols]
        # An all-NaN column yields NaN here; drop those so a single
        # empty column cannot poison the shared range.
        usable_mins = [v for v in col_mins if pd.notna(v)]
        usable_maxs = [v for v in col_maxs if pd.notna(v)]
        if usable_mins:
            vmin = min(usable_mins)
        if usable_maxs:
            vmax = max(usable_maxs)

    for prefix in value_prefixes:
        if max_cols == 'auto':
            n_cols_prefix = len(quantiles_to_plot)
            if plot_actuals:
                n_cols_prefix += 1
        else:
            n_cols_prefix = int(max_cols)
        if n_cols_prefix == 0:
            continue

        figsize_prefix = figsize or (n_cols_prefix * 3.5, n_rows * 3)
        fig, axes = plt.subplots(
            n_rows, n_cols_prefix,
            figsize=figsize_prefix,
            squeeze=False,
            constrained_layout=True
        )
        fig.suptitle(
            f"Forecast Visualization for '{prefix.upper()}'",
            fontsize=14, weight='bold'
        )

        # Carried from row to row so a year without its own actual
        # column reuses the latest one seen.
        last_known_actual_col = plot_structure[prefix].get("static_actual")

        for row_idx, year in enumerate(years_to_plot):
            ax_row = axes[row_idx]
            col_idx = 0

            if plot_actuals:
                actual_col = plot_structure[prefix].get(
                    year, {}).get('actual', last_known_actual_col)
                if actual_col:
                    last_known_actual_col = actual_col
                title = f"Actual ({year})"
                _plot_single_scatter(
                    ax_row[col_idx], df_wide, x_col, y_col, actual_col,
                    cmap, vmin, vmax, title, axis_off, show_grid,
                    grid_props
                )
                col_idx += 1

            for q_suffix in quantiles_to_plot:
                if col_idx >= n_cols_prefix:
                    # Grid is full; remaining quantiles do not fit.
                    break
                pred_col = plot_structure[prefix].get(
                    year, {}).get(q_suffix)
                title = (
                    f"{q_suffix.replace('q', 'Q').capitalize()} ({year})"
                )
                if is_deterministic:
                    title = f"Prediction ({year})"
                _plot_single_scatter(
                    ax_row[col_idx], df_wide, x_col, y_col, pred_col,
                    cmap, vmin, vmax, title, axis_off, show_grid,
                    grid_props
                )
                col_idx += 1

            # Blank unused axes in this row.
            for i in range(col_idx, n_cols_prefix):
                ax_row[i].axis('off')

        if cbar == 'uniform' and vmin is not None and vmax is not None:
            fig.colorbar(
                ScalarMappable(
                    norm=mcolors.Normalize(vmin=vmin, vmax=vmax),
                    cmap=cmap),
                ax=axes, orientation='vertical',
                fraction=0.02, pad=0.04
            )

        if savefig:
            fmts = [save_fmts] if isinstance(save_fmts, str) else save_fmts
            base, _ = os.path.splitext(savefig)
            for fmt in fmts:
                fname = f"{base}_{prefix}{fmt if fmt.startswith('.') else '.' + fmt}"
                # Ensure the output directory exists.
                out_dir = os.path.dirname(fname)
                if out_dir:
                    os.makedirs(out_dir, exist_ok=True)
                vlog(f"Saving figure to {fname}",
                     level=1, verbose=verbose)
                fig.savefig(fname, dpi=300, bbox_inches="tight")

        plt.show()
@check_non_emptiness
def forecast_view(
    forecast_df: pd.DataFrame,
    value_prefixes: Optional[List[str]] = None,
    kind: str = 'dual',
    view_quantiles: Optional[List[Union[str, float]]] = None,
    view_years: Optional[List[Union[str, int]]] = None,
    spatial_cols: Optional[Tuple[str, str]] = None,
    time_col: str = 'coord_t',
    max_cols: Union[int, str] = 'auto',
    cmap: str = 'viridis',
    cbar: str = 'uniform',
    axis_off: bool = False,
    show_grid: bool = True,
    grid_props: Optional[Dict] = None,
    figsize: Optional[Tuple[float, float]] = None,
    savefig: Optional[str] = None,
    save_fmts: Union[str, List[str]] = '.png',
    dt_col: Optional[str] = None,
    show: bool = True,
    cumulative: Union[bool, str] = False,
    _logger: Optional[Union[logging.Logger, Callable[[str], None]]] = None,
    stop_check: Optional[Callable[[], bool]] = None,
    verbose: int = 1,
    **kws
):
    """Generates and displays spatial forecast visualizations.

    This function creates a grid of scatter plots to visualize
    spatio-temporal forecast data. The input is first passed through
    :func:`format_forecast_dataframe` with ``to_wide=True``; plots are
    then arranged by year (rows) and metric, e.g. different quantiles
    (columns).

    Parameters
    ----------
    forecast_df : pd.DataFrame
        Input DataFrame containing forecast data.
    value_prefixes : list of str, optional
        The base names of the metrics to plot (e.g.,
        ``['subsidence', 'GWL']``). If ``None``, they are inferred with
        :func:`get_value_prefixes`, excluding ``'sample_idx'``, the
        spatial columns, the time column and ``'forecast_step'``. A
        separate figure is generated for each prefix.
    kind : {'dual', 'pred_only'}, default 'dual'
        Determines what to plot:

        - ``'dual'``: one extra column is reserved for actual values
          next to the predictions.
        - ``'pred_only'``: only predicted values are plotted.
    view_quantiles : list of str or float, optional
        A list to filter which quantiles are displayed. Values can be
        floats (e.g., ``[0.1, 0.5]``) or strings (e.g.,
        ``['q10', 'q50']`` or ``['10', '50']``). If ``None``, all
        detected quantiles are plotted. If nothing matches, the
        ``'pred'`` column is plotted instead.
    view_years : list of str or int, optional
        A list of years to display. Entries are normalised with
        ``_normalize_year_key`` and duplicates are dropped (first
        occurrence wins). If ``None``, all detected years in the
        forecast are plotted.
    spatial_cols : tuple of str, optional
        Names of the columns used for the x and y axes of the scatter
        plots. When ``None`` (the default), no spatial columns are
        forwarded to the formatting step and ``None`` is handed to the
        grid plotter.
    time_col : str, default 'coord_t'
        The name of the column representing the time dimension,
        forwarded to :func:`format_forecast_dataframe`.

        .. note:: `time_col` and `dt_col` can be used interchangeably.
    max_cols : int or 'auto', default 'auto'
        Controls the number of subplots per row.

        - If ``'auto'``, the number of columns equals the number of
          quantiles being plotted (plus one if ``kind='dual'``).
        - If an integer, it fixes the number of columns of the subplot
          grid.
    cmap : str, default 'viridis'
        The Matplotlib colormap to use for the scatter plots.
    cbar : {'uniform', 'individual'}, default 'uniform'
        Controls the color bar scaling:

        - ``'uniform'``: a single color range is computed from every
          numeric column referenced in the parsed plot structure and
          one color bar is drawn per figure.
        - ``'individual'``: no shared range is computed and no
          figure-level color bar is drawn (per-subplot color bars would
          be a future enhancement).
    axis_off : bool, default False
        If ``True``, turns off the axes for all subplots.
    show_grid : bool, default True
        If ``True`` and `axis_off` is ``False``, a grid is displayed on
        the subplots.
    grid_props : dict, optional
        Properties for the grid. Defaults to
        ``{'linestyle': ':', 'alpha': 0.7}``.
    figsize : tuple of (float, float), optional
        The size of the figure for each prefix. If ``None``, it is
        ``(n_cols * 3.5, n_rows * 3)``.
    savefig : str, optional
        If given, each figure is passed to ``save_figure`` with this
        path and then closed instead of being shown.
    save_fmts : str or list of str, default '.png'
        The format(s) to save the figure in (e.g.,
        ``['.png', '.pdf']``).
    dt_col : str, optional
        Alternative name for `time_col`; the two are reconciled by
        ``_coerce_dt_kw``.
    show : bool, default True
        Only consulted when `savefig` is not set: if ``True`` the
        figure is shown with ``plt.show()``, otherwise it is closed.
    cumulative : bool or str, default False
        When truthy, the prediction columns of each plotted statistic
        are replaced by their running sum over the plotted years, taken
        in ascending numeric order. As a string, the case-insensitive
        values ``'cum'``, ``'cumulative'``, ``'cummulative'``,
        ``'true'`` and ``'yes'`` enable it; any other string disables
        it. Actual columns are not accumulated.
    _logger : logging.Logger or callable, optional
        Sink forwarded to ``vlog`` and to the formatting, plotting and
        saving helpers.
    stop_check : callable, optional
        Zero-argument callable polled at several checkpoints; when it
        returns a truthy value an ``InterruptedError`` is raised.
    verbose : int, default 1
        Controls the verbosity of logging messages.
    **kws : dict
        Keyword arguments for feature extensions. Only ``s`` (default
        10) is read here; it is forwarded to the grid plotter.

    Returns
    -------
    None
        This function does not return any value. It displays and/or
        saves Matplotlib figures directly.

    Raises
    ------
    ValueError
        If `value_prefixes` is ``None`` and none can be auto-detected.
    InterruptedError
        If `stop_check` reports that the operation should stop.

    Warns
    -----
    UserWarning
        If ``detect_forecast_type`` reports a deterministic forecast.

    See Also
    --------
    format_forecast_dataframe : The utility used to detect and pivot
        the input DataFrame.
    get_value_prefixes : The utility used to auto-detect metric
        prefixes.

    Notes
    -----
    - Each row of a figure corresponds to a year from `view_years`.
    - A separate figure is generated for each prefix in
      `value_prefixes`.
    """
    # Canonicalise the time column name (dt_col / time_col aliases).
    kw = _coerce_dt_kw(
        dt_col=dt_col, time_col=time_col, _time_default=time_col)
    time_col = kw.get("dt_col", time_col)  # adopt canonical name

    # Decide whether to apply cumulative sum over forecast years.
    if isinstance(cumulative, str):
        cumulative_flag = cumulative.lower() in {
            'cum', 'cumulative', 'cummulative', 'true', 'yes'
        }
    else:
        cumulative_flag = bool(cumulative)

    _spatial_cols = spatial_cols or []

    is_q = detect_forecast_type(
        forecast_df, value_prefixes=value_prefixes
    )
    if stop_check and stop_check():
        raise InterruptedError("View configuration aborted.")

    if is_q == 'deterministic':
        warnings.warn(
            "Deterministic prediction is detected. It is recommended"
            " to use 'fusionlab.plot.forecast.plot_forecast_by_step'"
            " instead. Use at your own risk."
        )

    if value_prefixes is None:
        vlog("`value_prefixes` not provided. Auto-detecting...",
             level=1, verbose=verbose, logger=_logger)
        value_prefixes = get_value_prefixes(
            forecast_df,
            exclude_cols=[
                'sample_idx', *_spatial_cols, time_col, 'forecast_step'
            ]
        )
        if not value_prefixes:
            raise ValueError(
                "Could not auto-detect any value prefixes. "
                "Please provide them explicitly."
            )
        vlog(f"Detected prefixes: {value_prefixes}",
             level=2, verbose=verbose, logger=_logger)

    df_wide = format_forecast_dataframe(
        forecast_df,
        to_wide=True,
        value_prefixes=value_prefixes,
        id_vars=[
            c for c in ['sample_idx', *_spatial_cols]
            if c in forecast_df.columns
        ],
        time_col=time_col,
        static_actuals_cols=[
            c for c in forecast_df.columns
            if c.endswith('_actual') and c.count('_') == 1
        ],
        spatial_cols=_spatial_cols,
        verbose=verbose,
        _logger=_logger,
    )
    if stop_check and stop_check():
        raise InterruptedError("Format plot dataframe aborted.")

    plot_structure = _parse_wide_df_columns(df_wide, value_prefixes)

    # Years are collected across every prefix, so the same year shows
    # up once per prefix; a set keeps one grid row per year and stops
    # the cumulative pass below from adding a column to itself.
    all_years = sorted({
        y
        for p_data in plot_structure.values()
        for y in p_data
        if isinstance(y, str) and y.isdigit()
    })
    if view_years:
        # dict.fromkeys drops repeats while keeping the caller's order.
        years_to_plot = list(dict.fromkeys(
            _normalize_year_key(y) for y in view_years
        ))
    else:
        years_to_plot = all_years

    all_quantiles = sorted({
        s
        for p in plot_structure.values()
        for y in p.values()
        if isinstance(y, dict)
        for s in y
        if s.startswith('q')
    })

    if view_quantiles:
        requested = set()
        for q in view_quantiles:
            if isinstance(q, (float, np.floating)):
                # round() rather than int(): 0.29 * 100 evaluates to
                # 28.999999999999996 and must map to 'q29', not 'q28'.
                requested.add(f"q{int(round(q * 100))}")
            else:
                label = str(q)
                requested.add(
                    label if label.startswith('q') else f"q{label}"
                )
        quantiles_to_plot = [q for q in all_quantiles if q in requested]
    else:
        quantiles_to_plot = all_quantiles

    if not quantiles_to_plot:
        quantiles_to_plot = ['pred']
        vlog("No quantiles found. Assuming deterministic forecast.",
             level=1, verbose=verbose, logger=_logger)

    # ---------------- OPTIONAL: cumulative over years ----------------
    if cumulative_flag and years_to_plot:
        vlog(
            "Applying cumulative sum over forecast years in forecast_view.",
            level=1, verbose=verbose, logger=_logger,
        )
        # For each prefix (e.g. 'subsidence') and each stat
        # ('q10', 'q50', 'pred'), cum-sum along sorted years.
        for year_map in plot_structure.values():
            # Keep only requested years that exist for this prefix.
            keyed_years = []
            for y in years_to_plot:
                if y not in year_map:
                    continue
                try:
                    keyed_years.append((float(y), y))
                except ValueError:
                    # Non-numeric labels sort first, ordered by label.
                    keyed_years.append((0.0, y))
            years_here = [y_str for _, y_str in sorted(keyed_years)]
            if not years_here:
                continue

            for stat in quantiles_to_plot:
                running = None
                for y_str in years_here:
                    mapping = year_map.get(y_str)
                    if not isinstance(mapping, dict):
                        continue
                    col_name = mapping.get(stat)
                    if not col_name or col_name not in df_wide.columns:
                        continue
                    vals = df_wide[col_name].to_numpy()
                    if running is None:
                        running = vals.copy()
                    else:
                        running = running + vals
                    # Overwrite the column with the cumulative values.
                    df_wide[col_name] = running

    n_rows = len(years_to_plot)
    if n_rows == 0:
        vlog("No years to plot after filtering.",
             level=0, verbose=verbose, logger=_logger)
        return

    vmin, vmax = None, None
    if cbar == "uniform":
        all_plot_cols = [
            c
            for p in plot_structure.values()
            for y in p.values()
            for c in (y.values() if isinstance(y, dict) else [y])
            if c in df_wide
        ]
        numeric_plot_cols = []
        for c in all_plot_cols:
            s = pd.to_numeric(df_wide[c], errors="coerce")
            # Keep columns holding at least one number; columns that
            # are entirely non-numeric (e.g. a unit label such as
            # 'mm') are left untouched and ignored for the range.
            if s.notna().any():
                df_wide[c] = s
                numeric_plot_cols.append(c)
        if numeric_plot_cols:
            vmin = min(df_wide[c].min() for c in numeric_plot_cols)
            vmax = max(df_wide[c].max() for c in numeric_plot_cols)

    plot_kwargs = dict(
        cmap=cmap,
        vmin=vmin,
        vmax=vmax,
        axis_off=axis_off,
        show_grid=show_grid,
        grid_props=(grid_props or {'linestyle': ':', 'alpha': 0.7}),
        verbose=verbose
    )

    # Read the marker size once: popping it inside the loop would give
    # the caller's value to the first figure only and the default to
    # every later one.
    marker_size = kws.pop('s', 10)

    for prefix in value_prefixes:
        if stop_check and stop_check():
            raise InterruptedError("Forecast visualization aborted.")

        num_pred_cols = len(quantiles_to_plot)
        num_actual_cols = 1 if kind == 'dual' else 0
        if max_cols == 'auto':
            n_cols = num_pred_cols + num_actual_cols
        else:
            n_cols = int(max_cols)
        if n_cols == 0:
            continue

        fig_size = figsize or (n_cols * 3.5, n_rows * 3)
        fig, axes = plt.subplots(
            n_rows, n_cols,
            figsize=fig_size,
            squeeze=False,
            constrained_layout=True
        )
        fig.suptitle(
            f"Forecast Visualization for '{prefix.upper()}'",
            fontsize=14, weight='bold'
        )

        _plot_forecast_grid(
            fig, axes, df_wide, plot_structure,
            years_to_plot, quantiles_to_plot,
            prefix, kind, spatial_cols,
            s=marker_size,
            plot_kwargs=plot_kwargs,
            _logger=_logger,
        )

        if cbar == 'uniform' and vmin is not None and vmax is not None:
            fig.colorbar(
                ScalarMappable(
                    norm=mcolors.Normalize(vmin=vmin, vmax=vmax),
                    cmap=cmap),
                ax=axes.ravel().tolist(),
                orientation='vertical',
                shrink=0.8, pad=0.04
            )

        if savefig:
            # NOTE(review): the same `savefig` path is passed for every
            # prefix; confirm that `save_figure` disambiguates file
            # names so figures do not overwrite each other.
            save_figure(
                fig,
                savefile=savefig,
                save_fmts=save_fmts,
                dpi=300,
                bbox_inches="tight",
                verbose=verbose,
                _logger=_logger,
            )
            plt.close(fig)
        elif show:
            plt.show()  # for notebook debugging
        else:
            plt.close(fig)
def _normalize_year_key(y):
    """
    Convert a year-like value into the string key format used by
    ``plot_structure`` (for instance ``'2023'``).

    The value is resolved by the first rule that applies:

    1. ``pd.Timestamp`` / ``datetime.date`` / ``datetime.datetime``
       instances yield their ``year``, zero-padded to four digits.
    2. Anything ``float()`` accepts that lies within ``1e-6`` of a whole
       number yields that whole number, zero-padded to four digits.
    3. The same whole-number test is repeated on ``str(y)``.
    4. The first run of four digits found in ``str(y)`` is returned
       (e.g. ``'2023-01-01'`` gives ``'2023'``).
    5. Otherwise ``str(y)`` is returned unchanged.
    """

    def _whole_number_key(candidate):
        # Zero-padded integer string when `candidate` parses as a finite
        # float that is (almost) a whole number; ``None`` otherwise.
        try:
            number = float(candidate)
            if np.isfinite(number) and abs(number - round(number)) < 1e-6:
                return f"{int(round(number)):04d}"
        except Exception:
            pass
        return None

    # Rule 1: date-like objects carry the year directly.
    date_like = (pd.Timestamp, datetime.date, datetime.datetime)
    try:
        if isinstance(y, date_like):
            return f"{y.year:04d}"
    except Exception:
        # e.g. a date-like whose `year` cannot be formatted as an integer
        pass

    # Rule 2: the raw value itself is numeric.
    key = _whole_number_key(y)
    if key is not None:
        return key

    # Rule 3: the textual form is numeric (e.g. '2023.0').
    text = str(y)
    key = _whole_number_key(text)
    if key is not None:
        return key

    # Rules 4 and 5: first 4-digit run if present, else the text as-is.
    found = re.search(r"(\d{4})", text)
    return found.group(1) if found else text
@check_non_emptiness
def plot_forecasts(
    forecast_df: pd.DataFrame,
    target_name: str = "target",
    quantiles: Optional[List[float]] = None,
    output_dim: int = 1,
    kind: str = "temporal",
    actual_data: Optional[pd.DataFrame] = None,
    dt_col: Optional[str] = None,
    actual_target_name: Optional[str] = None,
    sample_ids: Optional[Union[int, List[int], str]] = "first_n",
    num_samples: int = 3,
    horizon_steps: Optional[Union[int, List[int]]] = 1,
    spatial_cols: Optional[List[str]] = None,
    max_cols: int = 2,
    figsize: Tuple[float, float] = (8, 4.5),
    scaler: Optional[Any] = None,
    scaler_feature_names: Optional[List[str]] = None,
    target_idx_in_scaler: Optional[int] = None,
    titles: Optional[List[str]] = None,
    cbar: Optional[str] = None,
    step_names: Optional[Dict[int, Any]] = None,
    show_grid: bool = True,
    grid_props: Optional[Dict] = None,
    savefig: Optional[str] = None,
    save_fmts: Optional[Union[str, List[str]]] = ".png",
    show: bool = True,
    verbose: int = 0,
    _logger: Optional[Union[logging.Logger, Callable[[str], None]]] = None,
    stop_check: Optional[Callable[[], bool]] = None,
    **plot_kwargs: Any
) -> None:
    """
    Visualize model forecasts stored in a long-format DataFrame.

    Two layouts are available:

    - ``kind='temporal'``: one line plot per selected sample (and per
      output dimension) showing the forecast along the horizon.
    - ``kind='spatial'``: one scatter plot per selected forecast step
      (and per output dimension), coloured by the predicted value.

    Point forecasts and quantile (probabilistic) forecasts are both
    handled. Temporal plots overlay the actual values whenever the
    matching ``*_actual`` column is present. Prediction and actual
    columns can be inverse-transformed first when a scaler is given.

    The input `forecast_df` is expected to be in long format, typically
    generated by
    :func:`~fusionlab.nn.utils.format_predictions_to_dataframe`. Columns
    are looked up by name:

    - quantile forecasts: ``{target_name}[_{dim}]_q{int(q * 100)}``
    - point forecasts: ``{target_name}[_{dim}]_pred``
    - actual values: ``{actual_target_name or target_name}[_{dim}]_actual``

    where the ``_{dim}`` part is only used when ``output_dim > 1``.

    Parameters
    ----------
    forecast_df : pd.DataFrame
        Forecast data in long format. Must include 'sample_idx' and
        'forecast_step' columns, along with prediction columns (and
        optionally actual and spatial coordinate columns). The frame is
        copied; the caller's object is not modified.
    target_name : str, default "target"
        Base name of the target variable, used to build the prediction
        column names (e.g. "target_pred", "target_q50").
    quantiles : List[float], optional
        Quantiles that were predicted (e.g. ``[0.1, 0.5, 0.9]``). They
        are sorted internally. The median line / colour column is the
        0.5 quantile when present, otherwise the middle element of the
        sorted list. For temporal plots the band between the lowest and
        highest quantile is shaded. If ``None`` or empty, a point
        forecast is assumed.
    output_dim : int, default 1
        Number of target variables predicted at each step. If > 1, one
        subplot is produced per output dimension.
    kind : {'temporal', 'spatial'}, default "temporal"
        Type of plot to generate. ``'spatial'`` requires `spatial_cols`.
    actual_data : pd.DataFrame, optional
        Accepted for API compatibility; currently not used by this
        function. Actual values are read from `forecast_df`.
    dt_col : str, optional
        Accepted for API compatibility; currently not used by this
        function. The x-axis of temporal plots is 'forecast_step'.
    actual_target_name : str, optional
        Base name used to build the ``*_actual`` column name when it
        differs from `target_name`. If ``None`` (or empty),
        `target_name` is used.
    sample_ids : int, List[int], str, default "first_n"
        Which samples to plot.

        - `int`: position in the array of unique 'sample_idx' values.
          Out-of-range positions fall back to the first sample with a
          warning.
        - `list` / `tuple`: 'sample_idx' values to keep; unknown values
          are dropped.
        - ``"first_n"``: the first `num_samples` unique samples.
        - ``"all"``: every unique sample.

        Any other string plots the first sample with a warning. When
        the selection ends up empty, the first available sample is used.
    num_samples : int, default 3
        Number of samples to plot if ``sample_ids="first_n"``.
    horizon_steps : int, List[int], str, default 1
        Forecast step(s) to plot for ``kind='spatial'``; values are
        matched against the 'forecast_step' column.

        - `int`: that single step.
        - `list` / `tuple`: all listed steps.
        - ``"all"`` or ``None``: every unique step in the data, sorted.
    spatial_cols : List[str], optional
        Required if ``kind='spatial'``. Exactly two column names from
        `forecast_df` holding the x and y coordinates (e.g.
        ``['longitude', 'latitude']``).
    max_cols : int, default 2
        Maximum number of subplots per row. Values below 1 are treated
        as 1.
    figsize : Tuple[float, float], default (8, 4.5)
        Size ``(width, height)`` in inches of **each subplot**. The
        figure size is this value multiplied by the number of grid
        columns and rows.
    scaler : Any, optional
        Fitted scaler exposing ``inverse_transform``. When provided
        together with `scaler_feature_names` and `target_idx_in_scaler`,
        every detected prediction and actual column is
        inverse-transformed before plotting. A column whose transform
        fails is left scaled and a warning is issued.
    scaler_feature_names : List[str], optional
        All feature names the `scaler` was fitted on; only its length is
        used, to build the placeholder array passed to the scaler.
    target_idx_in_scaler : int, optional
        Column position of the target inside the scaler's feature array.
    titles : List[str], optional
        Custom subplot titles, consumed in subplot order. Subplots
        beyond the end of the list keep their auto-generated title.
    cbar : {'uniform', None}, default=None
        Colour-bar behaviour for *spatial* plots. ``'uniform'`` computes
        one global ``vmin`` / ``vmax`` over the plotted steps, applies
        it to every subplot and attaches a single shared colour-bar.
        Any other value lets each subplot scale and show its own
        colour-bar.
    step_names : dict[int, Any] or None, default=None
        Optional mapping used to rename forecast steps in spatial
        subplot titles, e.g. ``{1: "2021", 2: "2022"}``. It is resolved
        through ``get_step_names``; steps without an entry are titled
        ``"Forecast Step: <step>"``.
    show_grid : bool, default=True
        Toggle ``ax.grid`` in both temporal and spatial plots.
    grid_props : dict or None, default=None
        Keyword arguments forwarded to ``ax.grid``. When ``None``,
        spatial plots use ``{'linestyle': ':', 'alpha': 0.7}`` and
        temporal plots use Matplotlib's defaults. Ignored when
        `show_grid` is ``False``.
    savefig : str, optional
        When truthy, the figure is handed to ``save_figure`` (together
        with `save_fmts`, ``dpi=300`` and ``bbox_inches='tight'``) and
        then closed instead of being shown.
    save_fmts : str or List[str], default ".png"
        Output format(s) forwarded to ``save_figure``.
    show : bool, default=True
        When `savefig` is not set: call ``plt.show()`` if ``True``,
        otherwise close the figure.
    verbose : int, default 0
        Verbosity level forwarded to ``vlog``.
    _logger : logging.Logger or callable, optional
        Log sink forwarded to ``vlog`` and ``save_figure``.
    stop_check : callable, optional
        Zero-argument callable polled at several points. When it
        returns a truthy value the figure (if any) is closed and
        ``InterruptedError`` is raised.
    **plot_kwargs : Any
        Extra options. Recognised keys are ``cmap``, ``s`` and ``alpha``
        (spatial scatter), and the dictionaries ``median_plot_kwargs``,
        ``fill_between_kwargs`` and ``point_plot_kwargs`` (temporal
        plots). Entries in those dictionaries override the built-in
        defaults (marker, colour, alpha, label).

    Returns
    -------
    None
        The figure is shown, saved, or closed; nothing is returned.

    Raises
    ------
    TypeError
        If `forecast_df` is not a pandas DataFrame.
    ValueError
        If 'sample_idx' or 'forecast_step' is missing from
        `forecast_df`; if ``kind='spatial'`` and `spatial_cols` is
        missing, does not hold exactly two names, or names columns
        absent from `forecast_df`; if `horizon_steps` is invalid; or if
        `kind` is unsupported.
    InterruptedError
        If `stop_check` requests an abort.

    See Also
    --------
    fusionlab.nn.utils.format_predictions_to_dataframe :
        Utility to generate the `forecast_df` expected by this function.
    matplotlib.pyplot.plot :
        Underlying plotting function for temporal forecasts.
    matplotlib.pyplot.scatter :
        Underlying plotting function for spatial forecasts.

    Examples
    --------
    >>> from fusionlab.nn.utils import format_predictions_to_dataframe
    >>> from fusionlab.plot.forecast import plot_forecasts
    >>> import pandas as pd
    >>> import numpy as np

    >>> # Assume preds_point (B,H,O) and preds_quant (B,H,Q) are available
    >>> B, H, O, Q_len = 4, 3, 1, 3
    >>> preds_point = np.random.rand(B, H, O)
    >>> preds_quant = np.random.rand(B, H, Q_len)
    >>> y_true_seq = np.random.rand(B, H, O)
    >>> quantiles_list = [0.1, 0.5, 0.9]

    >>> # Create DataFrames using format_predictions_to_dataframe
    >>> df_point = format_predictions_to_dataframe(
    ...     predictions=preds_point, y_true_sequences=y_true_seq,
    ...     target_name="value", forecast_horizon=H, output_dim=O
    ... )
    >>> df_quant = format_predictions_to_dataframe(
    ...     predictions=preds_quant, y_true_sequences=y_true_seq,
    ...     target_name="value", quantiles=quantiles_list,
    ...     forecast_horizon=H, output_dim=O
    ... )

    >>> # Example 1: Plot temporal point forecast for first sample
    >>> # plot_forecasts(df_point, target_name="value", sample_ids=0)

    >>> # Example 2: Plot temporal quantile forecast for first 2 samples
    >>> # plot_forecasts(df_quant, target_name="value",
    ... #                quantiles=quantiles_list, sample_ids="first_n",
    ... #                num_samples=2, max_cols=1)

    >>> # Example 3: Spatial plot (requires spatial_cols in df_quant)
    >>> # Assume df_quant has 'lon' and 'lat' columns
    >>> # df_quant['lon'] = np.random.rand(len(df_quant)) * 10
    >>> # df_quant['lat'] = np.random.rand(len(df_quant)) * 5
    >>> # plot_forecasts(df_quant, target_name="value",
    ... #                quantiles=quantiles_list, kind='spatial',
    ... #                horizon_steps=1, spatial_cols=['lon', 'lat'])
    """

    def _abort_if_requested(message, fig=None):
        # Poll the optional cancellation hook. The figure is closed first
        # so an aborted run does not leave it registered with pyplot.
        if stop_check and stop_check():
            if fig is not None:
                plt.close(fig)
            raise InterruptedError(message)

    def _column(base, o_idx, suffix):
        # '<base>[_<dim>]<suffix>'; the dimension tag only appears for
        # multi-output forecasts.
        dim_tag = f"_{o_idx}" if output_dim > 1 else ""
        return f"{base}{dim_tag}{suffix}"

    def _q_tag(q):
        # Truncating int() is kept on purpose: it is the naming scheme
        # this function has always used to look columns up.
        return f"q{int(q * 100)}"

    vlog(f"Starting forecast visualization (kind='{kind}')...",
         level=3, verbose=verbose, logger=_logger)

    if not isinstance(forecast_df, pd.DataFrame):
        raise TypeError(
            "`forecast_df` must be a pandas DataFrame, typically "
            "the output of `format_predictions_to_dataframe`."
        )
    if ('sample_idx' not in forecast_df.columns
            or 'forecast_step' not in forecast_df.columns):
        raise ValueError(
            "`forecast_df` must contain 'sample_idx' and 'forecast_step' "
            "columns."
        )

    _abort_if_requested("Plotting aborted.")

    # --- Data Preparation & Inverse Transform ---
    # Work on a copy to avoid modifying the original DataFrame.
    df_to_plot = forecast_df.copy()

    base_pred_name = target_name
    base_actual_name = actual_target_name if actual_target_name \
        else target_name

    # Resolve the quantile layout ONCE so that column detection, the
    # uniform colour range and the plots all agree on the median column.
    has_quantiles = quantiles is not None and len(quantiles) > 0
    sorted_quantiles = (
        sorted(float(q) for q in quantiles) if has_quantiles else []
    )
    if has_quantiles:
        median_q_val = (
            0.5 if 0.5 in sorted_quantiles
            else sorted_quantiles[len(sorted_quantiles) // 2]
        )
        median_tag = _q_tag(median_q_val)
        lower_tag = _q_tag(sorted_quantiles[0])
        upper_tag = _q_tag(sorted_quantiles[-1])
        pred_suffixes = [f"_{_q_tag(q)}" for q in sorted_quantiles]
        color_suffix = f"_{median_tag}"
        plot_title_suffix = f" (Median {median_tag})"
    else:
        pred_suffixes = ["_pred"]
        color_suffix = "_pred"
        plot_title_suffix = ""

    # Prediction / actual columns that really exist in the frame; these
    # are the ones the inverse transform is applied to.
    pred_col_names_in_df = []
    actual_col_names_in_df = []
    for o_idx in range(output_dim):
        for suffix in pred_suffixes:
            col_name = _column(base_pred_name, o_idx, suffix)
            if col_name in df_to_plot.columns:
                pred_col_names_in_df.append(col_name)
        actual_col = _column(base_actual_name, o_idx, "_actual")
        if actual_col in df_to_plot.columns:
            actual_col_names_in_df.append(actual_col)

    if not pred_col_names_in_df:
        warnings.warn("No prediction columns found in `forecast_df` "
                      "based on `target_name` and `quantiles`.")
        # Attempt to find any column ending with _pred or _qXX
        pred_cols_found = [c for c in df_to_plot.columns
                           if "_pred" in c or "_q" in c]
        if pred_cols_found:
            pred_col_names_in_df = pred_cols_found
            vlog("  Auto-detected prediction columns: "
                 f"{pred_col_names_in_df}",
                 level=2, verbose=verbose, logger=_logger)
        else:
            vlog("  No prediction columns could be auto-detected. "
                 "Plotting may be limited.",
                 level=2, verbose=verbose, logger=_logger)

    # Apply inverse transform if scaler is provided
    if scaler is not None:
        if scaler_feature_names is None or target_idx_in_scaler is None:
            warnings.warn(
                "Scaler provided, but `scaler_feature_names` or "
                "`target_idx_in_scaler` is missing. Inverse transform "
                "cannot be applied correctly to specific target columns.",
                UserWarning
            )
        else:
            vlog("  Applying inverse transformation using provided "
                 "scaler...",
                 level=4, verbose=verbose, logger=_logger)
            dummy_array_shape = (len(df_to_plot), len(scaler_feature_names))
            for col_to_inv in pred_col_names_in_df + actual_col_names_in_df:
                if col_to_inv not in df_to_plot.columns:
                    continue
                try:
                    # The scaler expects the full feature matrix: fill a
                    # zero placeholder and only read the target column
                    # back.
                    dummy = np.zeros(dummy_array_shape)
                    dummy[:, target_idx_in_scaler] = df_to_plot[col_to_inv]
                    df_to_plot[col_to_inv] = scaler.inverse_transform(
                        dummy
                    )[:, target_idx_in_scaler]
                except Exception as e:
                    # Best effort: keep plotting the scaled values.
                    warnings.warn(
                        "Failed to inverse transform column "
                        f"'{col_to_inv}'. Plotting scaled value. Error: {e}"
                    )
            vlog("  Inverse transformation applied.",
                 level=5, verbose=verbose, logger=_logger)

    _abort_if_requested("Quantile plot configuration aborted.")

    # --- Select Samples/Items to Plot ---
    unique_sample_ids = df_to_plot['sample_idx'].unique()
    selected_ids_for_plot = []
    if isinstance(sample_ids, str):
        selection = sample_ids.lower()
        if selection == "all":
            selected_ids_for_plot = unique_sample_ids
        elif selection == "first_n":
            selected_ids_for_plot = unique_sample_ids[:num_samples]
        else:
            warnings.warn(f"Unknown string for `sample_ids`: "
                          f"'{sample_ids}'. Plotting first sample.")
            selected_ids_for_plot = unique_sample_ids[:1]
    elif isinstance(sample_ids, int):
        if sample_ids < len(unique_sample_ids):
            selected_ids_for_plot = [unique_sample_ids[sample_ids]]
        else:
            warnings.warn(f"sample_idx {sample_ids} out of range. "
                          "Plotting first sample.")
            selected_ids_for_plot = unique_sample_ids[:1]
    elif isinstance(sample_ids, (list, tuple)):
        selected_ids_for_plot = [
            sid for sid in sample_ids if sid in unique_sample_ids
        ]

    if len(selected_ids_for_plot) == 0:
        vlog("No valid sample_idx found to plot. Defaulting to first "
             "available.",
             level=2, verbose=verbose, logger=_logger)
        selected_ids_for_plot = unique_sample_ids[:1]

    if len(selected_ids_for_plot) == 0:  # the frame holds no sample at all
        vlog("No sample to plot. Returning ...",
             level=1, verbose=verbose, logger=_logger)
        return

    vlog(f"  Plotting for sample_idx: {selected_ids_for_plot}",
         level=4, verbose=verbose, logger=_logger)

    cmap = plot_kwargs.get('cmap', 'viridis')

    # --- Plotting Logic ---
    if kind == "temporal":
        num_plots = len(selected_ids_for_plot) * output_dim
        if num_plots == 0:
            vlog("No data to plot for temporal type.",
                 level=2, verbose=verbose, logger=_logger)
            return

        # Guard against max_cols <= 0, which would divide by zero below.
        n_cols_plot = max(1, min(max_cols, num_plots))
        n_rows_plot = (num_plots + n_cols_plot - 1) // n_cols_plot
        fig, axes = plt.subplots(
            n_rows_plot, n_cols_plot,
            figsize=(n_cols_plot * figsize[0], n_rows_plot * figsize[1]),
            squeeze=False  # Always return 2D array for axes
        )
        axes_flat = axes.flatten()
        plot_idx = 0

        for s_idx in selected_ids_for_plot:
            sample_df = df_to_plot[df_to_plot['sample_idx'] == s_idx]
            if sample_df.empty:
                continue
            steps = sample_df['forecast_step']

            for o_idx in range(output_dim):
                if plot_idx >= len(axes_flat):
                    break  # Should not happen
                ax = axes_flat[plot_idx]

                _abort_if_requested("Temporal plot aborted.", fig)

                dim_tag = f"_{o_idx}" if output_dim > 1 else ""
                title = f"Sample ID: {s_idx}"
                if output_dim > 1:
                    title += f", Target Dim: {o_idx}"
                if titles and plot_idx < len(titles):
                    title = titles[plot_idx]
                ax.set_title(title)

                # Plot actuals if available
                actual_col = _column(base_actual_name, o_idx, "_actual")
                if actual_col in sample_df.columns:
                    ax.plot(
                        steps, sample_df[actual_col],
                        label=f"Actual {target_name}{dim_tag}",
                        marker='o', linestyle='--'
                    )

                # Plot predictions. Built-in styling goes first in each
                # dict so that user-supplied kwargs override it instead
                # of raising "multiple values for keyword argument".
                if has_quantiles:
                    median_col = _column(
                        base_pred_name, o_idx, f"_{median_tag}")
                    lower_q_col = _column(
                        base_pred_name, o_idx, f"_{lower_tag}")
                    upper_q_col = _column(
                        base_pred_name, o_idx, f"_{upper_tag}")

                    if median_col in sample_df.columns:
                        median_kw = {
                            "label": f"Median ({median_tag})",
                            "marker": "x",
                            **plot_kwargs.get("median_plot_kwargs", {}),
                        }
                        ax.plot(steps, sample_df[median_col], **median_kw)

                    if (lower_q_col in sample_df.columns
                            and upper_q_col in sample_df.columns):
                        band_kw = {
                            "color": "gray",
                            "alpha": 0.3,
                            "label": f"Interval ({lower_tag}-{upper_tag})",
                            **plot_kwargs.get("fill_between_kwargs", {}),
                        }
                        ax.fill_between(
                            steps,
                            sample_df[lower_q_col],
                            sample_df[upper_q_col],
                            **band_kw
                        )
                else:  # Point forecast
                    pred_col = _column(base_pred_name, o_idx, "_pred")
                    if pred_col in sample_df.columns:
                        point_kw = {
                            "label": f"Predicted {target_name}{dim_tag}",
                            "marker": "x",
                            **plot_kwargs.get("point_plot_kwargs", {}),
                        }
                        ax.plot(steps, sample_df[pred_col], **point_kw)

                ax.set_xlabel("Forecast Step into Horizon")
                ax.set_ylabel(
                    f"{target_name}"
                    f"{f' (Dim {o_idx})' if output_dim > 1 else ''}"
                )
                # Only build a legend when something was actually drawn.
                handles, _ = ax.get_legend_handles_labels()
                if handles:
                    ax.legend()
                if show_grid:
                    ax.grid(True, **(grid_props or {}))
                else:
                    ax.grid(False)
                plot_idx += 1

        # Hide unused subplots
        for i in range(plot_idx, len(axes_flat)):
            axes_flat[i].set_visible(False)
        fig.tight_layout()

    elif kind == "spatial":
        spatial_x_col = None
        spatial_y_col = None
        if spatial_cols is not None:
            spatial_cols = columns_manager(spatial_cols, empty_as_none=False)
            if len(spatial_cols) != 2:
                raise ValueError(
                    "Spatial_cols need exactly two columns (e.g., longitude,"
                    f" latitude ). Got {spatial_cols}")
            spatial_x_col, spatial_y_col = spatial_cols

        if spatial_x_col is None or spatial_y_col is None:
            raise ValueError(
                "`spatial_x_col` and `spatial_y_col` must be provided "
                "for `kind='spatial'`."
            )
        if (spatial_x_col not in df_to_plot.columns
                or spatial_y_col not in df_to_plot.columns):
            raise ValueError(
                f"Spatial columns '{spatial_x_col}' or '{spatial_y_col}' "
                "not found in forecast_df."
            )

        if isinstance(horizon_steps, int):
            steps_to_plot = [horizon_steps]
        elif isinstance(horizon_steps, (list, tuple)):
            steps_to_plot = list(horizon_steps)
        elif horizon_steps is None or \
                str(horizon_steps).lower() == "all":
            # Plot all unique forecast steps present in the data
            steps_to_plot = sorted(df_to_plot['forecast_step'].unique())
        else:
            raise ValueError("Invalid `horizon_steps`.")

        num_plots = len(steps_to_plot) * output_dim
        if num_plots == 0:
            vlog("No data/steps to plot for spatial type.",
                 level=2, verbose=verbose, logger=_logger)
            return

        # One Normalize shared by every subplot AND by the figure-level
        # colour-bar, so colours and colour-bar describe the same range.
        shared_norm = None
        if cbar == "uniform":
            vlog("Calculating uniform color scale for all subplots...",
                 level=2, verbose=verbose, logger=_logger)
            # Only the steps being plotted contribute to the range.
            df_for_scaling = df_to_plot[
                df_to_plot['forecast_step'].isin(steps_to_plot)
            ]
            existing_color_cols = [
                c for c in (
                    _column(base_pred_name, o_idx, color_suffix)
                    for o_idx in range(output_dim)
                )
                if c in df_for_scaling.columns
            ]
            min_vals = [df_for_scaling[c].dropna().min()
                        for c in existing_color_cols]
            max_vals = [df_for_scaling[c].dropna().max()
                        for c in existing_color_cols]
            valid_mins = [v for v in min_vals if pd.notna(v)]
            valid_maxs = [v for v in max_vals if pd.notna(v)]
            if valid_mins and valid_maxs:
                vmin, vmax = min(valid_mins), max(valid_maxs)
                vlog("Uniform color range set: "
                     f"vmin={vmin:.2f}, vmax={vmax:.2f}",
                     level=3, verbose=verbose, logger=_logger)
                shared_norm = mcolors.Normalize(vmin=vmin, vmax=vmax)

        n_cols_plot = max(1, min(max_cols, num_plots))
        n_rows_plot = (num_plots + n_cols_plot - 1) // n_cols_plot
        fig, axes = plt.subplots(
            n_rows_plot, n_cols_plot,
            figsize=(n_cols_plot * figsize[0], n_rows_plot * figsize[1]),
            squeeze=False
        )
        axes_flat = axes.flatten()
        plot_idx = 0
        n_drawn = 0

        step_labels = get_step_names(
            steps_to_plot,
            step_names=step_names,
            default_name="{}"
        )
        spatial_grid_props = grid_props or {'linestyle': ':', 'alpha': 0.7}

        for step in steps_to_plot:
            step_df = df_to_plot[df_to_plot['forecast_step'] == step]

            _abort_if_requested("Spatial plot aborted.", fig)

            if step_df.empty:
                continue

            for o_idx in range(output_dim):
                if plot_idx >= len(axes_flat):
                    break
                ax = axes_flat[plot_idx]

                # Median quantile (or point prediction) drives the colour.
                color_col = _column(base_pred_name, o_idx, color_suffix)

                if color_col not in step_df.columns:
                    warnings.warn(f"Color column '{color_col}' not found "
                                  "for spatial plot. Skipping subplot.")
                    ax.set_visible(False)
                    plot_idx += 1  # Increment to avoid reusing subplot
                    continue

                scatter_data = step_df.dropna(
                    subset=[spatial_x_col, spatial_y_col, color_col]
                )
                if scatter_data.empty:
                    warnings.warn(f"No valid data to plot for step {step}, "
                                  f"output_dim {o_idx} after dropna. "
                                  "Skipping.")
                    ax.set_visible(False)
                    plot_idx += 1
                    continue

                if shared_norm is not None:
                    norm = shared_norm
                else:
                    # Per-subplot scaling.
                    norm = mcolors.Normalize(
                        vmin=scatter_data[color_col].min(),
                        vmax=scatter_data[color_col].max()
                    )

                sc = ax.scatter(
                    scatter_data[spatial_x_col],
                    scatter_data[spatial_y_col],
                    c=scatter_data[color_col],
                    cmap=cmap,
                    norm=norm,
                    s=plot_kwargs.get('s', 50),  # Marker size
                    alpha=plot_kwargs.get('alpha', 0.7)
                )
                if cbar != 'uniform':
                    # Individual colour-bar next to each subplot.
                    fig.colorbar(
                        sc, ax=ax,
                        label=f"{target_name}{plot_title_suffix}"
                    )

                step_name = step_labels.get(step, '{}')
                title = step_name.format(f"Forecast Step: {step}")
                if output_dim > 1:
                    title += f", Target Dim: {o_idx}"
                if titles and plot_idx < len(titles):
                    title = titles[plot_idx]
                ax.set_title(title)
                ax.set_xlabel(spatial_x_col)
                ax.set_ylabel(spatial_y_col)

                if show_grid:
                    ax.grid(True, **spatial_grid_props)
                else:
                    ax.grid(False)

                plot_idx += 1
                n_drawn += 1

        for i in range(plot_idx, len(axes_flat)):
            axes_flat[i].set_visible(False)
        fig.tight_layout()

        if cbar == 'uniform' and shared_norm is not None and n_drawn > 0:
            # One colour-bar for the whole figure: leave 10% of the width
            # free on the right and put a dedicated axis there.
            fig.subplots_adjust(right=0.90)
            # [left, bottom, width, height] in figure coordinates
            cax = fig.add_axes([0.92, 0.15, 0.015, 0.7])
            fig.colorbar(
                ScalarMappable(norm=shared_norm, cmap=cmap),
                cax=cax,
                label=f"{target_name} {plot_title_suffix.strip()}".strip(),
                orientation='vertical'
            )

    else:
        raise ValueError(
            f"Unsupported `kind`: '{kind}'. "
            "Choose 'temporal' or 'spatial'."
        )

    _abort_if_requested("Plot forecast aborted.", fig)

    vlog("Forecast visualization complete.",
         level=3, verbose=verbose, logger=_logger)

    # Save figure to disk if requested, otherwise show or discard it.
    if savefig:
        save_figure(
            fig,
            savefile=savefig,
            save_fmts=save_fmts,
            dpi=300,
            bbox_inches="tight",
            _logger=_logger,
        )
        plt.close(fig)
    elif show:
        plt.show()  # for notebook debugging
    else:
        plt.close(fig)
@check_non_emptiness
def visualize_forecasts(
    forecast_df,
    dt_col,
    tname,
    test_data=None,
    eval_periods=None,
    mode="quantile",
    kind="spatial",
    actual_name=None,
    x=None,
    y=None,
    cmap="coolwarm",
    max_cols=3,
    axis="on",
    s=2,
    show_grid=True,
    grid_props=None,
    savefig=None,
    save_fmts=None,
    verbose=1,
    **kw
):
    r"""
    Visualize forecast results, optionally against actual test data,
    for one or more evaluation periods.

    The function draws a grid of scatter plots. Every evaluation period
    gets one panel with the predicted values; when `test_data` is given
    a second panel with the actual values is placed directly above it.
    The grid wraps after ``max_cols`` periods per row and all panels
    share the same colour range.

    .. math::

        \hat{y}_{t+i} = f\Bigl(
        X_{\text{static}},\;X_{\text{dynamic}},\;
        X_{\text{future}}\Bigr)

    for :math:`i = 1, \dots, N`, where :math:`N` is the forecast
    horizon.

    Parameters
    ----------
    forecast_df : pandas.DataFrame
        DataFrame containing forecast results with a time column, the
        x/y coordinate columns, and the prediction column.
    dt_col : str
        Name of the time column used to select evaluation periods
        (e.g. ``"year"``).
    tname : str
        Target variable name used to build the prediction column name
        (e.g. ``"subsidence"``).
    test_data : pandas.DataFrame, optional
        Frame holding the actual values. When provided, only periods
        present in both frames are plotted and an "Actual" panel is
        added for each of them.
    eval_periods : scalar or list, optional
        Evaluation period(s) to plot. If ``None``, the first three
        sorted unique values of ``forecast_df[dt_col]`` are used. Both
        frames are restricted to these periods with
        ``filter_by_period``.
    mode : {"quantile", "point"}, default "quantile"
        ``"quantile"`` plots the ``{tname}_q50`` column, ``"point"``
        plots ``{tname}_pred``.
    kind : str, default "spatial"
        If ``"spatial"``, `x` and `y` default to ``"longitude"`` and
        ``"latitude"`` when both are omitted and are checked with
        ``check_spatial_columns``. For any other value, both `x` and
        `y` must be given.
    actual_name : str, optional
        Name of the column holding the actual values. It is resolved
        through ``get_actual_column_name`` against `test_data` (or
        `forecast_df` when no test data is given).
    x : str, optional
        Column name for the x-axis.
    y : str, optional
        Column name for the y-axis.
    cmap : str, default "coolwarm"
        Colormap used for the scatter plots.
    max_cols : int, default 3
        Maximum number of evaluation periods per row. Values below 1
        are treated as 1.
    axis : str, default "on"
        Pass ``"off"`` to hide the axes of every panel; any other value
        keeps them.
    s : float, default 2
        Marker size passed to ``scatter``.
    show_grid : bool, default True
        Whether to draw a grid on each panel.
    grid_props : dict, optional
        Grid properties. Defaults to
        ``{"linestyle": ":", "alpha": 0.7}``.
    savefig : str, optional
        When truthy, the figure is handed to ``save_figure`` (with
        `save_fmts`, ``dpi=300``, ``bbox_inches="tight"``) and closed;
        otherwise ``plt.show()`` is called.
    save_fmts : str or list of str, optional
        Output format(s) forwarded to ``save_figure``.
    verbose : int, default 1
        When truthy, informational messages are printed.
    **kw
        Extra keyword arguments forwarded to ``Axes.scatter``. They
        override the built-in ``alpha=0.7``, ``edgecolors='k'`` and the
        shared ``vmin`` / ``vmax``.

    Returns
    -------
    None
        The figure is shown or saved; nothing is returned.

    Raises
    ------
    ValueError
        If `kind` is not ``"spatial"`` and `x` or `y` is missing, if
        `mode` is neither ``"quantile"`` nor ``"point"``, or if no
        evaluation period is available in the data.

    Examples
    --------
    Example 1: **Spatial Visualization**

    Forecasted and actual values of the **subsidence** target, using
    **longitude** and **latitude** as spatial coordinates and
    **quantile** mode.

    >>> import pandas as pd
    >>> from fusionlab.plot.forecast import visualize_forecasts
    >>> forecast_results = pd.DataFrame({
    ...     'longitude': [-103.808151, -103.808151, -103.808151],
    ...     'latitude': [0.473152, 0.473152, 0.473152],
    ...     'subsidence_q50': [0.3, 0.4, 0.5],
    ...     'subsidence': [0.35, 0.42, 0.49],
    ...     'date': ['2023-01-01', '2023-01-02', '2023-01-03']
    ... })
    >>> test_data = pd.DataFrame({
    ...     'longitude': [-103.808151, -103.808151, -103.808151],
    ...     'latitude': [0.473152, 0.473152, 0.473152],
    ...     'subsidence': [0.35, 0.41, 0.49],
    ...     'date': ['2023-01-01', '2023-01-02', '2023-01-03']
    ... })
    >>> visualize_forecasts(
    ...     forecast_df=forecast_results,
    ...     test_data=test_data,
    ...     dt_col="date",
    ...     tname="subsidence",
    ...     eval_periods=[2023, 2024],
    ...     mode="quantile",
    ...     kind="spatial",
    ...     cmap="coolwarm",
    ...     max_cols=2,
    ...     verbose=1
    ... )

    Example 2: **Non-Spatial Visualization**

    Same target in a **non-spatial** context: `x` and `y` are given
    explicitly and **point** mode is used.

    >>> forecast_results = pd.DataFrame({
    ...     'longitude': [-103.808151, -103.808151, -103.808151],
    ...     'latitude': [0.473152, 0.473152, 0.473152],
    ...     'subsidence_pred': [0.35, 0.41, 0.48],
    ...     'subsidence': [0.36, 0.43, 0.49],
    ...     'date': ['2023-01-01', '2023-01-02', '2023-01-03']
    ... })
    >>> test_data = pd.DataFrame({
    ...     'longitude': [-103.808151, -103.808151, -103.808151],
    ...     'latitude': [0.473152, 0.473152, 0.473152],
    ...     'subsidence': [0.36, 0.42, 0.50],
    ...     'date': ['2023-01-01', '2023-01-02', '2023-01-03']
    ... })
    >>> visualize_forecasts(
    ...     forecast_df=forecast_results,
    ...     test_data=test_data,
    ...     dt_col="date",
    ...     tname="subsidence",
    ...     eval_periods=[2023],
    ...     mode="point",
    ...     kind="non-spatial",
    ...     x="longitude",
    ...     y="latitude",
    ...     cmap="viridis",
    ...     max_cols=1,
    ...     axis="off",
    ...     show_grid=True,
    ...     grid_props={"linestyle": "--", "alpha": 0.5},
    ...     verbose=2
    ... )

    Notes
    -----
    - In ``quantile`` mode the column ``{tname}_q50`` is plotted; in
      ``point`` mode the column ``{tname}_pred`` is plotted.
    - Periods are matched on the string form of ``dt_col`` in both
      frames, so the column may hold strings, numbers or datetimes.
    - The colour range is the min/max of the prediction column,
      widened by the actual column of `test_data` when available.

    See Also
    --------
    generate_forecast : Function to generate forecast results.
    coverage_score : Function to compute the coverage score.

    References
    ----------
    .. [1] Kouadio L. et al., "Gofast Forecasting Model", Journal of
       Advanced Forecasting, 2025. (In review)
    """
    # Imported here rather than at module level, as in the original.
    from ..utils.ts_utils import filter_by_period

    # Check that forecast_df is a valid DataFrame
    is_frame(
        forecast_df, df_only=True,
        objname="Forecast data", error="raise"
    )

    if eval_periods is None:
        unique_periods = sorted(forecast_df[dt_col].unique())
        if verbose:
            print("No eval_period provided; using up to three unique "
                  + "periods from forecast data.")
        eval_periods = unique_periods[:3]

    eval_periods = columns_manager(eval_periods, to_string=True)

    if test_data is not None:
        is_frame(
            test_data, df_only=True,
            objname="Test data", error="raise"
        )
        test_data = filter_by_period(test_data, eval_periods, dt_col)

    forecast_df = filter_by_period(forecast_df, eval_periods, dt_col)

    # From here on periods are handled as strings; the row masks below
    # use the same string form so both sides always compare like for like.
    forecast_keys = forecast_df[dt_col].astype(str)
    test_keys = (
        test_data[dt_col].astype(str) if test_data is not None else None
    )
    eval_periods = forecast_keys.unique()

    # Determine x and y columns for spatial or non-spatial visualization
    if kind == "spatial":
        if x is None and y is None:
            x, y = "longitude", "latitude"
        check_spatial_columns(forecast_df, spatial_cols=(x, y))
    elif x is None or y is None:
        raise ValueError(
            "For non-spatial kind, both x and y must be provided.")
    x, y = assert_xy_in(x, y, data=forecast_df, asarray=False)

    # Set prediction column based on forecast mode
    if mode == "quantile":
        pred_col = f"{tname}_q50"
        pred_label = f"Predicted {tname} (q50)"
    elif mode == "point":
        pred_col = f"{tname}_pred"
        pred_label = f"Predicted {tname}"
    else:
        raise ValueError("Mode must be either 'quantile' or 'point'.")

    df_actual = test_data if test_data is not None else forecast_df
    actual_name = get_actual_column_name(
        df_actual, tname,
        actual_name=actual_name,
        default_to='tname',
    )

    # Global min-max so every panel shares one colour scale.
    vmin = forecast_df[pred_col].min()
    vmax = forecast_df[pred_col].max()
    if test_data is not None and actual_name in test_data.columns:
        vmin = min(vmin, test_data[actual_name].min())
        vmax = max(vmax, test_data[actual_name].max())

    # Keep only periods present in both frames (when test data is given).
    if test_data is not None:
        available_periods = is_in_if(
            forecast_keys.unique(),
            test_keys.unique(),
            return_intersect=True,
        )
    else:
        available_periods = eval_periods

    eval_periods = [p for p in eval_periods if p in available_periods]
    if len(eval_periods) == 0:
        raise ValueError(
            "[ERROR] No valid evaluation periods found in forecast or "
            "test data.")

    # Compute subplot grid dimensions
    n_periods = len(eval_periods)
    max_cols = max(1, int(max_cols))  # avoid a zero-column grid
    n_cols = min(n_periods, max_cols)
    n_rows = int(np.ceil(n_periods / max_cols))
    # Two stacked rows per period when actual values are plotted as well.
    rows_per_period = 2 if test_data is not None else 1
    total_rows = n_rows * rows_per_period

    # squeeze=False always yields a 2D array, including the 1x1 grid.
    fig, axes = plt.subplots(
        total_rows, n_cols,
        figsize=(5 * n_cols, 4 * total_rows),
        squeeze=False,
    )

    grid_kw = (
        grid_props if grid_props is not None
        else {"linestyle": ":", 'alpha': 0.7}
    )
    # Defaults first so that user-supplied **kw can override them.
    scatter_kw = {
        "alpha": 0.7, "edgecolors": 'k',
        "vmin": vmin, "vmax": vmax,
        **kw,
    }
    used_axes = set()

    def _draw_panel(ax, frame, color_col, title, cbar_label):
        # Scatter one frame on `ax` with the shared styling and its own
        # colour-bar.
        sc = ax.scatter(
            frame[x.name], frame[y.name],
            c=frame[color_col],
            cmap=cmap, s=s,
            **scatter_kw
        )
        ax.set_title(title)
        ax.set_xlabel(x.name.capitalize())
        ax.set_ylabel(y.name.capitalize())
        if axis == "off":
            ax.set_axis_off()
        else:
            ax.set_axis_on()
        if show_grid:
            ax.grid(True, **grid_kw)
        fig.colorbar(sc, ax=ax, label=cbar_label)
        used_axes.add(ax)

    for idx, period in enumerate(eval_periods):
        col_idx = idx % n_cols
        row_idx = (idx // n_cols) * rows_per_period

        forecast_subset = forecast_df[forecast_keys == period]
        if test_data is not None:
            test_subset = test_data[test_keys == period]
        else:
            test_subset = forecast_subset  # no test data: reuse forecast

        if forecast_subset.empty or test_subset.empty:
            if verbose:
                print(f"[WARNING] No data for period {period}; skipping.")
            continue

        if test_data is not None:
            # Actual values on top, predictions directly underneath.
            _draw_panel(
                axes[row_idx, col_idx], test_subset, actual_name,
                f"Actual {tname.capitalize()} ({period})",
                tname.capitalize(),
            )
            ax_pred = axes[row_idx + 1, col_idx]
        else:
            ax_pred = axes[row_idx, col_idx]

        _draw_panel(
            ax_pred, forecast_subset, pred_col,
            f"{pred_label} ({period})",
            pred_label,
        )

    # Hide grid cells that received no plot (skipped periods or a last
    # row that is only partly filled).
    for ax in axes.ravel():
        if ax not in used_axes:
            ax.set_visible(False)

    fig.tight_layout()

    # Save figure to disk if requested, otherwise display it.
    if savefig:
        save_figure(
            fig,
            savefile=savefig,
            save_fmts=save_fmts,
            dpi=300,
            bbox_inches="tight"
        )
        plt.close(fig)
    else:
        plt.show()
def _get_metrics_from_cols(
    columns: List[str],
    prefixes: List[str]
) -> List[str]:
    """Extract metric suffixes (e.g. ``q10``, ``actual``) from column names.

    A column contributes a metric when it starts with ``"<prefix>_"`` for
    one of the given prefixes; the metric is everything after that
    underscore.

    Parameters
    ----------
    columns : list of str
        Column labels to inspect. Non-string labels are ignored.
    prefixes : list of str
        Value prefixes (target names) to look for.

    Returns
    -------
    list of str
        Sorted, de-duplicated metric suffixes.
    """
    metrics = {
        col[len(prefix) + 1:]
        for col in columns
        if isinstance(col, str)
        for prefix in prefixes
        if col.startswith(prefix + '_')
    }
    return sorted(metrics)


def _parse_wide_df_columns(
    df_wide: pd.DataFrame,
    value_prefixes: List[str]
) -> Dict[str, Dict[str, Dict[str, str]]]:
    """Parse wide-format columns into a structured dictionary.

    The result maps ``prefix -> year -> suffix -> column name``. Columns
    without a four-digit year (e.g. ``GWL_q10``) and the bare prefix
    column itself (e.g. ``GWL``) are stored under the pseudo-year
    ``"static"``; the bare prefix column and year columns without a
    suffix (e.g. ``GWL_2022``) are registered with the suffix ``"pred"``.

    Parameters
    ----------
    df_wide : pandas.DataFrame
        Wide-format frame whose column labels are parsed. Non-string
        labels are skipped.
    value_prefixes : list of str
        Value prefixes (target names) expected at the start of the
        column names.

    Returns
    -------
    dict
        ``{prefix: {year_or_'static': {suffix: column_name}}}`` with one
        (possibly empty) entry per requested prefix.
    """
    plot_structure = {prefix: {} for prefix in value_prefixes}
    if not plot_structure:
        return plot_structure

    META_SUFFIXES = {"unit", "units", "uom"}  # add more if needed

    # Try the longest prefixes first: regex alternation is ordered, so
    # with overlapping prefixes ('sub', 'sub_rate') the column
    # 'sub_rate_q10' must be attributed to 'sub_rate', not to 'sub'.
    alternation = '|'.join(
        re.escape(p)
        for p in sorted(plot_structure, key=len, reverse=True)
    )
    # Columns with years, e.g., GWL_2022_q10 or GWL_2022_actual
    pattern_year = re.compile(
        r'(' + alternation + r')_(\d{4})_?(.*)'
    )
    # Columns without years, e.g., GWL_q10 or GWL_pred
    pattern_no_year = re.compile(
        r'(' + alternation + r')_([a-zA-Z].*)'
    )

    for col in df_wide.columns:
        if not isinstance(col, str):
            # Only textual labels can follow the naming convention.
            continue

        # A label that *is* a declared prefix is that prefix's base
        # column, even when it also looks like '<shorter prefix>_<suffix>'.
        if col in plot_structure:
            plot_structure[col].setdefault("static", {})['pred'] = col
            continue

        match_year = pattern_year.match(col)
        if match_year:
            prefix, year, suffix = match_year.groups()
            # If suffix is empty, it's a point prediction
            suffix = suffix or 'pred'
            if suffix in META_SUFFIXES:
                continue
            plot_structure[prefix].setdefault(year, {})[suffix] = col
            continue

        match_no_year = pattern_no_year.match(col)
        if match_no_year:
            # Handles columns like 'subsidence_q10', 'subsidence_actual'
            prefix, suffix = match_no_year.groups()
            if suffix in META_SUFFIXES:
                continue
            plot_structure[prefix].setdefault("static", {})[suffix] = col

    return plot_structure


def _grid_props_or_default(grid_props):
    """Return ``grid_props``, or the default light dotted grid if ``None``."""
    if grid_props is None:
        return {'linestyle': ':', 'alpha': 0.7}
    return grid_props


def _plot_spatial_subplot(
    ax, df, x_col, y_col, c_col,
    s=10,
    **kwargs
):
    """Create a single spatial subplot (scatter or hexbin).

    Parameters
    ----------
    ax : matplotlib.axes.Axes
        Axes to draw on.
    df : pandas.DataFrame
        Data holding the coordinate and value columns.
    x_col, y_col : str
        Names of the coordinate columns.
    c_col : str or None
        Name of the column used for colouring. When it is ``None`` or
        missing from ``df``, the axes is turned off and labelled
        "(Data not found)".
    s : float, default=10
        Marker size used in scatter mode.
    **kwargs
        Display options read by key: ``title``, ``cmap``, ``vmin``,
        ``vmax``, ``axis_off``, ``show_grid``, ``grid_props``,
        ``spatial_mode`` (``'scatter'`` by default, or ``'hexbin'``) and
        ``hexbin_gridsize`` (default 40). Other keys are ignored.

    Returns
    -------
    artist or None
        The object returned by ``ax.hexbin`` / ``ax.scatter``, or
        ``None`` when nothing could be drawn.
    """
    title = kwargs.get('title', '')

    if c_col is None or c_col not in df.columns:
        ax.set_title(
            f"{title}\n(Data not found)",
            fontsize=10,
            color='red',
        )
        ax.axis('off')
        return None

    # Drop incomplete rows so that only plottable points remain.
    plot_df = df[[x_col, y_col, c_col]].dropna()
    if plot_df.empty:
        ax.set_title(
            f"{title}\n(No valid data)",
            fontsize=10,
        )
        ax.axis('off')
        return None

    cmap = kwargs.get('cmap')
    vmin = kwargs.get('vmin')
    vmax = kwargs.get('vmax')
    spatial_mode = kwargs.get('spatial_mode', 'scatter')

    if spatial_mode == 'hexbin':
        artist = ax.hexbin(
            plot_df[x_col].to_numpy(),
            plot_df[y_col].to_numpy(),
            C=plot_df[c_col].to_numpy(),
            gridsize=kwargs.get('hexbin_gridsize', 40),
            cmap=cmap,
            vmin=vmin,
            vmax=vmax,
        )
    else:
        artist = ax.scatter(
            plot_df[x_col],
            plot_df[y_col],
            c=plot_df[c_col],
            cmap=cmap,
            vmin=vmin,
            vmax=vmax,
            s=s,
            edgecolors='k',
            linewidths=0.1,
            alpha=0.8,
        )

    ax.set_title(title, fontsize=10)
    if kwargs.get('axis_off'):
        ax.axis('off')
    else:
        ax.tick_params(
            axis='both',
            which='major',
            labelsize=8,
        )
        if kwargs.get('show_grid'):
            # grid_props may be absent or None; never unpack None.
            ax.grid(**_grid_props_or_default(kwargs.get('grid_props')))
    return artist


def _plot_temporal_subplot(ax, df, value_col, title, **kwargs):
    """Create a single line plot (fallback when no coordinates exist).

    Parameters
    ----------
    ax : matplotlib.axes.Axes
        Axes to draw on.
    df : pandas.DataFrame
        Data holding ``value_col``; its index is used as the x-axis.
    value_col : str or None
        Column to plot. When it is ``None`` or missing from ``df``, the
        axes is turned off and labelled "(Data not found)".
    title : str
        Subplot title.
    **kwargs
        Display options read by key: ``axis_off``, ``show_grid`` and
        ``grid_props``. Other keys are ignored.
    """
    if value_col is None or value_col not in df.columns:
        ax.set_title(f"{title}\n(Data not found)", fontsize=10, color='red')
        ax.axis('off')
        return

    plot_df = df[[value_col]].dropna()
    if plot_df.empty:
        ax.set_title(f"{title}\n(No valid data)", fontsize=10)
        ax.axis('off')
        return

    ax.plot(plot_df.index, plot_df[value_col])
    ax.set_title(title, fontsize=10)
    if not kwargs.get('axis_off'):
        ax.tick_params(axis='both', which='major', labelsize=8)
        if kwargs.get('show_grid'):
            # grid_props may be absent or None; never unpack None.
            ax.grid(**_grid_props_or_default(kwargs.get('grid_props')))


def _plot_forecast_grid(
    fig, axes, df_wide, plot_structure, years_to_plot,
    quantiles_to_plot, prefix, kind, spatial_cols, s,
    plot_kwargs, _logger
):
    """Fill a grid of axes with one row per year for a single prefix.

    Each row optionally starts with the actual values (``kind='dual'``)
    followed by one subplot per requested prediction suffix. Axes left
    unused in a row are turned off.

    Parameters
    ----------
    fig : matplotlib.figure.Figure
        Parent figure (not used directly here).
    axes : sequence of sequences of matplotlib.axes.Axes
        Indexed as ``axes[row][col]``, one row per entry of
        ``years_to_plot``.
    df_wide : pandas.DataFrame
        Wide-format data holding the value (and coordinate) columns.
    plot_structure : dict
        ``{prefix: {year: {suffix: column}}}`` as produced by
        ``_parse_wide_df_columns``.
    years_to_plot : sequence
        Year keys of ``plot_structure[prefix]`` to draw, one per row.
    quantiles_to_plot : sequence of str
        Suffixes (e.g. ``'q10'``, ``'pred'``) to draw, one per column.
    prefix : str
        Value prefix whose columns are plotted.
    kind : str
        ``'dual'`` adds a leading "Actual" subplot to each row.
    spatial_cols : tuple of str
        Names of the x/y coordinate columns. If any is missing from
        ``df_wide``, line plots are drawn instead of maps.
    s : float
        Marker size forwarded to the spatial subplot helper.
    plot_kwargs : dict
        Display options forwarded to the subplot helpers. The dict is
        not modified; a per-subplot copy carries the title.
    _logger : logging.Logger or callable or None
        Forwarded to ``vlog`` for the fallback message.
    """
    # Check if spatial coordinates are available for plotting.
    has_spatial_coords = all(
        c in df_wide.columns for c in spatial_cols
    )
    # Log a fallback message if spatial coordinates are not found.
    if not has_spatial_coords:
        vlog(
            f"Spatial columns {spatial_cols} not found. "
            "Falling back to temporal line plots.",
            level=1, verbose=plot_kwargs.get('verbose', 0),
            logger=_logger
        )

    def _draw(ax, value_col, title):
        # Work on a copy so the caller's plot_kwargs is never mutated.
        call_kwargs = {**plot_kwargs, 'title': title}
        if has_spatial_coords:
            _plot_spatial_subplot(
                ax, df_wide, *spatial_cols, value_col,
                s=s, **call_kwargs
            )
        else:
            _plot_temporal_subplot(
                ax, df_wide, value_col, **call_kwargs
            )

    prefix_structure = plot_structure[prefix]
    # Static actual column used until a year-specific one is found.
    last_known_actual_col = prefix_structure.get(
        "static", {}).get("actual")

    # Iterate through each year to create a row of plots.
    for row_idx, year in enumerate(years_to_plot):
        ax_row = axes[row_idx]
        year_cols = prefix_structure.get(year, {})
        col_idx = 0

        # Plot actual data if 'dual' mode is enabled.
        if kind == 'dual':
            actual_col = year_cols.get('actual', last_known_actual_col)
            if actual_col:
                # Remember it as the fallback for later years.
                last_known_actual_col = actual_col
            _draw(ax_row[col_idx], actual_col, f"Actual ({year})")
            col_idx += 1

        # Plot each requested prediction/quantile.
        for q_suffix in quantiles_to_plot:
            if col_idx >= len(ax_row):
                # No axes left in this row for further suffixes.
                break
            pred_col = year_cols.get(q_suffix)
            title_suffix = (
                q_suffix.replace('q', 'Q').capitalize()
                if 'q' in q_suffix else "Prediction"
            )
            _draw(ax_row[col_idx], pred_col, f"{title_suffix} ({year})")
            col_idx += 1

        # Turn off any remaining unused axes in the current row.
        for i in range(col_idx, len(ax_row)):
            ax_row[i].axis('off')


def _plot_single_scatter(
    ax, df, x_col, y_col, c_col, cmap, vmin, vmax, title,
    axis_off, show_grid, grid_props
):
    """Create a single scatter subplot.

    Parameters
    ----------
    ax : matplotlib.axes.Axes
        Axes to draw on.
    df : pandas.DataFrame
        Data holding the coordinate and value columns.
    x_col, y_col : str
        Names of the coordinate columns.
    c_col : str
        Name of the column used for colouring. When it is missing from
        ``df``, ticks are removed and the title reads "(Data not found)".
    cmap, vmin, vmax
        Colormap and colour limits forwarded to ``ax.scatter``.
    title : str
        Subplot title.
    axis_off : bool
        If true, the axes is turned off after plotting.
    show_grid : bool
        If true (and the axes is on), a grid is drawn.
    grid_props : dict or None
        Keyword arguments for ``ax.grid``; ``None`` selects a light
        dotted grid.

    Returns
    -------
    matplotlib.collections.PathCollection or None
        The scatter artist, or ``None`` when nothing could be drawn.
    """
    if c_col not in df.columns:
        ax.set_title(f"{title}\n(Data not found)", fontsize=10, color='red')
        ax.set_xticks([])
        ax.set_yticks([])
        return None

    # Drop NaNs for this specific plot to avoid plotting empty points
    plot_df = df[[x_col, y_col, c_col]].dropna()
    if plot_df.empty:
        ax.set_title(f"{title}\n(No valid data)", fontsize=10)
        ax.set_xticks([])
        ax.set_yticks([])
        return None

    scatter = ax.scatter(
        plot_df[x_col], plot_df[y_col], c=plot_df[c_col],
        cmap=cmap, vmin=vmin, vmax=vmax, s=10
    )
    ax.set_title(title, fontsize=10)
    if axis_off:
        ax.axis('off')
    else:
        ax.tick_params(axis='both', which='major', labelsize=8)
        if show_grid:
            # grid_props may be None; never unpack None.
            ax.grid(**_grid_props_or_default(grid_props))
    return scatter


def plot_eval_future(
    df_eval: Optional[pd.DataFrame] = None,
    df_future: Optional[pd.DataFrame] = None,
    target_name: str = 'subsidence',
    quantiles: Optional[Sequence[float]] = None,
    spatial_cols: Tuple[str, str] = ('coord_x', 'coord_y'),
    time_col: str = 'coord_t',
    eval_years: Optional[Sequence[Any]] = None,
    future_years: Optional[Sequence[Any]] = None,
    eval_view_quantiles: Optional[Sequence[Any]] = None,
    future_view_quantiles: Optional[Sequence[Any]] = None,
    cmap: str = 'viridis',
    cbar: str = 'uniform',
    axis_off: bool = False,
    show_grid: bool = True,
    grid_props: Optional[Dict[str, Any]] = None,
    figsize_eval: Optional[Tuple[float, float]] = None,
    figsize_future: Optional[Tuple[float, float]] = None,
    spatial_mode: str = 'hexbin',
    hexbin_gridsize: int = 40,
    savefig_prefix: Optional[str] = None,
    save_fmts: Union[str, Sequence[str]] = '.png',
    show: bool = True,
    verbose: int = 1,
    cumulative: bool = False,
    _logger: Optional[Union[logging.Logger, Callable[[str], None]]] = None,
    **kws,
) -> None:
    """
    Convenience wrapper to visualize evaluation and future forecasts.

    This function provides a high-level interface around
    :func:`forecast_view` to jointly visualize:

    * an **evaluation split** (``df_eval``) where actual vs predicted
      values are shown side-by-side for selected years, and
    * a **future split** (``df_future``) where only predicted
      quantiles are shown for future horizons.

    It is designed to work seamlessly with the outputs of
    :func:`fusionlab.nn.pinn.utils.format_and_forecast`, where the
    data frames contain columns such as:

    * ``sample_idx``
    * ``forecast_step``
    * ``coord_t`` (time), ``coord_x``, ``coord_y`` (spatial coords)
    * ``<target_name>_actual`` (evaluation only)
    * ``<target_name>_qXX`` or ``<target_name>_pred`` (predictions)

    Parameters
    ----------
    df_eval : pandas.DataFrame, optional
        Evaluation DataFrame with actual and predicted values for
        past / validation horizons. Typically the first return value
        of :func:`fusionlab.nn.pinn.utils.format_and_forecast` when
        called on a validation or test split. If ``None`` or empty,
        the evaluation panel is skipped.
    df_future : pandas.DataFrame, optional
        Future forecast DataFrame with predicted values only,
        usually the second return value of
        :func:`fusionlab.nn.pinn.utils.format_and_forecast` when
        called on a validation or test split (for training years)
        or true-future split (e.g. 2023–2025). If ``None`` or empty,
        the future panel is skipped.
    target_name : str, default='subsidence'
        Base name of the target variable to plot. The function will
        look for columns that start with this prefix, e.g.
        ``"subsidence_q10"``, ``"subsidence_q50"``,
        ``"subsidence_q90"`` or ``"subsidence_pred"``.
    quantiles : sequence of float, optional
        List of quantiles corresponding to the prediction columns,
        e.g. ``[0.1, 0.5, 0.9]`` for ``q10``, ``q50``, ``q90``.
        If provided, these are used as defaults for both evaluation
        and future views unless overridden via
        ``eval_view_quantiles`` or ``future_view_quantiles``.
    spatial_cols : tuple of str, default=('coord_x', 'coord_y')
        Names of the longitude / latitude (or generic x/y) columns
        in ``df_eval`` and ``df_future`` used for the spatial layout.
    time_col : str, default='coord_t'
        Name of the time column in ``df_eval`` and ``df_future``.
        This column is used to select which years (or time stamps)
        to display and to structure the grid (one row per time slice).
    eval_years : sequence of hashable, optional
        List of years (or time values in ``time_col``) to visualize
        in the evaluation split. If ``None``, all unique values of
        ``time_col`` in ``df_eval`` are collected and only the last
        one (e.g. the final validation year) is plotted.
    future_years : sequence of hashable, optional
        List of years (or time values in ``time_col``) to visualize
        in the future split. If ``None``, all unique values of
        ``time_col`` in ``df_future`` are used.
    eval_view_quantiles : sequence, optional
        Quantiles to display for the evaluation split. Elements can
        be floats (e.g. ``0.1``) or strings (e.g. ``"q10"``). If
        ``None``, defaults to ``quantiles`` (if provided), otherwise
        all available quantile columns for ``target_name`` are used.
    future_view_quantiles : sequence, optional
        Quantiles to display for the future split. Same conventions
        as ``eval_view_quantiles``. If ``None``, defaults to
        ``quantiles`` (if provided), otherwise all available quantile
        columns are used.
    cmap : str, default='viridis'
        Matplotlib colormap name used for the value maps.
    cbar : {"uniform", "independent"}, default='uniform'
        Controls how colorbars are scaled:

        * ``"uniform"`` : all subplots share a global ``vmin``/``vmax``
          inferred from the entire panel for consistent comparison.
        * ``"independent"`` : each subplot uses its own color scale.
    axis_off : bool, default=False
        If ``True``, turn off axes (ticks and frames) in the spatial
        plots for a cleaner, map-like appearance.
    show_grid : bool, default=True
        If ``True``, overlay a light grid on each subplot (e.g.
        longitude/latitude grid), controlled by ``grid_props``.
    grid_props : dict, optional
        Keyword arguments passed to the underlying grid plotting
        (e.g. ``{"linestyle": ":", "alpha": 0.7}``). If ``None``, a
        light dotted grid is used by default.
    figsize_eval : tuple of float, optional
        Figure size (width, height) in inches for the evaluation
        panel. If ``None``, a size is chosen automatically based on
        the number of columns and rows.
    figsize_future : tuple of float, optional
        Figure size (width, height) in inches for the future panel.
        If ``None``, a size is chosen automatically.
    spatial_mode : {"hexbin", "scatter"}, default='hexbin'
        Spatial visualization mode:

        * ``"hexbin"`` : use hexagonal binning (`hexbin`) to show
          spatial hotspots, suitable for dense datasets.
        * ``"scatter"`` : plot individual points as a scatter map.
    hexbin_gridsize : int, default=40
        Grid size passed to ``matplotlib.axes.Axes.hexbin`` when
        ``spatial_mode="hexbin"``. Larger values yield finer spatial
        resolution.
    savefig_prefix : str, optional
        If provided, figures are saved using this prefix. The
        evaluation split is saved as ``"{prefix}_eval.*"`` and the
        future split as ``"{prefix}_future.*"`` with the extensions
        controlled by ``save_fmts``.
    save_fmts : str or sequence of str, default='.png'
        File format(s) for saving figures, e.g. ``".png"`` or
        ``[".png", ".pdf"]``. Ignored if ``savefig_prefix`` is
        ``None``.
    show : bool, default=True
        If ``True``, display figures with ``plt.show()``. If ``False``
        and ``savefig_prefix`` is not ``None``, figures are only saved
        to disk and closed.
    verbose : int, default=1
        Verbosity level for logging. A value of ``0`` silences all
        messages, higher values enable more detailed logs via
        :func:`fusionlab.utils.generic_utils.vlog`.
    cumulative : bool, default=False
        If ``True``, instructs :func:`forecast_view` (for the
        *evaluation* split) to interpret the predictions for
        ``target_name`` as *cumulative along the forecast horizon*,
        instead of per-step rates. This is useful when
        :func:`fusionlab.nn.pinn.utils.format_and_forecast` has been
        configured to output relative or absolute cumulative values
        (e.g. cumulative subsidence at each year). When ``False``,
        per-step (rate-like) predictions are visualized. Currently
        this flag is forwarded to the evaluation split only; the
        future split uses whatever representation is present in
        ``df_future``.
    _logger : logging.Logger or callable, optional
        Optional logger instance or callable used for internal
        messages. If ``None``, messages fall back to
        :func:`fusionlab.utils.generic_utils.vlog`.
    **kws
        Additional keyword arguments forwarded to
        :func:`forecast_view`, allowing fine-grained control over
        plotting behaviour.

    Returns
    -------
    None
        The function creates and optionally saves matplotlib figures,
        but does not return any objects.

    Examples
    --------
    Basic usage with validation and future forecasts:

    >>> from fusionlab.plot.forecast import plot_eval_future
    >>> df_eval, df_future = df_eval_val, df_future_val
    >>> plot_eval_future(
    ...     df_eval=df_eval,
    ...     df_future=df_future,
    ...     target_name="subsidence",
    ...     quantiles=[0.1, 0.5, 0.9],
    ...     spatial_cols=("coord_x", "coord_y"),
    ...     time_col="coord_t",
    ...     eval_years=[2022],           # last validation year
    ...     future_years=[2023, 2024],   # first two forecast years
    ... )

    Plotting evaluation as cumulative subsidence (e.g. when
    ``format_and_forecast`` has been configured with cumulative
    outputs), while keeping the future panel with the raw values in
    ``df_future``:

    >>> from fusionlab.plot.forecast import plot_eval_future
    >>> plot_eval_future(
    ...     df_eval=df_eval,
    ...     df_future=df_future,
    ...     target_name="subsidence",
    ...     quantiles=[0.1, 0.5, 0.9],
    ...     eval_view_quantiles=[0.5],   # only median for eval
    ...     future_view_quantiles=[0.1, 0.5, 0.9],
    ...     spatial_mode="hexbin",
    ...     cumulative=True,             # show eval as cumulative
    ...     savefig_prefix="zhongshan_subsidence_view",
    ...     save_fmts=[".png", ".pdf"],
    ...     show=False,
    ... )
    """
    if grid_props is None:
        grid_props = {'linestyle': ':', 'alpha': 0.7}

    has_eval = df_eval is not None and not df_eval.empty
    if not has_eval:
        vlog(
            'plot_eval_future: no eval data provided.',
            level=3, verbose=verbose, logger=_logger,
        )

    has_future = df_future is not None and not df_future.empty
    if not has_future:
        vlog(
            'plot_eval_future: no future data provided.',
            level=3, verbose=verbose, logger=_logger,
        )

    if not (has_eval or has_future):
        vlog(
            'plot_eval_future: nothing to plot.',
            level=2, verbose=verbose, logger=_logger,
        )
        return

    # Panel-specific quantiles fall back to the shared ``quantiles``.
    if quantiles is not None:
        quantiles = list(quantiles)
    eval_view_quantiles = (
        quantiles if eval_view_quantiles is None
        else list(eval_view_quantiles)
    )
    future_view_quantiles = (
        quantiles if future_view_quantiles is None
        else list(future_view_quantiles)
    )

    # Options common to both panels. They are unpacked next to ``**kws``
    # so that a duplicated keyword still raises ``TypeError``.
    shared_kwargs = dict(
        spatial_cols=spatial_cols,
        time_col=time_col,
        cmap=cmap,
        cbar=cbar,
        axis_off=axis_off,
        show_grid=show_grid,
        grid_props=grid_props,
        save_fmts=save_fmts,
        show=show,
        verbose=verbose,
        _logger=_logger,
        spatial_mode=spatial_mode,
        hexbin_gridsize=hexbin_gridsize,
    )

    # ----------------- EVAL SPLIT: actual vs predicted -----------------
    if has_eval:
        if eval_years is None:
            # default = last available eval year
            years_eval = sorted(
                df_eval[time_col].dropna().unique().tolist()
            )[-1:]
        else:
            years_eval = list(eval_years)

        eval_save = (
            None if savefig_prefix is None
            else f"{savefig_prefix}_eval"
        )

        vlog(
            'plot_eval_future: plotting eval split.',
            level=2, verbose=verbose, logger=_logger,
        )
        forecast_view(
            forecast_df=df_eval,
            value_prefixes=[target_name],
            kind='dual',  # [actual] [q10/q50/q90] layout
            view_quantiles=eval_view_quantiles,
            view_years=years_eval,
            figsize=figsize_eval,
            savefig=eval_save,
            cumulative=cumulative,
            **shared_kwargs,
            **kws,
        )

    # --------------- FUTURE SPLIT: only predictions --------------------
    if has_future:
        if future_years is None:
            years_future = sorted(
                df_future[time_col].dropna().unique().tolist()
            )
        else:
            years_future = list(future_years)

        future_save = (
            None if savefig_prefix is None
            else f"{savefig_prefix}_future"
        )

        vlog(
            'plot_eval_future: plotting future split.',
            level=2, verbose=verbose, logger=_logger,
        )
        forecast_view(
            forecast_df=df_future,
            value_prefixes=[target_name],
            kind='pred_only',  # only q's, no actual
            view_quantiles=future_view_quantiles,
            view_years=years_future,
            figsize=figsize_future,
            savefig=future_save,
            **shared_kwargs,
            **kws,
        )