# Source code for fusionlab.plot.forecast

# -*- coding: utf-8 -*-
# License: BSD-3-Clause
# Author: LKouadio <etanoyau@gmail.com>

"""
Convenience helpers for **quick‑look** visual inspection of model
forecasts.
"""

from __future__ import annotations

import os
import re
import logging 
import warnings
from typing import Any, List, Optional, Tuple, Union, Dict, Callable 

import matplotlib.pyplot as plt
import matplotlib.colors as mcolors

from matplotlib.cm import ScalarMappable
import numpy as np
import pandas as pd

from ..core.checks import ( 
    check_non_emptiness, 
    check_spatial_columns, 
    is_in_if
)
from ..core.handlers import columns_manager
from ..utils.forecast_utils import ( 
    format_forecast_dataframe, 
    get_value_prefixes, 
    get_value_prefixes_in, 
    detect_forecast_type, 
    get_step_names 
)
from ..utils.generic_utils import ( 
    _coerce_dt_kw, 
    get_actual_column_name,
    vlog, save_figure 
)
from ..utils.validator import ( 
    assert_xy_in, is_frame, 
    validate_positive_integer 
)

__all__= [
    "forecast_view",
    "plot_forecast_by_step", 
    "plot_forecasts", 
    "visualize_forecasts", 
 ]

def plot_forecast_by_step(
    df: pd.DataFrame,
    value_prefixes: Optional[List[str]] = None,
    kind: str = "dual",
    steps: Optional[Union[int, List[int]]] = None,
    step_names: Optional[Union[Dict[int, str], List[str]]] = None,
    spatial_cols: Optional[Tuple[str, str]] = None,
    max_cols: Union[int, str] = "auto",
    cmap: str = "viridis",
    cbar: str = "uniform",
    axis_off: bool = False,
    show_grid: bool = True,
    grid_props: Optional[Dict] = None,
    figsize: Optional[Tuple[float, float]] = None,
    savefig: Optional[str] = None,
    save_fmts: Union[str, List[str]] = ".png",
    _logger: Optional[Union[logging.Logger, Callable[[str], None]]] = None,
    verbose: int = 1,
):
    """Plot forecast data, organising subplots by forecast step.

    One figure is produced per value prefix from a long-format
    DataFrame. Each row of a figure's grid corresponds to one unique
    ``forecast_step``. Each column corresponds to one metric: the
    actual values (if shown) first, then the prediction metrics such
    as quantiles. Subplots are spatial scatter plots when the
    ``spatial_cols`` columns exist in ``df``. Otherwise they are
    temporal line plots drawn against the DataFrame index.

    Parameters
    ----------
    df : pd.DataFrame
        Input long-format DataFrame. Must contain a ``'forecast_step'``
        column and value columns (e.g. ``'subsidence_q50'``).
    value_prefixes : list of str, optional
        Base names of the metrics to plot (e.g. ``['subsidence']``).
        If ``None``, prefixes are auto-detected from the column names,
        ignoring ``'sample_idx'``, ``'forecast_step'`` and
        ``spatial_cols``. A separate figure is generated per prefix.
    kind : {'dual', 'pred_only', 'spatial', 'temporal'}, default 'dual'
        What to plot, and how:

        - ``'dual'``: actual values (metrics whose suffix contains
          ``'actual'``) and predictions side by side.
        - ``'pred_only'``: predictions only.
        - ``'spatial'``: predictions only, as spatial scatter plots.
          If ``spatial_cols`` are missing, a warning is emitted and
          line plots are used instead.
        - ``'temporal'``: predictions only, always as line plots.

        With ``'dual'`` and ``'pred_only'``, the plot type follows the
        availability of ``spatial_cols``.
    steps : int or list of int, optional
        Forecast step(s) to visualise. Each value is validated as a
        positive integer. If ``None``, all unique steps are plotted.
    step_names : dict or list, optional
        Custom labels for the steps in subplot titles.

        - dict: maps a step number to a full label
          (e.g. ``{1: 'Year 1'}``).
        - list: ``step_names[0]`` names step 1, ``step_names[1]`` names
          step 2, etc. The title becomes ``"Step k: <name>"``.
    spatial_cols : tuple of str, optional
        Names of the x/y coordinate columns, e.g.
        ``('coord_x', 'coord_y')``.
    max_cols : int or 'auto', default 'auto'
        Number of subplots per row. With ``'auto'``, this is the number
        of prediction metrics for the prefix, plus one if actuals are
        shown. With an integer, surplus prediction metrics are dropped.
    cmap : str, default 'viridis'
        Matplotlib colormap passed to the subplot helpers.
    cbar : {'uniform', 'individual'}, default 'uniform'
        ``'uniform'``: all subplots of a figure share one colour range.
        The range is computed from the columns and steps shown in that
        figure, ignoring NaN values, and a single colour bar is added.
        ``'individual'``: no shared limits are passed to the helpers.
    axis_off : bool, default False
        If ``True``, hide axes, ticks and labels.
    show_grid : bool, default True
        If ``True`` and ``axis_off`` is ``False``, draw a grid.
    grid_props : dict, optional
        Grid properties. Defaults to ``{'linestyle': ':', 'alpha': 0.7}``.
    figsize : tuple of (float, float), optional
        Size of each figure. If ``None``, it is ``(4 * n_cols,
        3.5 * n_rows)``.
    savefig : str, optional
        Base path used to save the figures. ``'_<prefix>_by_step'`` is
        inserted before the extension, e.g. ``'plot.png'`` becomes
        ``'plot_subsidence_by_step.png'``. This way, figures for
        different prefixes do not overwrite each other. Saved figures
        are closed. Otherwise ``plt.show()`` is called.
    save_fmts : str or list of str, default '.png'
        Format(s) forwarded to ``save_figure``.
    _logger : logging.Logger or callable, optional
        Logger forwarded to ``vlog``.
    verbose : int, default 1
        Verbosity level for ``vlog`` messages.

    Returns
    -------
    None
        Figures are displayed and/or saved.

    Raises
    ------
    ValueError
        If ``df`` has no ``'forecast_step'`` column, or if prefixes
        must be auto-detected and none are found.

    See Also
    --------
    forecast_view : Plots by year/time instead of by step index.

    Notes
    -----
    - For each prefix, only the first (alphabetically sorted) metric
      containing ``'actual'`` is plotted as the actual column.
    - Line plots use the DataFrame index as the x axis. They are most
      meaningful when ``df`` is sorted by a relevant variable within
      each step.

    Examples
    --------
    >>> import numpy as np
    >>> import pandas as pd
    >>> # from fusionlab.plot.forecast import plot_forecast_by_step
    >>> data = {
    ...     'sample_idx': list(range(4)) * 2,
    ...     'forecast_step': [1] * 4 + [2] * 4,
    ...     'coord_x': np.random.rand(8),
    ...     'coord_y': np.random.rand(8),
    ...     'subsidence_q50': np.random.randn(8),
    ...     'subsidence_actual': np.random.randn(8),
    ... }
    >>> df_step_example = pd.DataFrame(data)
    >>> # plot_forecast_by_step(
    ... #     df=df_step_example,
    ... #     value_prefixes=['subsidence'],
    ... #     spatial_cols=('coord_x', 'coord_y'),
    ... #     step_names={1: 'Year 1 Forecast', 2: 'Year 2 Forecast'},
    ... #     kind='dual'
    ... # )
    """
    # Fail fast on the single structural requirement of the long format.
    if "forecast_step" not in df.columns:
        raise ValueError("DataFrame must contain 'forecast_step'.")

    # 1. Detect prefixes if not provided.
    if value_prefixes is None:
        vlog("Auto-detecting value prefixes…", level=1,
             verbose=verbose, logger=_logger)
        exclude_from_prefix = ["sample_idx", "forecast_step"]
        if spatial_cols:
            exclude_from_prefix.extend(spatial_cols)
        value_prefixes = get_value_prefixes_in(
            df, exclude_cols=exclude_from_prefix)
        if not value_prefixes:
            raise ValueError("Could not auto-detect value prefixes.")
        vlog(f"Detected prefixes: {value_prefixes}", level=2,
             verbose=verbose, logger=_logger)

    # Split the metric suffixes into actuals and predictions.
    metrics = _get_metrics_from_cols(df.columns, value_prefixes)
    actual_metrics = sorted(m for m in metrics if "actual" in m)
    pred_metrics = sorted(m for m in metrics if "actual" not in m)
    plot_actuals = kind == "dual" and bool(actual_metrics)

    # `kind` also selects the content ('dual' / 'pred_only'), so the plot
    # type must follow coordinate availability rather than only
    # ``kind == 'spatial'``. Otherwise the default never draws maps.
    has_spatial_coords = bool(spatial_cols) and all(
        c in df.columns for c in spatial_cols
    )
    if kind == "spatial" and not has_spatial_coords:
        warnings.warn(
            f"kind='spatial' but columns {spatial_cols} not found. "
            "Falling back to temporal line plots."
        )
    use_spatial = has_spatial_coords and kind != "temporal"
    vlog(
        "Plotting spatial data." if use_spatial
        else "Plotting temporal data.",
        level=2, verbose=verbose, logger=_logger,
    )

    # Unique steps to visualise, optionally restricted by `steps`.
    steps_to_plot = sorted(df["forecast_step"].unique())
    if steps is not None:
        steps = columns_manager(steps, empty_as_none=False)
        steps = [validate_positive_integer(v, f"Step {v}") for v in steps]
        steps_to_plot = sorted(
            is_in_if(steps_to_plot, steps, return_intersect=True)
        )

    n_rows = len(steps_to_plot)
    if n_rows == 0:
        vlog("No forecast steps to plot.", level=0,
             verbose=verbose, logger=_logger)
        return

    # Filter each step once; the frames are reused for every prefix.
    step_frames = {
        step: df[df["forecast_step"] == step] for step in steps_to_plot
    }
    plotted_rows = df[df["forecast_step"].isin(steps_to_plot)]
    grid_props = grid_props or {"linestyle": ":", "alpha": 0.7}

    def _step_label(step):
        """Build the subplot title fragment for one forecast step."""
        default = f"Step {step}"
        if isinstance(step_names, dict):
            return step_names.get(step, default)
        # The bounds check avoids a negative index when step == 0.
        if isinstance(step_names, list) and 0 <= step - 1 < len(step_names):
            return f"Step {step}: {step_names[step - 1]}"
        return default

    def _color_limits(frame, columns):
        """Return NaN-safe (min, max) over `columns`, or (None, None)."""
        if not columns:
            return None, None
        # Column-wise min/max skip NaN. Unlike dropna() on the whole
        # frame, a NaN in one column does not discard the other columns.
        lo = frame[columns].min().min()
        hi = frame[columns].max().max()
        if pd.isna(lo) or pd.isna(hi):
            return None, None
        return lo, hi

    def _draw(ax, frame, column, title, style):
        """Render one subplot as a spatial scatter or a line plot."""
        if use_spatial:
            _plot_spatial_subplot(
                ax, frame, *spatial_cols, column, title=title, **style
            )
        else:
            _plot_temporal_subplot(ax, frame, column, title=title, **style)

    # 2. One figure per prefix.
    for prefix in value_prefixes:
        prefix_preds = [
            m for m in pred_metrics if f"{prefix}_{m}" in df.columns
        ]
        prefix_actual = next(
            (m for m in actual_metrics if f"{prefix}_{m}" in df.columns),
            None,
        )
        show_actual = plot_actuals and prefix_actual is not None

        if max_cols == "auto":
            n_cols = len(prefix_preds) + int(show_actual)
        else:
            n_cols = int(max_cols)
        if n_cols <= 0:
            continue

        # The actual (if shown) takes column 0; predictions fill the rest.
        shown_preds = prefix_preds[:n_cols - int(show_actual)]
        shown_cols = [f"{prefix}_{m}" for m in shown_preds]
        if show_actual:
            shown_cols.insert(0, f"{prefix}_{prefix_actual}")

        # Colour limits are computed per figure, from the data it shows.
        vmin, vmax = (
            _color_limits(plotted_rows, shown_cols)
            if cbar == "uniform" else (None, None)
        )
        style = dict(
            cmap=cmap, vmin=vmin, vmax=vmax, axis_off=axis_off,
            show_grid=show_grid, grid_props=grid_props, verbose=verbose,
        )

        fig_size = figsize or (n_cols * 4, n_rows * 3.5)
        fig, axes = plt.subplots(
            n_rows, n_cols, figsize=fig_size, squeeze=False,
            constrained_layout=True,
        )
        fig.suptitle(
            f"Forecast for '{prefix.upper()}' by Step",
            fontsize=16, weight='bold'
        )

        for r_idx, step in enumerate(steps_to_plot):
            step_df = step_frames[step]
            step_lbl = _step_label(step)
            c_idx = 0
            if show_actual:
                _draw(axes[r_idx][c_idx], step_df,
                      f"{prefix}_{prefix_actual}",
                      f"Actual ({step_lbl})", style)
                c_idx += 1
            for metric in shown_preds:
                _draw(axes[r_idx][c_idx], step_df, f"{prefix}_{metric}",
                      f"{metric.replace('_', ' ').title()} ({step_lbl})",
                      style)
                c_idx += 1
            # Blank unused axes in this row.
            for j in range(c_idx, n_cols):
                axes[r_idx][j].axis("off")

        if cbar == "uniform" and vmin is not None and vmax is not None:
            fig.colorbar(
                ScalarMappable(
                    norm=mcolors.Normalize(vmin=vmin, vmax=vmax), cmap=cmap
                ),
                ax=axes.ravel().tolist(), orientation="vertical",
                shrink=0.8, pad=0.04
            )

        # 3. Save (one file set per prefix) or show.
        if savefig:
            base, ext = os.path.splitext(savefig)
            save_figure(
                fig, savefile=f"{base}_{prefix}_by_step{ext}",
                save_fmts=save_fmts, dpi=300, bbox_inches="tight"
            )
            plt.close(fig)
        else:
            plt.show()
