# -*- coding: utf-8 -*-
# License: BSD-3-Clause
# Author: LKouadio <etanoyau@gmail.com>
"""
Convenience helpers for **quick‑look** visual inspection of model
forecasts.
"""
from __future__ import annotations
import os
import re
import logging
import warnings
import datetime
from typing import (
Any, List, Optional,
Tuple, Union, Dict,
Callable,
Literal,
Mapping,
Sequence,
)
import matplotlib.pyplot as plt
import matplotlib.colors as mcolors
from matplotlib.cm import ScalarMappable
import numpy as np
import pandas as pd
from ..core.checks import (
check_non_emptiness,
check_spatial_columns,
is_in_if
)
from ..core.handlers import columns_manager
from ..utils.forecast_utils import (
format_forecast_dataframe,
get_value_prefixes,
get_value_prefixes_in,
detect_forecast_type,
get_step_names,
)
from ..utils.calibrate import (
calibrate_forecasts,
calibrate_probability_forecast
)
from ..utils.generic_utils import (
_coerce_dt_kw,
get_actual_column_name,
vlog, save_figure,
normalize_model_inputs
)
from ..utils.validator import (
assert_xy_in, is_frame,
validate_positive_integer
)
# Names exported by ``from <this module> import *``.
# NOTE(review): "forecast_view", "plot_forecasts" and "visualize_forecasts"
# are presumably defined further down in this module — confirm they exist.
__all__= [
"forecast_view",
"plot_forecast_by_step",
"plot_forecasts",
"visualize_forecasts",
"plot_reliability_diagram",
"plot_calibration_comparison"
]
@check_non_emptiness
def plot_calibration_comparison(
    *data: Union[pd.DataFrame, Mapping[str, pd.DataFrame]],
    quantiles: Optional[Sequence[float]] = None,
    q_prefix: Optional[str] = None,
    actual_col: Optional[str] = None,
    prob_col: Optional[str] = None,
    method: Literal["isotonic", "logistic"] = "isotonic",
    out_prefix: str = "calib",
    grid_mode: Literal["unit", "range"] = "unit",
    grid_size: int = 1001,
    group_by: Optional[str] = None,
    bins: int = 10,
    bin_strategy: Literal["uniform", "quantile"] = "uniform",
    show_grid: bool = True,
    grid_props: Optional[Dict[str, Any]] = None,
    figsize: Optional[Tuple[float, float]] = None,
    savefig: Optional[str] = None,
    save_fmts: Union[str, List[str]] = ".png",
    verbose: int = 1,
    _logger: Optional[Union[logging.Logger, Callable[[str], None]]] = None,
) -> plt.Axes:
    """
    Plot raw vs calibrated reliability curves for one or more models.

    For every model the original ("raw") calibration curve and the
    post-processed ("calibrated") curve are overlaid on a single Axes,
    next to the perfect-calibration diagonal. Both quantile forecasts
    and direct probability forecasts are supported.

    Parameters
    ----------
    *data : DataFrame or dict or list of DataFrames
        One or more forecast tables. Each must contain either:

        - quantile columns named ``"{q_prefix}_qXX"`` plus actuals in
          `actual_col`, or
        - a probability column `prob_col` plus binary outcomes in
          `actual_col`.

        You may supply a single DataFrame, a dict mapping labels to
        DataFrames, a list/tuple of DataFrames, or several DataFrame
        arguments. The inputs are normalised to ``{label: frame}`` by
        ``normalize_model_inputs``.
    quantiles : sequence of float, optional
        Nominal quantile levels (e.g. ``[0.1, 0.5, 0.9]``). Quantile
        mode is used only when `quantiles`, `q_prefix` and `actual_col`
        are all provided.
    q_prefix : str, optional
        Prefix of the quantile columns, e.g. ``"subsidence"`` for
        ``"subsidence_q10"``. ``XX`` is the integer percentage of the
        quantile level.
    actual_col : str, optional
        Column holding the true values (continuous for quantiles,
        0/1 flags for probabilities). Required in both modes.
    prob_col : str, optional
        Column holding a direct probability forecast (0-1). Only
        consulted when quantile mode is not fully specified; forecasts
        are then binned into `bins` intervals.
    method : {'isotonic', 'logistic'}, default 'isotonic'
        Calibration method forwarded to ``calibrate_forecasts`` or
        ``calibrate_probability_forecast``.
    out_prefix : str, default 'calib'
        Prefix forwarded to ``calibrate_forecasts``; the calibrated
        quantile columns are read back as
        ``"{out_prefix}_{q_prefix}_qXX"``.
    grid_mode : {'unit', 'range'}, default 'unit'
        Domain of the inversion grid (forwarded to
        ``calibrate_forecasts``):

        - 'unit'  -> ``np.linspace(0, 1, grid_size)``
        - 'range' -> ``np.linspace(min(raw), max(raw), grid_size)``
    grid_size : int, default 1001
        Number of points in the inversion grid.
    group_by : str, optional
        If provided, calibrate separately per ``df[group_by]``
        (e.g. ``'forecast_step'``).
    bins : int, default 10
        Number of bins when using `prob_col`.
    bin_strategy : {'uniform', 'quantile'}, default 'uniform'
        'uniform' uses equal-width bins on [0, 1]; any other value
        uses the unique quantiles of the raw forecasts as bin edges
        (equal-population bins).
    show_grid : bool, default True
        Whether to display grid lines.
    grid_props : dict, optional
        Keyword args passed to ``Axes.grid()``. Defaults to
        ``{'linestyle': ':', 'alpha': 0.7}``.
    figsize : tuple, optional
        Matplotlib figure size in inches.
    savefig : str, optional
        File path (without extension) to save the figure. When set the
        figure is saved and closed, otherwise it is shown.
    save_fmts : str or list of str, default '.png'
        File extension(s) used by ``save_figure``.
    verbose : int, default 1
        Verbosity level forwarded to ``vlog`` and ``save_figure``.
    _logger : Logger or callable, optional
        Logger forwarded to ``vlog`` and ``save_figure``.

    Returns
    -------
    matplotlib.axes.Axes
        The Axes containing the raw and calibrated curves.

    Raises
    ------
    ValueError
        If neither (`quantiles` + `q_prefix` + `actual_col`) nor
        (`prob_col` + `actual_col`) is provided.

    Notes
    -----
    - Quantile mode takes precedence: if it is fully specified,
      `prob_col` is not used.
    - In quantile mode each level is plotted against the fraction of
      rows where ``actual <= predicted quantile``.
    - In probability mode the observed event frequency of every bin is
      plotted against the bin centre; empty bins yield NaN. Values on
      or outside the outermost edges are counted in the first/last bin
      instead of being dropped.
    - All curves are computed before the figure is created, so invalid
      arguments do not leave an orphaned figure open.

    See Also
    --------
    calibrate_forecasts : Post-process quantile forecasts.
    calibrate_probability_forecast : Calibrate 0-1 forecasts
        with isotonic or logistic scaling.
    plot_reliability_diagram : Single-curve reliability plot.

    References
    ----------
    Bröcker, J. & Smith, L. A., 2007. Scoring Probabilistic
        Forecasts: The Importance of Being Proper.
        Weather and Forecasting, 22(2), pp.382-388.
    Wilks, D. S., 2011. Statistical Methods in the
        Atmospheric Sciences (3rd ed.). Academic Press.

    Examples
    --------
    >>> import pandas as pd
    >>> from fusionlab.plot.forecast import plot_calibration_comparison

    1) Single-model quantile calibration (default isotonic, unit grid)

    >>> df = pd.DataFrame({
    ...     'subsidence_q10': [1, 2, 3, 4],
    ...     'subsidence_q50': [2, 3, 4, 5],
    ...     'subsidence_q90': [3, 4, 5, 6],
    ...     'subsidence_actual': [1.5, 3.5, 4.2, 5.8]
    ... })
    >>> plot_calibration_comparison(
    ...     df,
    ...     quantiles=[0.1, 0.5, 0.9],
    ...     q_prefix='subsidence',
    ...     actual_col='subsidence_actual'
    ... )

    2) Single-model probability calibration (20 quantile bins)

    >>> pdf = pd.DataFrame({
    ...     'p_event': [0.1, 0.4, 0.8, 0.9, 0.3],
    ...     'event_flag': [0, 1, 1, 1, 0]
    ... })
    >>> plot_calibration_comparison(
    ...     pdf,
    ...     prob_col='p_event',
    ...     actual_col='event_flag',
    ...     bins=20,
    ...     bin_strategy='quantile'
    ... )

    3) Compare two models via a dict of DataFrames

    >>> df2 = df.copy()  # pretend a second model
    >>> plot_calibration_comparison(
    ...     {'XTFT': df, 'PINN': df2},
    ...     quantiles=[0.1, 0.5, 0.9],
    ...     q_prefix='subsidence',
    ...     actual_col='subsidence_actual'
    ... )

    4) Multiple unnamed DataFrames as separate models

    >>> plot_calibration_comparison(
    ...     df, df2,
    ...     quantiles=[0.1, 0.5, 0.9],
    ...     q_prefix='subsidence',
    ...     actual_col='subsidence_actual'
    ... )

    5) Per-horizon (step) calibration

    >>> df_long = pd.DataFrame({
    ...     'sample_idx': [0, 0, 0, 1, 1, 1],
    ...     'forecast_step': [1, 2, 3, 1, 2, 3],
    ...     'subsidence_q10': [.1, .2, .3, .1, .2, .3],
    ...     'subsidence_q50': [.5, .6, .7, .5, .6, .7],
    ...     'subsidence_q90': [.9, 1.0, 1.1, .9, 1.0, 1.1],
    ...     'subsidence_actual': [.2, .5, .8, .4, .7, 1.0]
    ... })
    >>> plot_calibration_comparison(
    ...     df_long,
    ...     quantiles=[0.1, 0.5, 0.9],
    ...     q_prefix='subsidence',
    ...     actual_col='subsidence_actual',
    ...     group_by='forecast_step'
    ... )

    6) Logistic scaling & range grid

    >>> plot_calibration_comparison(
    ...     df,
    ...     quantiles=[0.1, 0.5, 0.9],
    ...     q_prefix='subsidence',
    ...     actual_col='subsidence_actual',
    ...     method='logistic',
    ...     grid_mode='range',
    ...     grid_size=500
    ... )

    7) List of DataFrames (treated like multiple models)

    >>> plot_calibration_comparison(
    ...     [pdf, pdf.copy()],
    ...     prob_col='p_event',
    ...     actual_col='event_flag'
    ... )
    """
    models = normalize_model_inputs(*data)

    # Resolve the forecast type once, before any figure exists, so that bad
    # arguments cannot leave an orphaned pyplot figure behind. ``len`` is used
    # instead of truthiness so array-like `quantiles` are accepted as well.
    has_levels = quantiles is not None and len(quantiles) > 0
    quantile_mode = bool(has_levels and q_prefix and actual_col)
    if not quantile_mode and not (prob_col and actual_col):
        raise ValueError(
            "Must provide either (quantiles+q_prefix+actual_col) "
            "or (prob_col+actual_col)"
        )

    if grid_props is None:
        grid_props = {"linestyle": ":", "alpha": 0.7}

    # Each entry: (x values, y values, marker, legend label).
    curves = []
    for name, df in models.items():
        vlog(f"Processing model '{name}'", level=verbose,
             verbose=verbose, _logger=_logger)

        if quantile_mode:
            # Quantile-based calibration
            df_calib = calibrate_forecasts(
                df, quantiles, q_prefix, actual_col,
                method=method,
                out_prefix=out_prefix,
                grid_mode=grid_mode,
                grid_size=grid_size,
                group_by=group_by,
            )
            x_vals = list(quantiles)
            q_cols = [_calib_quantile_column(q_prefix, q) for q in x_vals]
            emp_raw = [
                np.mean(df[actual_col] <= df[col]) for col in q_cols
            ]
            emp_cal = [
                np.mean(df[actual_col] <= df_calib[f"{out_prefix}_{col}"])
                for col in q_cols
            ]
        else:
            # Probability-based calibration
            df_calib = calibrate_probability_forecast(
                df, prob_col, actual_col, method=method
            )
            y_pred = df[prob_col].to_numpy()
            y_cal = df_calib[f"{prob_col}_calib"].to_numpy()
            y_true = df[actual_col].to_numpy()

            if bin_strategy == "uniform":
                edges = np.linspace(0, 1, bins + 1)
            else:
                # np.unique collapses duplicated quantile edges.
                edges = np.unique(np.quantile(
                    y_pred, np.linspace(0, 1, bins + 1)
                ))
            x_vals = (edges[:-1] + edges[1:]) / 2
            emp_raw = _calib_bin_frequency(y_pred, y_true, edges)
            emp_cal = _calib_bin_frequency(y_cal, y_true, edges)

        curves.append((x_vals, emp_raw, "o", f"{name} raw"))
        curves.append((x_vals, emp_cal, "x", f"{name} calib"))

    fig, ax = plt.subplots(figsize=figsize)
    ax.plot([0, 1], [0, 1], "--", color="gray", label="Perfect")
    for x_vals, y_vals, marker, label in curves:
        ax.plot(x_vals, y_vals, marker=marker, label=label)

    if show_grid:
        ax.grid(True, **grid_props)
    ax.set_xlim(0, 1)
    ax.set_ylim(0, 1)
    ax.set_xlabel("Nominal probability")
    ax.set_ylabel("Empirical frequency")
    ax.set_title("Raw vs Calibrated Reliability")
    ax.legend()

    if savefig:
        save_figure(
            fig,
            savefile=savefig,
            save_fmts=save_fmts,
            dpi=300,
            bbox_inches="tight",
            verbose=verbose,
            _logger=_logger,
        )
        plt.close(fig)
    else:
        plt.show()
    return ax


def _calib_quantile_column(prefix, level):
    """Return the ``{prefix}_qXX`` column name for a quantile level."""
    # Rounding before truncation removes binary float noise
    # (0.29 * 100 == 28.999999999999996 must map to "q29", not "q28")
    # while still truncating genuinely fractional percentages.
    return f"{prefix}_q{int(round(level * 100, 8))}"


def _calib_bin_frequency(values, outcomes, edges):
    """Return the mean of `outcomes` per bin of `values` (NaN if empty)."""
    n_bins = len(edges) - 1
    if n_bins < 1:
        return []
    values = np.asarray(values, dtype=float)
    idx = np.digitize(values, edges, right=True) - 1
    # With right=True a value equal to the lowest edge lands in bin -1 and
    # values above the top edge land past the last bin; clip both into the
    # outermost bins so no observation is silently discarded.
    idx = np.clip(idx, 0, n_bins - 1)
    # NaN forecasts belong to no bin.
    idx[np.isnan(values)] = -1
    return [
        np.mean(outcomes[idx == i]) if np.any(idx == i) else np.nan
        for i in range(n_bins)
    ]
