Exercise: Quantile Forecasting with TFT Variants

Welcome to this exercise on quantile forecasting! Quantile forecasts provide an estimate of the prediction uncertainty by predicting multiple quantiles (e.g., 10th, 50th, 90th percentiles) instead of a single point value. This is crucial for understanding the range of potential future outcomes.
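To make the 10th/50th/90th percentile idea concrete, the NumPy sketch below draws many plausible outcomes for a single future time step and reads off their empirical quantiles. It is illustrative only and not part of fusionlab. About 80% of the outcomes fall between the 0.1 and 0.9 quantiles, and that is the kind of interval a quantile model learns to predict directly.

```python
import numpy as np

# Hypothetical example: 10,000 plausible outcomes for one future step
rng = np.random.default_rng(0)
outcomes = rng.normal(loc=5.0, scale=2.0, size=10_000)

# Empirical 10th, 50th and 90th percentiles
q10, q50, q90 = np.quantile(outcomes, [0.1, 0.5, 0.9])
print(f"q0.1={q10:.2f}, median={q50:.2f}, q0.9={q90:.2f}")

# Share of outcomes inside the q0.1 to q0.9 interval (close to 0.8)
inside = np.mean((outcomes >= q10) & (outcomes <= q90))
print(f"Fraction inside interval: {inside:.3f}")
```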

In this guide, you’ll learn to use two Temporal Fusion Transformer variants from fusionlab-learn for this task:

  1. The flexible TemporalFusionTransformer, in which static and future inputs are optional.

  2. The stricter TFT, which requires static, dynamic, and future inputs.

Learning Objectives:

  • Prepare data for multi-step quantile forecasting.

  • Instantiate and compile TFT models for quantile outputs using the quantiles parameter and combined_quantile_loss().

  • Correctly format inputs for both flexible (optional inputs) and stricter (all inputs required) TFT variants.

  • Train the models and interpret their multi-quantile predictions.

  • Visualize quantile forecasts to represent prediction uncertainty.

Let’s begin!

Prerequisites

Ensure you have fusionlab-learn and its common dependencies installed. For visualizations, matplotlib is also needed.

pip install fusionlab-learn matplotlib scikit-learn joblib

Exercise 1: Quantile Forecasting with Flexible TemporalFusionTransformer

In this part, we’ll use the flexible TemporalFusionTransformer with only dynamic (past observed) features to produce multi-step quantile forecasts.

Workflow:

  1. Generate synthetic time series data.

  2. Prepare sequences for multi-step forecasting.

  3. Define and compile the flexible TFT for quantile output.

  4. Train the model.

  5. Make and visualize quantile predictions.

Step 1.1: Imports and Setup

Import necessary libraries and fusionlab components.

import numpy as np
import pandas as pd
import tensorflow as tf
import matplotlib.pyplot as plt
import warnings
import os

# FusionLab imports
from fusionlab.nn.transformers import TemporalFusionTransformer
from fusionlab.nn.utils import create_sequences
from fusionlab.nn.losses import combined_quantile_loss
# Prepares dummy tensors when static and future inputs are None
from fusionlab.nn.utils import prepare_model_inputs

# Suppress warnings and TF logs
warnings.filterwarnings('ignore')
tf.get_logger().setLevel('ERROR')
if hasattr(tf, 'autograph'):
    tf.autograph.set_verbosity(0)

# Directory for saving outputs
exercise_output_dir_quant = "./quantile_forecast_exercise_outputs"
os.makedirs(exercise_output_dir_quant, exist_ok=True)

print("Libraries imported for Flexible TFT Quantile Exercise.")

Expected Output 1.1:

Libraries imported for Flexible TFT Quantile Exercise.
Step 1.2: Generate Synthetic Data

We use a sine wave with added Gaussian noise (standard deviation 0.15), sampled every 0.1 time units for 1,000 points.

np.random.seed(42)  # For reproducibility
tf.random.set_seed(42)

time_flex_q = np.arange(0, 100, 0.1)
amplitude_flex_q = np.sin(time_flex_q) + \
                   np.random.normal(0, 0.15, len(time_flex_q))
df_flex_q = pd.DataFrame({'Value': amplitude_flex_q})
print(f"Generated data shape for flexible TFT: {df_flex_q.shape}")

Expected Output 1.2:

Generated data shape for flexible TFT: (1000, 1)
Step 1.3: Prepare Sequences for Multi-Step Forecasting

We’ll predict the next 5 time steps using the past 10 steps. Targets are reshaped to (Samples, Horizon, OutputDim).

sequence_length_flex_q = 10
forecast_horizon_flex_q = 5  # Predict next 5 steps
target_col_flex_q = 'Value'

sequences_flex_q, targets_flex_q = create_sequences(
    df=df_flex_q,
    sequence_length=sequence_length_flex_q,
    target_col=target_col_flex_q,
    forecast_horizon=forecast_horizon_flex_q,
    verbose=0
)
sequences_flex_q = sequences_flex_q.astype(np.float32)
targets_flex_q = targets_flex_q.reshape(
    -1, forecast_horizon_flex_q, 1  # OutputDim = 1
).astype(np.float32)

print(f"\nFlexible TFT - Input sequences (X): {sequences_flex_q.shape}")
print(f"Flexible TFT - Target values (y): {targets_flex_q.shape}")
Expected Output 1.3:

(Number of samples = total steps - sequence length - horizon + 1 = 1000 - 10 - 5 + 1 = 986)

Flexible TFT - Input sequences (X): (986, 10, 1)
Flexible TFT - Target values (y): (986, 5, 1)
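The sample count follows from sliding a window of sequence_length inputs plus forecast_horizon targets forward one step at a time. The NumPy sketch below reproduces the shapes above under the assumption that create_sequences() windows the data this way; make_windows is a hypothetical helper written for illustration, not a fusionlab function.

```python
import numpy as np

def make_windows(values, seq_len, horizon):
    """Hypothetical helper: slide over a 2-D (time, features) array and
    return inputs (N, seq_len, features) and targets (N, horizon, 1),
    using column 0 as the target."""
    n_windows = len(values) - seq_len - horizon + 1
    X = np.stack([values[i:i + seq_len] for i in range(n_windows)])
    # Targets of window i start right after its inputs, at index i + seq_len
    y = np.stack([values[i + seq_len:i + seq_len + horizon, :1]
                  for i in range(n_windows)])
    return X.astype(np.float32), y.astype(np.float32)

series = np.sin(np.arange(0, 100, 0.1)).reshape(-1, 1)  # 1000 steps, 1 feature
X, y = make_windows(series, seq_len=10, horizon=5)
print(X.shape, y.shape)  # (986, 10, 1) (986, 5, 1)
```

The same "targets of window i start at index i + sequence_length" rule is what Step 1.6 relies on to align the plot's time axis.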
Step 1.4: Define Flexible TFT Model for Quantile Forecast

Instantiate TemporalFusionTransformer, providing the quantiles list. Static and future input dimensions default to None.

quantiles_to_predict_flex = [0.1, 0.5, 0.9]  # 10th, 50th, 90th
num_dynamic_features_flex_q = sequences_flex_q.shape[-1]

model_flex_q = TemporalFusionTransformer(
    dynamic_input_dim=num_dynamic_features_flex_q,
    forecast_horizon=forecast_horizon_flex_q,
    output_dim=1,  # Univariate target
    hidden_units=16, num_heads=2,
    num_lstm_layers=1, lstm_units=16,
    quantiles=quantiles_to_predict_flex  # Enable quantile output
)
print("\nFlexible TFT for quantiles instantiated.")

# Compile with combined_quantile_loss
loss_fn_flex_q = combined_quantile_loss(
    quantiles=quantiles_to_predict_flex
)
model_flex_q.compile(optimizer='adam', loss=loss_fn_flex_q)
print("Flexible TFT compiled with quantile loss.")

Expected Output 1.4:

Flexible TFT for quantiles instantiated.
Flexible TFT compiled with quantile loss.
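Quantile models are usually trained with the pinball (quantile) loss: for a quantile q and error e = y - y_hat, the loss is max(q * e, (q - 1) * e), so predicting too low costs q per unit and predicting too high costs 1 - q. The NumPy sketch below averages this loss over all samples, steps, and quantiles. It illustrates the general technique under the assumption that combined_quantile_loss() is built on the same idea; fusionlab's implementation may differ in details such as how it reduces across quantiles.

```python
import numpy as np

def pinball_loss(y_true, y_pred, quantiles):
    """Mean pinball loss over all samples, steps and quantiles.

    y_true: (batch, horizon, 1); y_pred: (batch, horizon, n_quantiles)
    with the last axis ordered like `quantiles`.
    """
    q = np.asarray(quantiles, dtype=float)   # broadcasts over the last axis
    error = y_true - y_pred                  # (batch, horizon, n_quantiles)
    return float(np.mean(np.maximum(q * error, (q - 1.0) * error)))

y_true = np.array([[[1.0]]])                 # one sample, one step
y_pred = np.array([[[0.0, 1.0, 2.0]]])       # q0.1, q0.5, q0.9 forecasts
total = pinball_loss(y_true, y_pred, [0.1, 0.5, 0.9])  # about 0.0667

# Asymmetry at q=0.9: predicting too low costs 0.9, too high about 0.1
under = pinball_loss(np.array([[[1.0]]]), np.array([[[0.0]]]), [0.9])
over = pinball_loss(np.array([[[0.0]]]), np.array([[[1.0]]]), [0.9])
print(total, under, over)
```

This asymmetry is what pushes the q=0.9 output upward and the q=0.1 output downward, producing the interval.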
Step 1.5: Train the Model

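Inputs follow the [static, dynamic, future] order. This model has no static or future features, so prepare_model_inputs() builds the list from the dynamic sequences alone, preparing dummy tensors for the missing static and future slots. Passing only [sequences_flex_q] is an alternative the flexible model also accepts.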

# Build the [static, dynamic, future] input list; dummy tensors
# stand in for the missing static and future inputs.
# Alternative: train_inputs_flex_q = [sequences_flex_q]
train_inputs_flex_q = prepare_model_inputs(
    dynamic_input=sequences_flex_q,
    static_input=None,
    future_input=None,
    model_type='strict'
)

print("\nStarting flexible TFT training (quantile)...")
history_flex_q = model_flex_q.fit(
    train_inputs_flex_q,
    targets_flex_q,        # Shape (Samples, Horizon, 1)
    epochs=10,             # Train a bit longer for quantiles
    batch_size=32,
    validation_split=0.2,  # Keras holds out the last 20% of samples
    verbose=1              # Show progress
)
print("Flexible TFT training finished.")
if history_flex_q and history_flex_q.history.get('val_loss'):
    val_loss_q = history_flex_q.history['val_loss'][-1]
    print(f"Final validation loss (quantile): {val_loss_q:.4f}")
Expected Output 1.5:

(Keras training logs for 10 epochs, then the final validation loss; your loss values may vary.)

Starting flexible TFT training (quantile)...
Epoch 1/10
25/25 [==============================] - 7s 47ms/step - loss: 0.2302 - val_loss: 0.1550
Epoch 2/10
25/25 [==============================] - 0s 8ms/step - loss: 0.1629 - val_loss: 0.1312
Epoch 3/10
25/25 [==============================] - 0s 9ms/step - loss: 0.1470 - val_loss: 0.1179
Epoch 4/10
25/25 [==============================] - 0s 9ms/step - loss: 0.1354 - val_loss: 0.1136
Epoch 5/10
25/25 [==============================] - 0s 9ms/step - loss: 0.1278 - val_loss: 0.1080
Epoch 6/10
25/25 [==============================] - 0s 8ms/step - loss: 0.1255 - val_loss: 0.1071
Epoch 7/10
25/25 [==============================] - 0s 9ms/step - loss: 0.1212 - val_loss: 0.1019
Epoch 8/10
25/25 [==============================] - 0s 9ms/step - loss: 0.1161 - val_loss: 0.1003
Epoch 9/10
25/25 [==============================] - 0s 8ms/step - loss: 0.1113 - val_loss: 0.0974
Epoch 10/10
25/25 [==============================] - 0s 8ms/step - loss: 0.1060 - val_loss: 0.0890
Flexible TFT training finished.
Final validation loss (quantile): 0.0890
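The 25 steps per epoch in the log follow from how Keras applies validation_split: it holds out the last fraction of the samples (taken before any shuffling) and batches the rest. The plain-Python check below assumes Keras splits at floor(n * (1 - validation_split)), as recent versions do, and confirms the split index that Step 1.6 recomputes to recover the validation samples.

```python
import math

n_samples = 1000 - 10 - 5 + 1   # windows from Step 1.3
validation_split = 0.2
batch_size = 32

# Keras keeps the first floor(n * (1 - split)) samples for training
split_at = math.floor(n_samples * (1 - validation_split))
n_val = n_samples - split_at
steps_per_epoch = math.ceil(split_at / batch_size)

print(split_at, n_val, steps_per_epoch)  # 788 198 25
```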
Step 1.6: Make and Visualize Quantile Predictions

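Predictions have shape (Batch, Horizon, NumQuantiles), with the last axis ordered as in quantiles_to_predict_flex (0.1, 0.5, 0.9). We recover the validation windows held out by validation_split, then plot the median (q=0.5) and the q0.1 to q0.9 prediction interval for one sample.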

num_samples_total_flex_q = sequences_flex_q.shape[0]
# validation_split=0.2 held out the last 20% of samples;
# recompute that split point to recover the validation windows
val_start_idx_flex_q = int(num_samples_total_flex_q * (1 - 0.2))

val_dynamic_flex_q = sequences_flex_q[val_start_idx_flex_q:]
val_actuals_flex_q = targets_flex_q[val_start_idx_flex_q:]

# Format the validation inputs the same way as the training inputs
val_inputs_list_flex_q = prepare_model_inputs(
    dynamic_input=val_dynamic_flex_q,
    static_input=None,
    future_input=None,
    model_type='strict'
)

print("\nMaking quantile predictions (flexible TFT)...")
val_predictions_flex_q = model_flex_q.predict(
    val_inputs_list_flex_q, verbose=0
)
print(f"Prediction output shape: {val_predictions_flex_q.shape}")

# --- Visualization for one sample ---
sample_to_plot_flex_q = 0  # Plot the first sample from validation
actual_vals_plot_flex = val_actuals_flex_q[sample_to_plot_flex_q, :, 0]
# Last axis holds the quantiles in order [0.1, 0.5, 0.9]
pred_quantiles_plot_flex = val_predictions_flex_q[sample_to_plot_flex_q, :, :]

# Align the time axis: the targets of window i start at
# time index i + sequence_length
plot_start_flex_q = (val_start_idx_flex_q + sequence_length_flex_q
                     + sample_to_plot_flex_q)
plot_time_flex_q = time_flex_q[
    plot_start_flex_q:plot_start_flex_q + forecast_horizon_flex_q
]

plt.figure(figsize=(12, 6))
plt.plot(plot_time_flex_q, actual_vals_plot_flex,
         label='Actual Value', marker='o', linestyle='--')
plt.plot(plot_time_flex_q, pred_quantiles_plot_flex[:, 1],  # Median (0.5)
         label='Predicted Median (q=0.5)', marker='x')
plt.fill_between(
    plot_time_flex_q,
    pred_quantiles_plot_flex[:, 0],  # Lower quantile (q=0.1)
    pred_quantiles_plot_flex[:, 2],  # Upper quantile (q=0.9)
    color='skyblue', alpha=0.4,
    label='Prediction Interval (q0.1-q0.9)'
)
plt.title('Flexible TFT Quantile Forecast (Dynamic Inputs Only)')
plt.xlabel('Time'); plt.ylabel('Value')
plt.legend(); plt.grid(True); plt.tight_layout()
# To save for documentation:
# plt.savefig(os.path.join(exercise_output_dir_quant,
#                          "exercise_quantile_tft_flexible.png"))
plt.show()
print("Flexible TFT quantile plot generated.")

Expected Plot 1.6:

(Figure: Flexible TFT Quantile Forecast Exercise) Visualization of the quantile forecast (median and interval) against actual validation data using the flexible TemporalFusionTransformer.
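A single plotted sample says little about calibration, so it can help to summarize the whole (Batch, Horizon, NumQuantiles) array. The sketch below uses synthetic stand-in arrays (not real model output) to show three common checks: splitting out the median and interval bounds, sorting along the quantile axis to remove any quantile crossing (a simple post-processing choice, not something fusionlab is known to apply), and measuring empirical coverage, which for a well-calibrated q0.1 to q0.9 interval should be near 80%. With the exercise data you would pass val_predictions_flex_q and val_actuals_flex_q instead.

```python
import numpy as np

def summarize_quantiles(preds, actuals):
    """preds: (batch, horizon, 3) ordered as [q0.1, q0.5, q0.9];
    actuals: (batch, horizon, 1). Returns median, bounds, coverage, width."""
    preds = np.sort(preds, axis=-1)        # remove any quantile crossing
    lower, median, upper = preds[..., 0], preds[..., 1], preds[..., 2]
    y = actuals[..., 0]
    coverage = np.mean((y >= lower) & (y <= upper))
    mean_width = np.mean(upper - lower)    # narrower is sharper
    return median, lower, upper, coverage, mean_width

# Stand-in data: actuals ~ N(0, 1); "predictions" are its true quantiles
rng = np.random.default_rng(1)
actuals = rng.normal(size=(2000, 5, 1))
true_q = np.array([-1.2816, 0.0, 1.2816])  # N(0,1) quantiles for 0.1/0.5/0.9
preds = np.broadcast_to(true_q, (2000, 5, 3)).copy()

median, lower, upper, cov, width = summarize_quantiles(preds, actuals)
print(f"coverage={cov:.3f}, mean width={width:.3f}")
```

Coverage well below 80% suggests the interval is too narrow; coverage well above it suggests the interval is wider than necessary.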


Exercise 2: Quantile Forecasting with Stricter TFT

Now, we use the stricter TFT class, which requires static, dynamic, and future inputs.

Workflow:

  1. Generate synthetic data with all three feature types.

  2. Define feature roles and scale the numerical features (the series identifier is already numeric, so no categorical encoding is needed).

  3. Use reshape_xtft_data() to prepare the three distinct input arrays.

  4. Define and compile the stricter TFT for quantile output.

  5. Train the model.

  6. Make and visualize quantile predictions.

Step 2.1: Imports for Stricter TFT

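(Most imports carry over from Exercise 1. We add StandardScaler for the numerical features, plus the data generator, the stricter TFT class, and reshape_xtft_data.)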

from sklearn.preprocessing import StandardScaler  # Scales numerical features
from fusionlab.datasets.make import make_multi_feature_time_series
from fusionlab.nn.transformers import TFT as TFTStricter  # Alias
from fusionlab.nn.utils import reshape_xtft_data

print("\nLibraries ready for Stricter TFT Quantile Exercise.")
Step 2.2: Generate Synthetic Multi-Feature Data

We use make_multi_feature_time_series for convenience.

n_items_strict_q = 2
n_timesteps_strict_q = 60
rng_seed_strict_q = 123
np.random.seed(rng_seed_strict_q)
tf.random.set_seed(rng_seed_strict_q)

data_bunch_strict_q = make_multi_feature_time_series(
    n_series=n_items_strict_q,
    n_timesteps=n_timesteps_strict_q,
    freq='D', seasonality_period=7, seed=rng_seed_strict_q,
    as_frame=False  # Get Bunch object
)
df_raw_strict_q = data_bunch_strict_q.frame.copy()
print(f"\nGenerated data for stricter TFT: {df_raw_strict_q.shape}")

Expected Output 2.2:

Generated data for stricter TFT: (120, 9)
Step 2.3: Define Features, Encode, and Scale

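We take the feature lists from data_bunch_strict_q. The series identifier, series_id (our ItemID), is already numeric from the data generator, so no label encoding is needed. Numerical features, including the target, are standardized with StandardScaler, and each fitted scaler is stored so predictions can be inverse-transformed later.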

target_col_sq = data_bunch_strict_q.target_col
dt_col_sq = data_bunch_strict_q.dt_col
static_cols_sq = data_bunch_strict_q.static_features
dynamic_cols_sq = data_bunch_strict_q.dynamic_features
future_cols_sq = data_bunch_strict_q.future_features
spatial_cols_sq = [data_bunch_strict_q.spatial_id_col]

df_processed_sq = df_raw_strict_q.copy()
scalers_sq = {}  # Fitted scalers, kept for inverse-transforming later
num_cols_to_scale_sq = [
    'base_level', 'dynamic_cov', 'target_lag1', target_col_sq
]
cols_actually_scaled_sq = []
for col in num_cols_to_scale_sq:
    # Scale only columns that exist and are numeric
    if col in df_processed_sq.columns and \
       pd.api.types.is_numeric_dtype(df_processed_sq[col]):
        scaler = StandardScaler()
        df_processed_sq[col] = scaler.fit_transform(df_processed_sq[[col]])
        scalers_sq[col] = scaler
        cols_actually_scaled_sq.append(col)
print(f"\nNumerical features scaled for stricter TFT: {cols_actually_scaled_sq}")

Expected Output 2.3:

Numerical features scaled for stricter TFT: ['base_level', 'dynamic_cov', 'target_lag1', 'target']
Step 2.4: Prepare Sequences with `reshape_xtft_data`

This utility separates features into static, dynamic, and future arrays.

time_steps_sq = 10
forecast_horizon_sq = 5

s_data_sq, d_data_sq, f_data_sq, t_data_sq = reshape_xtft_data(
    df=df_processed_sq, dt_col=dt_col_sq, target_col=target_col_sq,
    dynamic_cols=dynamic_cols_sq, static_cols=static_cols_sq,
    future_cols=future_cols_sq, spatial_cols=spatial_cols_sq,
    time_steps=time_steps_sq, forecast_horizons=forecast_horizon_sq,
    verbose=0
)
targets_sq = t_data_sq.astype(np.float32)
print("\nStricter TFT - Reshaped Data Shapes:")
print(f"  Static : {s_data_sq.shape}, Dynamic: {d_data_sq.shape}")
print(f"  Future : {f_data_sq.shape}, Target : {targets_sq.shape}")
Expected Output 2.4:

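(Shapes depend on the generation parameters, the lookback T, and the horizon H. For N=2 series of TS=60 steps, T=10, and H=5: windows per series = 60 - 10 - 5 + 1 = 46, so total = 2 * 46 = 92. The future array covers T + H = 15 steps per sample.)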

Stricter TFT - Reshaped Data Shapes:
  Static : (92, 2), Dynamic: (92, 10, 4)
  Future : (92, 15, 3), Target : (92, 5, 1)
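The total of 92 implies that windows are built within each series and never span the boundary between two series (one long 120-row series would give 106 windows). The NumPy sketch below reproduces that arithmetic under the assumption that reshape_xtft_data() groups rows by the spatial column, as the expected shapes suggest; count_windows_per_series is a hypothetical helper, and the series IDs 0 and 1 are stand-ins.

```python
import numpy as np

def count_windows_per_series(series_ids, time_steps, horizon):
    """Hypothetical helper: count sliding windows per series, never
    crossing series boundaries. Each window needs time_steps past rows
    plus horizon target rows from the same series."""
    counts = {}
    for sid in np.unique(series_ids):
        n_rows = int(np.sum(series_ids == sid))
        counts[int(sid)] = max(n_rows - time_steps - horizon + 1, 0)
    return counts

# Two series with 60 daily rows each, as in Step 2.2 (stand-in IDs)
series_ids = np.repeat([0, 1], 60)
counts = count_windows_per_series(series_ids, time_steps=10, horizon=5)
total = sum(counts.values())
print(counts, total)  # {0: 46, 1: 46} 92
```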
Step 2.5: Train/Validation Split

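(Exercise 1 let Keras hold out the last 20% through validation_split. Here we split manually at the same 80/20 point so that the static, dynamic, and future arrays stay aligned and the validation inputs remain available for prediction.)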

val_split_sq_frac = 0.2
n_samples_sq_total = s_data_sq.shape[0]
split_idx_sq_val = int(n_samples_sq_total * (1 - val_split_sq_frac))

# Split every input and the targets at the same index to keep samples aligned
X_s_train_sq, X_s_val_sq = s_data_sq[:split_idx_sq_val], s_data_sq[split_idx_sq_val:]
X_d_train_sq, X_d_val_sq = d_data_sq[:split_idx_sq_val], d_data_sq[split_idx_sq_val:]
X_f_train_sq, X_f_val_sq = f_data_sq[:split_idx_sq_val], f_data_sq[split_idx_sq_val:]
y_t_train_sq, y_t_val_sq = targets_sq[:split_idx_sq_val], targets_sq[split_idx_sq_val:]

train_inputs_strict_q = [X_s_train_sq, X_d_train_sq, X_f_train_sq]
val_inputs_strict_q = [X_s_val_sq, X_d_val_sq, X_f_val_sq]
print(f"\nData split for stricter TFT. Train samples: {len(y_t_train_sq)}")
# [out]: Data split for stricter TFT. Train samples: 73
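Because the three input arrays and the targets must stay aligned sample by sample, the split can be wrapped in a small helper. split_aligned below is a hypothetical convenience, not part of fusionlab; it splits any number of arrays at one shared index and reproduces the 73/19 split used here. One caution: if reshape_xtft_data() orders samples series by series (an assumption), a tail split like this draws all 19 validation samples from the second series.

```python
import numpy as np

def split_aligned(arrays, val_frac=0.2):
    """Hypothetical helper: split arrays sharing the same first dimension
    at one common index; returns (train_list, val_list)."""
    n = arrays[0].shape[0]
    if any(a.shape[0] != n for a in arrays):
        raise ValueError("All arrays must have the same number of samples")
    split = int(n * (1 - val_frac))
    return [a[:split] for a in arrays], [a[split:] for a in arrays]

# Stand-in arrays with the Step 2.4 shapes
static = np.zeros((92, 2))
dynamic = np.zeros((92, 10, 4))
future = np.zeros((92, 15, 3))
target = np.zeros((92, 5, 1))

train, val = split_aligned([static, dynamic, future, target])
print([a.shape[0] for a in train], [a.shape[0] for a in val])
```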
Step 2.6: Define and Train Stricter `TFT` Model

Instantiate the stricter TFT class, providing all three input dimensions and the quantiles list.

quantiles_strict_q = [0.1, 0.5, 0.9]
model_strict_q_ex = TFTStricter(
    static_input_dim=s_data_sq.shape[-1],
    dynamic_input_dim=d_data_sq.shape[-1],
    future_input_dim=f_data_sq.shape[-1],
    forecast_horizon=forecast_horizon_sq,
    quantiles=quantiles_strict_q, output_dim=1,
    hidden_units=16, num_heads=2, num_lstm_layers=1, lstm_units=16
)
print("\nStricter TFT model for quantiles instantiated.")

loss_fn_strict_q = combined_quantile_loss(quantiles=quantiles_strict_q)
model_strict_q_ex.compile(optimizer='adam', loss=loss_fn_strict_q)
print("Stricter TFT compiled.")

print("\nStarting stricter TFT training (quantile)...")
history_strict_q = model_strict_q_ex.fit(
    train_inputs_strict_q, y_t_train_sq,
    validation_data=(val_inputs_strict_q, y_t_val_sq),
    epochs=5, batch_size=16, verbose=0
)
print("Stricter TFT training finished.")
if history_strict_q and history_strict_q.history.get('val_loss'):
    val_loss_sq = history_strict_q.history['val_loss'][-1]
    print(f"Final validation loss (stricter TFT, quantile): {val_loss_sq:.4f}")

Expected Output 2.6 (the loss value may vary):

Stricter TFT model for quantiles instantiated.
Stricter TFT compiled.

Starting stricter TFT training (quantile)...
Stricter TFT training finished.
Final validation loss (stricter TFT, quantile): 0.1147
Step 2.7: Predictions and Visualization (Stricter TFT)

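(Prediction and visualization mirror Exercise 1, using `model_strict_q_ex`, `val_inputs_strict_q`, and `y_t_val_sq`. Because the target was standardized in Step 2.3, predictions and actuals are mapped back to the original scale with the target scaler stored in `scalers_sq` before plotting.)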

print("\nMaking quantile predictions (stricter TFT)...")
val_pred_scaled_sq = model_strict_q_ex.predict(val_inputs_strict_q, verbose=0)
print(f"Prediction output shape: {val_pred_scaled_sq.shape}")

# Inverse transform to the original scale (if the target was scaled).
# The scaler was fitted on a single column, so flatten all values,
# including every quantile, into one column before inverting.
target_scaler_sq = scalers_sq.get(target_col_sq)
if target_scaler_sq is not None:
    pred_flat = val_pred_scaled_sq.reshape(-1, 1)
    actual_flat = y_t_val_sq.reshape(-1, 1)
    pred_inv = target_scaler_sq.inverse_transform(pred_flat)
    actual_inv = target_scaler_sq.inverse_transform(actual_flat)
    pred_final_sq = pred_inv.reshape(val_pred_scaled_sq.shape)
    actual_final_sq = actual_inv.reshape(y_t_val_sq.shape)
else:
    pred_final_sq = val_pred_scaled_sq
    actual_final_sq = y_t_val_sq

# Plot one sample against the forecast step (0 .. H-1)
sample_idx_sq = 0
steps_sq = np.arange(forecast_horizon_sq)
plt.figure(figsize=(10, 5))
plt.plot(steps_sq, actual_final_sq[sample_idx_sq, :, 0],
         label='Actual', marker='o')
plt.plot(steps_sq, pred_final_sq[sample_idx_sq, :, 1],
         label='Median Pred (q=0.5)', marker='x')
plt.fill_between(steps_sq,
                 pred_final_sq[sample_idx_sq, :, 0],  # q=0.1
                 pred_final_sq[sample_idx_sq, :, 2],  # q=0.9
                 color='skyblue', alpha=0.4, label='Interval (q0.1-q0.9)')
plt.title('Stricter TFT Quantile Forecast')
plt.xlabel('Forecast step'); plt.ylabel(target_col_sq)
plt.legend(); plt.grid(True)
# plt.savefig(os.path.join(exercise_output_dir_quant,
#                          "exercise_quantile_tft_stricter.png"))
plt.show()

Expected Plot 2.7:

(Figure: Stricter TFT Quantile Forecast Exercise) Visualization of the quantile forecast using the stricter TFT model.
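Standardization is a positive affine map (x = z * sigma + mu), so inverting it preserves the order of the quantiles, and a scaler fitted on the single target column can be applied to every quantile. The NumPy sketch below mimics StandardScaler with an explicit mean and population standard deviation (stand-in values, not the exercise's fitted scaler). It checks that flattening to one column, inverting, and reshaping matches applying the formula directly to the (Batch, Horizon, NumQuantiles) array.

```python
import numpy as np

rng = np.random.default_rng(2)
raw_target = rng.normal(50.0, 8.0, size=(120, 1))  # stand-in target column
mu, sigma = raw_target.mean(), raw_target.std()    # ddof=0, like StandardScaler

# Stand-in scaled predictions (batch, horizon, n_quantiles), sorted per step
pred_scaled = np.sort(rng.normal(size=(19, 5, 3)), axis=-1)

# Route 1: flatten to one column, invert, reshape back
pred_inv = (pred_scaled.reshape(-1, 1) * sigma + mu).reshape(pred_scaled.shape)
# Route 2: apply the inverse formula directly with broadcasting
pred_direct = pred_scaled * sigma + mu

print(np.allclose(pred_inv, pred_direct))              # True
print(bool(np.all(np.diff(pred_inv, axis=-1) >= 0)))   # quantile order kept
```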

Discussion of Exercise

In this exercise, you explored quantile forecasting with two TFT variants:

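  1. Flexible `TemporalFusionTransformer`: Demonstrated with only dynamic inputs, showcasing its adaptability. Inputs follow the [static, dynamic, future] order with empty static and future slots ([None, dynamic_array, None], with dummy tensors prepared by prepare_model_inputs()).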

  2. Stricter `TFT`: Showcased with all three input types (static, dynamic, future) generated via make_multi_feature_time_series() and prepared using reshape_xtft_data(). Inputs are provided as [static_array, dynamic_array, future_array].

Key takeaways include:

  • Setting the quantiles parameter in the model’s __init__ method.

  • Using combined_quantile_loss() for training.

  • Understanding that the output shape becomes (Batch, Horizon, NumQuantiles), with one slice per requested quantile instead of a single point forecast per step.

  • Visualizing prediction intervals to assess forecast uncertainty.

This exercise provides a foundation for building more complex probabilistic forecasting models.