Forecasting Workflow Utilities

fusionlab-learn provides utility functions in the fusionlab.nn.utils module that streamline common steps of a time series forecasting pipeline. This guide walks through a typical workflow built on three utilities:

  1. prepare_model_inputs(): Standardizes the creation of the input list [static, dynamic, future] for various model types, handling optional inputs gracefully.

  2. format_predictions_to_dataframe(): Transforms raw model predictions (point or quantile) into a structured, long-format pandas DataFrame, suitable for analysis, storage, and further visualization.

  3. plot_forecasts(): Visualizes the formatted forecast DataFrame, comparing predictions against actuals over forecast steps (temporal) or across locations (spatial). (Note: this function lives in the `fusionlab.plot` subpackage, imported below from `fusionlab.plot.forecast`, rather than in `nn.utils`, but it is typically used together with the `nn.utils` helpers.)

Used together, these utilities take over most of the input-packaging, output-reshaping, and plotting boilerplate in a forecasting script, which keeps the code shorter and easier to maintain.

Prerequisites

Ensure fusionlab-learn and its common dependencies are installed. The examples also use TensorFlow to build the input tensors and matplotlib for the plots.

pip install fusionlab-learn tensorflow matplotlib scikit-learn

Common Setup for Examples

We’ll start with common imports and generate some basic dummy data that simulates static, dynamic, and future features, along with target values.

    import os
    import warnings

    # Reduce TensorFlow C++ log noise; this must be set before TF is imported
    os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'

    import numpy as np
    import pandas as pd
    import tensorflow as tf
    import matplotlib.pyplot as plt

    # FusionLab imports
    from fusionlab.nn.utils import (
        prepare_model_inputs,
        format_predictions_to_dataframe
    )
    try:
        from fusionlab.plot.forecast import plot_forecasts
    except ImportError:
        # Fallback for versions that expose plot_forecasts from nn.utils
        from fusionlab.nn.utils import plot_forecasts
        warnings.warn("Imported plot_forecasts from fusionlab.nn.utils. "
                      "Prefer fusionlab.plot.forecast where available.")

    # Suppress Python warnings and TF Python-side logs for cleaner output
    warnings.filterwarnings('ignore')
    tf.get_logger().setLevel('ERROR')
    if hasattr(tf, 'autograph'):
        tf.autograph.set_verbosity(0)

    # Base dimensions for dummy data
    B, T_PAST, H_OUT = 5, 12, 4       # Batch, past timesteps, forecast horizon
    D_S, D_D, D_F, D_O = 2, 3, 2, 1   # Static, dynamic, future, output dims
    T_FUTURE_TOTAL = T_PAST + H_OUT   # Future features span past window + horizon
    Q_LIST_VIZ = [0.1, 0.5, 0.9]      # Quantiles used in the forecast examples
    SEED = 42
    np.random.seed(SEED)
    tf.random.set_seed(SEED)

    # Generate dummy data components
    raw_static_data = tf.random.normal((B, D_S), dtype=tf.float32, seed=SEED)
    raw_dynamic_data = tf.random.normal((B, T_PAST, D_D), dtype=tf.float32, seed=SEED+1)
    raw_future_data = tf.random.normal((B, T_FUTURE_TOTAL, D_F), dtype=tf.float32, seed=SEED+2)
    raw_y_true_sequences = tf.random.normal((B, H_OUT, D_O), dtype=tf.float32, seed=SEED+3)

    # Simulate some spatial identifiers for later use (one row per sample)
    spatial_ids_df = pd.DataFrame({
        'location_id': [f'L{i}' for i in range(B)],
        'region': [f'R{i%2}' for i in range(B)]
    })

    print("Common setup complete. Dummy data generated.")
    print(f"  Static shape : {raw_static_data.shape}")
    print(f"  Dynamic shape: {raw_dynamic_data.shape}")
    print(f"  Future shape : {raw_future_data.shape}")
    print(f"  Target shape : {raw_y_true_sequences.shape}")

Expected Output (Common Setup):

Common setup complete. Dummy data generated.
  Static shape : (5, 2)
  Dynamic shape: (5, 12, 3)
  Future shape : (5, 16, 2)
  Target shape : (5, 4, 1)

Step 1: Preparing Model Inputs with prepare_model_inputs

API Reference:

prepare_model_inputs()

The first step in a forecasting pipeline after loading/generating raw features is to package them correctly for your specific model. prepare_model_inputs helps create the standard 3-element list [static_input, dynamic_input, future_input] that many fusionlab-learn models expect for their call method.
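
As a rough picture of the layout this guide's examples use for that list, the hypothetical helper below (`check_input_shapes`, not a fusionlab-learn function) verifies that the static, dynamic, and future arrays share a batch size and that the future array covers the past window plus the forecast horizon. That future-length convention is an assumption taken from the dummy data in the common setup; check your model's documentation for what it actually expects.

```python
import numpy as np

def check_input_shapes(static, dynamic, future, forecast_horizon):
    """Hypothetical sanity check for a [static, dynamic, future] list.

    Assumes static is (B, D_s), dynamic is (B, T_past, D_d) and
    future is (B, T_past + forecast_horizon, D_f), as in this guide's dummy data.
    """
    batch, t_past, _ = dynamic.shape
    if static.ndim != 2 or static.shape[0] != batch:
        raise ValueError(f"static must be (B, D_s) with B={batch}, got {static.shape}")
    expected_t = t_past + forecast_horizon
    if future.ndim != 3 or future.shape[:2] != (batch, expected_t):
        raise ValueError(f"future must be (B, {expected_t}, D_f), got {future.shape}")
    return {"static": static.shape, "dynamic": dynamic.shape, "future": future.shape}

# Same dimensions as the common setup: B=5, T_PAST=12, H_OUT=4
shapes = check_input_shapes(
    np.zeros((5, 2)), np.zeros((5, 12, 3)), np.zeros((5, 16, 2)), forecast_horizon=4
)
print(shapes)  # {'static': (5, 2), 'dynamic': (5, 12, 3), 'future': (5, 16, 2)}
```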

Scenario 1.1: Stricter Model (e.g., XTFT, TFTStricter)

These models typically require all three inputs (static, dynamic, future) to be actual tensors. If your data has no features of one type, prepare_model_inputs with model_type='strict' creates a placeholder tensor with zero features for that slot, as the second example shows.

    print("\n--- Preparing inputs for a 'strict' model ---")
    # Example 1: All inputs provided
    inputs_strict_all = prepare_model_inputs(
        dynamic_input=raw_dynamic_data,
        static_input=raw_static_data,
        future_input=raw_future_data,
        model_type='strict',
        forecast_horizon=H_OUT,  # Used for a dummy future input if future_input is None
        verbose=1
    )
    print(f"Strict (all provided): S={inputs_strict_all[0].shape}, "
          f"D={inputs_strict_all[1].shape}, F={inputs_strict_all[2].shape}")

    # Example 2: Static input is conceptually absent
    inputs_strict_no_static = prepare_model_inputs(
        dynamic_input=raw_dynamic_data,
        static_input=None,  # Static features are not available
        future_input=raw_future_data,
        model_type='strict',
        forecast_horizon=H_OUT,
        verbose=1
    )
    print(f"Strict (no static): S={inputs_strict_no_static[0].shape}, "
          f"D={inputs_strict_no_static[1].shape}, "
          f"F={inputs_strict_no_static[2].shape}")

Expected Output 1.1:

--- Preparing inputs for a 'strict' model ---
  prepare_model_inputs (strict): Passing inputs as is.
Strict (all provided): S=(5, 2), D=(5, 12, 3), F=(5, 16, 2)
  prepare_model_inputs (strict): Created dummy static input with shape (5, 0)
  prepare_model_inputs (strict): Passing inputs as is.
Strict (no static): S=(5, 0), D=(5, 12, 3), F=(5, 16, 2)
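
The zero-feature placeholder in the second example is an ordinary tensor whose last axis has length 0: it keeps the batch (and, for time-varying inputs, time) structure a strict model expects while carrying no values. The NumPy sketch below illustrates the idea with a hypothetical `fill_strict_inputs` helper; it is not the implementation of prepare_model_inputs, and the time length of the future placeholder (past window plus horizon) is an assumption mirroring this guide's dummy data.

```python
import numpy as np

def fill_strict_inputs(dynamic, static=None, future=None, forecast_horizon=1):
    """Hypothetical helper: replace missing inputs with zero-feature arrays."""
    batch, t_past = dynamic.shape[0], dynamic.shape[1]
    if static is None:
        # (B, 0): one row per sample, no static features
        static = np.zeros((batch, 0), dtype=dynamic.dtype)
    if future is None:
        # (B, T_past + H, 0): time axis kept, no future features (assumed length)
        future = np.zeros((batch, t_past + forecast_horizon, 0), dtype=dynamic.dtype)
    return [static, dynamic, future]

dynamic = np.random.default_rng(0).normal(size=(5, 12, 3)).astype(np.float32)
s, d, f = fill_strict_inputs(dynamic, forecast_horizon=4)
print(s.shape, d.shape, f.shape)  # (5, 0) (5, 12, 3) (5, 16, 0)
print(s.size)                     # 0: the placeholder holds no values
```
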

Scenario 1.2: Flexible Model (e.g., TemporalFusionTransformer)

Flexible models accept None for the optional inputs (static and future). With model_type='flexible', prepare_model_inputs passes these None values through unchanged.

    print("\n--- Preparing inputs for a 'flexible' model ---")
    # Example 1: Dynamic input only
    inputs_flex_dyn_only = prepare_model_inputs(
        dynamic_input=raw_dynamic_data,
        static_input=None,
        future_input=None,
        model_type='flexible',
        verbose=1
    )
    s_shape = inputs_flex_dyn_only[0].shape if inputs_flex_dyn_only[0] is not None else "None"
    d_shape = inputs_flex_dyn_only[1].shape
    f_shape = inputs_flex_dyn_only[2].shape if inputs_flex_dyn_only[2] is not None else "None"
    print(f"Flexible (dynamic only): S={s_shape}, D={d_shape}, F={f_shape}")

Expected Output 1.2:

--- Preparing inputs for a 'flexible' model ---
  prepare_model_inputs (flexible): Passing inputs as is (Static: <class 'NoneType'>, Dynamic: <class 'tensorflow.python.framework.ops.EagerTensor'>, Future: <class 'NoneType'>)
Flexible (dynamic only): S=None, D=(5, 12, 3), F=None
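
Since any slot of a flexible model's input list may be None, code that inspects the list has to guard each element, as the shape printing above does. A small hypothetical helper (`describe_inputs`, plain Python rather than part of fusionlab-learn) keeps that check in one place; it works for anything with a `.shape` attribute, such as NumPy arrays or TensorFlow tensors.

```python
import numpy as np

def describe_inputs(inputs, names=("static", "dynamic", "future")):
    """Return a {name: shape or 'None'} mapping for a model input list."""
    return {
        name: (tuple(x.shape) if x is not None else "None")
        for name, x in zip(names, inputs)
    }

inputs = [None, np.zeros((5, 12, 3)), None]
print(describe_inputs(inputs))
# {'static': 'None', 'dynamic': (5, 12, 3), 'future': 'None'}
```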

Step 2: Simulate Model Prediction

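For this exercise, we won’t train a full model. Instead, we’ll simulate the kind of output a quantile forecasting model produces: a tensor of shape (Batch, Horizon, NumQuantiles * OutputDim). Because the values are random, the simulated quantiles are not guaranteed to be ordered; the sample output in Step 3 includes rows where sales_q10 is larger than sales_q90.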

    # Simulate predictions (e.g., from an XTFT model)
    # Shape: (Batch, Horizon, NumQuantiles * OutputDim)
    # For this example, OutputDim=1 and NumQuantiles=len(Q_LIST_VIZ)=3
    simulated_predictions_quant = tf.random.normal(
        (B, H_OUT, len(Q_LIST_VIZ) * D_O), dtype=tf.float32, seed=SEED+4
    )
    print(f"\nSimulated quantile predictions shape: {simulated_predictions_quant.shape}")

Expected Output 2:

Simulated quantile predictions shape: (5, 4, 3)
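
With output_dim=1, the last axis of this tensor holds one column per entry of Q_LIST_VIZ, and this guide assumes the columns follow the list's order. Scenario 4.2 later takes the median with the hard-coded slice `[:, :, 1:2]`; the NumPy sketch below instead looks the index up from the quantile list, and also shows a simple, optional way to enforce q10 <= q50 <= q90 by sorting along the quantile axis. How quantiles and targets are interleaved when output_dim > 1 is not covered here.

```python
import numpy as np

quantiles = [0.1, 0.5, 0.9]   # same values as Q_LIST_VIZ, ascending order assumed
batch, horizon = 5, 4
rng = np.random.default_rng(42)
preds = rng.normal(size=(batch, horizon, len(quantiles)))  # (B, H, Q) with OutputDim=1

# Look up the median's position instead of hard-coding index 1
median_idx = quantiles.index(0.5)
median = preds[:, :, median_idx:median_idx + 1]  # slice keeps a trailing axis of size 1

# Optional: sort along the quantile axis so that q10 <= q50 <= q90 in every row
# (a simple post-hoc fix for crossing quantiles, applied here by hand)
sorted_preds = np.sort(preds, axis=-1)

print(median_idx, median.shape)  # 1 (5, 4, 1)
```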

Step 3: Format Predictions with format_predictions_to_dataframe

API Reference:

format_predictions_to_dataframe()

This utility takes the raw prediction tensor (and optionally actuals, spatial data, etc.) and converts it into a well-structured, long-format pandas DataFrame. This DataFrame is then easy to analyze, save, or pass to plotting functions.

Scenario 3.1: Formatting Quantile Forecasts with Actuals and Spatial Data

    print("\n--- Formatting quantile predictions to DataFrame ---")
    # Use the spatial_ids_df created in the common setup;
    # it must have one row per sample (B rows), like the predictions
    spatial_data_for_format = spatial_ids_df  # Shape (B, NumSpatialFeatures)

    forecast_df_viz = format_predictions_to_dataframe(
        predictions=simulated_predictions_quant,
        y_true_sequences=raw_y_true_sequences,
        target_name="sales",  # Base name for columns
        quantiles=Q_LIST_VIZ,
        forecast_horizon=H_OUT,  # Helps structure the DataFrame
        output_dim=D_O,          # Number of target variables
        spatial_data_array=spatial_data_for_format,  # DataFrame with B rows
        spatial_cols_names=['location_id', 'region'],  # Spatial column names (as in spatial_ids_df)
        verbose=1
    )
    print("\nFormatted DataFrame head (Quantile Forecast):")
    print(forecast_df_viz.head(H_OUT * 2))  # Rows for the first two samples
    print(f"\nFormatted DataFrame shape: {forecast_df_viz.shape}")
    print(f"Formatted DataFrame columns: {forecast_df_viz.columns.tolist()}")

Expected Output 3.1:

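(A long-format DataFrame with one row per sample and forecast step and the columns sample_idx, forecast_step, sales_q10, sales_q50, sales_q90 and sales_actual. Note that the spatial identifier columns were not added in the sample run shown below.)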

--- Formatting quantile predictions to DataFrame ---
[INFO] Starting prediction formatting to DataFrame.
    [INFO]   Raw predictions shape: (5, 4, 3)
    [INFO]   Inferred/Validated: Samples=5, Horizon=4, OutputDim=1, NumQuantiles=3
    [INFO]   Added prediction columns: ['sales_q10', 'sales_q50', 'sales_q90']
    [INFO]   Added actual value columns: ['sales_actual']
[INFO] Prediction formatting to DataFrame complete.

Formatted DataFrame head (Quantile Forecast):
   sample_idx  forecast_step  sales_q10  sales_q50  sales_q90  sales_actual
0           0              1  -0.492519   0.314352  -0.939723     -0.019795
1           0              2  -0.489788   1.087007   0.165282      0.407925
2           0              3   0.692570  -0.101750  -0.165129     -0.115735
3           0              4   0.622007   0.223282   0.049389     -0.308791
4           1              1  -1.499012  -0.228126  -0.840142      0.445111
5           1              2  -0.401215   1.823693   1.008885     -0.407488
6           1              3   1.087821   0.155696  -0.351913      2.175023
7           1              4  -0.040999  -1.583362   1.056865      0.755576

Formatted DataFrame shape: (20, 6)
Formatted DataFrame columns: ['sample_idx', 'forecast_step', 'sales_q10', 'sales_q50', 'sales_q90', 'sales_actual']
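
To make the long format concrete, the sketch below rebuilds a table with the same layout from NumPy arrays: one row per (sample, forecast step) pair in sample-major order, one column per quantile, the actual value, and spatial identifiers broadcast to every step of their sample. The helper `to_long_format` is hypothetical and only illustrates the reshaping; it is not how format_predictions_to_dataframe is implemented, and the column names simply mirror the output above.

```python
import numpy as np
import pandas as pd

def to_long_format(preds, y_true, quantiles, target="sales", spatial=None):
    """Sketch: (B, H, Q) predictions and (B, H, 1) actuals -> long DataFrame.

    Assumes output_dim=1 and one prediction column per quantile.
    `spatial`, if given, is a DataFrame with one row per sample.
    """
    n_samples, horizon, _ = preds.shape
    df = pd.DataFrame({
        # Sample-major order: all steps of sample 0, then sample 1, ...
        "sample_idx": np.repeat(np.arange(n_samples), horizon),
        "forecast_step": np.tile(np.arange(1, horizon + 1), n_samples),
    })
    if spatial is not None:
        # Broadcast each sample's identifiers to all of its forecast steps
        idx = df["sample_idx"].to_numpy()
        for col in spatial.columns:
            df[col] = spatial[col].to_numpy()[idx]
    for q_i, q in enumerate(quantiles):
        # Row-major reshape matches the (sample, step) row order above
        df[f"{target}_q{int(round(q * 100))}"] = preds[:, :, q_i].reshape(-1)
    df[f"{target}_actual"] = y_true[:, :, 0].reshape(-1)
    return df

rng = np.random.default_rng(0)
preds = rng.normal(size=(5, 4, 3))
y_true = rng.normal(size=(5, 4, 1))
spatial = pd.DataFrame({
    "location_id": [f"L{i}" for i in range(5)],
    "region": [f"R{i % 2}" for i in range(5)],
})
long_df = to_long_format(preds, y_true, [0.1, 0.5, 0.9], spatial=spatial)
print(long_df.shape)  # (20, 8)
print(long_df.head())
```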

Step 4: Visualizing Formatted Predictions with plot_forecasts

API Reference:

plot_forecasts()

Once predictions are in the long-format DataFrame produced by format_predictions_to_dataframe, plot_forecasts can visualize them directly, either over forecast steps (kind="temporal") or across coordinates (kind="spatial").

Scenario 4.1: Temporal Quantile Forecast for Selected Samples

    print("\n--- Visualizing Temporal Quantile Forecast ---")
    plot_forecasts(
        forecast_df=forecast_df_viz,
        target_name="sales",
        quantiles=Q_LIST_VIZ,
        output_dim=D_O,
        kind="temporal",
        sample_ids=[0, 1],  # Plot the first two samples
        max_cols=1,         # One subplot per row
        figsize_per_subplot=(10, 4),
        verbose=1
    )
    # To save (evaluation_plot_dir is a placeholder for your output directory):
    # fig_path = os.path.join(evaluation_plot_dir, "workflow_temporal_quantile.png")
    # plt.savefig(fig_path)

Expected Plot 4.1:

(Two subplots, each showing actual vs. median and prediction interval for sample_idx 0 and 1 respectively)

[Figure: Temporal Quantile Forecast from Workflow Utilities. Temporal plot showing actuals, median forecast, and prediction intervals for selected samples.]
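
To see what such a panel is made of, or to build a custom variant, the sketch below draws a comparable view for one sample directly with matplotlib from a long-format DataFrame: a shaded band between the lowest and highest quantile, a line for the median, and a dashed line for the actuals. `plot_sample_quantiles` is a hypothetical helper, not part of fusionlab-learn, and it assumes the sales_q10 / sales_q50 / sales_q90 / sales_actual column names produced in Step 3.

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend so the sketch runs headless
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

def plot_sample_quantiles(df, sample_id, target="sales",
                          q_low=10, q_mid=50, q_high=90, ax=None):
    """Shade the q_low..q_high band, draw the median, and show actuals for one sample."""
    sub = df[df["sample_idx"] == sample_id].sort_values("forecast_step")
    steps = sub["forecast_step"].to_numpy()
    if ax is None:
        _, ax = plt.subplots(figsize=(10, 4))
    ax.fill_between(steps, sub[f"{target}_q{q_low}"], sub[f"{target}_q{q_high}"],
                    alpha=0.3, label=f"q{q_low}-q{q_high} interval")
    ax.plot(steps, sub[f"{target}_q{q_mid}"], marker="o", label="median")
    ax.plot(steps, sub[f"{target}_actual"], "k--", marker="x", label="actual")
    ax.set_xlabel("Forecast step")
    ax.set_ylabel(target)
    ax.set_title(f"Sample {sample_id}")
    ax.legend()
    return ax

# Small synthetic long-format table with two samples and four steps
rng = np.random.default_rng(1)
demo = pd.DataFrame({
    "sample_idx": np.repeat([0, 1], 4),
    "forecast_step": np.tile(np.arange(1, 5), 2),
    "sales_q10": rng.normal(-1.0, 0.1, 8),
    "sales_q50": rng.normal(0.0, 0.1, 8),
    "sales_q90": rng.normal(1.0, 0.1, 8),
    "sales_actual": rng.normal(0.0, 0.5, 8),
})
ax = plot_sample_quantiles(demo, sample_id=0)
```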

Scenario 4.2: Spatial Point Forecast for a Specific Horizon Step

First, let’s create a point forecast DataFrame for this.

    # Simulate point predictions by taking the median column (index 1 of Q_LIST_VIZ)
    simulated_predictions_point = simulated_predictions_quant[:, :, 1:2]

    forecast_df_point_for_spatial = format_predictions_to_dataframe(
        predictions=simulated_predictions_point,
        y_true_sequences=raw_y_true_sequences,
        target_name="sales",
        # No quantiles for a point forecast
        forecast_horizon=H_OUT,
        output_dim=D_O,
        spatial_data_array=spatial_ids_df,
        spatial_cols_names=['location_id', 'region'],
        verbose=0
    )
    # Add dummy longitude/latitude for spatial plotting.
    # In a real case, these would come from your spatial_data_array.

    # 1. The DataFrame has one row per (sample, forecast step)
    n_rows = len(forecast_df_point_for_spatial)
    assert n_rows == B * H_OUT  # 5 * 4 = 20

    # 2. Create one coordinate per sample
    base_lon = np.linspace(-100, -90, B)  # 5 points from -100 to -90
    base_lat = np.linspace(30, 35, B)     # 5 points from 30 to 35

    # 3. Rows are ordered sample by sample, so repeat each coordinate H_OUT
    #    times (np.tile would cycle through the samples and misalign them)
    forecast_df_point_for_spatial["longitude"] = np.repeat(base_lon, H_OUT)
    forecast_df_point_for_spatial["latitude"] = np.repeat(base_lat, H_OUT)

    print("\n--- Visualizing Spatial Point Forecast ---")
    plot_forecasts(
        forecast_df=forecast_df_point_for_spatial,
        target_name="sales",
        # No quantiles
        output_dim=D_O,
        kind="spatial",
        horizon_steps=1,  # Plot the first forecast step
        spatial_cols=['longitude', 'latitude'],
        figsize_per_subplot=(7, 6),
        verbose=1,
        # Additional kwargs forwarded to the scatter plot
        s=50, cmap='coolwarm'  # Marker size and colormap
    )
    # To save (evaluation_plot_dir is a placeholder for your output directory):
    # fig_path = os.path.join(evaluation_plot_dir, "workflow_spatial_point.png")
    # plt.savefig(fig_path)

Expected Plot 4.2:

(A scatter plot showing predicted 'sales_pred' values at different longitude/latitude points for the first forecast horizon step.)

[Figure: Spatial Point Forecast from Workflow Utilities. Spatial plot showing point forecast values across coordinates for a specific horizon step.]
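
The choice of `np.repeat` over `np.tile` for the dummy coordinates matters. The formatted DataFrame is ordered sample by sample (all forecast steps of sample 0, then sample 1, and so on, as the Step 3 output shows), so each sample's coordinate must be repeated once per step. `np.tile` yields a vector of the same length but cycles through the samples, attaching the wrong coordinates to most rows. The NumPy check below makes the difference visible with this guide's dimensions.

```python
import numpy as np

B, H_OUT = 5, 4
# Row order of the formatted DataFrame: sample 0 steps 1..4, then sample 1, ...
sample_idx = np.repeat(np.arange(B), H_OUT)
base_lon = np.linspace(-100, -90, B)   # one longitude per sample

lon_repeat = np.repeat(base_lon, H_OUT)
lon_tile = np.tile(base_lon, H_OUT)

# Correct alignment means row i carries the longitude of sample_idx[i]
print(np.array_equal(lon_repeat, base_lon[sample_idx]))  # True
print(np.array_equal(lon_tile, base_lon[sample_idx]))    # False

# Rows for forecast step 1, i.e. what a horizon_steps=1 plot would use
step1 = np.arange(0, B * H_OUT, H_OUT)
print(lon_repeat[step1].tolist())  # [-100.0, -97.5, -95.0, -92.5, -90.0]
print(lon_tile[step1].tolist())    # [-100.0, -90.0, -92.5, -95.0, -97.5]
```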


Conclusion

This guide demonstrated a streamlined workflow using key utilities from fusionlab.nn.utils and the fusionlab.plot subpackage:

  • `prepare_model_inputs` helps in correctly structuring the potentially complex list of inputs (static, dynamic, future) that forecasting models require, handling optional inputs gracefully.

  • `format_predictions_to_dataframe` transforms raw model outputs (point or quantile, single or multi-target) into a standardized long-format DataFrame, which is essential for systematic analysis, storage, and as input to other evaluation tools.

  • `plot_forecasts` offers a versatile way to quickly visualize these formatted predictions, allowing for temporal inspection of individual series and spatial distribution of forecasts.

By leveraging these functions, users can significantly reduce boilerplate code, ensure data consistency, and focus more on model development and interpretation. For more detailed evaluation metrics, please refer to the metrics page.