@check_non_emptiness
def plot_reliability_diagram(
    *data: Union[pd.DataFrame, Mapping[str, pd.DataFrame]],
    quantiles: Optional[Sequence[float]] = None,
    q_prefix: Optional[str] = None,
    actual_col: Optional[str] = None,
    prob_col: Optional[str] = None,
    bins: int = 10,
    bin_strategy: Literal["uniform", "quantile"] = "uniform",
    index_col: str = "sample_idx",
    step_col: str = "forecast_step",
    time_col: str = "coord_t",
    xlabel: str = "Nominal probability",
    ylabel: str = "Empirical frequency",
    title: str = "Reliability Diagram",
    dt_col: Optional[str] = None,
    show_grid: bool = True,
    grid_props: Optional[Dict[str, Any]] = None,
    figsize: Optional[Tuple[float, float]] = None,
    savefig: Optional[str] = None,
    save_fmts: Union[str, List[str]] = ".png",
    verbose: int = 1,
    _logger: Optional[Union[logging.Logger,
                            Callable[[str], None]]] = None
) -> plt.Axes:
    """
    Plot reliability (calibration) curves for quantile or probability
    forecasts across one or more models.

    Parameters
    ----------
    *data : DataFrame or dict or list of DataFrames
        One or more forecast tables. Each table must contain either:

        - quantile columns named ``{q_prefix}_qXX`` plus
          ``actual_col`` for true values, or
        - a probability column ``prob_col`` (0-1) plus
          ``actual_col`` for binary events.

        You may pass a single DataFrame, a dict mapping model names to
        DataFrames, a list/tuple of DataFrames, or several DataFrame
        arguments.
    quantiles : list of float, optional
        Nominal quantile levels (e.g. [0.1,0.5,0.9]). If provided,
        ``q_prefix`` and ``actual_col`` must be set and the curve
        plots q vs empirical coverage. Takes precedence over
        ``prob_col``.
    q_prefix : str, optional
        Prefix of quantile columns (e.g. 'subsidence' to find
        'subsidence_q10', 'subsidence_q50', etc.). ``XX`` is the
        integer percentage of the quantile level.
    actual_col : str, optional
        Column name of true values (continuous for quantiles or
        binary for probabilities). In probability mode, when omitted,
        the column is looked up with ``get_actual_column_name``.
    prob_col : str, optional
        Name of a direct probability forecast (0-1). Used only when
        ``quantiles`` is None; reliability is then computed by
        binning forecasts into ``bins`` groups.
    bins : int, default 10
        Number of bins when ``prob_col`` is used.
    bin_strategy : {'uniform','quantile'}, default 'uniform'
        How to choose probability bins:

        - 'uniform': equal-width in [0,1]
        - any other value: edges at the unique quantiles of the
          forecasts (equal-population bins)
    index_col : str, default 'sample_idx'
        Name of the sample identifier column (unused internally).
    step_col : str, default 'forecast_step'
        Name of the integer step column (unused internally).
    time_col : str, default 'coord_t'
        Name of the datetime column. Only passed through
        ``_coerce_dt_kw``; not used for plotting.
    xlabel : str, default 'Nominal probability'
        X-axis label.
    ylabel : str, default 'Empirical frequency'
        Y-axis label.
    title : str, default 'Reliability Diagram'
        Plot title.
    dt_col : str, optional
        Alternate name for the datetime column. Handled via
        ``_coerce_dt_kw`` to unify with ``time_col``.
    show_grid : bool, default True
        Whether to display the grid.
    grid_props : dict, optional
        Passed to ``Axes.grid()``. Defaults to
        ``{'linestyle':':','alpha':0.7}``.
    figsize : (float,float), optional
        Figure size in inches.
    savefig : str, optional
        Path (no extension) to save the figure. Uses
        ``save_fmts`` to determine extensions. When set the figure is
        saved and closed, otherwise it is shown.
    save_fmts : str or list of str, default '.png'
        File format(s) for saving (e.g. ['.png','.pdf']).
    verbose : int, default 1
        Verbosity level forwarded to ``vlog`` and ``save_figure``.
    _logger : Logger or callable, optional
        Logger forwarded to ``save_figure``.

    Returns
    -------
    matplotlib.axes.Axes
        The Axes object with the reliability curves.

    Raises
    ------
    ValueError
        If ``quantiles`` is given without ``q_prefix``/``actual_col``,
        or if neither ``quantiles`` nor ``prob_col`` is provided.
    KeyError
        If a required quantile column or ``prob_col`` is missing from
        one of the models.

    Notes
    -----
    - When using ``quantiles``, the model is evaluated at each
      q-level by computing the fraction of true values <= predicted
      q-quantile.
    - When using ``prob_col``, forecasts are binned and, for every
      non-empty bin, the mean forecast is plotted against the
      empirical event frequency. Forecasts on or outside the outermost
      edges are counted in the first/last bin.
    - Multiple models can be compared by passing a dict or list of
      DataFrames; each appears as a separate line.
    - All curves are computed before the figure is created, so invalid
      arguments do not leave an orphaned figure open.

    See Also
    --------
    plot_calibration_comparison : Raw vs calibrated reliability curves.
    compute_quantile_coverage : Calculate empirical coverage for
        quantile forecasts.
    pivot_forecast_dataframe : Pivot long-format forecast tables to
        wide format by step or date.
    plot_probability_calibration : For continuous-valued models,
        alternate calibration plot by grouping residuals.

    References
    ----------
    Bröcker, J., & Smith, L. A. (2007). Scoring Probabilistic
        Forecasts: The Importance of Being Proper. Weather and
        Forecasting, 22(2), 382-388.
    Wilks, D. S. (2011). Statistical Methods in the Atmospheric
        Sciences (3rd ed.). Academic Press.

    Examples
    --------
    >>> import pandas as pd
    >>> from fusionlab.plot.forecast import plot_reliability_diagram
    >>> # 1) Single-model quantiles
    >>> df = pd.DataFrame({
    ...     'subsidence_q10': [1,2,3,4],
    ...     'subsidence_q50': [2,3,4,5],
    ...     'subsidence_q90': [3,4,5,6],
    ...     'subsidence_actual': [1.5, 3.5, 4.2, 5.8]
    ... })
    >>> plot_reliability_diagram(
    ...     df,
    ...     quantiles=[0.1,0.5,0.9],
    ...     q_prefix='subsidence',
    ...     actual_col='subsidence_actual'
    ... )
    >>> # 2) Probability forecasts, 20 quantile bins
    >>> pdf = pd.DataFrame({
    ...     'p_event': [0.1,0.4,0.8,0.9,0.3],
    ...     'event_flag': [0,1,1,1,0]
    ... })
    >>> plot_reliability_diagram(
    ...     pdf,
    ...     prob_col='p_event',
    ...     actual_col='event_flag',
    ...     bins=20, bin_strategy='quantile'
    ... )
    >>> # 3) Compare two models
    >>> df1 = df.copy()
    >>> df2 = df.copy()
    >>> plot_reliability_diagram(
    ...     {'XTFT': df1, 'PINN': df2},
    ...     quantiles=[0.1,0.5,0.9],
    ...     q_prefix='subsidence',
    ...     actual_col='subsidence_actual',
    ...     figsize=(8,5)
    ... )
    """
    # Canonicalise the datetime keyword. The resolved name is not needed for
    # plotting; the call is kept so any handling done by the helper still runs.
    kw = _coerce_dt_kw(
        dt_col=dt_col, time_col=time_col, _time_default=time_col)
    time_col = kw.get("dt_col", time_col)

    # Normalise input to a dict of {label: df}.
    models = normalize_model_inputs(*data)

    # Validate the mode before any figure exists so a bad call cannot leave
    # an orphaned pyplot figure behind.
    if quantiles is not None:
        if not q_prefix or not actual_col:
            raise ValueError(
                "Must provide q_prefix and actual_col for quantiles")
    elif prob_col is None:
        raise ValueError(
            "Either 'quantiles' or 'prob_col' must be provided")

    if grid_props is None:
        grid_props = {"linestyle": ":", "alpha": 0.7}

    # Each entry: (x values, y values, marker, legend label).
    curves = []
    for name, df in models.items():
        # NOTE(review): `_logger` is not forwarded to vlog here — confirm
        # whether that is intentional.
        vlog(f"Processing model '{name}'", level=1, verbose=verbose)

        if quantiles is not None:
            # --- quantile calibration ---
            nominal = []
            coverage = []
            for q in quantiles:
                col = _reliability_quantile_column(q_prefix, q)
                if col not in df.columns:
                    raise KeyError(
                        f"Missing column '{col}' in model '{name}'")
                y_pred = df[col].to_numpy()
                y_true = df[actual_col].to_numpy()
                nominal.append(q)
                coverage.append(np.mean(y_true <= y_pred))
            curves.append((nominal, coverage, "o", name))
        else:
            # --- probability forecast calibration ---
            if prob_col not in df.columns:
                raise KeyError(f"Missing prob_col '{prob_col}'"
                               f" in model '{name}'")
            p = df[prob_col].to_numpy()
            # If actual_col is not provided, try to infer it.
            acol = actual_col or get_actual_column_name(df)
            y = df[acol].to_numpy()

            if bin_strategy == "uniform":
                bins_edges = np.linspace(0, 1, bins + 1)
            else:  # 'quantile'; np.unique collapses duplicated edges
                bins_edges = np.unique(
                    np.quantile(p, np.linspace(0, 1, bins + 1)))

            p_bar, y_bar = _reliability_bin_stats(p, y, bins_edges)
            curves.append((p_bar, y_bar, "s", name))

    # Draw only after every curve was computed successfully.
    fig, ax = plt.subplots(figsize=figsize)
    ax.plot([0, 1], [0, 1], linestyle="--", color="gray", label="Perfect")
    for x_vals, y_vals, marker, label in curves:
        ax.plot(x_vals, y_vals, marker=marker, label=label)

    # Final touches.
    if show_grid:
        ax.grid(True, **grid_props)
    ax.set_xlim(0, 1)
    ax.set_ylim(0, 1)
    ax.set_xlabel(xlabel)
    ax.set_ylabel(ylabel)
    ax.set_title(title)
    ax.legend()

    # Save or show.
    if savefig:
        save_figure(
            fig,
            savefile=savefig,
            save_fmts=save_fmts,
            dpi=300,
            bbox_inches="tight",
            verbose=verbose,
            _logger=_logger,
        )
        plt.close(fig)
    else:
        plt.show()
    return ax


def _reliability_quantile_column(prefix, level):
    """Return the ``{prefix}_qXX`` column name for a quantile level."""
    # Rounding before truncation removes binary float noise
    # (0.29 * 100 == 28.999999999999996 must map to "q29", not "q28")
    # while still truncating genuinely fractional percentages.
    return f"{prefix}_q{int(round(level * 100, 8))}"


def _reliability_bin_stats(probs, outcomes, edges):
    """Return (mean forecast, mean outcome) lists for the non-empty bins."""
    n_bins = len(edges) - 1
    if n_bins < 1:
        return [], []
    probs = np.asarray(probs, dtype=float)
    idx = np.digitize(probs, edges, right=True) - 1
    # With right=True a forecast equal to the lowest edge lands in bin -1
    # (always the case for the minimum when edges come from quantiles);
    # clip so boundary and out-of-range forecasts stay in the outer bins.
    idx = np.clip(idx, 0, n_bins - 1)
    # NaN forecasts belong to no bin.
    idx[np.isnan(probs)] = -1

    mean_forecast = []
    mean_outcome = []
    for i in range(n_bins):
        members = idx == i
        if members.any():
            mean_forecast.append(probs[members].mean())
            mean_outcome.append(outcomes[members].mean())
    return mean_forecast, mean_outcome
