Quantile Forecasting with TFT Variants

This guide demonstrates how to configure and train Temporal Fusion Transformer (TFT) models available in fusionlab-learn to produce quantile forecasts. Instead of predicting a single point value, the model predicts multiple quantiles (e.g., 10th, 50th, 90th percentiles), providing an estimate of the prediction uncertainty.

We will show examples using both:

  1. The flexible TemporalFusionTransformer (handling optional inputs, demonstrated with dynamic inputs only).

  2. The stricter TFT (requiring all static, dynamic, and future inputs).

Prerequisites

Ensure you have fusionlab-learn and its dependencies installed:

pip install fusionlab-learn matplotlib

Example 1: Quantile Forecasting with Flexible TemporalFusionTransformer

This example uses only dynamic (past observed) features and configures the model to output quantile predictions for multiple steps ahead.

Workflow:

  1. Generate simple synthetic time series data.

  2. Prepare sequences and multi-step targets using create_sequences().

  3. Instantiate the flexible TemporalFusionTransformer with specified quantiles and output_dim.

  4. Compile the model using combined_quantile_loss().

  5. Train the model.

  6. Interpret and visualize the multi-quantile output.

Step 1.1: Imports and Setup

Import standard libraries and fusionlab components.

import numpy as np
import pandas as pd
import tensorflow as tf
import matplotlib.pyplot as plt
import warnings
import os

# FusionLab imports
from fusionlab.nn.transformers import TemporalFusionTransformer
from fusionlab.nn.utils import create_sequences
from fusionlab.nn.losses import combined_quantile_loss

# Suppress warnings and TF logs
warnings.filterwarnings('ignore')
tf.get_logger().setLevel('ERROR')
if hasattr(tf, 'autograph'):
    tf.autograph.set_verbosity(0)
print("Libraries imported for Flexible TFT Quantile Example.")

Step 1.2: Generate Synthetic Data

A simple sine wave with noise serves as our univariate time series.

time_flex = np.arange(0, 100, 0.1)
amplitude_flex = np.sin(time_flex) + np.random.normal(
    0, 0.15, len(time_flex)
    )
df_flex = pd.DataFrame({'Value': amplitude_flex})
print(f"Generated data shape for flexible TFT: {df_flex.shape}")

Step 1.3: Prepare Sequences for Multi-Step Forecasting

We use past observations to predict multiple future steps. Targets are reshaped to (Samples, Horizon, OutputDim).

sequence_length_flex = 10
forecast_horizon_flex = 5 # Predict next 5 steps
target_col_flex = 'Value'

sequences_flex, targets_flex = create_sequences(
    df=df_flex,
    sequence_length=sequence_length_flex,
    target_col=target_col_flex,
    forecast_horizon=forecast_horizon_flex,
    verbose=0
)
sequences_flex = sequences_flex.astype(np.float32)
targets_flex = targets_flex.reshape(
    -1, forecast_horizon_flex, 1 # OutputDim = 1
    ).astype(np.float32)

print(f"\nFlexible TFT - Input sequences shape (X): {sequences_flex.shape}")
print(f"Flexible TFT - Target values shape (y): {targets_flex.shape}")
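
The windowing this step relies on can be sketched in plain NumPy. This is a minimal illustration, not the implementation of create_sequences(): the helper name make_windows is hypothetical, and it assumes each sample pairs sequence_length consecutive values with the forecast_horizon values that immediately follow them.

```python
import numpy as np

def make_windows(values, sequence_length, forecast_horizon):
    """Hypothetical helper: pair each input window with the steps that follow it."""
    values = np.asarray(values, dtype=np.float32)
    n_samples = len(values) - sequence_length - forecast_horizon + 1
    if n_samples <= 0:
        raise ValueError("Series is too short for the requested window sizes.")
    # Inputs: sequence_length consecutive values starting at position i
    X = np.stack([values[i : i + sequence_length] for i in range(n_samples)])
    # Targets: the forecast_horizon values right after each input window
    y = np.stack([
        values[i + sequence_length : i + sequence_length + forecast_horizon]
        for i in range(n_samples)
    ])
    # Add a trailing feature/output axis: (samples, steps, 1)
    return X[..., np.newaxis], y[..., np.newaxis]

series = np.arange(20, dtype=np.float32)
X_demo, y_demo = make_windows(series, sequence_length=10, forecast_horizon=5)
print(X_demo.shape, y_demo.shape)
```

Under this assumption, the targets of sample i begin at position i + sequence_length in the original series, which is the offset used to build the time axis in Step 1.6.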

Step 1.4: Define Flexible TFT Model for Quantile Forecast

Instantiate TemporalFusionTransformer, providing the quantiles list. Static and future input dimensions default to None.

quantiles_to_predict = [0.1, 0.5, 0.9] # 10th, 50th, 90th
num_dynamic_features_flex = sequences_flex.shape[-1]

model_flex = TemporalFusionTransformer(
    dynamic_input_dim=num_dynamic_features_flex,
    # static_input_dim=None, # Default
    # future_input_dim=None, # Default
    forecast_horizon=forecast_horizon_flex,
    output_dim=1, # Univariate target
    hidden_units=16, num_heads=2,
    quantiles=quantiles_to_predict, # Enable quantile output
    num_lstm_layers=1, lstm_units=16
)
print("\nFlexible TFT for quantiles instantiated.")

# Compile with combined_quantile_loss
loss_fn_flex = combined_quantile_loss(quantiles=quantiles_to_predict)
model_flex.compile(optimizer='adam', loss=loss_fn_flex)
print("Flexible TFT compiled with quantile loss.")
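
For intuition about what a quantile loss rewards, here is a minimal NumPy sketch of the standard pinball loss. It is an illustration under stated assumptions, not the code behind combined_quantile_loss(): the function name pinball_loss is hypothetical, and averaging over samples, steps, and quantiles is a choice made here for simplicity.

```python
import numpy as np

def pinball_loss(y_true, y_pred, quantiles):
    """Mean pinball loss (illustrative sketch).

    y_true: (batch, horizon, 1)
    y_pred: (batch, horizon, n_quantiles), ordered like `quantiles`
    """
    q = np.asarray(quantiles, dtype=np.float64)  # (n_quantiles,)
    error = y_true - y_pred                      # broadcasts to y_pred's shape
    # Under-prediction (error > 0) costs q * error;
    # over-prediction (error < 0) costs (1 - q) * |error|.
    loss = np.maximum(q * error, (q - 1.0) * error)
    return float(loss.mean())

y_true_demo = np.array([[[1.0], [2.0]]])                       # (1, 2, 1)
y_pred_demo = np.array([[[0.5, 1.0, 1.5], [1.0, 2.0, 3.0]]])   # (1, 2, 3)
demo_loss = pinball_loss(y_true_demo, y_pred_demo, [0.1, 0.5, 0.9])
print(demo_loss)
```

The asymmetry is what separates the outputs: for q=0.9, predicting too low is penalized nine times more heavily than predicting too high, so that output is pushed toward the upper end of the target distribution.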

Step 1.5: Train the Model

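Because this model uses only dynamic inputs, the sequence array is passed on its own; a single-element list, [sequences_flex], also works. When static or future inputs are used as well, the expected order is [static, dynamic, future].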

# Dynamic-only model: pass the sequence array directly
# (a single-element list, [sequences_flex], also works)
train_inputs_flex = sequences_flex

print("\nStarting flexible TFT training (quantile)...")
history_flex = model_flex.fit(
    train_inputs_flex,
    targets_flex,
    epochs=5, batch_size=32, validation_split=0.2, verbose=0
)
print("Flexible TFT training finished.")
if history_flex and history_flex.history.get('val_loss'):
    val_loss = history_flex.history['val_loss'][-1]
    print(f"Final validation loss (quantile): {val_loss:.4f}")

Step 1.6: Make and Visualize Quantile Predictions

Predictions have shape (Batch, Horizon, NumQuantiles), with the last axis ordered like quantiles_to_predict. We plot the median and the q=0.1 to q=0.9 prediction interval for one validation sample.

num_samples_flex = sequences_flex.shape[0]
val_start_idx_flex = int(num_samples_flex * (1 - 0.2))
val_dynamic_inputs_flex = sequences_flex[val_start_idx_flex:]
val_actuals_flex = targets_flex[val_start_idx_flex:]

val_inputs_list_flex = [val_dynamic_inputs_flex]

print("\nMaking quantile predictions (flexible TFT)...")
val_predictions_quantiles = model_flex.predict(
    val_inputs_list_flex, verbose=0
    )
print(f"Prediction output shape: {val_predictions_quantiles.shape}")

# Visualization for one sample
sample_to_plot_flex = 0
actual_vals_flex = val_actuals_flex[sample_to_plot_flex, :, 0]
pred_quantiles_flex = val_predictions_quantiles[sample_to_plot_flex, :, :]

# Time stamps of the forecast steps for the chosen sample
plot_start_flex = (
    val_start_idx_flex + sequence_length_flex + sample_to_plot_flex
)
plot_time_axis_flex = time_flex[
    plot_start_flex : plot_start_flex + forecast_horizon_flex
    ]

plt.figure(figsize=(12, 6))
plt.plot(plot_time_axis_flex, actual_vals_flex,
         label='Actual Value', marker='o', linestyle='--')
plt.plot(plot_time_axis_flex, pred_quantiles_flex[:, 1], # Median (0.5)
         label='Predicted Median (q=0.5)', marker='x')
plt.fill_between(
    plot_time_axis_flex,
    pred_quantiles_flex[:, 0], # Lower quantile (q=0.1)
    pred_quantiles_flex[:, 2], # Upper quantile (q=0.9)
    color='gray', alpha=0.3,
    label='Prediction Interval (q=0.1 to q=0.9)'
)
plt.title('Flexible TFT Quantile Forecast (Dynamic Inputs Only)')
plt.xlabel('Time'); plt.ylabel('Value')
plt.legend(); plt.grid(True); plt.tight_layout()
# plt.savefig("docs/source/images/forecasting_quantile_tft_flexible.png")
plt.show()
print("Flexible TFT quantile plot generated.")

Example Output Plot (Flexible TFT):

[Figure: Flexible TFT Quantile Forecast]

Visualization of the quantile forecast (median and interval) against actual validation data using the flexible TemporalFusionTransformer.
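
Beyond a single plotted sample, the whole validation set can be summarized numerically. The sketch below is one possible check, with a hypothetical helper name (interval_report) and toy data standing in for real predictions. It reports how often the actual value falls inside the outer quantiles, the average interval width, and how often the predicted quantiles are out of order.

```python
import numpy as np

def interval_report(y_true, y_pred):
    """Summarize quantile predictions (illustrative sketch).

    y_true: (batch, horizon, 1)
    y_pred: (batch, horizon, n_quantiles), ordered from lowest to highest quantile
    """
    actual = y_true[..., 0]
    lower, upper = y_pred[..., 0], y_pred[..., -1]
    inside = (actual >= lower) & (actual <= upper)
    # A "crossing" is any step where a higher quantile is predicted below a lower one
    crossing = np.any(np.diff(y_pred, axis=-1) < 0, axis=-1)
    return {
        "coverage": float(inside.mean()),
        "mean_width": float((upper - lower).mean()),
        "crossing_rate": float(crossing.mean()),
    }

rng_demo = np.random.default_rng(0)
y_demo = rng_demo.normal(size=(200, 5, 1))
# Toy "predictions": the 10th/50th/90th percentiles of a standard normal
pred_demo = np.broadcast_to(np.array([-1.2816, 0.0, 1.2816]), (200, 5, 3))
report = interval_report(y_demo, pred_demo)
print(report)
```

With the arrays from this example, the call would be interval_report(val_actuals_flex, val_predictions_quantiles). For quantiles 0.1 and 0.9 the nominal coverage is 80%; a model trained for only five epochs may land well away from that.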


Example 2: Quantile Forecasting with Stricter TFT

This example uses the TFT class, which requires static, dynamic, and future inputs to be provided and non-None.

Workflow:

  1. Generate synthetic data with static, dynamic, and future features.

  2. Use reshape_xtft_data() to prepare the three separate input arrays and multi-step targets.

  3. Define and compile the stricter TFT model with quantile outputs.

  4. Train the model using the required three-part input list.

  5. Make and visualize quantile predictions.

Step 2.1: Imports for Stricter TFT

In addition to the imports from Example 1, this example needs StandardScaler and reshape_xtft_data.

# Imports from previous example are assumed
from sklearn.preprocessing import StandardScaler
from fusionlab.nn.transformers import TFT as TFTStricter # Alias
from fusionlab.nn.utils import reshape_xtft_data
print("\nLibraries imported for Stricter TFT Quantile Example.")

Step 2.2: Generate Synthetic Data (Multi-Feature)

We create data with distinct static, dynamic, and future features.

# Define the RNG (choose any seed for reproducibility)
rng = np.random.default_rng(seed=42)
n_items_strict = 2
n_timesteps_strict = 60 # More data
date_rng_strict = pd.date_range(
    start='2020-01-01', periods=n_timesteps_strict, freq='MS'
    )
df_list_strict = []
for item_id in range(n_items_strict):
    time_idx = np.arange(n_timesteps_strict)
    value = (50 + item_id * 20 + time_idx * 0.8 +
             15 * np.sin(2 * np.pi * time_idx / 12) +
             rng.normal(0, 5, n_timesteps_strict)) # Use main rng
    static_val = item_id * 10
    future_val = (time_idx % 6 == 0).astype(float) # Event every 6 months
    item_df = pd.DataFrame({
        'Date': date_rng_strict, 'ItemID': item_id,
        'StaticFeature': static_val,
        'Month': date_rng_strict.month, # Dynamic
        'ValueLag1': pd.Series(value).shift(1), # Dynamic
        'FutureEvent': future_val, # Future
        'TargetValue': value
    })
    df_list_strict.append(item_df)
df_strict_raw = pd.concat(df_list_strict).dropna().reset_index(drop=True)
print(f"Generated data shape for stricter TFT: {df_strict_raw.shape}")

Step 2.3: Define Features & Scale

Define column roles and scale numerical features.

target_col_s = 'TargetValue'
dt_col_s = 'Date'
static_cols_s = ['ItemID', 'StaticFeature']
dynamic_cols_s = ['Month', 'ValueLag1']
future_cols_s = ['FutureEvent', 'Month'] # Month can be known future
spatial_cols_s = ['ItemID']

scaler_s = StandardScaler()
cols_to_scale_s = ['TargetValue', 'ValueLag1', 'StaticFeature']
df_strict_scaled = df_strict_raw.copy()
df_strict_scaled[cols_to_scale_s] = scaler_s.fit_transform(
    df_strict_scaled[cols_to_scale_s]
    )
print("Numerical features scaled for stricter TFT.")

Step 2.4: Prepare Sequences with reshape_xtft_data

This utility separates static, dynamic, and future features into the required arrays.

time_steps_s = 12         # 1 year lookback
forecast_horizon_s = 6    # Predict 6 months

s_data, d_data, f_data, t_data = reshape_xtft_data(
    df=df_strict_scaled, dt_col=dt_col_s, target_col=target_col_s,
    dynamic_cols=dynamic_cols_s, static_cols=static_cols_s,
    future_cols=future_cols_s, spatial_cols=spatial_cols_s,
    time_steps=time_steps_s, forecast_horizons=forecast_horizon_s,
    verbose=0
)
# Target shape for loss: (Samples, Horizon, OutputDim=1)
targets_s = t_data.astype(np.float32) # reshape_xtft_data returns (N,H,1)

print("\nStricter TFT - Reshaped Data Shapes:")
print(f"  Static : {s_data.shape}, Dynamic: {d_data.shape}")
print(f"  Future : {f_data.shape}, Target : {targets_s.shape}")

Step 2.5: Train/Validation Split of Sequences

Split the generated sequence arrays.

val_split_s = 0.2
n_samples_s = s_data.shape[0]
split_idx_s = int(n_samples_s * (1 - val_split_s))

X_s_train, X_s_val = s_data[:split_idx_s], s_data[split_idx_s:]
X_d_train, X_d_val = d_data[:split_idx_s], d_data[split_idx_s:]
X_f_train, X_f_val = f_data[:split_idx_s], f_data[split_idx_s:]
y_t_train, y_t_val = targets_s[:split_idx_s], targets_s[split_idx_s:]

train_inputs_s = [X_s_train, X_d_train, X_f_train]
val_inputs_s = [X_s_val, X_d_val, X_f_val]
print(f"Data split. Train sequences: {len(y_t_train)}")
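
The three input arrays and the targets must stay aligned sample by sample, so it can help to split them through one function that checks the sample counts first. The following is a minimal sketch with a hypothetical helper name (split_aligned) and toy arrays whose shapes are placeholders rather than the shapes reshape_xtft_data actually returns.

```python
import numpy as np

def split_aligned(arrays, val_fraction):
    """Hypothetical helper: split several arrays at the same sample index."""
    n_samples = arrays[0].shape[0]
    if any(a.shape[0] != n_samples for a in arrays):
        raise ValueError("All arrays must share the same number of samples.")
    split_idx = int(n_samples * (1 - val_fraction))
    train = [a[:split_idx] for a in arrays]
    val = [a[split_idx:] for a in arrays]
    return train, val

# Toy arrays standing in for static, dynamic, future, and target data
static_demo = np.zeros((10, 2))
dynamic_demo = np.zeros((10, 12, 2))
future_demo = np.zeros((10, 12, 2))
target_demo = np.zeros((10, 6, 1))
train_parts, val_parts = split_aligned(
    [static_demo, dynamic_demo, future_demo, target_demo], val_fraction=0.2
)
print([a.shape[0] for a in train_parts], [a.shape[0] for a in val_parts])
```

Note that a positional split like this takes the last 20% of sequences in whatever order they were produced. If the sequences are grouped by item rather than ordered by time, the validation set may come mostly from one item, so it is worth inspecting the ordering of the returned arrays.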

Step 2.6: Define and Train Stricter TFT Model

Instantiate the stricter TFT class, providing all three input dimensions and the quantiles list.

quantiles_s = [0.1, 0.5, 0.9]
model_strict = TFTStricter( # Using the aliased stricter TFT
    static_input_dim=s_data.shape[-1],
    dynamic_input_dim=d_data.shape[-1],
    future_input_dim=f_data.shape[-1],
    forecast_horizon=forecast_horizon_s,
    quantiles=quantiles_s,
    output_dim=1, # Univariate target
    hidden_units=16, num_heads=2, num_lstm_layers=1, lstm_units=16
)
print("\nStricter TFT model for quantiles instantiated.")

loss_fn_s = combined_quantile_loss(quantiles=quantiles_s)
model_strict.compile(optimizer='adam', loss=loss_fn_s)
print("Stricter TFT compiled with quantile loss.")

print("\nStarting stricter TFT training (quantile)...")
history_s = model_strict.fit(
    train_inputs_s, # Must be [Static, Dynamic, Future]
    y_t_train,
    validation_data=(val_inputs_s, y_t_val),
    epochs=5, batch_size=16, verbose=0
)
print("Stricter TFT training finished.")
if history_s and history_s.history.get('val_loss'):
    val_loss_s = history_s.history['val_loss'][-1]
    print(f"Final validation loss (stricter TFT): {val_loss_s:.4f}")

Step 2.7: Make Predictions and Visualize (Stricter TFT)

Prediction and visualization follow the same pattern as in Example 1, except that the values shown here are in scaled units.

print("\nMaking quantile predictions (stricter TFT)...")
val_predictions_s = model_strict.predict(val_inputs_s, verbose=0)
print(f"Prediction output shape: {val_predictions_s.shape}")

# Predictions are in scaled units. scaler_s was fit on three columns, so
# inverting only 'TargetValue' requires that column's mean and scale.
# For simplicity, the plot below uses the scaled values.

# Plot one sample from validation set
sample_to_plot_s = 0
actual_s = y_t_val[sample_to_plot_s, :, 0] # Scaled
pred_q_s = val_predictions_s[sample_to_plot_s, :, :] # Scaled

# Create a dummy time axis for this sample's forecast
plot_time_axis_s = np.arange(forecast_horizon_s)

plt.figure(figsize=(12, 6))
plt.plot(plot_time_axis_s, actual_s, label='Actual (Scaled)',
         marker='o', linestyle='--')
plt.plot(plot_time_axis_s, pred_q_s[:, 1], # Median
         label='Predicted Median (q=0.5, Scaled)', marker='x')
plt.fill_between(
    plot_time_axis_s, pred_q_s[:, 0], pred_q_s[:, 2],
    color='gray', alpha=0.3,
    label='Prediction Interval (q=0.1 to q=0.9, Scaled)'
)
plt.title('Stricter TFT Quantile Forecast (Validation Sample - Scaled)')
plt.xlabel('Forecast Step'); plt.ylabel('Scaled Value')
plt.legend(); plt.grid(True); plt.tight_layout()
# plt.savefig("docs/source/images/forecasting_quantile_tft_stricter.png")
plt.show()
print("Stricter TFT quantile plot generated.")

Example Output Plot (Stricter TFT - Scaled Values):

[Figure: Stricter TFT Quantile Forecast]

Visualization of the quantile forecast using the stricter TFT model (showing scaled values for simplicity).
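
To report forecasts in original units, the scaling of the target column has to be undone. Because the scaler in Step 2.3 was fit on three columns at once, calling its inverse_transform on a predictions array of shape (Batch, Horizon, NumQuantiles) would not line up with what it expects. One option is to apply the target column's own mean and scale directly. The sketch below shows the idea with NumPy stand-ins; the helper name inverse_target and the toy data are illustrative, and with a fitted StandardScaler the two statistics are available as its mean_ and scale_ attributes.

```python
import numpy as np

def inverse_target(scaled, mean_, scale_, target_idx=0):
    """Undo standard scaling for one column; works on arrays of any shape."""
    return scaled * scale_[target_idx] + mean_[target_idx]

# Toy stand-in for the three scaled columns; the target is column 0
rng_demo = np.random.default_rng(0)
raw = rng_demo.normal(
    loc=[100.0, 95.0, 5.0], scale=[20.0, 20.0, 5.0], size=(50, 3)
)

# Per-column statistics, computed the way StandardScaler does (population std)
mean_demo = raw.mean(axis=0)
scale_demo = raw.std(axis=0)

scaled_target = (raw[:, 0] - mean_demo[0]) / scale_demo[0]
restored = inverse_target(scaled_target, mean_demo, scale_demo, target_idx=0)
print(np.allclose(restored, raw[:, 0]))
```

In this example 'TargetValue' is the first entry of cols_to_scale_s, so the corresponding call would be inverse_target(val_predictions_s, scaler_s.mean_, scaler_s.scale_, target_idx=0), and likewise for y_t_val.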