fusionlab.plot.forecast.plot_forecasts

fusionlab.plot.forecast.plot_forecasts(forecast_df, target_name='target', quantiles=None, output_dim=1, kind='temporal', actual_data=None, dt_col=None, actual_target_name=None, sample_ids='first_n', num_samples=3, horizon_steps=1, spatial_cols=None, max_cols=2, figsize=(8, 4.5), scaler=None, scaler_feature_names=None, target_idx_in_scaler=None, titles=None, cbar=None, step_names=None, show_grid=True, grid_props=None, savefig=None, save_fmts='.png', show=True, verbose=0, _logger=None, **plot_kwargs)

Visualizes model forecasts from a structured DataFrame.

This function generates plots to visualize time series forecasts, supporting temporal line plots for individual samples/items and spatial scatter plots for specific forecast horizons. It can handle point forecasts and quantile (probabilistic) forecasts, optionally overlaying actual values for comparison. Predictions and actuals can be inverse-transformed if a scaler is provided.

The input forecast_df is expected to be in a long format, typically generated by format_predictions_to_dataframe(), containing ‘sample_idx’ and ‘forecast_step’ columns, along with prediction columns (e.g., ‘{target_name}_pred’, ‘{target_name}_q50’) and optionally actual value columns (e.g., ‘{target_name}_actual’).
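The naming convention above can be sketched in plain Python. The helper below, expected_columns, is hypothetical and not part of fusionlab. It assumes quantile columns follow a '{target_name}_q{percent}' pattern, extrapolated from the '{target_name}_q50' example, and it covers only the single-output case (output_dim=1).

```python
from typing import List, Optional


def expected_columns(
    target_name: str = "target",
    quantiles: Optional[List[float]] = None,
    with_actual: bool = True,
) -> List[str]:
    """Hypothetical helper: list the columns a long-format forecast_df
    is expected to contain for a single-output forecast."""
    cols = ["sample_idx", "forecast_step"]
    if quantiles:
        # Assumption: 0.5 -> "q50", extrapolated from '{target_name}_q50'.
        cols += [f"{target_name}_q{int(round(q * 100))}" for q in quantiles]
    else:
        # Point forecast: a single '{target_name}_pred' column.
        cols.append(f"{target_name}_pred")
    if with_actual:
        cols.append(f"{target_name}_actual")
    return cols


point_cols = expected_columns("value")
quant_cols = expected_columns("value", quantiles=[0.1, 0.5, 0.9])
```

Comparing such a list against forecast_df.columns before plotting is one way to catch a mismatched target_name early.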

Parameters:
  • forecast_df (pd.DataFrame) – A pandas DataFrame containing the forecast data in long format. Must include ‘sample_idx’ and ‘forecast_step’ columns, along with prediction columns (and optionally actuals and spatial coordinate columns).

  • target_name (str, default "target") – The base name of the target variable. This is used to construct column names for predictions and actuals (e.g., “target_pred”, “target_q50”, “target_actual”).

  • quantiles (List[float], optional) – A list of quantiles that were predicted (e.g., [0.1, 0.5, 0.9]). If provided, the function will plot the median quantile as a line and the range between the first and last quantile as a shaded uncertainty interval (for temporal plots). If None, a point forecast is assumed. Default is None.

  • output_dim (int, default 1) – The number of target variables (output dimensions) predicted at each time step. If > 1, separate subplots or plot groups may be generated for each output dimension.

  • kind ({'temporal', 'spatial'}, default "temporal") – The type of plot to generate:

    • 'temporal': Plots forecasts over the horizon for selected samples (time series line plots).

    • 'spatial': Creates scatter plots of forecast values over spatial coordinates for selected horizon steps. Requires spatial_cols to be specified.

  • actual_data (pd.DataFrame, optional) – An optional DataFrame containing the true actual values for comparison. If forecast_df already contains actual value columns (e.g., ‘{target_name}_actual’), this may not be needed or can be used to supplement. If provided, dt_col and actual_target_name might be used for alignment or direct plotting. Note: Current implementation primarily uses actuals from `forecast_df`.

  • dt_col (str, optional) – Name of the datetime column in actual_data or forecast_df if needed for x-axis labeling in temporal plots. If forecast_df uses ‘forecast_step’, this might be less critical unless aligning with a true date axis from actual_data.

  • actual_target_name (str, optional) – The name of the target column in actual_data if it differs from target_name. If None, target_name is assumed.

  • sample_ids (int, List[int], str, default "first_n") – Specifies which samples (based on ‘sample_idx’ in forecast_df) to plot for kind='temporal':

    • If int: Plots the sample at that specific index.

    • If List[int]: Plots all samples with these indices.

    • If "first_n": Plots the first num_samples.

    • If "all": Plots all unique samples (can be many plots).

  • num_samples (int, default 3) – Number of samples to plot if sample_ids="first_n".

  • horizon_steps (int, List[int], str, default 1) – Specifies which forecast horizon step(s) to plot for kind='spatial':

    • If int: Plots the specified single step (1-indexed).

    • If List[int]: Plots all specified steps.

    • If "all" or None: Plots all unique forecast steps available.

  • spatial_cols (List[str], optional) – Required if kind='spatial'. A list of two column names from forecast_df representing the x and y spatial coordinates (e.g., ['longitude', 'latitude']). Default is None.

  • max_cols (int, default 2) – Maximum number of subplots to arrange per row in the figure.

  • figsize (Tuple[float, float], default (8, 4.5)) – The size (width, height) in inches for each subplot. The total figure size will be inferred based on the number of rows and columns of subplots.

  • scaler (Any, optional) – A fitted scikit-learn-like scaler object (must have an inverse_transform method). If provided along with scaler_feature_names and target_idx_in_scaler, prediction and actual columns related to target_name will be inverse-transformed before plotting. Default is None.

  • scaler_feature_names (List[str], optional) – A list of all feature names (in order) that the scaler was originally fit on. Required for targeted inverse transform if scaler is provided. Default is None.

  • target_idx_in_scaler (int, optional) – The index of the target_name (or the specific output dimension being plotted) within the scaler_feature_names list. Required for targeted inverse transform if scaler is provided. Default is None.

  • titles (List[str], optional) – A list of custom titles for the subplots. If provided, its length should match the number of subplots generated. Default is None (titles are auto-generated).

  • cbar ({'uniform', None}, default None) – Controls colour‑bar behaviour for spatial plots. Use 'uniform' to compute a single global vmin / vmax and attach one shared colour‑bar to the figure. Any other value (or None) lets each subplot auto‑scale its own colour‑bar.

  • step_names (dict[int, Any] or None, default None) – Optional mapping that renames forecast‑step integers when building subplot titles, e.g. {1: "2021", 2: "2022"}. Keys are coerced to int; values are converted to str. Steps missing from the mapping fall back to the step number itself.

  • show_grid (bool, default True) – Toggle the display of ax.grid in both temporal and spatial plots. Set to False for a cleaner look when overlaying graphics on a map.

  • grid_props (dict or None, default None) – Keyword arguments forwarded to matplotlib.axes.Axes.grid (e.g. {"linestyle": ":", "alpha": 0.6}). Ignored when show_grid is False.

  • verbose (int, default 0) – Verbosity level for logging during plot generation:

    • 0: Silent.

    • 1 or higher: Print informational messages.

  • **plot_kwargs (Any) – Additional keyword arguments to pass to the underlying Matplotlib plotting functions (e.g., ax.plot, ax.scatter, ax.fill_between). Can be used to customize line styles, colors, marker sizes, etc. For example, median_plot_kwargs={'color': 'blue'}, fill_between_kwargs={'color': 'lightblue'}.

  • savefig (str | None)

  • save_fmts (str | List[str] | None)

  • show (bool)

  • _logger (Logger | Callable[[str], None] | None)

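One common way to inverse-transform a single column with a scaler that was fit on several features is to place the scaled values in a placeholder matrix at position target_idx_in_scaler, call inverse_transform, and read that column back. The sketch below illustrates the idea under stated assumptions: TinyMinMaxScaler is a hypothetical stand-in for a fitted scikit-learn-like scaler, inverse_target is not a fusionlab function, the feature names are illustrative, and the library's internal handling may differ.

```python
import numpy as np


class TinyMinMaxScaler:
    """Hypothetical stand-in for a fitted scikit-learn-like scaler."""

    def __init__(self, data_min, data_max):
        self.data_min_ = np.asarray(data_min, dtype=float)
        self.data_max_ = np.asarray(data_max, dtype=float)

    def inverse_transform(self, X):
        # Map values from [0, 1] back to the original per-feature range.
        X = np.asarray(X, dtype=float)
        return X * (self.data_max_ - self.data_min_) + self.data_min_


def inverse_target(values, scaler, scaler_feature_names, target_idx_in_scaler):
    """Inverse-transform one column using a scaler fit on many features."""
    values = np.asarray(values, dtype=float)
    # Placeholder matrix: one row per value, one column per fitted feature.
    dummy = np.zeros((values.shape[0], len(scaler_feature_names)))
    dummy[:, target_idx_in_scaler] = values
    # Only the target column of the result is meaningful.
    return scaler.inverse_transform(dummy)[:, target_idx_in_scaler]


feature_names = ["rainfall", "value", "temperature"]  # illustrative names
scaler = TinyMinMaxScaler(
    data_min=[0.0, 10.0, -5.0], data_max=[50.0, 30.0, 35.0]
)
restored = inverse_target(
    [0.0, 0.5, 1.0], scaler, feature_names, target_idx_in_scaler=1
)
```

The placeholder approach works for scalers that transform each feature independently, which is why both scaler_feature_names (for the matrix width) and target_idx_in_scaler (for the column position) are needed.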

Returns:

This function directly generates and shows/saves plots using Matplotlib and does not return any value.

Return type:

None

Raises:
  • TypeError – If forecast_df is not a pandas DataFrame.

  • ValueError – Raised in any of the following cases:

    • Essential columns such as ‘sample_idx’ or ‘forecast_step’ are missing from forecast_df.

    • kind='spatial' is chosen but spatial_cols is not provided or not found in forecast_df.

    • An unsupported kind is specified.

See also

fusionlab.nn.utils.format_predictions_to_dataframe

Utility to generate the forecast_df expected by this function.

matplotlib.pyplot.plot

Underlying plotting function for temporal forecasts.

matplotlib.pyplot.scatter

Underlying plotting function for spatial forecasts.

Examples

>>> from fusionlab.nn.utils import format_predictions_to_dataframe
>>> from fusionlab.plot.forecast import plot_forecasts
>>> import pandas as pd
>>> import numpy as np
>>> # Assume preds_point (B,H,O) and preds_quant (B,H,Q) are available
>>> B, H, O, Q_len = 4, 3, 1, 3
>>> preds_point = np.random.rand(B, H, O)
>>> preds_quant = np.random.rand(B, H, Q_len)
>>> y_true_seq = np.random.rand(B, H, O)
>>> quantiles_list = [0.1, 0.5, 0.9]
>>> # Create DataFrames using format_predictions_to_dataframe
>>> df_point = format_predictions_to_dataframe(
...     predictions=preds_point, y_true_sequences=y_true_seq,
...     target_name="value", forecast_horizon=H, output_dim=O
... )
>>> df_quant = format_predictions_to_dataframe(
...     predictions=preds_quant, y_true_sequences=y_true_seq,
...     target_name="value", quantiles=quantiles_list,
...     forecast_horizon=H, output_dim=O
... )
>>> # Example 1: Plot temporal point forecast for first sample
>>> # plot_forecasts(df_point, target_name="value", sample_ids=0)
>>> # Example 2: Plot temporal quantile forecast for first 2 samples
>>> # plot_forecasts(df_quant, target_name="value",
... #                quantiles=quantiles_list, sample_ids="first_n",
... #                num_samples=2, max_cols=1)
>>> # Example 3: Spatial plot (requires spatial_cols in df_quant)
>>> # Assume df_quant has 'lon' and 'lat' columns
>>> # df_quant['lon'] = np.random.rand(len(df_quant)) * 10
>>> # df_quant['lat'] = np.random.rand(len(df_quant)) * 5
>>> # plot_forecasts(df_quant, target_name="value",
... #                quantiles=quantiles_list, kind='spatial',
... #                horizon_steps=1, spatial_cols=['lon', 'lat'])
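
The sample_ids selection rules and the max_cols / figsize layout described under Parameters can be sketched in plain Python. Both helpers are hypothetical and not part of fusionlab. In particular, treating the total figure size as the per-subplot figsize multiplied by the grid dimensions is an assumption about how the size is inferred.

```python
import math
from typing import List, Sequence, Tuple, Union


def resolve_sample_ids(
    available: Sequence[int],
    sample_ids: Union[int, List[int], str] = "first_n",
    num_samples: int = 3,
) -> List[int]:
    """Hypothetical helper mirroring the documented sample_ids options."""
    unique_ids = sorted(set(available))
    if isinstance(sample_ids, str):
        if sample_ids == "first_n":
            return unique_ids[:num_samples]
        if sample_ids == "all":
            return unique_ids
        raise ValueError(f"Unsupported sample_ids: {sample_ids!r}")
    if isinstance(sample_ids, int):
        sample_ids = [sample_ids]
    # Keep only requested ids that actually exist in the data.
    return [s for s in sample_ids if s in unique_ids]


def grid_layout(
    n_plots: int,
    max_cols: int = 2,
    figsize: Tuple[float, float] = (8, 4.5),
) -> Tuple[int, int, Tuple[float, float]]:
    """Hypothetical helper: rows, columns, and assumed total figure size."""
    n_cols = max(1, min(n_plots, max_cols))
    n_rows = math.ceil(n_plots / n_cols)
    # Assumption: figsize is per subplot, so the figure scales with the grid.
    total = (figsize[0] * n_cols, figsize[1] * n_rows)
    return n_rows, n_cols, total


# 'sample_idx' values as they might appear in a long-format forecast_df
sample_idx_column = [0, 0, 0, 1, 1, 1, 2, 2, 2, 3, 3, 3]
selected = resolve_sample_ids(sample_idx_column)  # default "first_n", 3 samples
n_rows, n_cols, total_size = grid_layout(len(selected))
```

With the defaults, three selected samples fill a 2 x 2 grid with one empty cell, so lowering max_cols to 1 (as in Example 2) gives a single column of stacked plots instead.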