@check_non_emptiness
def forecast_view_in(
    forecast_df: pd.DataFrame,
    value_prefixes: Optional[List[str]] = None,
    kind: str = 'dual',
    view_quantiles: Optional[List[Union[str, float]]] = None,
    view_years: Optional[List[Union[str, int]]] = None,
    spatial_cols: Tuple[str, str] = ('coord_x', 'coord_y'),
    time_col: str = 'coord_t',
    max_cols: Union[int, str] = 'auto',
    cmap: str = 'viridis',
    cbar: str = 'uniform',
    axis_off: bool = False,
    show_grid: bool = True,
    grid_props: Optional[Dict] = None,
    figsize: Optional[Tuple[float, float]] = None,
    savefig: Optional[str] = None,
    save_fmts: Union[str, List[str]] = '.png',
    verbose: int = 1
):
    """Generate and display spatial forecast visualisations.

    Pivots ``forecast_df`` to wide format with
    ``format_forecast_dataframe``. It then draws one figure per value
    prefix: a grid of scatter plots with one row per year and one
    column per metric (the actuals if ``kind='dual'``, then each
    selected quantile).

    Parameters
    ----------
    forecast_df : pd.DataFrame
        Forecast data, in long or wide format.
    value_prefixes : list of str, optional
        Base names of the metrics to plot (e.g. ``['subsidence']``).
        Auto-detected with ``get_value_prefixes`` when ``None``.
    kind : {'dual', 'pred_only'}, default 'dual'
        ``'dual'`` adds an actual-value column before the predictions.
        ``'pred_only'`` plots predictions only.
    view_quantiles : list of str or float, optional
        Quantiles to display. Floats are converted with rounding
        (``0.1`` -> ``'q10'``). Other values are converted to strings,
        with ``'q'`` prefixed if missing (``'50'`` or ``50`` ->
        ``'q50'``). When ``None``, all detected quantiles are shown.
    view_years : list of str or int, optional
        Years (rows) to display. Defaults to all detected years.
    spatial_cols : tuple of str, default ('coord_x', 'coord_y')
        Columns used as the x and y axes of the scatter plots.
    time_col : str, default 'coord_t'
        Time column of a long-format input, forwarded to the formatter.
    max_cols : int or 'auto', default 'auto'
        Subplots per row. ``'auto'`` gives the number of quantiles,
        plus one when ``kind='dual'``. Quantiles that do not fit are
        dropped.
    cmap : str, default 'viridis'
        Matplotlib colormap.
    cbar : {'uniform', 'individual'}, default 'uniform'
        ``'uniform'``: all subplots of a figure share one colour range.
        The range comes from the NaN-free min/max of that prefix's
        columns, and one colour bar is drawn. ``'individual'``: no
        shared limits are used.
    axis_off : bool, default False
        If ``True``, hide axes for all subplots.
    show_grid : bool, default True
        If ``True`` and ``axis_off`` is ``False``, draw a grid.
    grid_props : dict, optional
        Grid properties. Defaults to ``{'linestyle': ':', 'alpha': 0.7}``.
    figsize : tuple of (float, float), optional
        Figure size. Defaults to ``(3.5 * n_cols, 3 * n_rows)``.
    savefig : str, optional
        Base path. Each format is saved as
        ``'<base>_<prefix><fmt>'``, and missing directories are
        created.
    save_fmts : str or list of str, default '.png'
        Output format(s). A leading dot is added if missing.
    verbose : int, default 1
        Verbosity level for ``vlog``.

    Returns
    -------
    None

    Raises
    ------
    ValueError
        If prefixes must be auto-detected and none are found.

    Notes
    -----
    - If a year has no actual column, the most recent known actual
      column (initially the prefix's ``'static_actual'``) is reused.
    - If no quantile matches, the forecast is treated as
      deterministic and the ``'pred'`` suffix is plotted.
    """
    if value_prefixes is None:
        vlog("`value_prefixes` not provided. Auto-detecting...",
             level=1, verbose=verbose)
        value_prefixes = get_value_prefixes(forecast_df)
        if not value_prefixes:
            raise ValueError(
                "Could not auto-detect any value prefixes. Please "
                "provide them explicitly."
            )
        vlog(f"Detected prefixes: {value_prefixes}", level=2,
             verbose=verbose)
    value_prefixes = columns_manager(value_prefixes)

    df_wide = format_forecast_dataframe(
        df=forecast_df,
        to_wide=True,
        value_prefixes=value_prefixes,
        id_vars=[
            c for c in ['sample_idx', *spatial_cols]
            if c in forecast_df.columns
        ],
        time_col=time_col,
        static_actuals_cols=[
            c for c in forecast_df.columns
            if c.endswith('_actual') and c.count('_') == 1
        ],
        verbose=verbose
    )

    vlog("Parsing wide DataFrame columns to build plot structure.",
         level=1, verbose=verbose)
    plot_structure = _parse_wide_df_columns(df_wide, value_prefixes)

    # --- Filter data based on user preferences ---
    all_years = sorted({
        y for p_data in plot_structure.values()
        for y in p_data if y.isdigit()
    })
    years_to_plot = [str(y) for y in view_years] if view_years else all_years

    all_quantiles = sorted({
        suffix
        for p_data in plot_structure.values()
        for y_data in p_data.values() if isinstance(y_data, dict)
        for suffix in y_data if suffix.startswith('q')
    })

    if view_quantiles:
        req_q_norm = set()
        for q in view_quantiles:
            if isinstance(q, (float, np.floating)):
                # round() guards against float artefacts: 0.29 * 100 is
                # 28.999999999999996, which int() alone truncates to 28.
                req_q_norm.add(f"q{int(round(float(q) * 100))}")
            else:
                label = str(q)
                req_q_norm.add(label if label.startswith('q') else f"q{label}")
        quantiles_to_plot = [q for q in all_quantiles if q in req_q_norm]
    else:
        quantiles_to_plot = all_quantiles

    # Deterministic case: plot e.g. 'subsidence_2022_pred' instead.
    is_deterministic = not quantiles_to_plot
    if is_deterministic:
        quantiles_to_plot = ['pred']
        vlog("No quantiles found. Assuming deterministic forecast.",
             level=1, verbose=verbose)

    # --- Setup plot layout ---
    plot_actuals = (kind == 'dual')
    n_rows = len(years_to_plot)
    if n_rows == 0:
        vlog("No years to plot after filtering.", level=0, verbose=verbose)
        return

    x_col, y_col = spatial_cols
    grid_props = grid_props or {'linestyle': ':', 'alpha': 0.7}

    def _collect_cols(node, out):
        """Recursively collect column names from a plot_structure node."""
        if isinstance(node, str):
            out.append(node)
        elif isinstance(node, dict):
            for sub in node.values():
                _collect_cols(sub, out)

    def _prefix_color_limits(prefix):
        """NaN-safe (vmin, vmax) over the columns of one prefix."""
        found = []
        _collect_cols(plot_structure.get(prefix, {}), found)
        valid = [c for c in found if c in df_wide.columns]
        mins = [v for v in (df_wide[c].min() for c in valid) if pd.notna(v)]
        maxs = [v for v in (df_wide[c].max() for c in valid) if pd.notna(v)]
        return (min(mins) if mins else None, max(maxs) if maxs else None)

    for prefix in value_prefixes:
        if max_cols == 'auto':
            n_cols_prefix = len(quantiles_to_plot) + int(plot_actuals)
        else:
            n_cols_prefix = int(max_cols)
        if n_cols_prefix <= 0:
            continue

        # One colour scale per figure: prefixes may have different units.
        vmin, vmax = (
            _prefix_color_limits(prefix) if cbar == 'uniform'
            else (None, None)
        )

        figsize_prefix = figsize or (n_cols_prefix * 3.5, n_rows * 3)
        fig, axes = plt.subplots(
            n_rows, n_cols_prefix, figsize=figsize_prefix,
            squeeze=False, constrained_layout=True
        )
        fig.suptitle(
            f"Forecast Visualization for '{prefix.upper()}'",
            fontsize=14, weight='bold'
        )

        last_known_actual_col = plot_structure[prefix].get("static_actual")
        for row_idx, year in enumerate(years_to_plot):
            ax_row = axes[row_idx]
            col_idx = 0

            if plot_actuals:
                # Fall back to the latest known actual column for years
                # without their own actuals.
                actual_col = plot_structure[prefix].get(year, {}).get(
                    'actual', last_known_actual_col)
                if actual_col:
                    last_known_actual_col = actual_col
                _plot_single_scatter(
                    ax_row[col_idx], df_wide, x_col, y_col, actual_col,
                    cmap, vmin, vmax, f"Actual ({year})",
                    axis_off, show_grid, grid_props
                )
                col_idx += 1

            for q_suffix in quantiles_to_plot:
                if col_idx >= n_cols_prefix:
                    break
                pred_col = plot_structure[prefix].get(year, {}).get(q_suffix)
                if is_deterministic:
                    title = f"Prediction ({year})"
                else:
                    title = (
                        f"{q_suffix.replace('q', 'Q').capitalize()} ({year})"
                    )
                _plot_single_scatter(
                    ax_row[col_idx], df_wide, x_col, y_col, pred_col,
                    cmap, vmin, vmax, title,
                    axis_off, show_grid, grid_props
                )
                col_idx += 1

            for i in range(col_idx, n_cols_prefix):
                ax_row[i].axis('off')

        if cbar == 'uniform' and vmin is not None and vmax is not None:
            fig.colorbar(
                ScalarMappable(
                    norm=mcolors.Normalize(vmin=vmin, vmax=vmax), cmap=cmap),
                ax=axes, orientation='vertical', fraction=0.02, pad=0.04
            )

        if savefig:
            fmts = [save_fmts] if isinstance(save_fmts, str) else save_fmts
            base, _ = os.path.splitext(savefig)
            for fmt in fmts:
                suffix = fmt if fmt.startswith('.') else '.' + fmt
                fname = f"{base}_{prefix}{suffix}"
                out_dir = os.path.dirname(fname)
                if out_dir:
                    os.makedirs(out_dir, exist_ok=True)
                vlog(f"Saving figure to {fname}", level=1, verbose=verbose)
                fig.savefig(fname, dpi=300, bbox_inches="tight")

        plt.show()