# [docs]  (stray documentation-viewer link text; not executable code)
def plot_forecast_by_step(
    df: pd.DataFrame,
    value_prefixes: Optional[List[str]] = None,
    kind: str = "dual",
    steps: Optional[Union[int, List[int]]] = None,
    step_names: Optional[Union[Dict[int, str], List[str]]] = None,
    spatial_cols: Optional[Tuple[str, str]] = None,
    max_cols: Union[int, str] = "auto",
    cmap: str = "viridis",
    cbar: str = "uniform",
    axis_off: bool = False,
    show_grid: bool = True,
    grid_props: Optional[Dict] = None,
    figsize: Optional[Tuple[float, float]] = None,
    savefig: Optional[str] = None,
    save_fmts: Union[str, List[str]] = ".png",
    _logger: Optional[Union[logging.Logger, Callable[[str], None]]] = None,
    verbose: int = 1,
):
    """Plot forecast data on a grid with one row per forecast step.

    One figure is produced per value prefix. Each row of the grid
    corresponds to a unique ``forecast_step``; the columns hold the
    actual values (optional) followed by the prediction metrics.

    Spatial scatter plots are drawn when `spatial_cols` are provided
    and present in `df`; otherwise the function falls back to temporal
    line plots.

    Parameters
    ----------
    df : pd.DataFrame
        Input long-format DataFrame. Must contain a 'forecast_step'
        column and value columns (e.g., 'subsidence_q50').
    value_prefixes : list of str, optional
        The base names of metrics to plot (e.g., ``['subsidence']``).
        If ``None``, prefixes are auto-detected from column names,
        ignoring 'sample_idx', 'forecast_step' and `spatial_cols`.
        A separate figure is generated for each prefix.
    kind : {'dual', 'pred_only'}, default 'dual'
        Determines what to plot:

        - ``'dual'``: Plots both actual values and predictions
          side-by-side for comparison (when actual columns exist).
        - ``'pred_only'``: Plots only the predicted values.

        The legacy values ``'spatial'`` and ``'temporal'`` are still
        accepted: both plot predictions only, and ``'temporal'``
        forces line plots even when spatial columns are available.
    steps : int or list of int, optional
        A specific forecast step or list of steps to visualize. If
        ``None``, all unique steps found in the DataFrame are plotted.
    step_names : dict or list, optional
        Custom labels for the forecast steps in plot titles.

        - If a dict, maps step numbers to names (e.g., ``{1: 'Year 1'}``);
          the name replaces the default ``"Step N"`` label.
        - If a list or tuple, maps step index to names (e.g.,
          ``['Y1', 'Y2']``, where index 0 corresponds to step 1); the
          label becomes ``"Step N: name"``.
    spatial_cols : tuple of str, optional
        Tuple of column names for spatial coordinates, e.g.,
        ``('coord_x', 'coord_y')``. If provided and found, spatial
        scatter plots are created. If ``None`` or columns are not found,
        the function falls back to temporal line plots (with a warning
        when columns were requested but are missing).
    max_cols : int or 'auto', default 'auto'
        Controls the number of subplots per row. If ``'auto'``, it's
        determined by the number of metrics (e.g., actual + quantiles)
        being plotted for each prefix. With an integer, panels beyond
        that many columns are not drawn.
    cmap : str, default 'viridis'
        The Matplotlib colormap for spatial plots.
    cbar : {'uniform', 'individual'}, default 'uniform'
        Controls color bar scaling for spatial plots.

        - ``'uniform'``: All subplots share a single color scale,
          computed from every plotted value column of `df`.
        - ``'individual'``: No shared limits are passed to the
          subplot helpers.
    axis_off : bool, default False
        If ``True``, turns off plot axes, ticks, and labels.
    show_grid : bool, default True
        If ``True`` and `axis_off` is ``False``, displays a grid on
        the subplots.
    grid_props : dict, optional
        Properties for the grid, e.g., ``{'linestyle': '--'}``.
        Defaults to ``{'linestyle': ':', 'alpha': 0.7}``.
    figsize : tuple of (float, float), optional
        The size of the figure for each `value_prefix`. If ``None``,
        the size is ``(n_cols * 4, n_rows * 3.5)``.
    savefig : str, optional
        Path to save the figure(s). ``_<prefix>_by_step`` is inserted
        before the extension (e.g., 'my_plot.png' becomes
        'my_plot_subsidence_by_step.png') so that figures of different
        prefixes do not overwrite each other. When set, figures are
        saved and closed instead of shown.
    save_fmts : str or list of str, default '.png'
        Format(s) to save the figure, e.g., ``['.png', '.pdf']``.
    _logger : Logger or callable, optional
        Logger forwarded to ``vlog`` and ``save_figure``.
    verbose : int, default 1
        Controls the verbosity of logging messages.

    Returns
    -------
    None
        This function displays and/or saves Matplotlib figures and
        does not return any value.

    Raises
    ------
    ValueError
        If `df` has no 'forecast_step' column or no value prefix can
        be auto-detected.

    See Also
    --------
    forecast_view : A related function that plots by year/time
        instead of by step index.

    Notes
    -----
    - The function expects a long-format DataFrame where each row is a
      unique combination of a sample and a forecast step.
    - If `spatial_cols` are not found, temporal line plots are
      generated by plotting values against the DataFrame's index. This
      is most meaningful when the DataFrame is sorted by a relevant
      variable (e.g., a spatial coordinate) for each step.

    Examples
    --------
    >>> import numpy as np
    >>> import pandas as pd
    >>> from fusionlab.plot.forecast import plot_forecast_by_step
    >>> data = {
    ...     'sample_idx': list(range(4)) * 2,
    ...     'forecast_step': [1] * 4 + [2] * 4,
    ...     'coord_x': np.random.rand(8),
    ...     'coord_y': np.random.rand(8),
    ...     'subsidence_q50': np.random.randn(8),
    ...     'subsidence_actual': np.random.randn(8),
    ... }
    >>> df_step_example = pd.DataFrame(data)
    >>> plot_forecast_by_step(  # doctest: +SKIP
    ...     df=df_step_example,
    ...     value_prefixes=['subsidence'],
    ...     spatial_cols=('coord_x', 'coord_y'),
    ...     step_names={1: 'Year 1 Forecast', 2: 'Year 2 Forecast'},
    ...     kind='dual'
    ... )
    """
    # Cheapest validation first: nothing below works without this column.
    if "forecast_step" not in df.columns:
        raise ValueError("DataFrame must contain 'forecast_step'.")

    # 1. Detect prefixes if not provided.
    if value_prefixes is None:
        vlog("Auto-detecting value prefixes…",
             level=1, verbose=verbose, logger=_logger)
        # Bookkeeping / coordinate columns are not forecast values.
        excluded = ["sample_idx", "forecast_step"]
        if spatial_cols:
            excluded.extend(spatial_cols)
        value_prefixes = get_value_prefixes_in(df, exclude_cols=excluded)
        if not value_prefixes:
            raise ValueError("Could not auto-detect value prefixes.")
        vlog(f"Detected prefixes: {value_prefixes}",
             level=2, verbose=verbose, logger=_logger)

    # Collect metric suffixes found in columns.
    metrics = _get_metrics_from_cols(df.columns, value_prefixes)
    actual_metrics = sorted(m for m in metrics if "actual" in m)
    pred_metrics = sorted(m for m in metrics if "actual" not in m)
    plot_actuals = kind == "dual" and bool(actual_metrics)

    # 2. Choose the layout from the available coordinates. `kind` only
    #    decides whether actuals are shown; tying the layout to
    #    kind == 'spatial' made scatter plots unreachable for the documented
    #    'dual' / 'pred_only' values.
    has_spatial_coords = bool(spatial_cols) and all(
        c in df.columns for c in spatial_cols
    )
    spatial_mode = has_spatial_coords and kind != "temporal"
    if not has_spatial_coords and (spatial_cols or kind == "spatial"):
        warnings.warn(
            f"Spatial columns {spatial_cols} not found in the DataFrame. "
            "Falling back to temporal line plots."
        )
    vlog(
        "Plotting spatial data." if spatial_mode
        else "Plotting temporal data.",
        level=2, verbose=verbose, logger=_logger,
    )

    # Unique steps to visualize.
    steps_to_plot = sorted(df["forecast_step"].unique())
    if steps is not None:
        steps = columns_manager(steps, empty_as_none=False)
        steps = [validate_positive_integer(v, f"Step {v}") for v in steps]
        steps_to_plot = sorted(
            is_in_if(steps_to_plot, steps, return_intersect=True)
        )
    n_rows = len(steps_to_plot)
    if n_rows == 0:
        vlog("No forecast steps to plot.", level=0, verbose=verbose,
             logger=_logger)
        return

    # 3. Compute global color limits if cbar is 'uniform'.
    vmin, vmax = None, None
    if cbar == "uniform":
        cols_for_cbar = [
            f"{p}_{m}" for p in value_prefixes for m in metrics
            if f"{p}_{m}" in df.columns
        ]
        if cols_for_cbar:
            # min()/max() already skip NaN column-wise. A row-wise dropna()
            # would discard every row with a single missing metric and could
            # leave the limits too narrow or NaN.
            lo = df[cols_for_cbar].min().min()
            hi = df[cols_for_cbar].max().max()
            if pd.notna(lo) and pd.notna(hi):
                vmin, vmax = lo, hi

    # Shared kwargs for plotting helpers.
    plot_kwargs = dict(
        cmap=cmap, vmin=vmin, vmax=vmax, axis_off=axis_off,
        show_grid=show_grid,
        grid_props=grid_props or {"linestyle": ":", "alpha": 0.7},
        verbose=verbose,
    )

    def _draw_panel(ax, frame, column, panel_title):
        """Draw one value column on `ax` using the selected layout."""
        if spatial_mode:
            _plot_spatial_subplot(
                ax, frame, *spatial_cols, column,
                title=panel_title, **plot_kwargs
            )
        else:
            _plot_temporal_subplot(
                ax, frame, column, title=panel_title, **plot_kwargs
            )

    # 4. Loop over each prefix and produce a grid of sub-plots.
    for prefix in value_prefixes:
        # Select relevant metrics for this prefix.
        prefix_preds = [
            m for m in pred_metrics if f"{prefix}_{m}" in df.columns
        ]
        prefix_actual = next(
            (m for m in actual_metrics if f"{prefix}_{m}" in df.columns),
            None,
        )
        show_actual = bool(plot_actuals and prefix_actual)

        # Determine number of columns for the grid.
        if max_cols == "auto":
            n_cols = len(prefix_preds) + (1 if show_actual else 0)
        else:
            n_cols = int(max_cols)
        if n_cols == 0:
            continue

        # Create figure and axes grid.
        fig_size = figsize or (n_cols * 4, n_rows * 3.5)
        fig, axes = plt.subplots(
            n_rows, n_cols, figsize=fig_size, squeeze=False,
            constrained_layout=True,
        )
        fig.suptitle(
            f"Forecast for '{prefix.upper()}' by Step",
            fontsize=16, weight='bold'
        )

        # Iterate over each forecast step (row).
        for r_idx, step in enumerate(steps_to_plot):
            step_df = df[df["forecast_step"] == step]
            step_lbl = f"Step {step}"
            # Custom label overrides.
            if isinstance(step_names, dict):
                step_lbl = step_names.get(step, step_lbl)
            elif (
                isinstance(step_names, (list, tuple))
                # The lower bound stops step 0 / negative steps from
                # silently picking a name from the end of the list.
                and 0 <= step - 1 < len(step_names)
            ):
                step_lbl = f"Step {step}: {step_names[step - 1]}"

            c_idx = 0
            # Plot actuals if requested and available.
            if show_actual:
                _draw_panel(
                    axes[r_idx][c_idx], step_df,
                    f"{prefix}_{prefix_actual}", f"Actual ({step_lbl})",
                )
                c_idx += 1

            # Plot prediction metrics.
            for metric in prefix_preds:
                if c_idx >= n_cols:
                    break  # a fixed `max_cols` leaves no room for more
                _draw_panel(
                    axes[r_idx][c_idx], step_df, f"{prefix}_{metric}",
                    f"{metric.replace('_', ' ').title()} ({step_lbl})",
                )
                c_idx += 1

            # Blank unused axes in this row.
            for j in range(c_idx, n_cols):
                axes[r_idx][j].axis("off")

        # Add a single color-bar if using uniform scaling. A color scale is
        # only meaningful for the scatter (spatial) layout.
        if (
            spatial_mode and cbar == "uniform"
            and vmin is not None and vmax is not None
        ):
            fig.colorbar(
                ScalarMappable(
                    norm=mcolors.Normalize(vmin=vmin, vmax=vmax),
                    cmap=cmap
                ),
                ax=axes.ravel().tolist(), orientation="vertical",
                shrink=0.8, pad=0.04
            )

        # 5. Save figure to disk if requested.
        if savefig:
            # One file per prefix; reusing `savefig` verbatim would make
            # every figure overwrite the previous one.
            root, ext = os.path.splitext(savefig)
            save_figure(
                fig,
                savefile=f"{root}_{prefix}_by_step{ext}",
                save_fmts=save_fmts,
                dpi=300,
                bbox_inches="tight",
                verbose=verbose,
                _logger=_logger,
            )
            plt.close(fig)
        else:
            plt.show()