@check_non_emptiness
def forecast_view(
    forecast_df: pd.DataFrame,
    value_prefixes: Optional[List[str]] = None,
    kind: str = 'dual',
    view_quantiles: Optional[List[Union[str, float]]] = None,
    view_years: Optional[List[Union[str, int]]] = None,
    spatial_cols: Tuple[str, str] = None,
    time_col: str = 'coord_t',
    max_cols: Union[int, str] = 'auto',
    cmap: str = 'viridis',
    cbar: str = 'uniform',
    axis_off: bool = False,
    show_grid: bool = True,
    grid_props: Optional[Dict] = None,
    figsize: Optional[Tuple[float, float]] = None,
    savefig: Optional[str] = None,
    save_fmts: Union[str, List[str]] = '.png',
    dt_col: Optional[str] = None,
    show: bool = True,
    _logger: Optional[Union[logging.Logger, Callable[[str], None]]] = None,
    verbose: int = 1,
    **kws
):
    """Generate and display spatial forecast visualisations.

    Pivots ``forecast_df`` to wide format with
    ``format_forecast_dataframe``. It then draws one figure per value
    prefix: a grid with one row per year and one column per metric
    (the actuals if ``kind='dual'``, then each selected quantile). The
    grid content is rendered by ``_plot_forecast_grid``.

    Parameters
    ----------
    forecast_df : pd.DataFrame
        Forecast data, in long or wide format.
    value_prefixes : list of str, optional
        Base names of the metrics to plot (e.g.
        ``['subsidence', 'GWL']``). When ``None``, they are
        auto-detected, excluding ``'sample_idx'``, ``spatial_cols``,
        the time column and ``'forecast_step'``. One figure is made per
        prefix.
    kind : {'dual', 'pred_only'}, default 'dual'
        ``'dual'`` reserves one extra column for actual values.
        ``'pred_only'`` plots predictions only.
    view_quantiles : list of str or float, optional
        Quantiles to display. Floats are converted with rounding
        (``0.1`` -> ``'q10'``). Other values are converted to strings,
        with ``'q'`` prefixed if missing (``'50'`` or ``50`` ->
        ``'q50'``). When ``None``, all detected quantiles are shown.
    view_years : list of str or int, optional
        Years (rows) to display. Defaults to all detected years, each
        listed once even when several prefixes share it.
    spatial_cols : tuple of str, optional
        Names of the x/y coordinate columns, e.g.
        ``('coord_x', 'coord_y')``. Default is ``None``.
    time_col : str, default 'coord_t'
        Time column of a long-format input.

        .. note:: ``time_col`` and ``dt_col`` can be used
           interchangeably.
    max_cols : int or 'auto', default 'auto'
        Subplots per row. ``'auto'`` gives the number of quantiles,
        plus one when ``kind='dual'``. An integer fixes the number of
        columns.
    cmap : str, default 'viridis'
        Matplotlib colormap.
    cbar : {'uniform', 'individual'}, default 'uniform'
        ``'uniform'``: all subplots of a figure share one colour range.
        The range comes from the NaN-free min/max of that prefix's
        columns, and one colour bar is drawn. ``'individual'``: no
        shared limits are passed to the grid helper.
    axis_off : bool, default False
        If ``True``, hide axes for all subplots.
    show_grid : bool, default True
        If ``True`` and ``axis_off`` is ``False``, draw a grid.
    grid_props : dict, optional
        Grid properties. Defaults to ``{'linestyle': ':', 'alpha': 0.7}``.
    figsize : tuple of (float, float), optional
        Figure size. Defaults to ``(3.5 * n_cols, 3 * n_rows)``.
    savefig : str, optional
        Base path. The prefix is inserted before the extension, e.g.
        ``'my_forecast.png'`` becomes ``'my_forecast_subsidence.png'``,
        so figures for different prefixes do not overwrite each other.
        Saved figures are closed.
    save_fmts : str or list of str, default '.png'
        Format(s) forwarded to ``save_figure``.
    dt_col : str, optional
        Alias of ``time_col``.
    show : bool, default True
        When not saving, call ``plt.show()``. If ``False``, the figure
        is closed instead.
    _logger : logging.Logger or callable, optional
        Logger forwarded to ``vlog`` and the formatter.
    verbose : int, default 1
        Verbosity level. ``0`` is silent.
    **kws : dict
        Extension keywords. ``s`` (default 10) is the marker size
        passed to ``_plot_forecast_grid`` for every figure.

    Returns
    -------
    None

    Raises
    ------
    ValueError
        If prefixes must be auto-detected and none are found.

    See Also
    --------
    format_forecast_dataframe : Detects and pivots the input DataFrame.
    get_value_prefixes : Auto-detects metric prefixes.

    Notes
    -----
    - A warning is emitted for deterministic forecasts. If no quantile
      matches, the ``'pred'`` suffix is plotted.
    """
    # Canonicalise the time-column name (`dt_col` is an alias).
    kw = _coerce_dt_kw(
        dt_col=dt_col, time_col=time_col, _time_default=time_col)
    time_col = kw.get("dt_col", time_col)
    _spatial_cols = spatial_cols or []
    # Pop once, outside the prefix loop. Popping per prefix would apply
    # the user's marker size to the first figure only.
    marker_size = kws.pop('s', 10)

    is_q = detect_forecast_type(forecast_df, value_prefixes=value_prefixes)
    if is_q == 'deterministic':
        warnings.warn(
            "Deterministic prediction is detected. It is recommended"
            " to use 'fusionlab.plot.forecast.plot_forecast_by_step'"
            " instead. Use at your own risk."
        )

    if value_prefixes is None:
        vlog("`value_prefixes` not provided. Auto-detecting...",
             level=1, verbose=verbose, logger=_logger)
        value_prefixes = get_value_prefixes(
            forecast_df,
            exclude_cols=['sample_idx', *_spatial_cols, time_col,
                          'forecast_step']
        )
        if not value_prefixes:
            raise ValueError(
                "Could not auto-detect any value prefixes. "
                "Please provide them explicitly."
            )
        vlog(f"Detected prefixes: {value_prefixes}", level=2,
             verbose=verbose, logger=_logger)

    df_wide = format_forecast_dataframe(
        forecast_df,
        to_wide=True,
        value_prefixes=value_prefixes,
        id_vars=[c for c in ['sample_idx', *_spatial_cols]
                 if c in forecast_df.columns],
        time_col=time_col,
        static_actuals_cols=[
            c for c in forecast_df.columns
            if c.endswith('_actual') and c.count('_') == 1
        ],
        spatial_cols=_spatial_cols,
        verbose=verbose,
        _logger=_logger,
    )
    plot_structure = _parse_wide_df_columns(df_wide, value_prefixes)

    # A set: prefixes usually share years, and duplicates would
    # produce duplicated grid rows.
    all_years = sorted({
        y for p_data in plot_structure.values()
        for y in p_data if y.isdigit()
    })
    years_to_plot = [str(y) for y in view_years] if view_years else all_years

    all_quantiles = sorted({
        s for p in plot_structure.values()
        for y in p.values() if isinstance(y, dict)
        for s in y if s.startswith('q')
    })

    if view_quantiles:
        req_q_norm = set()
        for q in view_quantiles:
            if isinstance(q, (float, np.floating)):
                # round() guards against float artefacts: 0.29 * 100 is
                # 28.999999999999996, which int() alone truncates to 28.
                req_q_norm.add(f"q{int(round(float(q) * 100))}")
            else:
                label = str(q)
                req_q_norm.add(label if label.startswith('q') else f"q{label}")
        quantiles_to_plot = [q for q in all_quantiles if q in req_q_norm]
    else:
        quantiles_to_plot = all_quantiles

    if not quantiles_to_plot:
        quantiles_to_plot = ['pred']
        vlog("No quantiles found. Assuming deterministic forecast.",
             level=1, verbose=verbose, logger=_logger)

    n_rows = len(years_to_plot)
    if n_rows == 0:
        vlog("No years to plot after filtering.", level=0,
             verbose=verbose, logger=_logger)
        return

    def _prefix_color_limits(prefix):
        """NaN-safe (vmin, vmax) over all columns mapped for `prefix`."""
        cols = [
            c for y in plot_structure.get(prefix, {}).values()
            for c in (y.values() if isinstance(y, dict) else [y])
            if c in df_wide
        ]
        mins = [v for v in (df_wide[c].min() for c in cols) if pd.notna(v)]
        maxs = [v for v in (df_wide[c].max() for c in cols) if pd.notna(v)]
        return (min(mins) if mins else None, max(maxs) if maxs else None)

    grid_style = grid_props or {'linestyle': ':', 'alpha': 0.7}
    num_pred_cols = len(quantiles_to_plot)
    num_actual_cols = 1 if kind == 'dual' else 0

    for prefix in value_prefixes:
        if max_cols == 'auto':
            n_cols = num_pred_cols + num_actual_cols
        else:
            n_cols = int(max_cols)
        if n_cols <= 0:
            continue

        # One colour scale per figure: prefixes may have different units.
        vmin, vmax = (
            _prefix_color_limits(prefix) if cbar == 'uniform'
            else (None, None)
        )
        plot_kwargs = dict(
            cmap=cmap, vmin=vmin, vmax=vmax, axis_off=axis_off,
            show_grid=show_grid, grid_props=grid_style, verbose=verbose
        )

        fig_size = figsize or (n_cols * 3.5, n_rows * 3)
        fig, axes = plt.subplots(
            n_rows, n_cols, figsize=fig_size, squeeze=False,
            constrained_layout=True
        )
        fig.suptitle(
            f"Forecast Visualization for '{prefix.upper()}'",
            fontsize=14, weight='bold'
        )

        _plot_forecast_grid(
            fig, axes, df_wide, plot_structure, years_to_plot,
            quantiles_to_plot, prefix, kind, spatial_cols,
            s=marker_size,
            plot_kwargs=plot_kwargs,
            _logger=_logger,
        )

        if cbar == 'uniform' and vmin is not None and vmax is not None:
            fig.colorbar(
                ScalarMappable(
                    norm=mcolors.Normalize(vmin=vmin, vmax=vmax), cmap=cmap),
                ax=axes.ravel().tolist(), orientation='vertical',
                shrink=0.8, pad=0.04
            )

        if savefig:
            # Insert the prefix so that each figure gets its own file.
            base, ext = os.path.splitext(savefig)
            save_figure(
                fig, savefile=f"{base}_{prefix}{ext}", save_fmts=save_fmts,
                dpi=300, bbox_inches="tight"
            )
            plt.close(fig)
        elif show:
            plt.show()  # for notebook debugging
        else:
            plt.close(fig)