@check_non_emptiness
def forecast_view_in(
forecast_df: pd.DataFrame,
value_prefixes: Optional[List[str]]=None,
kind: str = 'dual',
view_quantiles: Optional[List[Union[str, float]]] = None,
view_years: Optional[List[Union[str, int]]] = None,
spatial_cols: Tuple[str, str] = ('coord_x', 'coord_y'),
time_col='coord_t',
max_cols: Union[int, str] = 'auto',
cmap: str = 'viridis',
cbar: str = 'uniform',
axis_off: bool = False,
show_grid: bool = True,
grid_props: Optional[Dict] = None,
figsize: Optional[Tuple[float, float]] = None,
savefig: Optional[str] = None,
save_fmts: Union[str, List[str]] = '.png',
verbose: int = 1
):
"""Generates and displays spatial forecast visualizations.
This function creates a grid of scatter plots to visualize
spatio-temporal forecast data. It can handle both long-format
and wide-format DataFrames, automatically arranging plots by
year and metric (e.g., different quantiles).
Parameters
----------
forecast_df : pd.DataFrame
Input DataFrame containing forecast data. The function
auto-detects if the format is long or wide.
value_prefixes : list of str, optional
The base names of the metrics to plot (e.g.,
``['subsidence', 'GWL']``). If ``None``, the function
will attempt to automatically infer these from the
DataFrame's column names. A separate figure is generated
for each prefix.
kind : {'dual', 'pred_only'}, default 'dual'
Determines what to plot:
- ``'dual'``: Plots both actual values and predictions
side-by-side for comparison, if actuals are available.
- ``'pred_only'``: Plots only the predicted values.
view_quantiles : list of str or float, optional
A list to filter which quantiles are displayed. Values can be
floats (e.g., ``[0.1, 0.5]``) or strings (e.g.,
``['q10', 'q50']``). If ``None``, all detected quantiles
are plotted.
view_years : list of str or int, optional
A list of years to display. If ``None``, all detected years
in the forecast are plotted.
spatial_cols : tuple of str, default ('coord_x', 'coord_y')
A tuple containing the names of the columns to be used for
the x and y axes of the scatter plots.
time_col : str, default 'coord_t'
The name of the column representing the time dimension in a
long-format DataFrame. Used by the internal format detector.
max_cols : int or 'auto', default 'auto'
Controls the number of subplots per row.
- If ``'auto'``, the number of columns is automatically set
to the number of quantiles being plotted (plus one if
`kind='dual'`).
- If an integer, sets a fixed number of columns for the
prediction plots. If `kind='dual'`, an additional column
for the actuals is added, potentially exceeding this number.
cmap : str, default 'viridis'
The Matplotlib colormap to use for the scatter plots.
cbar : {'uniform', 'individual'}, default 'uniform'
Controls the color bar scaling:
- ``'uniform'``: All subplots in a figure share a single,
uniform color scale, determined by the global min/max of
all plotted data. A single color bar is displayed for the
entire figure.
- ``'individual'``: Each subplot has its own color bar scaled
to its own data range. (Note: Current implementation
defaults to uniform; individual color bars would be a
future enhancement).
axis_off : bool, default False
If ``True``, turns off the axes (ticks, labels, spines) for
all subplots.
show_grid : bool, default True
If ``True`` and `axis_off` is ``False``, a grid is displayed
on the subplots.
grid_props : dict, optional
A dictionary of properties to pass to `ax.grid()` (e.g.,
``{'linestyle': ':', 'alpha': 0.7}``).
figsize : tuple of (float, float), optional
The size of the figure for each `value_prefix`. If ``None``,
the size is automatically calculated based on the number of
rows and columns.
savefig : str, optional
If a file path is provided (e.g., 'my_forecast.png'), the
figure(s) will be saved. The prefix name will be appended
to the filename (e.g., 'my_forecast_subsidence.png').
save_fmts : str or list of str, default '.png'
The format(s) to save the figure in (e.g., ``['.png', '.pdf']``).
verbose : int, default 1
Controls the verbosity of logging messages. `0` is silent,
`1` provides basic info, and higher values provide more detail.
Returns
-------
None
This function does not return any value. It displays and/or
saves Matplotlib figures directly.
See Also
--------
format_forecast_dataframe : The utility used to detect and pivot
the input DataFrame.
get_value_prefixes : The utility used to auto-detect metric prefixes.
Notes
-----
- The function creates a grid where each row corresponds to a year
from `view_years`, and columns correspond to the actuals (if
`kind='dual'`) and each of the selected quantiles.
- If an actual value for a specific year is not available, the
function will attempt to fill that plot using the most recent
known 'actual' value to facilitate comparison across the row.
- Currently, a separate figure is generated for each prefix in
`value_prefixes`.
"""
# Use the format utility to ensure we have a wide DataFrame
# Auto-detect value prefixes if not provided
if value_prefixes is None:
vlog("`value_prefixes` not provided. Auto-detecting...",
level=1, verbose=verbose)
value_prefixes = get_value_prefixes(forecast_df)
if not value_prefixes:
raise ValueError(
"Could not auto-detect any value prefixes. Please "
"provide them explicitly."
)
vlog(f"Detected prefixes: {value_prefixes}",
level=2, verbose=verbose)
value_prefixes = columns_manager(value_prefixes)
df_wide = format_forecast_dataframe(
df= forecast_df,
to_wide=True,
value_prefixes=value_prefixes,
# Pass relevant pivot args if df is long
id_vars=[
c for c in ['sample_idx', *spatial_cols]
if c in forecast_df.columns
],
time_col=time_col,
static_actuals_cols=[
c for c in forecast_df.columns
if c.endswith('_actual') and c.count('_') == 1
],
verbose=verbose,
)
vlog("Parsing wide DataFrame columns to build plot structure.",
level=1, verbose=verbose)
plot_structure = _parse_wide_df_columns(df_wide, value_prefixes)
# --- Filter data based on user preferences ---
all_years = sorted(list(
set(y for p_data in plot_structure.values()
for y in p_data if y.isdigit())
))
years_to_plot = [str(y) for y in view_years] if view_years else all_years
# Identify available and requested quantiles
all_quantiles = sorted(list(set(
suffix for p_data in plot_structure.values()
for y_data in p_data.values() if isinstance(y_data, dict)
for suffix in y_data if suffix.startswith('q')
)))
if view_quantiles:
# Normalize requested quantiles to 'qXX' format
req_q_norm = []
for q in view_quantiles:
if isinstance(q, float):
req_q_norm.append(f"q{int(q*100)}")
else: # is str
req_q_norm.append(q if q.startswith('q') else f"q{q}")
quantiles_to_plot = [q for q in all_quantiles if q in req_q_norm]
else:
quantiles_to_plot = all_quantiles
# Handle deterministic case (no quantiles found)
is_deterministic = not bool(quantiles_to_plot)
if is_deterministic:
# Look for a default prediction column, e.g., 'subsidence_2022_pred'
quantiles_to_plot = ['pred']
vlog("No quantiles found. Assuming deterministic forecast.",
level=1, verbose=verbose)
# --- Setup Plot Layout ---
plot_actuals = (kind == 'dual')
n_rows = len(years_to_plot)
if n_rows == 0:
vlog("No years to plot after filtering.", level=0,
verbose=verbose)
return
# --- Prepare for plotting ---
x_col, y_col = spatial_cols
grid_props = grid_props or {'linestyle': ':', 'alpha': 0.7}
# Determine uniform color range if needed
vmin, vmax = None, None
if cbar == 'uniform':
all_plot_cols = []
# Collect every column-reference that actually exists in df_wide
def _collect_cols(node):
"""Recursively collect column names from plot_structure nodes."""
if isinstance(node, str):
all_plot_cols.append(node)
elif isinstance(node, dict):
for sub in node.values():
_collect_cols(sub)
for p_data in plot_structure.values():
_collect_cols(p_data)
valid_plot_cols = [c for c in all_plot_cols if c in df_wide.columns]
if valid_plot_cols:
min_vals = [df_wide[c].dropna().min() for c in valid_plot_cols]
max_vals = [df_wide[c].dropna().max() for c in valid_plot_cols]
if any(pd.notna(v) for v in min_vals):
vmin = min(v for v in min_vals if pd.notna(v))
if any(pd.notna(v) for v in max_vals):
vmax = max(v for v in max_vals if pd.notna(v))
vmin = min(min_vals)
vmax = max(max_vals)
for prefix in value_prefixes:
if max_cols == 'auto':
n_cols_prefix = len(quantiles_to_plot)
if plot_actuals: n_cols_prefix += 1
else:
n_cols_prefix = int(max_cols)
if n_cols_prefix == 0:
continue
figsize_prefix = figsize or (n_cols_prefix * 3.5, n_rows * 3)
fig, axes = plt.subplots(
n_rows, n_cols_prefix, figsize=figsize_prefix,
squeeze=False, constrained_layout=True
)
fig.suptitle(
f"Forecast Visualization for '{prefix.upper()}'",
fontsize=14, weight='bold'
)
last_known_actual_col = plot_structure[prefix].get("static_actual")
for row_idx, year in enumerate(years_to_plot):
ax_row = axes[row_idx]
col_idx = 0
if plot_actuals:
actual_col = plot_structure[prefix].get(
year, {}).get('actual', last_known_actual_col)
if actual_col: last_known_actual_col = actual_col
title = f"Actual ({year})"
_plot_single_scatter(
ax_row[col_idx], df_wide, x_col, y_col,
actual_col, cmap, vmin, vmax, title,
axis_off, show_grid, grid_props
)
col_idx += 1
for q_suffix in quantiles_to_plot:
if col_idx >= n_cols_prefix: continue
pred_col = plot_structure[prefix].get(year, {}).get(q_suffix)
title = f"{q_suffix.replace('q', 'Q').capitalize()} ({year})"
if is_deterministic: title = f"Prediction ({year})"
_plot_single_scatter(
ax_row[col_idx], df_wide, x_col, y_col, pred_col,
cmap, vmin, vmax, title, axis_off, show_grid, grid_props
)
col_idx += 1
for i in range(col_idx, n_cols_prefix):
ax_row[i].axis('off')
if cbar == 'uniform' and vmin is not None and vmax is not None:
fig.colorbar(
ScalarMappable(
norm=mcolors.Normalize(vmin=vmin, vmax=vmax), cmap=cmap),
ax=axes, orientation='vertical', fraction=0.02, pad=0.04
)
if savefig:
fmts = [save_fmts] if isinstance(save_fmts, str) else save_fmts
base, _ = os.path.splitext(savefig)
for fmt in fmts:
fname = f"{base}_{prefix}{fmt if fmt.startswith('.') else '.' + fmt}"
# --- ensure output directory exists --
out_dir = os.path.dirname(fname)
if out_dir and not os.path.exists(out_dir):
os.makedirs(out_dir, exist_ok=True)
vlog(f"Saving figure to {fname}", level=1, verbose=verbose)
fig.savefig(fname, dpi=300, bbox_inches="tight")
plt.show()
[docs]
@check_non_emptiness
def forecast_view(
    forecast_df: pd.DataFrame,
    value_prefixes: Optional[List[str]] = None,
    kind: str = 'dual',
    view_quantiles: Optional[List[Union[str, float]]] = None,
    view_years: Optional[List[Union[str, int]]] = None,
    spatial_cols: Tuple[str, str] = None,
    time_col: str = 'coord_t',
    max_cols: Union[int, str] = 'auto',
    cmap: str = 'viridis',
    cbar: str = 'uniform',
    axis_off: bool = False,
    show_grid: bool = True,
    grid_props: Optional[Dict] = None,
    figsize: Optional[Tuple[float, float]] = None,
    savefig: Optional[str] = None,
    save_fmts: Union[str, List[str]] = '.png',
    dt_col: Optional[str] = None,
    show: bool = True,
    cumulative: Union[bool, str] = False,
    _logger: Optional[Union[logging.Logger, Callable[[str], None]]] = None,
    stop_check: Callable[[], bool] = None,
    verbose: int = 1,
    **kws
):
    """Generate a year-by-quantile grid of spatial forecast scatter plots.

    The input frame is converted to wide format with
    ``format_forecast_dataframe``, its columns are parsed into a
    ``{prefix: {year: {stat: column}}}`` structure, and one figure is
    drawn per value prefix. Each figure has one row per selected year;
    the individual cells are rendered by ``_plot_forecast_grid``.

    Parameters
    ----------
    forecast_df : pd.DataFrame
        Input DataFrame containing forecast data. It is handed to
        ``format_forecast_dataframe`` with ``to_wide=True``.
    value_prefixes : list of str, optional
        The base names of the metrics to plot (e.g.,
        ``['subsidence', 'GWL']``). If ``None``, they are inferred with
        ``get_value_prefixes`` while excluding ``'sample_idx'``, the
        spatial columns, the time column and ``'forecast_step'``.
        A separate figure is generated for each prefix.
    kind : {'dual', 'pred_only'}, default 'dual'
        Forwarded to the grid plotter. With ``max_cols='auto'``,
        ``'dual'`` reserves one extra column (for actual values) next
        to the prediction columns.
    view_quantiles : list of str or float, optional
        Filter for the quantiles to display. Floats are mapped to
        ``'q<int(q*100)>'``; other values are converted to ``str`` and
        prefixed with ``'q'`` when the prefix is missing. Only
        quantiles detected in the data are kept. If nothing is left
        (or no quantile exists), the ``'pred'`` column is plotted.
    view_years : list of str or int, optional
        Years to display. Each entry is normalised with
        ``_normalize_year_key`` and duplicates are dropped. If
        ``None``, all detected (digit-only) years are plotted.
    spatial_cols : tuple of str, optional
        Names of the x and y coordinate columns. ``None`` (the default)
        is treated as "no spatial columns" for prefix detection and
        pivoting, and is forwarded unchanged to the grid plotter.
    time_col : str, default 'coord_t'
        The name of the column representing the time dimension in a
        long-format DataFrame.

        .. note::
           ``time_col`` and ``dt_col`` can be used interchangeably;
           they are reconciled by ``_coerce_dt_kw``.
    max_cols : int or 'auto', default 'auto'
        Number of subplot columns per figure.

        - ``'auto'``: number of plotted quantiles, plus one when
          ``kind='dual'``.
        - integer: the grid has exactly ``int(max_cols)`` columns.
          A value of ``0`` skips plotting.
    cmap : str, default 'viridis'
        The Matplotlib colormap to use for the scatter plots.
    cbar : {'uniform', 'individual'}, default 'uniform'
        With ``'uniform'``, a global ``vmin``/``vmax`` is computed over
        every numeric plot column and a single colour bar is attached
        to each figure. Any other value leaves ``vmin``/``vmax`` as
        ``None`` and draws no figure-level colour bar.
    axis_off : bool, default False
        Forwarded to the grid plotter (turn subplot axes off).
    show_grid : bool, default True
        Forwarded to the grid plotter (display a grid on subplots).
    grid_props : dict, optional
        Grid styling forwarded to the grid plotter. Defaults to
        ``{'linestyle': ':', 'alpha': 0.7}``.
    figsize : tuple of (float, float), optional
        Figure size. If ``None``, ``(n_cols * 3.5, n_rows * 3)``.
    savefig : str, optional
        If given, every figure is written through ``save_figure`` with
        this path and then closed (it is not shown).
    save_fmts : str or list of str, default '.png'
        The format(s) forwarded to ``save_figure``.
    dt_col : str, optional
        Alias of ``time_col`` (see the note above).
    show : bool, default True
        Only used when ``savefig`` is not set: call ``plt.show()`` if
        ``True``, otherwise close the figure.
    cumulative : bool or str, default False
        If truthy (or one of ``'cum'``, ``'cumulative'``,
        ``'cummulative'``, ``'true'``, ``'yes'``, case-insensitive),
        the plotted prediction columns are replaced by their running
        sum over the selected years, per prefix and per statistic.
    _logger : logging.Logger or callable, optional
        Sink forwarded to ``vlog`` and to the helper functions.
    stop_check : callable, optional
        Zero-argument callable polled at several checkpoints; when it
        returns a truthy value an ``InterruptedError`` is raised.
    verbose : int, default 1
        Controls the verbosity of logging messages.
    **kws : dict
        Extra options. Only ``s`` (scatter marker size, default ``10``)
        is consumed here.

    Returns
    -------
    None
        Figures are shown and/or saved as a side effect.

    Raises
    ------
    ValueError
        If ``value_prefixes`` is ``None`` and none can be detected.
    InterruptedError
        If ``stop_check`` requests an abort.

    Warns
    -----
    UserWarning
        If ``detect_forecast_type`` reports a deterministic forecast.

    See Also
    --------
    format_forecast_dataframe : The utility used to detect and pivot
        the input DataFrame.
    get_value_prefixes : The utility used to auto-detect metric prefixes.

    Notes
    -----
    - Each row of a figure corresponds to one selected year.
    - How individual cells are drawn (including the treatment of
      missing actual values) is delegated to ``_plot_forecast_grid``.
    - With ``cbar='uniform'`` the plotted columns of the internal wide
      frame are coerced to numeric before the colour range is computed.
    """
    # Resolve the `dt_col` / `time_col` aliases to one column name.
    kw = _coerce_dt_kw(
        dt_col=dt_col, time_col=time_col, _time_default=time_col)
    time_col = kw.get("dt_col", time_col)

    # Accept both booleans and a few string spellings for `cumulative`.
    if isinstance(cumulative, str):
        cumulative_flag = cumulative.lower() in {
            'cum', 'cumulative', 'cummulative', 'true', 'yes'
        }
    else:
        cumulative_flag = bool(cumulative)

    _spatial_cols = spatial_cols or []

    is_q = detect_forecast_type(
        forecast_df, value_prefixes=value_prefixes
    )
    if stop_check and stop_check():
        raise InterruptedError("View configuration aborted.")

    if is_q == 'deterministic':
        warnings.warn(
            "Deterministic prediction is detected. It is recommended"
            " to use 'fusionlab.plot.forecast.plot_forecast_by_step'"
            " instead. Use at your own risk."
        )

    if value_prefixes is None:
        vlog("`value_prefixes` not provided. Auto-detecting...",
             level=1, verbose=verbose, logger=_logger)
        value_prefixes = get_value_prefixes(
            forecast_df,
            exclude_cols=['sample_idx', *_spatial_cols,
                          time_col, 'forecast_step']
        )
        if not value_prefixes:
            raise ValueError(
                "Could not auto-detect any value prefixes. "
                "Please provide them explicitly."
            )
        vlog(f"Detected prefixes: {value_prefixes}",
             level=2, verbose=verbose, logger=_logger)

    df_wide = format_forecast_dataframe(
        forecast_df, to_wide=True, value_prefixes=value_prefixes,
        id_vars=[c for c in ['sample_idx', *_spatial_cols]
                 if c in forecast_df.columns],
        time_col=time_col,
        static_actuals_cols=[
            c for c in forecast_df.columns
            if c.endswith('_actual') and c.count('_') == 1
        ],
        spatial_cols=_spatial_cols,
        verbose=verbose,
        _logger=_logger,
    )
    if stop_check and stop_check():
        raise InterruptedError("Format plot dataframe aborted.")

    plot_structure = _parse_wide_df_columns(df_wide, value_prefixes)

    # --- Years -----------------------------------------------------------
    # A set is required here: every prefix contributes its own year keys,
    # so a plain list would repeat each year once per prefix. Repeated
    # years would duplicate subplot rows and make the cumulative pass
    # below add the same column to the running sum more than once.
    all_years = sorted({
        y for p_data in plot_structure.values() for y in p_data
        if isinstance(y, str) and y.isdigit()
    })
    if view_years:
        # Order-preserving de-duplication (e.g. 2023 and '2023.0').
        years_to_plot = list(dict.fromkeys(
            _normalize_year_key(y) for y in view_years
        ))
    else:
        years_to_plot = all_years

    # --- Quantiles -------------------------------------------------------
    all_quantiles = sorted({
        s for p in plot_structure.values()
        for y in p.values() if isinstance(y, dict)
        for s in y if s.startswith('q')
    })
    if view_quantiles:
        req_q_norm = []
        for q in view_quantiles:
            if isinstance(q, float):
                # NOTE(review): int() truncates, so 0.29 -> 'q28'. Kept
                # as-is on the assumption that column names are built
                # the same way upstream -- confirm before switching to
                # round().
                req_q_norm.append(f"q{int(q*100)}")
            else:
                q_str = str(q)  # tolerate ints such as 10 -> 'q10'
                req_q_norm.append(
                    q_str if q_str.startswith('q') else f"q{q_str}")
        quantiles_to_plot = [q for q in all_quantiles if q in req_q_norm]
    else:
        quantiles_to_plot = all_quantiles

    if not quantiles_to_plot:
        quantiles_to_plot = ['pred']
        vlog("No quantiles found. Assuming deterministic forecast.",
             level=1, verbose=verbose, logger=_logger)

    # --- Optional: cumulative sum over years -----------------------------
    if cumulative_flag and years_to_plot:
        vlog(
            "Applying cumulative sum over forecast years in forecast_view.",
            level=1,
            verbose=verbose,
            logger=_logger,
        )
        for prefix, year_map in plot_structure.items():
            # Keep only the selected years that exist for this prefix.
            years_here = []
            for y in years_to_plot:
                if y in year_map:
                    try:
                        years_here.append((float(y), y))
                    except ValueError:
                        # Non-numeric labels sort first.
                        years_here.append((0.0, y))
            years_here = [y_str for _, y_str in sorted(years_here)]
            if not years_here:
                continue

            for stat in quantiles_to_plot:
                running = None
                for y_str in years_here:
                    mapping = year_map.get(y_str)
                    if not isinstance(mapping, dict):
                        continue
                    col_name = mapping.get(stat)
                    if not col_name or col_name not in df_wide.columns:
                        continue
                    vals = df_wide[col_name].to_numpy()
                    if running is None:
                        running = vals.copy()
                    else:
                        running = running + vals
                    # Overwrite the year's column with the running total.
                    df_wide[col_name] = running

    n_rows = len(years_to_plot)
    if n_rows == 0:
        vlog("No years to plot after filtering.", level=0, verbose=verbose,
             logger=_logger)
        return

    # --- Shared colour range ---------------------------------------------
    vmin, vmax = None, None
    if cbar == "uniform":
        all_plot_cols = [
            c for p in plot_structure.values()
            for y in p.values()
            for c in (y.values() if isinstance(y, dict) else [y])
            if c in df_wide
        ]
        numeric_plot_cols = []
        for c in all_plot_cols:
            s = pd.to_numeric(df_wide[c], errors="coerce")
            if s.notna().any():  # skip columns with no numeric value
                df_wide[c] = s
                numeric_plot_cols.append(c)
        if numeric_plot_cols:
            vmin = min(df_wide[c].min() for c in numeric_plot_cols)
            vmax = max(df_wide[c].max() for c in numeric_plot_cols)

    plot_kwargs = dict(
        cmap=cmap, vmin=vmin, vmax=vmax, axis_off=axis_off,
        show_grid=show_grid,
        grid_props=(grid_props or {'linestyle': ':', 'alpha': 0.7}),
        verbose=verbose
    )

    # Read the marker size once. Popping it inside the loop would hand the
    # user's value to the first figure only and the default to the rest.
    marker_size = kws.pop('s', 10)

    for prefix in value_prefixes:
        if stop_check and stop_check():
            raise InterruptedError("Forecast visualization aborted.")

        num_pred_cols = len(quantiles_to_plot)
        num_actual_cols = 1 if kind == 'dual' else 0
        if max_cols == 'auto':
            n_cols = num_pred_cols + num_actual_cols
        else:
            n_cols = int(max_cols)
        if n_cols == 0:
            continue

        fig_size = figsize or (n_cols * 3.5, n_rows * 3)
        fig, axes = plt.subplots(
            n_rows, n_cols, figsize=fig_size, squeeze=False,
            constrained_layout=True
        )
        fig.suptitle(
            f"Forecast Visualization for '{prefix.upper()}'",
            fontsize=14, weight='bold'
        )
        _plot_forecast_grid(
            fig, axes, df_wide, plot_structure, years_to_plot,
            quantiles_to_plot, prefix, kind, spatial_cols,
            s=marker_size, plot_kwargs=plot_kwargs,
            _logger=_logger,
        )

        if cbar == 'uniform' and vmin is not None and vmax is not None:
            fig.colorbar(
                ScalarMappable(
                    norm=mcolors.Normalize(vmin=vmin, vmax=vmax), cmap=cmap),
                ax=axes.ravel().tolist(), orientation='vertical',
                shrink=0.8, pad=0.04
            )

        if savefig:
            # NOTE(review): the same `savefig` path is passed for every
            # prefix; unless `save_figure` disambiguates the name, later
            # figures may overwrite earlier ones -- confirm.
            save_figure(
                fig, savefile=savefig,
                save_fmts=save_fmts,
                dpi=300,
                bbox_inches="tight",
                verbose=verbose,
                _logger=_logger,
            )
            plt.close(fig)
        elif show:
            plt.show()  # for notebook debugging
        else:
            plt.close(fig)