@check_non_emptiness
def plot_forecasts(
    forecast_df: pd.DataFrame,
    target_name: str = "target",
    quantiles: Optional[List[float]] = None,
    output_dim: int = 1,
    kind: str = "temporal",
    actual_data: Optional[pd.DataFrame] = None,
    dt_col: Optional[str] = None,
    actual_target_name: Optional[str] = None,
    sample_ids: Optional[Union[int, List[int], str]] = "first_n",
    num_samples: int = 3,
    horizon_steps: Optional[Union[int, List[int]]] = 1,
    spatial_cols: Optional[List[str]] = None,
    max_cols: int = 2,
    figsize: Tuple[float, float] = (8, 4.5),
    scaler: Optional[Any] = None,
    scaler_feature_names: Optional[List[str]] = None,
    target_idx_in_scaler: Optional[int] = None,
    titles: Optional[List[str]] = None,
    cbar: Optional[str] = None,
    step_names: Optional[Dict[int, Any]] = None,
    show_grid: bool = True,
    grid_props: Optional[Dict] = None,
    savefig: Optional[str] = None,
    save_fmts: Optional[Union[str, List[str]]] = ".png",
    show: bool = True,
    verbose: int = 0,
    _logger: Optional[Union[logging.Logger, Callable[[str], None]]] = None,
    **plot_kwargs: Any
) -> None:
    """
    Visualizes model forecasts from a structured DataFrame.

    This function generates plots to visualize time series forecasts,
    supporting temporal line plots for individual samples/items and
    spatial scatter plots for specific forecast horizons. It can handle
    point forecasts and quantile (probabilistic) forecasts, optionally
    overlaying actual values for comparison. Predictions and actuals can
    be inverse-transformed if a scaler is provided.

    The input `forecast_df` is expected to be in a long format,
    typically generated by
    :func:`~fusionlab.nn.utils.format_predictions_to_dataframe`,
    containing 'sample_idx' and 'forecast_step' columns, along with
    prediction columns (e.g., '{target_name}_pred',
    '{target_name}_q50') and optionally actual value columns
    (e.g., '{target_name}_actual'). When ``output_dim > 1`` the column
    names carry an extra ``_<dim>`` infix, e.g. '{target_name}_0_q50'.

    Parameters
    ----------
    forecast_df : pd.DataFrame
        A pandas DataFrame containing the forecast data in long format.
        Must include 'sample_idx' and 'forecast_step' columns, along
        with prediction columns (and optionally actuals and spatial
        coordinate columns).
    target_name : str, default "target"
        The base name of the target variable. Used to construct column
        names for predictions and actuals (e.g., "target_pred",
        "target_q50", "target_actual").
    quantiles : List[float], optional
        A list of quantiles that were predicted (e.g., `[0.1, 0.5, 0.9]`).
        Quantiles are sorted internally. The median (``0.5`` if present,
        otherwise the middle of the sorted list) is drawn as a line and
        the range between the lowest and highest quantile is shaded as
        an uncertainty interval (temporal plots). In spatial plots the
        median column colours the points. If ``None`` (or empty), a
        point forecast is assumed. Default is ``None``.
    output_dim : int, default 1
        The number of target variables (output dimensions) predicted at
        each time step. One subplot is produced per output dimension.
    kind : {'temporal', 'spatial'}, default "temporal"
        The type of plot to generate:
        - ``'temporal'``: Plots forecasts over the horizon for selected
          samples (time series line plots).
        - ``'spatial'``: Creates scatter plots of forecast values over
          spatial coordinates for selected horizon steps. Requires
          `spatial_cols` to be specified.
    actual_data : pd.DataFrame, optional
        Currently unused; actual values are read only from
        '{target}_actual' columns of `forecast_df`. Kept for API
        compatibility.
    dt_col : str, optional
        Currently unused. Kept for API compatibility.
    actual_target_name : str, optional
        Base name used to look up the '<name>_actual' column(s) if it
        differs from `target_name`. If ``None``, `target_name` is used.
    sample_ids : int, List[int], str, default "first_n"
        Specifies which samples (based on 'sample_idx' in
        `forecast_df`) to plot for `kind='temporal'`.
        - If `int`: Plots the sample at that position among the unique
          sample ids.
        - If `List[int]` (or tuple): Plots all samples with these ids.
        - If ``"first_n"``: Plots the first `num_samples`.
        - If ``"all"``: Plots all unique samples (can be many plots).
        Invalid selections fall back to the first available sample.
    num_samples : int, default 3
        Number of samples to plot if `sample_ids="first_n"`.
    horizon_steps : int, List[int], str, default 1
        Specifies which forecast horizon step(s) to plot for
        `kind='spatial'`.
        - If `int`: Plots the specified single step (1-indexed).
        - If `List[int]` (or tuple): Plots all specified steps.
        - If ``"all"`` or ``None``: Plots all unique forecast steps.
    spatial_cols : List[str], optional
        Required if `kind='spatial'`. Two column names from
        `forecast_df` giving the x and y spatial coordinates
        (e.g., `['longitude', 'latitude']`). Default is ``None``.
    max_cols : int, default 2
        Maximum number of subplots per row (values below 1 are treated
        as 1).
    figsize : Tuple[float, float], default (8, 4.5)
        The size `(width, height)` in inches for **each subplot**. The
        total figure size is scaled by the number of rows and columns.
    scaler : Any, optional
        A fitted scikit-learn-like scaler object (must have an
        `inverse_transform` method). If provided along with
        `scaler_feature_names` and `target_idx_in_scaler`, prediction and
        actual columns related to `target_name` are inverse-transformed
        before plotting. Columns that fail to transform are plotted in
        scaled units with a warning. Default is ``None``.
    scaler_feature_names : List[str], optional
        All feature names (in order) that the `scaler` was fit on.
        Required for the targeted inverse transform. Default is ``None``.
    target_idx_in_scaler : int, optional
        Index of the target within `scaler_feature_names`. Required for
        the targeted inverse transform. Default is ``None``.
    titles : List[str], optional
        Custom titles for the subplots, in drawing order. Subplots beyond
        the length of this list keep their auto-generated title.
    cbar : {'uniform', None}, default=None
        Controls colour-bar behaviour for *spatial* plots. ``'uniform'``
        computes a single global ``vmin`` / ``vmax`` over the plotted
        steps, applies it to every subplot and attaches one shared
        colour-bar to the figure. Any other value (or *None*) lets each
        subplot auto-scale and carry its own colour-bar.
    step_names : dict[int, Any] or None, default=None
        Optional mapping that renames forecast-step integers in spatial
        subplot titles, e.g. ``{1: "2021", 2: "2022"}``. Names may
        contain ``"{}"``, which is replaced by ``"Forecast Step: <step>"``.
        Steps without a name fall back to ``"{}"``, i.e. the default
        title.
    show_grid : bool, default=True
        Toggle ``ax.grid`` in both temporal and spatial plots.
    grid_props : dict or None, default=None
        Keyword arguments forwarded to
        :meth:`matplotlib.axes.Axes.grid` (e.g.
        ``{"linestyle": ":", "alpha": 0.6}``). Spatial plots default to
        ``{"linestyle": ":", "alpha": 0.7}``. Ignored when *show_grid*
        is *False*.
    savefig : str, optional
        If given, the figure is written via ``save_figure`` and closed
        instead of being shown.
    save_fmts : str or list of str, default ".png"
        File format(s) forwarded to ``save_figure``.
    show : bool, default=True
        When not saving, call ``plt.show()`` if True, otherwise close
        the figure.
    verbose : int, default 0
        Verbosity level for logging during plot generation.
    _logger : logging.Logger or callable, optional
        Sink forwarded to ``vlog`` for progress messages.
    **plot_kwargs : Any
        Styling options read by name:
        - ``cmap`` (default ``'viridis'``), ``s`` (default 50) and
          ``alpha`` (default 0.7) for the spatial scatter.
        - ``median_plot_kwargs``, ``fill_between_kwargs`` and
          ``point_plot_kwargs``: dicts forwarded to ``ax.plot`` /
          ``ax.fill_between`` in temporal plots. Their keys override the
          built-in defaults (marker, label, colour, alpha), e.g.
          ``fill_between_kwargs={'color': 'lightblue'}``.

    Returns
    -------
    None
        This function directly generates and shows/saves plots using
        Matplotlib and does not return any value.

    Raises
    ------
    TypeError
        If `forecast_df` is not a pandas DataFrame.
    ValueError
        If 'sample_idx' or 'forecast_step' are missing from
        `forecast_df`; if an unsupported `kind` is given; if
        `kind='spatial'` and `spatial_cols` are missing, not two
        columns, or not found; or if `horizon_steps` is invalid.

    See Also
    --------
    fusionlab.nn.utils.format_predictions_to_dataframe :
        Utility to generate the `forecast_df` expected by this function.
    matplotlib.pyplot.plot : Underlying plotting function for temporal
        forecasts.
    matplotlib.pyplot.scatter : Underlying plotting function for spatial
        forecasts.

    Examples
    --------
    >>> from fusionlab.nn.utils import format_predictions_to_dataframe
    >>> from fusionlab.plot.forecast import plot_forecasts
    >>> import pandas as pd
    >>> import numpy as np
    >>> # Assume preds_point (B,H,O) and preds_quant (B,H,Q) are available
    >>> B, H, O, Q_len = 4, 3, 1, 3
    >>> preds_point = np.random.rand(B, H, O)
    >>> preds_quant = np.random.rand(B, H, Q_len)
    >>> y_true_seq = np.random.rand(B, H, O)
    >>> quantiles_list = [0.1, 0.5, 0.9]
    >>> df_point = format_predictions_to_dataframe(
    ...     predictions=preds_point, y_true_sequences=y_true_seq,
    ...     target_name="value", forecast_horizon=H, output_dim=O
    ... )
    >>> df_quant = format_predictions_to_dataframe(
    ...     predictions=preds_quant, y_true_sequences=y_true_seq,
    ...     target_name="value", quantiles=quantiles_list,
    ...     forecast_horizon=H, output_dim=O
    ... )
    >>> # Example 1: Plot temporal point forecast for first sample
    >>> # plot_forecasts(df_point, target_name="value", sample_ids=0)
    >>> # Example 2: Plot temporal quantile forecast for first 2 samples
    >>> # plot_forecasts(df_quant, target_name="value",
    ... #                quantiles=quantiles_list, sample_ids="first_n",
    ... #                num_samples=2, max_cols=1)
    >>> # Example 3: Spatial plot (requires spatial_cols in df_quant)
    >>> # df_quant['lon'] = np.random.rand(len(df_quant)) * 10
    >>> # df_quant['lat'] = np.random.rand(len(df_quant)) * 5
    >>> # plot_forecasts(df_quant, target_name="value",
    ... #                quantiles=quantiles_list, kind='spatial',
    ... #                horizon_steps=1, spatial_cols=['lon', 'lat'],
    ... #                cbar='uniform')
    """
    vlog(f"Starting forecast visualization (kind='{kind}')...",
         level=3, verbose=verbose, logger=_logger)

    if not isinstance(forecast_df, pd.DataFrame):
        raise TypeError("`forecast_df` must be a pandas DataFrame, typically "
                        "the output of `format_predictions_to_dataframe`.")
    if ('sample_idx' not in forecast_df.columns
            or 'forecast_step' not in forecast_df.columns):
        raise ValueError(
            "`forecast_df` must contain 'sample_idx' and 'forecast_step' "
            "columns."
        )
    # Fail fast instead of preparing data for a plot we cannot draw.
    if kind not in ("temporal", "spatial"):
        raise ValueError(
            f"Unsupported `kind`: '{kind}'. "
            "Choose 'temporal' or 'spatial'."
        )

    # Work on a copy so the inverse transform never mutates the input.
    df_to_plot = forecast_df.copy()

    base_pred_name = target_name
    base_actual_name = actual_target_name or target_name

    def _col(base: str, o_idx: int, suffix: str) -> str:
        # Multi-output frames carry an ``_<o_idx>`` infix before the suffix.
        infix = f"_{o_idx}" if output_dim > 1 else ""
        return f"{base}{infix}{suffix}"

    def _qtag(q: float) -> str:
        # Column tag used for a quantile, e.g. 0.1 -> "q10".
        return f"q{int(q * 100)}"

    # Sort once so low / median / high are chosen consistently everywhere
    # (both in the plots and in the uniform colour-scale scan).
    sorted_quantiles = (
        sorted(float(q) for q in quantiles) if quantiles else []
    )
    median_q = None
    if sorted_quantiles:
        median_q = (0.5 if 0.5 in sorted_quantiles
                    else sorted_quantiles[len(sorted_quantiles) // 2])

    # --- Identify prediction and actual columns present in the frame ---
    pred_col_names_in_df: List[str] = []
    actual_col_names_in_df: List[str] = []
    for o_idx in range(output_dim):
        if sorted_quantiles:
            candidates = [_col(base_pred_name, o_idx, f"_{_qtag(q)}")
                          for q in sorted_quantiles]
        else:
            candidates = [_col(base_pred_name, o_idx, "_pred")]
        pred_col_names_in_df.extend(
            c for c in candidates if c in df_to_plot.columns
        )
        actual_col = _col(base_actual_name, o_idx, "_actual")
        if actual_col in df_to_plot.columns:
            actual_col_names_in_df.append(actual_col)

    if not pred_col_names_in_df:
        warnings.warn("No prediction columns found in `forecast_df` "
                      "based on `target_name` and `quantiles`.")
        # Fall back to any column that looks like a prediction.
        pred_cols_found = [c for c in df_to_plot.columns
                           if "_pred" in c or "_q" in c]
        if pred_cols_found:
            pred_col_names_in_df = pred_cols_found
            vlog(f"  Auto-detected prediction columns: {pred_col_names_in_df}",
                 level=2, verbose=verbose, logger=_logger)
        else:
            vlog("  No prediction columns could be auto-detected. "
                 "Plotting may be limited.", level=2, verbose=verbose,
                 logger=_logger)

    # --- Optional inverse transform (best effort, per column) ---
    if scaler is not None:
        if scaler_feature_names is None or target_idx_in_scaler is None:
            warnings.warn(
                "Scaler provided, but `scaler_feature_names` or "
                "`target_idx_in_scaler` is missing. Inverse transform "
                "cannot be applied correctly to specific target columns.",
                UserWarning
            )
        else:
            vlog("  Applying inverse transformation using provided scaler...",
                 level=4, verbose=verbose, logger=_logger)
            # De-duplicate so no column is inverse-transformed twice.
            cols_to_inverse_transform = list(dict.fromkeys(
                pred_col_names_in_df + actual_col_names_in_df
            ))
            dummy_shape = (len(df_to_plot), len(scaler_feature_names))
            for col_to_inv in cols_to_inverse_transform:
                if col_to_inv not in df_to_plot.columns:
                    continue
                try:
                    # The scaler expects every feature; only the target
                    # slot carries real values.
                    dummy = np.zeros(dummy_shape)
                    dummy[:, target_idx_in_scaler] = df_to_plot[col_to_inv]
                    df_to_plot[col_to_inv] = scaler.inverse_transform(
                        dummy
                    )[:, target_idx_in_scaler]
                except Exception as e:
                    # Deliberate best effort: keep scaled values, but warn.
                    warnings.warn(
                        f"Failed to inverse transform column '{col_to_inv}'. "
                        f"Plotting scaled value. Error: {e}"
                    )
            vlog("  Inverse transformation applied.", level=5,
                 verbose=verbose, logger=_logger)

    # --- Select samples to plot ---
    unique_sample_ids = df_to_plot['sample_idx'].unique()
    selected_ids_for_plot = []
    if isinstance(sample_ids, str):
        selector = sample_ids.lower()
        if selector == "all":
            selected_ids_for_plot = unique_sample_ids
        elif selector == "first_n":
            selected_ids_for_plot = unique_sample_ids[:num_samples]
        else:
            warnings.warn(f"Unknown string for `sample_ids`: "
                          f"'{sample_ids}'. Plotting first sample.")
            selected_ids_for_plot = unique_sample_ids[:1]
    elif isinstance(sample_ids, (int, np.integer)):
        if sample_ids < len(unique_sample_ids):
            selected_ids_for_plot = [unique_sample_ids[sample_ids]]
        else:
            warnings.warn(f"sample_idx {sample_ids} out of range. "
                          "Plotting first sample.")
            selected_ids_for_plot = unique_sample_ids[:1]
    elif isinstance(sample_ids, (list, tuple)):
        selected_ids_for_plot = [
            sid for sid in sample_ids if sid in unique_sample_ids
        ]

    if len(selected_ids_for_plot) == 0:
        vlog("No valid sample_idx found to plot. Defaulting to first available.",
             level=2, verbose=verbose, logger=_logger)
        selected_ids_for_plot = unique_sample_ids[:1]

    if len(selected_ids_for_plot) == 0:
        vlog("No sample to plot. Returning ...", level=1,
             verbose=verbose, logger=_logger)
        return

    vlog(f"  Plotting for sample_idx: {selected_ids_for_plot}",
         level=4, verbose=verbose, logger=_logger)

    n_cols_max = max(1, int(max_cols))

    if kind == "temporal":
        num_plots = len(selected_ids_for_plot) * output_dim
        if num_plots == 0:
            vlog("No data to plot for temporal type.", level=2,
                 verbose=verbose, logger=_logger)
            return

        n_cols_plot = min(n_cols_max, num_plots)
        n_rows_plot = (num_plots + n_cols_plot - 1) // n_cols_plot
        fig, axes = plt.subplots(
            n_rows_plot, n_cols_plot,
            figsize=(n_cols_plot * figsize[0], n_rows_plot * figsize[1]),
            squeeze=False
        )
        axes_flat = axes.flatten()
        plot_idx = 0

        for s_idx in selected_ids_for_plot:
            sample_df = df_to_plot[df_to_plot['sample_idx'] == s_idx]
            if sample_df.empty:
                continue
            # Draw lines in horizon order regardless of row order.
            sample_df = sample_df.sort_values('forecast_step')
            steps_x = sample_df['forecast_step']

            for o_idx in range(output_dim):
                if plot_idx >= len(axes_flat):
                    break
                ax = axes_flat[plot_idx]
                dim_tag = f"_{o_idx}" if output_dim > 1 else ""

                title = f"Sample ID: {s_idx}"
                if output_dim > 1:
                    title += f", Target Dim: {o_idx}"
                if titles and plot_idx < len(titles):
                    title = titles[plot_idx]
                ax.set_title(title)

                actual_col = _col(base_actual_name, o_idx, "_actual")
                if actual_col in sample_df.columns:
                    ax.plot(
                        steps_x, sample_df[actual_col],
                        label=f'Actual {target_name}{dim_tag}',
                        marker='o', linestyle='--'
                    )

                if median_q is not None:
                    median_col = _col(base_pred_name, o_idx,
                                      f"_{_qtag(median_q)}")
                    lo_q, hi_q = sorted_quantiles[0], sorted_quantiles[-1]
                    lower_q_col = _col(base_pred_name, o_idx, f"_{_qtag(lo_q)}")
                    upper_q_col = _col(base_pred_name, o_idx, f"_{_qtag(hi_q)}")
                    if median_col in sample_df.columns:
                        # User kwargs override the defaults instead of
                        # clashing with them.
                        median_kw = {'label': f'Median ({_qtag(median_q)})',
                                     'marker': 'x'}
                        median_kw.update(
                            plot_kwargs.get("median_plot_kwargs") or {})
                        ax.plot(steps_x, sample_df[median_col], **median_kw)
                    if (lower_q_col in sample_df.columns
                            and upper_q_col in sample_df.columns):
                        band_kw = {
                            'color': 'gray', 'alpha': 0.3,
                            'label': (f'Interval ({_qtag(lo_q)}'
                                      f'-{_qtag(hi_q)})'),
                        }
                        band_kw.update(
                            plot_kwargs.get("fill_between_kwargs") or {})
                        ax.fill_between(
                            steps_x, sample_df[lower_q_col],
                            sample_df[upper_q_col], **band_kw
                        )
                else:
                    pred_col = _col(base_pred_name, o_idx, "_pred")
                    if pred_col in sample_df.columns:
                        point_kw = {'label': f'Predicted {target_name}{dim_tag}',
                                    'marker': 'x'}
                        point_kw.update(
                            plot_kwargs.get("point_plot_kwargs") or {})
                        ax.plot(steps_x, sample_df[pred_col], **point_kw)

                ax.set_xlabel("Forecast Step into Horizon")
                ax.set_ylabel(
                    f"{target_name}{f' (Dim {o_idx})' if output_dim > 1 else ''}"
                )
                # Avoid matplotlib's "no artists with labels" warning.
                if ax.get_legend_handles_labels()[0]:
                    ax.legend()
                if show_grid:
                    ax.grid(True, **(grid_props or {}))
                else:
                    ax.grid(False)
                plot_idx += 1

        for ax in axes_flat[plot_idx:]:
            ax.set_visible(False)
        fig.tight_layout()

    else:  # kind == "spatial"
        spatial_x_col = spatial_y_col = None
        if spatial_cols is not None:
            spatial_cols = columns_manager(spatial_cols, empty_as_none=False)
            if len(spatial_cols) != 2:
                raise ValueError(
                    "Spatial_cols need exactly two columns (e.g., longitude,"
                    f" latitude ). Got {spatial_cols}")
            spatial_x_col, spatial_y_col = spatial_cols

        if spatial_x_col is None or spatial_y_col is None:
            raise ValueError(
                "`spatial_x_col` and `spatial_y_col` must be provided "
                "for `kind='spatial'`."
            )
        if (spatial_x_col not in df_to_plot.columns
                or spatial_y_col not in df_to_plot.columns):
            raise ValueError(
                f"Spatial columns '{spatial_x_col}' or '{spatial_y_col}' "
                "not found in forecast_df."
            )

        if isinstance(horizon_steps, (int, np.integer)):
            steps_to_plot = [horizon_steps]
        elif isinstance(horizon_steps, (list, tuple)):
            steps_to_plot = list(horizon_steps)
        elif horizon_steps is None or str(horizon_steps).lower() == "all":
            steps_to_plot = sorted(df_to_plot['forecast_step'].unique())
        else:
            raise ValueError("Invalid `horizon_steps`.")

        num_plots = len(steps_to_plot) * output_dim
        if num_plots == 0:
            vlog("No data/steps to plot for spatial type.", level=2,
                 verbose=verbose, logger=_logger)
            return

        def _color_col(o_idx: int) -> str:
            # Median quantile for probabilistic forecasts, else the point.
            if median_q is not None:
                return _col(base_pred_name, o_idx, f"_{_qtag(median_q)}")
            return _col(base_pred_name, o_idx, "_pred")

        plot_title_suffix = (f" (Median {_qtag(median_q)})"
                             if median_q is not None else "")
        # Bound before the loop so the shared colour-bar can always use it.
        cmap = plot_kwargs.get('cmap', 'viridis')

        vmin, vmax = None, None
        if cbar == "uniform":
            vlog("Calculating uniform color scale for all subplots...",
                 level=2, verbose=verbose, logger=_logger)
            df_for_scaling = df_to_plot[
                df_to_plot['forecast_step'].isin(steps_to_plot)
            ]
            existing_color_cols = [
                c for c in (_color_col(o) for o in range(output_dim))
                if c in df_for_scaling.columns
            ]
            mins = [v for v in (df_for_scaling[c].dropna().min()
                                for c in existing_color_cols) if pd.notna(v)]
            maxs = [v for v in (df_for_scaling[c].dropna().max()
                                for c in existing_color_cols) if pd.notna(v)]
            if mins:
                vmin = min(mins)
            if maxs:
                vmax = max(maxs)
            if vmin is not None and vmax is not None:
                vlog(f"Uniform color range set: vmin={vmin:.2f}, vmax={vmax:.2f}",
                     level=3, verbose=verbose, logger=_logger)
        # Only share a scale when one could actually be computed.
        use_uniform = (cbar == "uniform"
                       and vmin is not None and vmax is not None)

        n_cols_plot = min(n_cols_max, num_plots)
        n_rows_plot = (num_plots + n_cols_plot - 1) // n_cols_plot
        fig, axes = plt.subplots(
            n_rows_plot, n_cols_plot,
            figsize=(n_cols_plot * figsize[0], n_rows_plot * figsize[1]),
            squeeze=False
        )
        axes_flat = axes.flatten()
        plot_idx = 0

        step_labels = get_step_names(
            steps_to_plot, step_names=step_names, default_name="{}"
        )
        grid_props = grid_props or {'linestyle': ':', 'alpha': 0.7}

        for step in steps_to_plot:
            step_df = df_to_plot[df_to_plot['forecast_step'] == step]
            if step_df.empty:
                continue
            for o_idx in range(output_dim):
                if plot_idx >= len(axes_flat):
                    break
                ax = axes_flat[plot_idx]
                color_col = _color_col(o_idx)

                if color_col not in step_df.columns:
                    warnings.warn(f"Color column '{color_col}' not found "
                                  "for spatial plot. Skipping subplot.")
                    plot_idx += 1  # leave this cell empty
                    continue

                scatter_data = step_df.dropna(
                    subset=[spatial_x_col, spatial_y_col, color_col]
                )
                if scatter_data.empty:
                    warnings.warn(f"No valid data to plot for step {step}, "
                                  f"output_dim {o_idx} after dropna. Skipping.")
                    plot_idx += 1
                    continue

                # In uniform mode every subplot must use the shared range,
                # otherwise the single colour-bar would not match the dots.
                if use_uniform:
                    norm = mcolors.Normalize(vmin=vmin, vmax=vmax)
                else:
                    norm = mcolors.Normalize(
                        vmin=scatter_data[color_col].min(),
                        vmax=scatter_data[color_col].max()
                    )

                sc = ax.scatter(
                    scatter_data[spatial_x_col],
                    scatter_data[spatial_y_col],
                    c=scatter_data[color_col],
                    cmap=cmap,
                    norm=norm,
                    s=plot_kwargs.get('s', 50),
                    alpha=plot_kwargs.get('alpha', 0.7)
                )
                if not use_uniform:
                    fig.colorbar(sc, ax=ax,
                                 label=f"{target_name}{plot_title_suffix}")

                step_name = step_labels.get(step, '{}')
                title = step_name.format(f"Forecast Step: {step}")
                if output_dim > 1:
                    title += f", Target Dim: {o_idx}"
                if titles and plot_idx < len(titles):
                    title = titles[plot_idx]
                ax.set_title(title)
                ax.set_xlabel(spatial_x_col)
                ax.set_ylabel(spatial_y_col)
                if show_grid:
                    ax.grid(True, **grid_props)
                else:
                    ax.grid(False)
                plot_idx += 1

        for ax in axes_flat[plot_idx:]:
            ax.set_visible(False)
        fig.tight_layout()

        if use_uniform:
            # Reserve the right 10% of the figure for one shared colour-bar.
            fig.subplots_adjust(right=0.90)
            # [left, bottom, width, height] in figure coordinates.
            cax = fig.add_axes([0.92, 0.15, 0.015, 0.7])
            mappable = ScalarMappable(
                norm=mcolors.Normalize(vmin=vmin, vmax=vmax), cmap=cmap
            )
            fig.colorbar(
                mappable, cax=cax,
                label=f"{target_name}{plot_title_suffix}",
                orientation='vertical'
            )

    vlog("Forecast visualization complete.", level=3,
         verbose=verbose, logger=_logger)

    # Save to disk if requested; otherwise show or close.
    if savefig:
        save_figure(
            fig, savefile=savefig, save_fmts=save_fmts,
            dpi=300, bbox_inches="tight"
        )
        plt.close(fig)
    elif show:
        plt.show()
    else:
        plt.close(fig)
@check_non_emptiness
def visualize_forecasts(
    forecast_df,
    dt_col,
    tname,
    test_data=None,
    eval_periods=None,
    mode="quantile",
    kind="spatial",
    actual_name=None,
    x=None,
    y=None,
    cmap="coolwarm",
    max_cols=3,
    axis="on",
    s=2,
    show_grid=True,
    grid_props=None,
    savefig=None,
    save_fmts=None,
    verbose=1,
    **kw
):
    r"""
    Visualize forecast results and actual test data for one or more
    evaluation periods.

    The function plots a grid of scatter plots comparing actual values
    with forecasted predictions. When ``test_data`` is given, each
    evaluation period yields two stacked plots: one for actual values
    and one for predicted values. Without ``test_data`` only the
    prediction plot is drawn. The grid wraps after ``max_cols`` columns.

    .. math::

        \hat{y}_{t+i} = f\Bigl(
            X_{\text{static}},\;X_{\text{dynamic}},\;
            X_{\text{future}}\Bigr)

    for :math:`i = 1, \dots, N`, where :math:`N` is the forecast horizon.

    All subplots share one colour range: the global min/max of the
    prediction column (and of the actual column when ``test_data`` is
    given).

    Parameters
    ----------
    forecast_df : pandas.DataFrame
        DataFrame containing forecast results with a time column,
        spatial coordinates, and prediction columns.
    dt_col : str
        Name of the time column used to filter forecast results
        (e.g. ``"year"``). Periods are matched on the string form of
        this column, so integer, string and datetime columns all work.
    tname : str
        Target variable name used to construct forecast columns
        (e.g. ``"subsidence"``). This argument is required.
    test_data : pandas.DataFrame, optional
        Observed data. When given, an "Actual" row is drawn above each
        prediction row and only periods present in both frames are used.
    eval_periods : scalar or list, optional
        Evaluation period(s) used to select forecast results. If
        ``None``, up to three unique periods from ``forecast_df[dt_col]``
        are used.
    mode : {"quantile", "point"}, default "quantile"
        Forecast mode. ``"quantile"`` plots ``<tname>_q50``;
        ``"point"`` plots ``<tname>_pred``.
    kind : str, default "spatial"
        Type of visualization. If ``"spatial"``, spatial columns are
        required (default ``"longitude"``/``"latitude"``); otherwise
        both ``x`` and ``y`` must be given.
    actual_name : str, optional
        Name of the actual-value column; resolved with
        ``get_actual_column_name`` when not given.
    x : str, optional
        Column name for the x-axis.
    y : str, optional
        Column name for the y-axis.
    cmap : str, default "coolwarm"
        Colormap used for scatter plots.
    max_cols : int, default 3
        Maximum number of evaluation periods per row (values below 1
        are treated as 1).
    axis : {"on", "off"}, default "on"
        Whether to keep the axes visible (``"on"``) or hide them
        (``"off"``).
    s : float, default 2
        Marker size forwarded to ``ax.scatter``.
    show_grid : bool, default True
        Whether to draw a grid on each subplot.
    grid_props : dict, optional
        Grid properties. If ``None``, ``{"linestyle": ":", "alpha": 0.7}``
        is used.
    savefig : str, optional
        If given, the figure is written via ``save_figure`` and closed
        instead of being shown.
    save_fmts : str or list of str, optional
        File format(s) forwarded to ``save_figure``.
    verbose : int, default 1
        Verbosity level. Controls the amount of output printed.
    **kw
        Extra keyword arguments forwarded to every ``ax.scatter`` call.

    Returns
    -------
    None
        The function displays (or saves) the visualization.

    Raises
    ------
    ValueError
        If ``mode`` is invalid, if ``kind`` is non-spatial and ``x`` or
        ``y`` is missing, or if no evaluation period remains after
        filtering.

    Examples
    --------
    Example 1: **Spatial Visualization**

    >>> from fusionlab.plot.forecast import visualize_forecasts
    >>> forecast_results = pd.DataFrame({
    ...     'longitude': [-103.808151, -103.808151, -103.808151],
    ...     'latitude': [0.473152, 0.473152, 0.473152],
    ...     'subsidence_q50': [0.3, 0.4, 0.5],
    ...     'subsidence': [0.35, 0.42, 0.49],
    ...     'date': ['2023-01-01', '2023-01-02', '2023-01-03']
    ... })
    >>> test_data = pd.DataFrame({
    ...     'longitude': [-103.808151, -103.808151, -103.808151],
    ...     'latitude': [0.473152, 0.473152, 0.473152],
    ...     'subsidence': [0.35, 0.41, 0.49],
    ...     'date': ['2023-01-01', '2023-01-02', '2023-01-03']
    ... })
    >>> visualize_forecasts(
    ...     forecast_df=forecast_results,
    ...     test_data=test_data,
    ...     dt_col="date",
    ...     tname="subsidence",
    ...     eval_periods=[2023, 2024],
    ...     mode="quantile",
    ...     kind="spatial",
    ...     cmap="coolwarm",
    ...     max_cols=2,
    ...     verbose=1
    ... )

    Example 2: **Non-Spatial Visualization**

    >>> forecast_results = pd.DataFrame({
    ...     'longitude': [-103.808151, -103.808151, -103.808151],
    ...     'latitude': [0.473152, 0.473152, 0.473152],
    ...     'subsidence_pred': [0.35, 0.41, 0.48],
    ...     'subsidence': [0.36, 0.43, 0.49],
    ...     'date': ['2023-01-01', '2023-01-02', '2023-01-03']
    ... })
    >>> visualize_forecasts(
    ...     forecast_df=forecast_results,
    ...     dt_col="date",
    ...     tname="subsidence",
    ...     eval_periods=[2023],
    ...     mode="point",
    ...     kind="non-spatial",
    ...     x="longitude",
    ...     y="latitude",
    ...     cmap="viridis",
    ...     max_cols=1,
    ...     axis="off",
    ...     show_grid=True,
    ...     grid_props={"linestyle": "--", "alpha": 0.5},
    ...     verbose=2
    ... )

    Notes
    -----
    - In ``quantile`` mode, the column ``<tname>_q50`` is plotted.
    - In ``point`` mode, the column ``<tname>_pred`` is plotted.
    - For spatial visualizations, if neither ``x`` nor ``y`` is given,
      they default to ``"longitude"`` and ``"latitude"``.
    - Unused grid cells (and periods skipped for lack of data) are
      hidden.

    See Also
    --------
    generate_forecast : Function to generate forecast results.
    coverage_score : Function to compute the coverage score.

    References
    ----------
    .. [1] Kouadio L. et al., "Gofast Forecasting Model", Journal of
           Advanced Forecasting, 2025. (In review)
    """
    from ..utils.ts_utils import filter_by_period

    is_frame(forecast_df, df_only=True, objname="Forecast data",
             error="raise")

    if eval_periods is None:
        unique_periods = sorted(forecast_df[dt_col].unique())
        if verbose:
            print("No eval_period provided; using up to three unique "
                  "periods from forecast data.")
        eval_periods = unique_periods[:3]
    eval_periods = columns_manager(eval_periods, to_string=True)

    if test_data is not None:
        is_frame(test_data, df_only=True, objname="Test data",
                 error="raise")
        test_data = filter_by_period(test_data, eval_periods, dt_col)
    forecast_df = filter_by_period(forecast_df, eval_periods, dt_col)

    # Periods are handled as strings from here on; the same ``astype(str)``
    # keys are reused for filtering so int/str/datetime columns all match.
    forecast_keys = forecast_df[dt_col].astype(str)
    eval_periods = forecast_keys.unique()

    if kind == "spatial":
        if x is None and y is None:
            x, y = "longitude", "latitude"
        check_spatial_columns(forecast_df, spatial_cols=(x, y))
        x, y = assert_xy_in(x, y, data=forecast_df, asarray=False)
    else:
        if x is None or y is None:
            raise ValueError("For non-spatial kind, both x and y must be provided.")
        x, y = assert_xy_in(x, y, data=forecast_df, asarray=False)

    if mode == "quantile":
        pred_col = f"{tname}_q50"
        pred_label = f"Predicted {tname} (q50)"
    elif mode == "point":
        pred_col = f"{tname}_pred"
        pred_label = f"Predicted {tname}"
    else:
        raise ValueError("Mode must be either 'quantile' or 'point'.")

    df_actual = test_data if test_data is not None else forecast_df
    actual_name = get_actual_column_name(
        df_actual, tname, actual_name=actual_name, default_to='tname',
    )

    # Global colour range shared by all subplots.
    vmin = forecast_df[pred_col].min()
    vmax = forecast_df[pred_col].max()
    if test_data is not None and actual_name in test_data.columns:
        vmin = min(vmin, test_data[actual_name].min())
        vmax = max(vmax, test_data[actual_name].max())

    test_keys = None
    if test_data is not None:
        test_keys = test_data[dt_col].astype(str)
        available_periods = is_in_if(
            forecast_keys.unique(),
            test_keys.unique(),
            return_intersect=True,
        )
    else:
        available_periods = eval_periods

    eval_periods = [p for p in eval_periods if p in available_periods]
    if len(eval_periods) == 0:
        raise ValueError(
            "[ERROR] No valid evaluation periods found in forecast or test data.")

    # Grid layout: one column per period, wrapping after ``max_cols``;
    # two rows per period block when actuals are drawn.
    n_periods = len(eval_periods)
    n_cols = min(n_periods, max(1, int(max_cols)))
    n_rows = int(np.ceil(n_periods / n_cols))
    rows_per_block = 2 if test_data is not None else 1
    total_rows = n_rows * rows_per_block

    # squeeze=False guarantees a 2D array, including the 1x1 case.
    fig, axes = plt.subplots(
        total_rows, n_cols,
        figsize=(5 * n_cols, 4 * total_rows),
        squeeze=False
    )
    drawn = np.zeros(axes.shape, dtype=bool)

    if show_grid and grid_props is None:
        grid_props = {"linestyle": ":", 'alpha': 0.7}

    def _finish_axes(ax):
        # Shared axis/grid styling for every drawn subplot.
        ax.set_xlabel(x.name.capitalize())
        ax.set_ylabel(y.name.capitalize())
        if axis == "off":
            ax.set_axis_off()
        else:
            ax.set_axis_on()
        if show_grid:
            ax.grid(True, **grid_props)

    for idx, period in enumerate(eval_periods):
        col_idx = idx % n_cols
        row_idx = (idx // n_cols) * rows_per_block

        forecast_subset = forecast_df[forecast_keys == period]
        if test_data is not None:
            test_subset = test_data[test_keys == period]
        else:
            test_subset = forecast_subset  # no separate actuals

        if forecast_subset.empty or test_subset.empty:
            if verbose:
                print(f"[WARNING] No data for period {period}; skipping.")
            continue

        if test_data is not None:
            ax_actual = axes[row_idx, col_idx]
            sc_actual = ax_actual.scatter(
                test_subset[x.name],
                test_subset[y.name],
                c=test_subset[actual_name],
                cmap=cmap,
                alpha=0.7,
                edgecolors='k',
                s=s,
                vmin=vmin,
                vmax=vmax,
                **kw
            )
            ax_actual.set_title(f"Actual {tname.capitalize()} ({period})")
            _finish_axes(ax_actual)
            fig.colorbar(sc_actual, ax=ax_actual, label=tname.capitalize())
            drawn[row_idx, col_idx] = True

        pred_row = row_idx + 1 if test_data is not None else row_idx
        ax_pred = axes[pred_row, col_idx]
        sc_pred = ax_pred.scatter(
            forecast_subset[x.name],
            forecast_subset[y.name],
            c=forecast_subset[pred_col],
            cmap=cmap,
            alpha=0.7,
            edgecolors='k',
            s=s,
            vmin=vmin,
            vmax=vmax,
            **kw
        )
        ax_pred.set_title(f"{pred_label} ({period})")
        _finish_axes(ax_pred)
        fig.colorbar(sc_pred, ax=ax_pred, label=pred_label)
        drawn[pred_row, col_idx] = True

    # Hide grid cells that received no plot.
    for ax, used in zip(axes.flat, drawn.flat):
        if not used:
            ax.set_visible(False)

    fig.tight_layout()
    if savefig:
        save_figure(
            fig, savefile=savefig, save_fmts=save_fmts,
            dpi=300, bbox_inches="tight"
        )
        plt.close(fig)
    else:
        plt.show()
def _get_metrics_from_cols( columns: List[str], prefixes: List[str] ) -> List[str]: """Extracts metric suffixes (e.g., q10, actual) from column names.""" metrics = set() for col in columns: for p in prefixes: if col.startswith(p + '_'): metrics.add(col[len(p)+1:]) return sorted(list(metrics)) def _parse_wide_df_columns( df_wide: pd.DataFrame, value_prefixes: List[str] ) -> Dict[str, Dict[str, Dict[str, str]]]: """Parses wide-format columns into a structured dictionary.""" plot_structure = {prefix: {} for prefix in value_prefixes} # Regex for columns with years, e.g., GWL_2022_q10 or GWL_2022_actual pattern_year = re.compile( r'(' + '|'.join(re.escape(p) for p in value_prefixes) + r')_(\d{4})_?(.*)' ) # Regex for columns without years, e.g., GWL_q10 or GWL_pred pattern_no_year = re.compile( r'(' + '|'.join(re.escape(p) for p in value_prefixes) + r')_([a-zA-Z].*)' ) for col in df_wide.columns: match_year = pattern_year.match(col) match_no_year = pattern_no_year.match(col) if match_year: prefix, year, suffix = match_year.groups() # If suffix is empty, it's a point prediction suffix = suffix or 'pred' if year not in plot_structure[prefix]: plot_structure[prefix][year] = {} plot_structure[prefix][year][suffix] = col elif match_no_year: # Handles columns like 'subsidence_q10', # 'subsidence_actual' prefix, suffix = match_no_year.groups() if "static" not in plot_structure[prefix]: plot_structure[prefix]["static"] = {} plot_structure[prefix]["static"][suffix] = col elif any(col == p for p in value_prefixes): # Handles base prefix column e.g. 
'subsidence' if "static" not in plot_structure[col]: plot_structure[col]["static"] ={} plot_structure[col]['static']['pred'] = col return plot_structure def _plot_spatial_subplot(ax, df, x_col, y_col, c_col, s= 10, **kwargs): """Helper to create a single scatter subplot.""" if c_col is None or c_col not in df.columns: ax.set_title(f"{kwargs.get('title', '')}\n(Data not found)", fontsize=10, color='red') ax.axis('off') return None plot_df = df[[x_col, y_col, c_col]].dropna() if plot_df.empty: ax.set_title(f"{kwargs.get('title', '')}\n(No valid data)", fontsize=10) ax.axis('off') return None scatter = ax.scatter( plot_df[x_col], plot_df[y_col], c=plot_df[c_col], cmap=kwargs.get('cmap'), vmin=kwargs.get('vmin'), vmax=kwargs.get('vmax'), s=s, edgecolors='k', linewidths=0.1, alpha=0.8 ) ax.set_title(kwargs.get('title', ''), fontsize=10) if kwargs.get('axis_off'): ax.axis('off') else: ax.tick_params(axis='both', which='major', labelsize=8) if kwargs.get('show_grid'): ax.grid(**kwargs.get('grid_props')) return scatter def _plot_temporal_subplot(ax, df, value_col, title, **kwargs): """Helper to create a single line plot (fallback).""" if value_col is None or value_col not in df.columns: ax.set_title(f"{title}\n(Data not found)", fontsize=10, color='red') ax.axis('off') return plot_df = df[[value_col]].dropna() if plot_df.empty: ax.set_title(f"{title}\n(No valid data)", fontsize=10) ax.axis('off') return ax.plot(plot_df.index, plot_df[value_col]) ax.set_title(title, fontsize=10) if not kwargs.get('axis_off'): ax.tick_params(axis='both', which='major', labelsize=8) if kwargs.get('show_grid'): ax.grid(**kwargs.get('grid_props')) def _plot_forecast_grid( fig, axes, df_wide, plot_structure, years_to_plot, quantiles_to_plot, prefix, kind, spatial_cols, s, plot_kwargs, _logger ): # Check if spatial coordinates are available for plotting. has_spatial_coords = all( c in df_wide.columns for c in spatial_cols ) # Log a fallback message if spatial coordinates are not found. 
if not has_spatial_coords: vlog( f"Spatial columns {spatial_cols} not found. " "Falling back to temporal line plots.", level=1, verbose=plot_kwargs.get('verbose', 0), logger=_logger ) # Get the last known static actual column for fallback. last_known_actual_col = plot_structure[prefix].get( "static", {}).get("actual") # Iterate through each year to create a row of plots. for row_idx, year in enumerate(years_to_plot): ax_row = axes[row_idx] col_idx = 0 # Plot actual data if 'dual' mode is enabled. if kind == 'dual': actual_col = plot_structure[prefix].get( year, {}).get('actual', last_known_actual_col) if actual_col: last_known_actual_col = actual_col title = f"Actual ({year})" plot_kwargs['title'] = title # Choose plotting function based on coordinate availability. if has_spatial_coords: _plot_spatial_subplot( ax_row[col_idx], df_wide, *spatial_cols, actual_col, s =s, **plot_kwargs ) else: _plot_temporal_subplot( ax_row[col_idx], df_wide, actual_col, **plot_kwargs ) col_idx += 1 # Plot each requested prediction/quantile. for q_suffix in quantiles_to_plot: if col_idx >= len(ax_row): continue pred_col = plot_structure[prefix].get(year, {}).get(q_suffix) # Format the title for the subplot. title_suffix = ( q_suffix.replace('q', 'Q').capitalize() if 'q' in q_suffix else "Prediction" ) title = f"{title_suffix} ({year})" plot_kwargs['title'] = title # Choose plotting function based on coordinate availability. if has_spatial_coords: _plot_spatial_subplot( ax_row[col_idx], df_wide, *spatial_cols, pred_col, s=s, **plot_kwargs ) else: _plot_temporal_subplot( ax_row[col_idx], df_wide, pred_col, **plot_kwargs ) col_idx += 1 # Turn off any remaining unused axes in the current row. 
for i in range(col_idx, len(ax_row)): ax_row[i].axis('off') def _plot_single_scatter( ax, df, x_col, y_col, c_col, cmap, vmin, vmax, title, axis_off, show_grid, grid_props ): """Helper to create a single scatter subplot.""" if c_col not in df.columns: ax.set_title(f"{title}\n(Data not found)", fontsize=10, color='red') ax.set_xticks([]) ax.set_yticks([]) return None # Drop NaNs for this specific plot to avoid plotting empty points plot_df = df[[x_col, y_col, c_col]].dropna() if plot_df.empty: ax.set_title(f"{title}\n(No valid data)", fontsize=10) ax.set_xticks([]) ax.set_yticks([]) return None scatter = ax.scatter( plot_df[x_col], plot_df[y_col], c=plot_df[c_col], cmap=cmap, vmin=vmin, vmax=vmax, s=10 ) ax.set_title(title, fontsize=10) if axis_off: ax.axis('off') else: ax.tick_params(axis='both', which='major', labelsize=8) if show_grid: ax.grid(**grid_props) return scatter