def _normalize_year_key(y):
    """
    Turn a year specification into the string key used by
    ``plot_structure`` (e.g. ``'2023'``).

    Resolution order:

    1. ``pd.Timestamp`` / ``datetime.date`` / ``datetime.datetime``
       instances yield their ``.year`` zero-padded to four digits.
    2. Values that ``float()`` accepts and that are finite and within
       ``1e-6`` of a whole number yield that number, zero-padded to
       four digits (``2023.0`` -> ``'2023'``, ``7`` -> ``'0007'``).
    3. The same numeric test is repeated on ``str(y)``.
    4. Otherwise the first run of four digits found in ``str(y)`` is
       returned (``'2023-01-01'`` -> ``'2023'``).
    5. If nothing matches, ``str(y)`` is returned unchanged.
    """

    def _whole_number_key(candidate):
        # Zero-padded text when `candidate` parses as a finite,
        # (nearly) integral number; ``None`` in every other case.
        try:
            number = float(candidate)
            if not np.isfinite(number):
                return None
            nearest = round(number)
            if abs(number - nearest) < 1e-6:
                return f"{int(nearest):04d}"
        except Exception:
            pass
        return None

    # Date-like objects: formatting is guarded because some date-like
    # values may expose a non-integer ``year``.
    try:
        date_like = isinstance(
            y, (pd.Timestamp, datetime.date, datetime.datetime))
        if date_like:
            return f"{y.year:04d}"
    except Exception:
        pass

    # Numeric interpretation of the raw object first ...
    from_raw = _whole_number_key(y)
    if from_raw is not None:
        return from_raw

    # ... then of its string representation.
    text = str(y)
    from_text = _whole_number_key(text)
    if from_text is not None:
        return from_text

    # Fallback: first four-digit group, else the text itself.
    found = re.search(r"(\d{4})", text)
    return found.group(1) if found else text
[docs]
@check_non_emptiness
def plot_forecasts(
forecast_df: pd.DataFrame,
target_name: str = "target",
quantiles: Optional[List[float]] = None,
output_dim: int = 1,
kind: str = "temporal",
actual_data: Optional[pd.DataFrame] = None,
dt_col: Optional[str] = None,
actual_target_name: Optional[str] = None,
sample_ids: Optional[Union[int, List[int], str]] = "first_n",
num_samples: int = 3,
horizon_steps: Optional[Union[int, List[int]]] = 1,
spatial_cols: Optional[List[str]] = None,
max_cols: int = 2,
figsize: Tuple[float, float] = (8, 4.5),
scaler: Optional[Any] = None,
scaler_feature_names: Optional[List[str]] = None,
target_idx_in_scaler: Optional[int] = None,
titles: Optional[List[str]] = None,
cbar: Optional[str] = None,
step_names : Dict[int, Any] =None,
show_grid: bool = True,
grid_props: Optional[Dict] = None,
savefig: Optional [str]=None,
save_fmts : Optional [Union[str, List[str]]]=".png",
show: bool=True,
verbose: int = 0,
_logger: Optional[Union[logging.Logger, Callable[[str], None]]] = None,
stop_check: Callable [[], bool] =None,
**plot_kwargs: Any
) -> None:
"""
Visualizes model forecasts from a structured DataFrame.
This function generates plots to visualize time series forecasts,
supporting temporal line plots for individual samples/items and
spatial scatter plots for specific forecast horizons. It can
handle point forecasts and quantile (probabilistic) forecasts,
optionally overlaying actual values for comparison. Predictions
and actuals can be inverse-transformed if a scaler is provided.
The input `forecast_df` is expected to be in a long format, typically
generated by :func:`~fusionlab.nn.utils.format_predictions_to_dataframe`,
containing 'sample_idx' and 'forecast_step' columns, along with
prediction columns (e.g., '{target_name}_pred',
'{target_name}_q50') and optionally actual value columns
(e.g., '{target_name}_actual').
Parameters
----------
forecast_df : pd.DataFrame
A pandas DataFrame containing the forecast data in long format.
Must include 'sample_idx' and 'forecast_step' columns,
along with prediction columns (and optionally actuals and
spatial coordinate columns).
target_name : str, default "target"
The base name of the target variable. This is used to
construct column names for predictions and actuals (e.g.,
"target_pred", "target_q50", "target_actual").
quantiles : List[float], optional
A list of quantiles that were predicted (e.g., `[0.1, 0.5, 0.9]`).
If provided, the function will plot the median quantile as a line
and the range between the first and last quantile as a shaded
uncertainty interval (for temporal plots). If ``None``, a point
forecast is assumed. Default is ``None``.
output_dim : int, default 1
The number of target variables (output dimensions) predicted
at each time step. If > 1, separate subplots or plot groups
may be generated for each output dimension.
kind : {'temporal', 'spatial'}, default "temporal"
The type of plot to generate:
- ``'temporal'``: Plots forecasts over the horizon for selected
samples (time series line plots).
- ``'spatial'``: Creates scatter plots of forecast values over
spatial coordinates for selected horizon steps. Requires
`spatial_cols` to be specified.
actual_data : pd.DataFrame, optional
An optional DataFrame containing the true actual values for
comparison. If `forecast_df` already contains actual value
columns (e.g., '{target_name}_actual'), this may not be needed
or can be used to supplement. If provided, `dt_col` and
`actual_target_name` might be used for alignment or direct plotting.
*Note: Current implementation primarily uses actuals from `forecast_df`.*
dt_col : str, optional
Name of the datetime column in `actual_data` or `forecast_df`
if needed for x-axis labeling in temporal plots. If `forecast_df`
uses 'forecast_step', this might be less critical unless aligning
with a true date axis from `actual_data`.
actual_target_name : str, optional
The name of the target column in `actual_data` if it differs
from `target_name`. If ``None``, `target_name` is assumed.
sample_ids : int, List[int], str, default "first_n"
Specifies which samples (based on 'sample_idx' in `forecast_df`)
to plot for `kind='temporal'`.
- If `int`: Plots the sample at that specific index.
- If `List[int]`: Plots all samples with these indices.
- If ``"first_n"``: Plots the first `num_samples`.
- If ``"all"``: Plots all unique samples (can be many plots).
num_samples : int, default 3
Number of samples to plot if `sample_ids="first_n"`.
horizon_steps : int, List[int], str, default 1
Specifies which forecast horizon step(s) to plot for
`kind='spatial'`.
- If `int`: Plots the specified single step (1-indexed).
- If `List[int]`: Plots all specified steps.
- If ``"all"`` or ``None``: Plots all unique forecast steps available.
spatial_cols : List[str], optional
Required if `kind='spatial'`. A list of two column names
from `forecast_df` representing the x and y spatial
coordinates (e.g., `['longitude', 'latitude']`).
Default is ``None``.
max_cols : int, default 2
Maximum number of subplots to arrange per row in the figure.
figsize : Tuple[float, float], default (8, 4.5)
The size `(width, height)` in inches for **each subplot**.
The total figure size will be inferred based on the number
of rows and columns of subplots.
scaler : Any, optional
A fitted scikit-learn-like scaler object (must have an
`inverse_transform` method). If provided along with
`scaler_feature_names` and `target_idx_in_scaler`,
prediction and actual columns related to `target_name`
will be inverse-transformed before plotting. Default is ``None``.
scaler_feature_names : List[str], optional
A list of all feature names (in order) that the `scaler` was
originally fit on. Required for targeted inverse transform if
`scaler` is provided. Default is ``None``.
target_idx_in_scaler : int, optional
The index of the `target_name` (or the specific output
dimension being plotted) within the `scaler_feature_names`
list. Required for targeted inverse transform if `scaler` is
provided. Default is ``None``.
titles : List[str], optional
A list of custom titles for the subplots. If provided, its
length should match the number of subplots generated.
Default is ``None`` (titles are auto-generated).
cbar : {'uniform', None}, default=None
Controls colour‑bar behaviour for *spatial* plots. Use
``'uniform'`` to compute a single global ``vmin`` / ``vmax`` and
attach one shared colour‑bar to the figure. Any other value (or
*None*) lets each subplot auto‑scale its own colour‑bar.
step_names : dict[int, Any] or None, default=None
Optional mapping that renames forecast‑step integers when building
subplot titles—e.g. ``{1: "2021", 2: "2022"}``. Keys are coerced to
``int``; values are converted to ``str``. Missing steps fall back
to *default_name* (``""`` is the step number itself).
show_grid : bool, default=True
Toggle the display of ``ax.grid`` in both temporal and spatial
plots. Set to *False* for a cleaner look when overlaying graphics
on a map.
grid_props : dict or None, default=None
Keyword arguments forwarded to :pyfunc:`matplotlib.axes.Axes.grid`
(e.g. ``{"linestyle": ":", "alpha": 0.6}``). Ignored when
*show_grid* is *False*.
verbose : int, default 0
Verbosity level for logging during plot generation.
- ``0``: Silent.
- ``1`` or higher: Print informational messages.
**plot_kwargs : Any
Additional keyword arguments to pass to the underlying
Matplotlib plotting functions (e.g., `ax.plot`, `ax.scatter`,
`ax.fill_between`). Can be used to customize line styles,
colors, marker sizes, etc. For example,
`median_plot_kwargs={'color': 'blue'}`,
`fill_between_kwargs={'color': 'lightblue'}`.
Returns
-------
None
This function directly generates and shows/saves plots using
Matplotlib and does not return any value.
Raises
------
TypeError
If `forecast_df` is not a pandas DataFrame.
ValueError
If essential columns like 'sample_idx' or 'forecast_step'
are missing from `forecast_df`.
If `kind='spatial'` is chosen but `spatial_cols` are not
provided or not found in `forecast_df`.
If an unsupported `kind` is specified.
See Also
--------
fusionlab.nn.utils.format_predictions_to_dataframe :
Utility to generate the `forecast_df` expected by this function.
matplotlib.pyplot.plot : Underlying plotting function for temporal forecasts.
matplotlib.pyplot.scatter : Underlying plotting function for spatial forecasts.
Examples
--------
>>> from fusionlab.nn.utils import format_predictions_to_dataframe
>>> from fusionlab.plot.forecast import plot_forecasts
>>> import pandas as pd
>>> import numpy as np
>>> # Assume preds_point (B,H,O) and preds_quant (B,H,Q) are available
>>> B, H, O, Q_len = 4, 3, 1, 3
>>> preds_point = np.random.rand(B, H, O)
>>> preds_quant = np.random.rand(B, H, Q_len)
>>> y_true_seq = np.random.rand(B, H, O)
>>> quantiles_list = [0.1, 0.5, 0.9]
>>> # Create DataFrames using format_predictions_to_dataframe
>>> df_point = format_predictions_to_dataframe(
... predictions=preds_point, y_true_sequences=y_true_seq,
... target_name="value", forecast_horizon=H, output_dim=O
... )
>>> df_quant = format_predictions_to_dataframe(
... predictions=preds_quant, y_true_sequences=y_true_seq,
... target_name="value", quantiles=quantiles_list,
... forecast_horizon=H, output_dim=O
... )
>>> # Example 1: Plot temporal point forecast for first sample
>>> # plot_forecasts(df_point, target_name="value", sample_ids=0)
>>> # Example 2: Plot temporal quantile forecast for first 2 samples
>>> # plot_forecasts(df_quant, target_name="value",
... # quantiles=quantiles_list, sample_ids="first_n",
... # num_samples=2, max_cols=1)
>>> # Example 3: Spatial plot (requires spatial_cols in df_quant)
>>> # Assume df_quant has 'lon' and 'lat' columns
>>> # df_quant['lon'] = np.random.rand(len(df_quant)) * 10
>>> # df_quant['lat'] = np.random.rand(len(df_quant)) * 5
>>> # plot_forecasts(df_quant, target_name="value",
... # quantiles=quantiles_list, kind='spatial',
... # horizon_steps=1, spatial_cols=['lon', 'lat'])
"""
vlog(f"Starting forecast visualization (kind='{kind}')...",
level=3, verbose=verbose, logger=_logger)
if not isinstance(forecast_df, pd.DataFrame):
raise TypeError("`forecast_df` must be a pandas DataFrame, typically "
"the output of `format_predictions_to_dataframe`.")
if 'sample_idx' not in forecast_df.columns or \
'forecast_step' not in forecast_df.columns:
raise ValueError(
"`forecast_df` must contain 'sample_idx' and 'forecast_step' "
"columns."
)
if stop_check and stop_check():
raise InterruptedError("Plotting aborted.")
# --- Data Preparation & Inverse Transform ---
# Work on a copy to avoid modifying the original DataFrame
df_to_plot = forecast_df.copy()
actual_col_names_in_df = [] # To store names of actual columns in df_to_plot
pred_col_names_in_df = [] # To store names of prediction columns
# Identify prediction and actual columns based on target_name and quantiles
base_pred_name = target_name
base_actual_name = actual_target_name if actual_target_name \
else target_name
if quantiles:
# Sort quantiles to ensure consistent order for plotting (e.g., low, mid, high)
sorted_quantiles = sorted([float(q) for q in quantiles])
for o_idx in range(output_dim):
for q_val in sorted_quantiles:
q_suffix = f"_q{int(q_val*100)}"
col_name = f"{base_pred_name}"
if output_dim > 1: col_name += f"_{o_idx}"
col_name += q_suffix
if col_name in df_to_plot.columns:
pred_col_names_in_df.append(col_name)
# Actual column for this output dimension
actual_col = f"{base_actual_name}"
if output_dim > 1: actual_col += f"_{o_idx}"
actual_col += "_actual"
if actual_col in df_to_plot.columns:
actual_col_names_in_df.append(actual_col)
else: # Point forecast
for o_idx in range(output_dim):
col_name = f"{base_pred_name}"
if output_dim > 1: col_name += f"_{o_idx}"
col_name += "_pred"
if col_name in df_to_plot.columns:
pred_col_names_in_df.append(col_name)
actual_col = f"{base_actual_name}"
if output_dim > 1: actual_col += f"_{o_idx}"
actual_col += "_actual"
if actual_col in df_to_plot.columns:
actual_col_names_in_df.append(actual_col)
if not pred_col_names_in_df:
warnings.warn("No prediction columns found in `forecast_df` "
"based on `target_name` and `quantiles`.")
# Attempt to find any column ending with _pred or _qXX
pred_cols_found = [c for c in df_to_plot.columns if "_pred" in c or "_q" in c]
if pred_cols_found:
pred_col_names_in_df = pred_cols_found
vlog(f" Auto-detected prediction columns: {pred_col_names_in_df}",
level=2, verbose=verbose,logger=_logger)
else:
vlog(" No prediction columns could be auto-detected. "
"Plotting may be limited.", level=2, verbose=verbose,
logger=_logger)
# Apply inverse transform if scaler is provided
if scaler is not None:
if scaler_feature_names is None or target_idx_in_scaler is None:
warnings.warn(
"Scaler provided, but `scaler_feature_names` or "
"`target_idx_in_scaler` is missing. Inverse transform "
"cannot be applied correctly to specific target columns.",
UserWarning
)
else:
vlog(" Applying inverse transformation using provided scaler...",
level=4, verbose=verbose, logger=_logger)
cols_to_inverse_transform = pred_col_names_in_df + actual_col_names_in_df
dummy_array_shape = (len(df_to_plot), len(scaler_feature_names))
for col_to_inv in cols_to_inverse_transform:
if col_to_inv in df_to_plot.columns:
try:
dummy = np.zeros(dummy_array_shape)
dummy[:, target_idx_in_scaler] = df_to_plot[col_to_inv]
df_to_plot[col_to_inv] = scaler.inverse_transform(
dummy
)[:, target_idx_in_scaler]
except Exception as e:
warnings.warn(
f"Failed to inverse transform column '{col_to_inv}'. "
f"Plotting scaled value. Error: {e}"
)
vlog(" Inverse transformation applied.",
level=5, verbose=verbose, logger=_logger
)
if stop_check and stop_check():
raise InterruptedError("Quantile plot configuration aborted.")
# --- Select Samples/Items to Plot ---
unique_sample_ids = df_to_plot['sample_idx'].unique()
selected_ids_for_plot = []
if isinstance(sample_ids, str):
if sample_ids.lower() == "all":
selected_ids_for_plot = unique_sample_ids
elif sample_ids.lower() == "first_n":
selected_ids_for_plot = unique_sample_ids[:num_samples]
else:
warnings.warn(f"Unknown string for `sample_ids`: "
f"'{sample_ids}'. Plotting first sample.")
selected_ids_for_plot = unique_sample_ids[:1]
elif isinstance(sample_ids, int):
if sample_ids < len(unique_sample_ids):
selected_ids_for_plot = [unique_sample_ids[sample_ids]]
else:
warnings.warn(f"sample_idx {sample_ids} out of range. "
"Plotting first sample.")
selected_ids_for_plot = unique_sample_ids[:1]
elif isinstance(sample_ids, list):
selected_ids_for_plot = [
sid for sid in sample_ids if sid in unique_sample_ids
]
if len(selected_ids_for_plot) ==0:
vlog("No valid sample_idx found to plot. Defaulting to first available.",
level=2, verbose=verbose, logger=_logger)
selected_ids_for_plot = unique_sample_ids[:1]
if len(selected_ids_for_plot) ==0 : # check if is empty
vlog ("No sample to plot. Returning ...", level =1, verbose=verbose)
return
vlog(f" Plotting for sample_idx: {selected_ids_for_plot}",
level=4, verbose=verbose, logger=_logger)
cmap = plot_kwargs.get('cmap', 'viridis')
# --- Plotting Logic ---
if kind == "temporal":
num_plots = len(selected_ids_for_plot) * output_dim
if num_plots == 0:
vlog("No data to plot for temporal type.", level=2,
verbose=verbose, logger=_logger)
return
n_cols_plot = min(max_cols, num_plots)
n_rows_plot = (num_plots + n_cols_plot - 1) // n_cols_plot
fig, axes = plt.subplots(
n_rows_plot, n_cols_plot,
figsize=(n_cols_plot * figsize[0],
n_rows_plot * figsize[1]),
squeeze=False # Always return 2D array for axes
)
axes_flat = axes.flatten()
plot_idx = 0
for s_idx in selected_ids_for_plot:
sample_df = df_to_plot[df_to_plot['sample_idx'] == s_idx]
if sample_df.empty:
continue
for o_idx in range(output_dim):
if plot_idx >= len(axes_flat): break # Should not happen
ax = axes_flat[plot_idx]
if stop_check and stop_check():
raise InterruptedError("Temporal plot aborted.")
title = f"Sample ID: {s_idx}"
if output_dim > 1:
title += f", Target Dim: {o_idx}"
if titles and plot_idx < len(titles):
title = titles[plot_idx]
ax.set_title(title)
# Plot actuals if available
actual_col = f"{base_actual_name}"
if output_dim > 1: actual_col += f"_{o_idx}"
actual_col += "_actual"
if actual_col in sample_df.columns:
ax.plot(
sample_df['forecast_step'], sample_df[actual_col],
label=f'Actual {target_name}'
f'{f"_{o_idx}" if output_dim > 1 else ""}',
marker='o', linestyle='--'
)
# Plot predictions
if quantiles:
median_q_val = 0.5
if 0.5 not in quantiles: # Find closest if 0.5 not present
median_q_val = quantiles[len(quantiles) // 2]
median_col = f"{base_pred_name}"
if output_dim > 1: median_col += f"_{o_idx}"
median_col += f"_q{int(median_q_val*100)}"
lower_q_col = f"{base_pred_name}"
if output_dim > 1: lower_q_col += f"_{o_idx}"
lower_q_col += f"_q{int(sorted_quantiles[0]*100)}"
upper_q_col = f"{base_pred_name}"
if output_dim > 1: upper_q_col += f"_{o_idx}"
upper_q_col += f"_q{int(sorted_quantiles[-1]*100)}"
if median_col in sample_df.columns:
ax.plot(
sample_df['forecast_step'], sample_df[median_col],
label=f'Median (q{int(median_q_val*100)})',
marker='x',
**plot_kwargs.get("median_plot_kwargs", {})
)
if lower_q_col in sample_df.columns and \
upper_q_col in sample_df.columns:
ax.fill_between(
sample_df['forecast_step'],
sample_df[lower_q_col],
sample_df[upper_q_col],
color='gray', alpha=0.3,
label=f'Interval (q{int(sorted_quantiles[0]*100)}'
f'-q{int(sorted_quantiles[-1]*100)})',
**plot_kwargs.get("fill_between_kwargs", {})
)
else: # Point forecast
pred_col = f"{base_pred_name}"
if output_dim > 1: pred_col += f"_{o_idx}"
pred_col += "_pred"
if pred_col in sample_df.columns:
ax.plot(
sample_df['forecast_step'], sample_df[pred_col],
label=f'Predicted {target_name}'
f'{f"_{o_idx}" if output_dim > 1 else ""}',
marker='x',
**plot_kwargs.get("point_plot_kwargs", {})
)
ax.set_xlabel("Forecast Step into Horizon")
ax.set_ylabel(
f"{target_name}{f' (Dim {o_idx})' if output_dim > 1 else ''}"
)
ax.legend()
ax.grid(True)
plot_idx += 1
# Hide unused subplots
for i in range(plot_idx, len(axes_flat)):
axes_flat[i].set_visible(False)
fig.tight_layout()
# plt.show()
elif kind == "spatial":
spatial_x_col = None
spatial_y_col =None
if spatial_cols is not None:
spatial_cols = columns_manager (spatial_cols, empty_as_none=False)
if len(spatial_cols )!=2:
raise ValueError(
"Spatial_cols need exactly two columns (e.g., longitude,"
f" latitude ). Got {spatial_cols}")
spatial_x_col , spatial_y_col = spatial_cols
if spatial_x_col is None or spatial_y_col is None:
raise ValueError(
"`spatial_x_col` and `spatial_y_col` must be provided "
"for `kind='spatial'`."
)
if spatial_x_col not in df_to_plot.columns or \
spatial_y_col not in df_to_plot.columns:
raise ValueError(
f"Spatial columns '{spatial_x_col}' or '{spatial_y_col}' "
"not found in forecast_df."
)
steps_to_plot = []
if isinstance(horizon_steps, int):
steps_to_plot = [horizon_steps]
elif isinstance(horizon_steps, list):
steps_to_plot = horizon_steps
elif horizon_steps is None or \
str(horizon_steps).lower() == "all":
# Plot all unique forecast steps present in the data
steps_to_plot = sorted(df_to_plot['forecast_step'].unique())
else:
raise ValueError("Invalid `horizon_steps`.")
num_plots = len(steps_to_plot) * output_dim
if num_plots == 0:
vlog("No data/steps to plot for spatial type.",
level=2, verbose=verbose, logger=_logger)
return
vmin, vmax = None, None
if cbar=="uniform":
vlog("Calculating uniform color scale for all subplots...",
level=2, verbose=verbose, logger=_logger)
# Filter DataFrame to only the steps being plotted
df_for_scaling = df_to_plot[
df_to_plot['forecast_step'].isin(steps_to_plot)
]
# Identify all columns that will be used for coloring
cols_to_scan = []
if quantiles:
median_q = 0.5
if 0.5 not in quantiles:
median_q = sorted(quantiles)[len(quantiles) // 2]
for o_idx in range(output_dim):
col_name = f"{base_pred_name}"
if output_dim > 1: col_name += f"_{o_idx}"
col_name += f"_q{int(median_q*100)}"
cols_to_scan.append(col_name)
else: # Point forecasts
for o_idx in range(output_dim):
col_name = f"{base_pred_name}"
if output_dim > 1: col_name += f"_{o_idx}"
col_name += "_pred"
cols_to_scan.append(col_name)
existing_color_cols = [
c for c in cols_to_scan if c in df_for_scaling.columns
]
if existing_color_cols:
# Calculate min/max on the filtered DataFrame
min_vals = [
df_for_scaling[c].dropna().min()
for c in existing_color_cols]
max_vals = [df_for_scaling[c].dropna().max()
for c in existing_color_cols]
if any(pd.notna(v) for v in min_vals):
vmin = min(v for v in min_vals if pd.notna(v))
if any(pd.notna(v) for v in max_vals):
vmax = max(v for v in max_vals if pd.notna(v))
if vmin is not None and vmax is not None:
vlog(f"Uniform color range set: vmin={vmin:.2f}, vmax={vmax:.2f}",
level=3, verbose=verbose, logger=_logger)
n_cols_plot = min(max_cols, num_plots)
n_rows_plot = (num_plots + n_cols_plot - 1) // n_cols_plot
fig, axes = plt.subplots(
n_rows_plot, n_cols_plot,
figsize=(n_cols_plot * figsize[0],
n_rows_plot * figsize[1]),
squeeze=False
)
axes_flat = axes.flatten()
plot_idx = 0
step_names = get_step_names(
steps_to_plot, step_names =step_names,
default_name= "{}"
)
grid_props = grid_props or {'linestyle': ':', 'alpha': 0.7}
for step in steps_to_plot:
step_df = df_to_plot[df_to_plot['forecast_step'] == step]
if stop_check and stop_check():
raise InterruptedError("Spatial plot aborted.")
if step_df.empty:
continue
for o_idx in range(output_dim):
if plot_idx >= len(axes_flat): break
ax = axes_flat[plot_idx]
# Determine column to use for color intensity
color_col = None
plot_title_suffix = ""
if quantiles: # Use median for color
median_q_val = 0.5
if 0.5 not in quantiles:
median_q_val = quantiles[len(quantiles) // 2]
color_col = f"{base_pred_name}"
if output_dim > 1: color_col += f"_{o_idx}"
color_col += f"_q{int(median_q_val*100)}"
plot_title_suffix = f" (Median q{int(median_q_val*100)})"
else: # Point forecast
color_col = f"{base_pred_name}"
if output_dim > 1: color_col += f"_{o_idx}"
color_col += "_pred"
if color_col not in step_df.columns:
warnings.warn(f"Color column '{color_col}' not found "
"for spatial plot. Skipping subplot.")
plot_idx +=1 # Increment to avoid reusing subplot
continue
scatter_data = step_df.dropna(
subset=[spatial_x_col, spatial_y_col, color_col]
)
if scatter_data.empty:
warnings.warn(f"No valid data to plot for step {step}, "
f"output_dim {o_idx} after dropna. Skipping.")
plot_idx += 1
continue
# Normalize color data for better visualization
if cbar !='uniform':
norm = mcolors.Normalize(
vmin=scatter_data[color_col].min(),
vmax=scatter_data[color_col].max()
)
else:
norm=None # norm will be applied for uniform scale
step_name = step_names.get(step, '{}') # for consistency
sc = ax.scatter(
scatter_data[spatial_x_col],
scatter_data[spatial_y_col],
c=scatter_data[color_col],
cmap=cmap,
norm=norm,
s=plot_kwargs.get('s', 50), # Marker size
alpha=plot_kwargs.get('alpha', 0.7)
)
if cbar !='uniform': # use default behavior
fig.colorbar(sc, ax=ax, label=f"{target_name}{plot_title_suffix}")
title = step_name.format(f"Forecast Step: {step}")
if output_dim > 1:
title += f", Target Dim: {o_idx}"
if titles and plot_idx < len(titles):
title = titles[plot_idx]
ax.set_title(title)
ax.set_xlabel(spatial_x_col)
ax.set_ylabel(spatial_y_col)
if show_grid:
ax.grid(True, **grid_props)
else:
ax.grid(False)
plot_idx += 1
for i in range(plot_idx, len(axes_flat)):
axes_flat[i].set_visible(False)
fig.tight_layout()
if cbar=='uniform':
# This creates one colorbar for the whole
# figure if uniform scale was used.
# Adjust the main plot area to make space for the colorbar
# This leaves 10% space on the right for the cbar
fig.subplots_adjust(right=0.90)
# Create a new axis for the colorbar
# [left, bottom, width, height] in figure coordinates
cax = fig.add_axes([0.92, 0.15, 0.015, 0.7])
norm = mcolors.Normalize(vmin=vmin, vmax=vmax)
mappable = ScalarMappable(norm=norm, cmap=cmap)
# Add colorbar to the figure
fig.colorbar(
mappable,
cax=cax,
# ax=axes.ravel().tolist(), # Associate with all axes
label=f"{target_name} {plot_title_suffix.strip()}",
orientation='vertical',
shrink=0.8,
pad=0.04
)
else:
raise ValueError(
f"Unsupported `kind`: '{kind}'. "
"Choose 'temporal' or 'spatial'."
)
if stop_check and stop_check():
raise InterruptedError("Plot forecast aborted.")
vlog("Forecast visualization complete.", level=3,
verbose=verbose, logger=_logger)
# 4. Save figure to disk if requested.
if savefig:
save_figure (
fig, savefile = savefig,
save_fmts= save_fmts,
dpi=300, bbox_inches="tight" ,
_logger = _logger,
)
plt.close(fig)
else:
if show:
plt.show() # for notebook debugging
else:
plt.close(fig)
[docs]
@check_non_emptiness
def visualize_forecasts(
    forecast_df,
    dt_col,
    tname,
    test_data=None,
    eval_periods=None,
    mode="quantile",
    kind="spatial",
    actual_name=None,
    x=None,
    y=None,
    cmap="coolwarm",
    max_cols=3,
    axis="on",
    s=2,
    show_grid=True,
    grid_props=None,
    savefig=None,
    save_fmts=None,
    verbose=1,
    **kw
):
    r"""
    Visualize forecast results and actual test data for one or more
    evaluation periods.

    The function plots a grid of scatter plots comparing actual values
    with forecasted predictions. When ``test_data`` is supplied, each
    evaluation period yields two plots: one for actual values and, in
    the row directly below, one for predicted values. Without
    ``test_data`` only the predicted values are drawn. If multiple
    evaluation periods are provided, the grid layout wraps after
    ``max_cols`` columns.

    .. math::

        \hat{y}_{t+i} = f\Bigl(
        X_{\text{static}},\;X_{\text{dynamic}},\;
        X_{\text{future}}\Bigr)

    for :math:`i = 1, \dots, N`, where :math:`N` is the forecast horizon.

    Parameters
    ----------
    forecast_df : pandas.DataFrame
        DataFrame containing forecast results with a time column,
        spatial coordinates, and prediction columns.
    dt_col : str
        Name of the time column used to filter forecast results (e.g.
        ``"year"``).
    tname : str
        Target variable name used to construct forecast columns (e.g.
        ``"subsidence"``). This argument is required.
    test_data : pandas.DataFrame, optional
        DataFrame holding the actual values. When given, it is filtered
        to the evaluation periods and plotted above the predictions.
    eval_periods : scalar or list, optional
        Evaluation period(s) used to select forecast results. If set to
        ``None``, the function selects up to three unique periods from
        ``forecast_df[dt_col]``.
    mode : str, optional
        Forecast mode. Must be either ``"quantile"`` or ``"point"``.
        Default is ``"quantile"``.
    kind : str, optional
        Type of visualization. If ``"spatial"``, ``x`` and ``y`` default
        to ``"longitude"`` and ``"latitude"`` when both are omitted and
        the spatial columns are checked; otherwise both `x` and `y`
        must be provided.
    actual_name : str, optional
        Name of the column holding the actual values. It is resolved
        through ``get_actual_column_name`` against ``test_data`` (or
        ``forecast_df`` when no test data is given).
    x : str, optional
        Column name for the x-axis. Required for non-spatial plots.
    y : str, optional
        Column name for the y-axis. Required for non-spatial plots.
    cmap : str, optional
        Colormap used for scatter plots. Default is ``"coolwarm"``.
    max_cols : int, optional
        Maximum number of evaluation periods to plot per row. If the
        number of periods exceeds ``max_cols``, a new row is started.
    axis : str, optional
        ``"off"`` hides the axes of every subplot; any other value
        keeps them visible.
    s : float, optional
        Marker size forwarded to ``Axes.scatter``. Default is ``2``.
    show_grid : bool, default=True
        Whether to draw a grid on each subplot.
    grid_props : dict, optional
        Grid properties. If ``None``, ``{"linestyle": ":", "alpha": 0.7}``
        is used.
    savefig : str, optional
        When truthy, the figure is written through ``save_figure`` and
        closed instead of being shown.
    save_fmts : str or list of str, optional
        Output format(s) forwarded to ``save_figure``.
    verbose : int, optional
        Verbosity level. Controls the amount of output printed.
    **kw
        Extra keyword arguments forwarded to every ``Axes.scatter`` call.

    Returns
    -------
    None
        The function displays (or saves) the visualization plot.

    Raises
    ------
    ValueError
        If ``mode`` is not ``"quantile"`` or ``"point"``, if `x`/`y` are
        missing for a non-spatial plot, or if no evaluation period is
        available in the data.

    Examples
    --------
    Example 1: **Spatial Visualization**

    In this example, we visualize the forecasted and actual values of the
    **subsidence** target variable, using **longitude** and **latitude**
    for the spatial coordinates. We visualize the results for two
    evaluation periods (2023 and 2024), using **quantile** mode for the
    forecast.

    >>> import pandas as pd
    >>> from fusionlab.plot.forecast import visualize_forecasts
    >>> forecast_results = pd.DataFrame({
    ...     'longitude': [-103.808151, -103.808151, -103.808151],
    ...     'latitude': [0.473152, 0.473152, 0.473152],
    ...     'subsidence_q50': [0.3, 0.4, 0.5],
    ...     'subsidence': [0.35, 0.42, 0.49],
    ...     'date': ['2023-01-01', '2023-01-02', '2023-01-03']
    ... })
    >>> test_data = pd.DataFrame({
    ...     'longitude': [-103.808151, -103.808151, -103.808151],
    ...     'latitude': [0.473152, 0.473152, 0.473152],
    ...     'subsidence': [0.35, 0.41, 0.49],
    ...     'date': ['2023-01-01', '2023-01-02', '2023-01-03']
    ... })
    >>> visualize_forecasts(
    ...     forecast_df=forecast_results,
    ...     test_data=test_data,
    ...     dt_col="date",
    ...     tname="subsidence",
    ...     eval_periods=[2023, 2024],
    ...     mode="quantile",
    ...     kind="spatial",
    ...     cmap="coolwarm",
    ...     max_cols=2,
    ...     verbose=1
    ... )

    Example 2: **Non-Spatial Visualization**

    In this example, we visualize the forecasted and actual values of the
    **subsidence** target variable in a **non-spatial** context. The columns
    `longitude` and `latitude` are still provided but used for non-spatial
    x and y axes. Evaluation is for 2023.

    >>> from fusionlab.plot.forecast import visualize_forecasts
    >>> forecast_results = pd.DataFrame({
    ...     'longitude': [-103.808151, -103.808151, -103.808151],
    ...     'latitude': [0.473152, 0.473152, 0.473152],
    ...     'subsidence_pred': [0.35, 0.41, 0.48],
    ...     'subsidence': [0.36, 0.43, 0.49],
    ...     'date': ['2023-01-01', '2023-01-02', '2023-01-03']
    ... })
    >>> test_data = pd.DataFrame({
    ...     'longitude': [-103.808151, -103.808151, -103.808151],
    ...     'latitude': [0.473152, 0.473152, 0.473152],
    ...     'subsidence': [0.36, 0.42, 0.50],
    ...     'date': ['2023-01-01', '2023-01-02', '2023-01-03']
    ... })
    >>> visualize_forecasts(
    ...     forecast_df=forecast_results,
    ...     test_data=test_data,
    ...     dt_col="date",
    ...     tname="subsidence",
    ...     eval_periods=[2023],
    ...     mode="point",
    ...     kind="non-spatial",
    ...     x="longitude",
    ...     y="latitude",
    ...     cmap="viridis",
    ...     max_cols=1,
    ...     axis="off",
    ...     show_grid=True,
    ...     grid_props={"linestyle": "--", "alpha": 0.5},
    ...     verbose=2
    ... )

    Notes
    -----
    - In ``quantile`` mode, the function uses the column
      ``<tname>_q50`` for visualization.
    - In ``point`` mode, the column ``<tname>_pred`` is used.
    - For spatial visualizations, if ``x`` and ``y`` are not provided,
      they default to ``"longitude"`` and ``"latitude"``.
    - The evaluation period(s) are determined by filtering
      ``forecast_df[dt_col] == <eval_period>``.
    - ``assert_xy_in`` is used to validate that the x and y columns
      exist in the provided DataFrames.
    - All subplots share one colour range computed from the predictions
      and, when available, the actual values.

    See Also
    --------
    generate_forecast : Function to generate forecast results.
    coverage_score : Function to compute the coverage score.

    References
    ----------
    .. [1] Kouadio L. et al., "Gofast Forecasting Model", Journal of
           Advanced Forecasting, 2025. (In review)
    """
    # Function-scope import, kept exactly where the original placed it.
    from ..utils.ts_utils import filter_by_period

    # Check that forecast_df is a valid DataFrame.
    is_frame(
        forecast_df,
        df_only=True,
        objname="Forecast data",
        error="raise"
    )

    if eval_periods is None:
        unique_periods = sorted(forecast_df[dt_col].unique())
        if verbose:
            print("No eval_period provided; using up to three unique " +
                  "periods from forecast data.")
        eval_periods = unique_periods[:3]
    eval_periods = columns_manager(eval_periods, to_string=True)

    if test_data is not None:
        is_frame(
            test_data,
            df_only=True,
            objname="Test data",
            error="raise"
        )
        # Restrict the actual values to the requested periods.
        test_data = filter_by_period(test_data, eval_periods, dt_col)
    forecast_df = filter_by_period(forecast_df, eval_periods, dt_col)

    # From here on, periods are the string form of what survived the
    # filtering, so they can be compared against both frames.
    eval_periods = forecast_df[dt_col].astype(str).unique()

    # Determine x and y columns for spatial or non-spatial visualization.
    if kind == "spatial":
        if x is None and y is None:
            x, y = "longitude", "latitude"
        check_spatial_columns(forecast_df, spatial_cols=(x, y))
    elif x is None or y is None:
        raise ValueError("For non-spatial kind, both x and y must be provided.")
    # `x` and `y` are used below through their `.name` attribute.
    x, y = assert_xy_in(x, y, data=forecast_df, asarray=False)

    # Set prediction column based on forecast mode.
    if mode == "quantile":
        pred_col = f"{tname}_q50"
        pred_label = f"Predicted {tname} (q50)"
    elif mode == "point":
        pred_col = f"{tname}_pred"
        pred_label = f"Predicted {tname}"
    else:
        raise ValueError("Mode must be either 'quantile' or 'point'.")

    df_actual = test_data if test_data is not None else forecast_df
    actual_name = get_actual_column_name(
        df_actual, tname, actual_name=actual_name,
        default_to='tname',
    )

    # Global colour range shared by every subplot.
    vmin = forecast_df[pred_col].min()
    vmax = forecast_df[pred_col].max()
    if test_data is not None and actual_name in test_data.columns:
        vmin = min(vmin, test_data[actual_name].min())
        vmax = max(vmax, test_data[actual_name].max())

    # Keep only the periods present in both frames (when test data exists).
    if test_data is not None:
        available_periods = is_in_if(
            forecast_df[dt_col].astype(str).unique(),
            test_data[dt_col].astype(str).unique(),
            return_intersect=True,
        )
    else:
        available_periods = eval_periods
    eval_periods = [p for p in eval_periods if p in available_periods]
    if len(eval_periods) == 0:
        raise ValueError(
            "[ERROR] No valid evaluation periods found in forecast or test data.")

    # Compute subplot grid dimensions; two rows per block of periods when
    # actual values are plotted above the predictions.
    has_actuals = test_data is not None
    n_periods = len(eval_periods)
    n_cols = min(n_periods, max_cols)
    n_rows = int(np.ceil(n_periods / max_cols))
    total_rows = n_rows * 2 if has_actuals else n_rows

    # squeeze=False always yields a 2D array of axes, so `axes[r, c]`
    # is valid even for a 1x1 grid (a bare Axes is not subscriptable).
    fig, axes = plt.subplots(
        total_rows, n_cols, figsize=(5 * n_cols, 4 * total_rows),
        squeeze=False,
    )
    drawn = np.zeros(axes.shape, dtype=bool)

    # Resolve the grid style once instead of on every subplot.
    if show_grid and grid_props is None:
        grid_props = {"linestyle": ":", 'alpha': 0.7}

    def _finish_axes(ax, title):
        """Apply the title, axis labels, axis visibility and grid."""
        ax.set_title(title)
        ax.set_xlabel(x.name.capitalize())
        ax.set_ylabel(y.name.capitalize())
        if axis == "off":
            ax.set_axis_off()
        else:
            ax.set_axis_on()
        if show_grid:
            ax.grid(True, **grid_props)

    for idx, period in enumerate(eval_periods):
        col_idx = idx % n_cols
        row_idx = (idx // n_cols) * 2 if has_actuals else (idx // n_cols)

        # Filter data for the current period using 'isin' for robustness.
        forecast_subset = forecast_df[forecast_df[dt_col].isin([period])]
        if has_actuals:
            test_subset = test_data[test_data[dt_col].isin([period])]
        else:
            test_subset = forecast_subset
        if forecast_subset.empty or test_subset.empty:
            if verbose:
                print(f"[WARNING] No data for period {period}; skipping.")
            continue

        # Plot actual values.
        if has_actuals:
            ax_actual = axes[row_idx, col_idx]
            sc_actual = ax_actual.scatter(
                test_subset[x.name],
                test_subset[y.name],
                c=test_subset[actual_name],
                cmap=cmap,
                alpha=0.7,
                edgecolors='k',
                s=s,
                vmin=vmin,
                vmax=vmax,
                **kw
            )
            _finish_axes(ax_actual, f"Actual {tname.capitalize()} ({period})")
            fig.colorbar(sc_actual, ax=ax_actual, label=tname.capitalize())
            drawn[row_idx, col_idx] = True

        # Plot predicted values (directly below the actuals when present).
        pred_row = row_idx + 1 if has_actuals else row_idx
        ax_pred = axes[pred_row, col_idx]
        sc_pred = ax_pred.scatter(
            forecast_subset[x.name],
            forecast_subset[y.name],
            c=forecast_subset[pred_col],
            cmap=cmap,
            alpha=0.7,
            edgecolors='k',
            s=s,
            vmin=vmin,
            vmax=vmax,
            **kw
        )
        _finish_axes(ax_pred, f"{pred_label} ({period})")
        fig.colorbar(sc_pred, ax=ax_pred, label=pred_label)
        drawn[pred_row, col_idx] = True

    # Hide grid cells that received no data (incomplete last row or
    # skipped periods) instead of leaving empty frames.
    for unused_ax in axes[~drawn]:
        unused_ax.set_visible(False)

    fig.tight_layout()

    # Save figure to disk if requested, otherwise display it.
    if savefig:
        save_figure(
            fig, savefile=savefig,
            save_fmts=save_fmts,
            dpi=300, bbox_inches="tight"
        )
        plt.close(fig)
    else:
        plt.show()
def _get_metrics_from_cols(
columns: List[str], prefixes: List[str]
) -> List[str]:
"""Extracts metric suffixes (e.g., q10, actual) from column names."""
metrics = set()
for col in columns:
for p in prefixes:
if col.startswith(p + '_'):
metrics.add(col[len(p)+1:])
return sorted(list(metrics))
def _parse_wide_df_columns(
df_wide: pd.DataFrame,
value_prefixes: List[str]
) -> Dict[str, Dict[str, Dict[str, str]]]:
"""Parses wide-format columns into a structured dictionary."""
plot_structure = {prefix: {} for prefix in value_prefixes}
META_SUFFIXES = {"unit", "units", "uom"} # add more if needed
# Regex for columns with years, e.g., GWL_2022_q10 or GWL_2022_actual
pattern_year = re.compile(
r'(' + '|'.join(re.escape(p) for p in value_prefixes) +
r')_(\d{4})_?(.*)'
)
# Regex for columns without years, e.g., GWL_q10 or GWL_pred
pattern_no_year = re.compile(
r'(' + '|'.join(re.escape(p) for p in value_prefixes) +
r')_([a-zA-Z].*)'
)
for col in df_wide.columns:
match_year = pattern_year.match(col)
match_no_year = pattern_no_year.match(col)
if match_year:
prefix, year, suffix = match_year.groups()
# If suffix is empty, it's a point prediction
suffix = suffix or 'pred'
if suffix in META_SUFFIXES:
continue
if year not in plot_structure[prefix]:
plot_structure[prefix][year] = {}
plot_structure[prefix][year][suffix] = col
elif match_no_year:
# Handles columns like 'subsidence_q10',
# 'subsidence_actual'
prefix, suffix = match_no_year.groups()
if suffix in META_SUFFIXES:
continue
if "static" not in plot_structure[prefix]:
plot_structure[prefix]["static"] = {}
plot_structure[prefix]["static"][suffix] = col
elif any(col == p for p in value_prefixes):
# Handles base prefix column e.g. 'subsidence'
if "static" not in plot_structure[col]:
plot_structure[col]["static"] ={}
plot_structure[col]['static']['pred'] = col
return plot_structure
def _plot_spatial_subplot(
ax, df, x_col, y_col, c_col, s=10, **kwargs
):
"""Helper to create a single spatial subplot."""
title = kwargs.get('title', '')
if c_col is None or c_col not in df.columns:
ax.set_title(
f"{title}\n(Data not found)",
fontsize=10,
color='red',
)
ax.axis('off')
return None
plot_df = df[[x_col, y_col, c_col]].dropna()
if plot_df.empty:
ax.set_title(
f"{title}\n(No valid data)",
fontsize=10,
)
ax.axis('off')
return None
cmap = kwargs.get('cmap')
vmin = kwargs.get('vmin')
vmax = kwargs.get('vmax')
axis_off = kwargs.get('axis_off')
show_grid = kwargs.get('show_grid')
grid_props = kwargs.get('grid_props')
spatial_mode = kwargs.get('spatial_mode', 'scatter')
hexbin_gridsize = kwargs.get('hexbin_gridsize', 40)
if spatial_mode == 'hexbin':
artist = ax.hexbin(
plot_df[x_col].to_numpy(),
plot_df[y_col].to_numpy(),
C=plot_df[c_col].to_numpy(),
gridsize=hexbin_gridsize,
cmap=cmap,
vmin=vmin,
vmax=vmax,
)
else:
artist = ax.scatter(
plot_df[x_col],
plot_df[y_col],
c=plot_df[c_col],
cmap=cmap,
vmin=vmin,
vmax=vmax,
s=s,
edgecolors='k',
linewidths=0.1,
alpha=0.8,
)
ax.set_title(title, fontsize=10)
if axis_off:
ax.axis('off')
else:
ax.tick_params(
axis='both',
which='major',
labelsize=8,
)
if show_grid:
ax.grid(**grid_props)
return artist
def _plot_temporal_subplot(ax, df, value_col, title, **kwargs):
"""Helper to create a single line plot (fallback)."""
if value_col is None or value_col not in df.columns:
ax.set_title(f"{title}\n(Data not found)", fontsize=10, color='red')
ax.axis('off')
return
plot_df = df[[value_col]].dropna()
if plot_df.empty:
ax.set_title(f"{title}\n(No valid data)", fontsize=10)
ax.axis('off')
return
ax.plot(plot_df.index, plot_df[value_col])
ax.set_title(title, fontsize=10)
if not kwargs.get('axis_off'):
ax.tick_params(axis='both', which='major', labelsize=8)
if kwargs.get('show_grid'):
ax.grid(**kwargs.get('grid_props'))
def _plot_forecast_grid(
fig, axes, df_wide, plot_structure, years_to_plot,
quantiles_to_plot, prefix, kind, spatial_cols, s, plot_kwargs,
_logger
):
# Check if spatial coordinates are available for plotting.
has_spatial_coords = all(
c in df_wide.columns for c in spatial_cols
)
# Log a fallback message if spatial coordinates are not found.
if not has_spatial_coords:
vlog(
f"Spatial columns {spatial_cols} not found. "
"Falling back to temporal line plots.",
level=1, verbose=plot_kwargs.get('verbose', 0),
logger=_logger
)
# Get the last known static actual column for fallback.
last_known_actual_col = plot_structure[prefix].get(
"static", {}).get("actual")
# Iterate through each year to create a row of plots.
for row_idx, year in enumerate(years_to_plot):
ax_row = axes[row_idx]
col_idx = 0
# Plot actual data if 'dual' mode is enabled.
if kind == 'dual':
actual_col = plot_structure[prefix].get(
year, {}).get('actual', last_known_actual_col)
if actual_col:
last_known_actual_col = actual_col
title = f"Actual ({year})"
plot_kwargs['title'] = title
# Choose plotting function based on coordinate availability.
if has_spatial_coords:
_plot_spatial_subplot(
ax_row[col_idx], df_wide, *spatial_cols,
actual_col, s =s, **plot_kwargs
)
else:
_plot_temporal_subplot(
ax_row[col_idx], df_wide, actual_col, **plot_kwargs
)
col_idx += 1
# Plot each requested prediction/quantile.
for q_suffix in quantiles_to_plot:
if col_idx >= len(ax_row): continue
pred_col = plot_structure[prefix].get(year, {}).get(q_suffix)
# Format the title for the subplot.
title_suffix = (
q_suffix.replace('q', 'Q').capitalize()
if 'q' in q_suffix else "Prediction"
)
title = f"{title_suffix} ({year})"
plot_kwargs['title'] = title
# Choose plotting function based on coordinate availability.
if has_spatial_coords:
_plot_spatial_subplot(
ax_row[col_idx], df_wide, *spatial_cols,
pred_col, s=s, **plot_kwargs
)
else:
_plot_temporal_subplot(
ax_row[col_idx], df_wide, pred_col, **plot_kwargs
)
col_idx += 1
# Turn off any remaining unused axes in the current row.
for i in range(col_idx, len(ax_row)):
ax_row[i].axis('off')
def _plot_single_scatter(
ax, df, x_col, y_col, c_col, cmap, vmin, vmax, title,
axis_off, show_grid, grid_props
):
"""Helper to create a single scatter subplot."""
if c_col not in df.columns:
ax.set_title(f"{title}\n(Data not found)",
fontsize=10, color='red')
ax.set_xticks([])
ax.set_yticks([])
return None
# Drop NaNs for this specific plot to avoid plotting empty points
plot_df = df[[x_col, y_col, c_col]].dropna()
if plot_df.empty:
ax.set_title(f"{title}\n(No valid data)", fontsize=10)
ax.set_xticks([])
ax.set_yticks([])
return None
scatter = ax.scatter(
plot_df[x_col], plot_df[y_col], c=plot_df[c_col],
cmap=cmap, vmin=vmin, vmax=vmax, s=10
)
ax.set_title(title, fontsize=10)
if axis_off:
ax.axis('off')
else:
ax.tick_params(axis='both', which='major', labelsize=8)
if show_grid:
ax.grid(**grid_props)
return scatter
def plot_eval_future(
    df_eval: Optional[pd.DataFrame] = None,
    df_future: Optional[pd.DataFrame] = None,
    target_name: str = 'subsidence',
    quantiles: Optional[Sequence[float]] = None,
    spatial_cols: Tuple[str, str] = ('coord_x', 'coord_y'),
    time_col: str = 'coord_t',
    eval_years: Optional[Sequence[Any]] = None,
    future_years: Optional[Sequence[Any]] = None,
    eval_view_quantiles: Optional[Sequence[Any]] = None,
    future_view_quantiles: Optional[Sequence[Any]] = None,
    cmap: str = 'viridis',
    cbar: str = 'uniform',
    axis_off: bool = False,
    show_grid: bool = True,
    grid_props: Optional[Dict[str, Any]] = None,
    figsize_eval: Optional[Tuple[float, float]] = None,
    figsize_future: Optional[Tuple[float, float]] = None,
    spatial_mode: str = 'hexbin',
    hexbin_gridsize: int = 40,
    savefig_prefix: Optional[str] = None,
    save_fmts: Union[str, Sequence[str]] = '.png',
    show: bool = True,
    verbose: int = 1,
    cumulative: bool = False,
    _logger: Optional[Union[logging.Logger,
                            Callable[[str], None]]] = None,
    **kws,
) -> None:
    """
    Plot an evaluation panel and a future-forecast panel in one call.

    Thin front-end over :func:`forecast_view` that issues up to two
    calls:

    * the **evaluation split** (``df_eval``) is drawn with
      ``kind='dual'`` so actual and predicted values sit side by side
      for the selected years;
    * the **future split** (``df_future``) is drawn with
      ``kind='pred_only'`` so only predicted columns appear.

    The frames are expected to follow the layout produced by
    :func:`fusionlab.nn.pinn.utils.format_and_forecast`, i.e. columns
    such as ``sample_idx``, ``forecast_step``, ``coord_t``,
    ``coord_x``, ``coord_y``, ``<target_name>_actual`` (evaluation
    only) and ``<target_name>_qXX`` or ``<target_name>_pred``.

    Parameters
    ----------
    df_eval : pandas.DataFrame, optional
        Frame with actual and predicted values for past / validation
        horizons, typically the first output of
        :func:`fusionlab.nn.pinn.utils.format_and_forecast`. When
        ``None`` or empty the evaluation panel is skipped.
    df_future : pandas.DataFrame, optional
        Frame with predictions only, typically the second output of
        :func:`fusionlab.nn.pinn.utils.format_and_forecast`. When
        ``None`` or empty the future panel is skipped.
    target_name : str, default='subsidence'
        Prefix of the value columns to plot, e.g. ``"subsidence"`` for
        ``"subsidence_q10"``, ``"subsidence_q50"``,
        ``"subsidence_q90"`` or ``"subsidence_pred"``.
    quantiles : sequence of float, optional
        Quantile levels matching the prediction columns (for instance
        ``[0.1, 0.5, 0.9]``). Serves as the fallback for both
        ``eval_view_quantiles`` and ``future_view_quantiles``.
    spatial_cols : tuple of str, default=('coord_x', 'coord_y')
        Names of the x / y (e.g. longitude / latitude) columns.
    time_col : str, default='coord_t'
        Name of the time column used to pick the slices to draw.
    eval_years : sequence of hashable, optional
        Time values to show for the evaluation split. When ``None``,
        the distinct non-null values of ``time_col`` in ``df_eval`` are
        sorted and only the last one is kept.
    future_years : sequence of hashable, optional
        Time values to show for the future split. When ``None``, every
        distinct non-null value of ``time_col`` in ``df_future`` is
        used, in sorted order.
    eval_view_quantiles : sequence, optional
        Quantiles displayed for the evaluation split, given as floats
        (``0.1``) or labels (``"q10"``). Falls back to ``quantiles``;
        if that is ``None`` too, ``None`` is handed to
        :func:`forecast_view`.
    future_view_quantiles : sequence, optional
        Same as ``eval_view_quantiles`` but for the future split.
    cmap : str, default='viridis'
        Matplotlib colormap name.
    cbar : {"uniform", "independent"}, default='uniform'
        Colour scaling policy: ``"uniform"`` shares one
        ``vmin``/``vmax`` across a panel, ``"independent"`` scales each
        subplot on its own.
    axis_off : bool, default=False
        Hide ticks and frames for a map-like look.
    show_grid : bool, default=True
        Overlay a grid styled by ``grid_props``.
    grid_props : dict, optional
        Grid styling keywords. ``None`` is replaced by
        ``{'linestyle': ':', 'alpha': 0.7}``.
    figsize_eval : tuple of float, optional
        Figure size in inches for the evaluation panel.
    figsize_future : tuple of float, optional
        Figure size in inches for the future panel.
    spatial_mode : {"hexbin", "scatter"}, default='hexbin'
        ``"hexbin"`` aggregates points into hexagonal bins (suited to
        dense data); ``"scatter"`` draws individual points.
    hexbin_gridsize : int, default=40
        Grid size used when ``spatial_mode="hexbin"``; larger values
        give finer resolution.
    savefig_prefix : str, optional
        Base path for saved figures. The evaluation figure target is
        ``"{prefix}_eval"`` and the future one ``"{prefix}_future"``;
        extensions come from ``save_fmts``.
    save_fmts : str or sequence of str, default='.png'
        Output format(s), e.g. ``".png"`` or ``[".png", ".pdf"]``.
    show : bool, default=True
        Forwarded to :func:`forecast_view` to control on-screen
        display.
    verbose : int, default=1
        Verbosity passed to :func:`fusionlab.utils.generic_utils.vlog`
        and to :func:`forecast_view`.
    cumulative : bool, default=False
        Forwarded to the **evaluation** call only, asking
        :func:`forecast_view` to treat the predictions as cumulative
        along the forecast horizon rather than per-step values. The
        future call does not receive this flag.
    _logger : logging.Logger or callable, optional
        Sink for internal messages.
    **kws
        Extra keyword arguments passed to both :func:`forecast_view`
        calls.

    Returns
    -------
    None
        Figures are produced (and optionally saved) as a side effect.

    Examples
    --------
    >>> from fusionlab.plot.forecast import plot_eval_future
    >>> plot_eval_future(
    ...     df_eval=df_eval,
    ...     df_future=df_future,
    ...     target_name="subsidence",
    ...     quantiles=[0.1, 0.5, 0.9],
    ...     eval_years=[2022],
    ...     future_years=[2023, 2024],
    ... )

    Median-only evaluation shown as cumulative values, figures written
    to disk without being displayed:

    >>> plot_eval_future(
    ...     df_eval=df_eval,
    ...     df_future=df_future,
    ...     quantiles=[0.1, 0.5, 0.9],
    ...     eval_view_quantiles=[0.5],
    ...     cumulative=True,
    ...     savefig_prefix="zhongshan_subsidence_view",
    ...     save_fmts=[".png", ".pdf"],
    ...     show=False,
    ... )
    """
    if grid_props is None:
        grid_props = {'linestyle': ':', 'alpha': 0.7}

    def _say(text, level):
        # Single funnel for every log line emitted by this wrapper.
        vlog(text, level=level, verbose=verbose, logger=_logger)

    # --- which splits are actually drawable? ---------------------------
    has_eval = not (df_eval is None or df_eval.empty)
    if not has_eval:
        _say('plot_eval_future: no eval data provided.', 3)

    has_future = not (df_future is None or df_future.empty)
    if not has_future:
        _say('plot_eval_future: no future data provided.', 3)

    if not (has_eval or has_future):
        _say('plot_eval_future: nothing to plot.', 2)
        return

    # --- resolve the quantile selections -------------------------------
    base_q = None if quantiles is None else list(quantiles)
    q_for_eval = (
        base_q if eval_view_quantiles is None
        else list(eval_view_quantiles)
    )
    q_for_future = (
        base_q if future_view_quantiles is None
        else list(future_view_quantiles)
    )

    def _draw_split(frame, layout, years, shown_q, size, tag, note,
                    **extra):
        # Everything the two panels share lives here; only the frame,
        # layout kind, years, quantiles, figure size and file tag vary.
        target = None
        if savefig_prefix is not None:
            target = f"{savefig_prefix}_{tag}"
        _say(note, 2)
        forecast_view(
            forecast_df=frame,
            value_prefixes=[target_name],
            kind=layout,
            view_quantiles=shown_q,
            view_years=years,
            spatial_cols=spatial_cols,
            time_col=time_col,
            cmap=cmap,
            cbar=cbar,
            axis_off=axis_off,
            show_grid=show_grid,
            grid_props=grid_props,
            figsize=size,
            savefig=target,
            save_fmts=save_fmts,
            show=show,
            verbose=verbose,
            _logger=_logger,
            spatial_mode=spatial_mode,
            hexbin_gridsize=hexbin_gridsize,
            **extra,
            **kws,
        )

    # --- evaluation panel: actual next to predictions ------------------
    if has_eval:
        if eval_years is not None:
            years_eval = list(eval_years)
        else:
            seen = df_eval[time_col].dropna().unique().tolist()
            seen.sort()
            # Default to the most recent evaluation slice only.
            years_eval = seen[-1:]
        _draw_split(
            df_eval, 'dual', years_eval, q_for_eval, figsize_eval,
            'eval', 'plot_eval_future: plotting eval split.',
            cumulative=cumulative,
        )

    # --- future panel: predictions only --------------------------------
    if has_future:
        if future_years is not None:
            years_future = list(future_years)
        else:
            years_future = df_future[time_col].dropna().unique().tolist()
            years_future.sort()
        _draw_split(
            df_future, 'pred_only', years_future, q_for_future,
            figsize_future, 'future',
            'plot_eval_future: plotting future split.',
        )