Exercise: Quantile Forecasting with TFT Variants
Welcome to this exercise on quantile forecasting! Quantile forecasts provide an estimate of the prediction uncertainty by predicting multiple quantiles (e.g., 10th, 50th, 90th percentiles) instead of a single point value. This is crucial for understanding the range of potential future outcomes.
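To make "multiple quantiles" concrete, the sketch below draws many samples from a normal distribution with NumPy, computes the 10th, 50th, and 90th percentiles, and checks how many samples fall between the 10th and 90th percentiles. The distribution and sample size are arbitrary choices for illustration; no fusionlab code is involved.

```python
import numpy as np

# Hypothetical "future outcomes" for a single time step: 100,000 draws
rng = np.random.default_rng(0)
outcomes = rng.normal(loc=0.0, scale=1.0, size=100_000)

# The three quantiles used throughout this exercise
quantiles = [0.1, 0.5, 0.9]
q10, q50, q90 = np.quantile(outcomes, quantiles)

# Fraction of outcomes inside the q0.1-q0.9 band (nominally 0.9 - 0.1 = 0.8)
coverage = np.mean((outcomes >= q10) & (outcomes <= q90))
print(f"q10={q10:.2f}, median={q50:.2f}, q90={q90:.2f}, coverage={coverage:.3f}")
```

Because these are empirical quantiles of the same sample, the coverage is close to 0.8 by construction. For a forecast model, the same coverage check on held-out data indicates whether its predicted interval is well calibrated.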
In this guide, you’ll learn to use two Temporal Fusion Transformer
variants from fusionlab-learn for this task:
1. The flexible TemporalFusionTransformer.
2. The stricter TFT.
Learning Objectives:
Prepare data for multi-step quantile forecasting.
Instantiate and compile TFT models for quantile outputs using the quantiles parameter and combined_quantile_loss().
Correctly format inputs for both flexible (optional inputs) and stricter (all inputs required) TFT variants.
Train the models and interpret their multi-quantile predictions.
Visualize quantile forecasts to represent prediction uncertainty.
Let’s begin!
Prerequisites
Ensure you have fusionlab-learn and its common dependencies
installed. For visualizations, matplotlib is also needed.
```bash
pip install fusionlab-learn matplotlib scikit-learn joblib
```
Exercise 1: Quantile Forecasting with Flexible TemporalFusionTransformer
In this part, we’ll use the flexible
TemporalFusionTransformer with
only dynamic (past observed) features to produce multi-step quantile
forecasts.
Workflow:
1. Generate synthetic time series data.
2. Prepare sequences for multi-step forecasting.
3. Define and compile the flexible TFT for quantile output.
4. Train the model.
5. Make and visualize quantile predictions.
- Step 1.1: Imports and Setup
Import the necessary libraries and fusionlab components.
```python
import numpy as np
import pandas as pd
import tensorflow as tf
import matplotlib.pyplot as plt
import warnings
import os

# FusionLab imports
from fusionlab.nn.transformers import TemporalFusionTransformer
from fusionlab.nn.utils import create_sequences
from fusionlab.nn.losses import combined_quantile_loss
# For preparing dummy tensors when static and future inputs are None
from fusionlab.nn.utils import prepare_model_inputs

# Suppress warnings and TF logs
warnings.filterwarnings('ignore')
tf.get_logger().setLevel('ERROR')
if hasattr(tf, 'autograph'):
    tf.autograph.set_verbosity(0)

# Directory for saving outputs
exercise_output_dir_quant = "./quantile_forecast_exercise_outputs"
os.makedirs(exercise_output_dir_quant, exist_ok=True)

print("Libraries imported for Flexible TFT Quantile Exercise.")
```
Expected Output 1.1:
Libraries imported for Flexible TFT Quantile Exercise.
- Step 1.2: Generate Synthetic Data
We use a simple sine wave with noise.
```python
np.random.seed(42)  # For reproducibility
tf.random.set_seed(42)

time_flex_q = np.arange(0, 100, 0.1)  # 1000 time points
amplitude_flex_q = (np.sin(time_flex_q)
                    + np.random.normal(0, 0.15, len(time_flex_q)))
df_flex_q = pd.DataFrame({'Value': amplitude_flex_q})
print(f"Generated data shape for flexible TFT: {df_flex_q.shape}")
```
Expected Output 1.2:
Generated data shape for flexible TFT: (1000, 1)
- Step 1.3: Prepare Sequences for Multi-Step Forecasting
We’ll predict the next 5 time steps using the past 10 steps. Targets are reshaped to (Samples, Horizon, OutputDim).
```python
sequence_length_flex_q = 10
forecast_horizon_flex_q = 5  # Predict next 5 steps
target_col_flex_q = 'Value'

sequences_flex_q, targets_flex_q = create_sequences(
    df=df_flex_q,
    sequence_length=sequence_length_flex_q,
    target_col=target_col_flex_q,
    forecast_horizon=forecast_horizon_flex_q,
    verbose=0
)
sequences_flex_q = sequences_flex_q.astype(np.float32)
targets_flex_q = targets_flex_q.reshape(
    -1, forecast_horizon_flex_q, 1  # OutputDim = 1
).astype(np.float32)

print(f"\nFlexible TFT - Input sequences (X): {sequences_flex_q.shape}")
print(f"Flexible TFT - Target values (y): {targets_flex_q.shape}")
```
- Expected Output 1.3:
(Num samples = 1000 - 10 - 5 + 1 = 986)
Flexible TFT - Input sequences (X): (986, 10, 1)
Flexible TFT - Target values (y): (986, 5, 1)
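The sample count follows from sliding a window of length T + H over the series: each window yields T input steps followed by the next H target steps, giving N - T - H + 1 samples. The NumPy sketch below reproduces those shapes with a hypothetical helper, make_windows; it illustrates the arithmetic only and is not fusionlab's create_sequences implementation.

```python
import numpy as np

def make_windows(values: np.ndarray, seq_len: int, horizon: int):
    """Hypothetical sliding-window helper (illustration only).

    values: array of shape (N, F); column 0 is treated as the target.
    Returns X of shape (N - seq_len - horizon + 1, seq_len, F) and
    y of shape (N - seq_len - horizon + 1, horizon, 1).
    """
    n_samples = len(values) - seq_len - horizon + 1
    X = np.stack([values[i:i + seq_len] for i in range(n_samples)])
    y = np.stack([values[i + seq_len:i + seq_len + horizon, 0]
                  for i in range(n_samples)])
    return X.astype(np.float32), y[..., np.newaxis].astype(np.float32)

# Same sizes as Step 1.3: 1000 points, 10 input steps, 5-step horizon
series = np.arange(1000, dtype=np.float64).reshape(-1, 1)
X, y = make_windows(series, seq_len=10, horizon=5)
print(X.shape, y.shape)  # (986, 10, 1) (986, 5, 1)

# The target window of sample i starts right after its input window,
# at index i + seq_len; the time-axis alignment in Step 1.6 relies on this.
assert y[0, 0, 0] == series[10, 0]
```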
- Step 1.4: Define Flexible TFT Model for Quantile Forecast
Instantiate TemporalFusionTransformer, providing the quantiles list. Static and future input dimensions default to None.
```python
quantiles_to_predict_flex = [0.1, 0.5, 0.9]  # 10th, 50th, 90th percentiles
num_dynamic_features_flex_q = sequences_flex_q.shape[-1]

model_flex_q = TemporalFusionTransformer(
    dynamic_input_dim=num_dynamic_features_flex_q,
    forecast_horizon=forecast_horizon_flex_q,
    output_dim=1,  # Univariate target
    hidden_units=16, num_heads=2,
    num_lstm_layers=1, lstm_units=16,
    quantiles=quantiles_to_predict_flex  # Enable quantile output
)
print("\nFlexible TFT for quantiles instantiated.")

# Compile with combined_quantile_loss
loss_fn_flex_q = combined_quantile_loss(
    quantiles=quantiles_to_predict_flex
)
model_flex_q.compile(optimizer='adam', loss=loss_fn_flex_q)
print("Flexible TFT compiled with quantile loss.")
```
Expected Output 1.4:
Flexible TFT for quantiles instantiated.
Flexible TFT compiled with quantile loss.
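Quantile models are commonly trained with the pinball (quantile) loss: for quantile q and error e = y - ŷ, the loss is max(q·e, (q - 1)·e), so under-predictions are weighted by q and over-predictions by 1 - q. The NumPy sketch below averages this loss over samples, horizon steps, and quantiles. That combined_quantile_loss() reduces the per-quantile losses in exactly this way is an assumption here; the fusionlab source is the authority on its precise reduction.

```python
import numpy as np

def pinball_loss(y_true, y_pred, quantiles):
    """Mean pinball loss over samples, horizon steps, and quantiles.

    y_true: (batch, horizon, 1); y_pred: (batch, horizon, n_quantiles).
    Illustrative re-implementation, not fusionlab's combined_quantile_loss.
    """
    q = np.asarray(quantiles, dtype=np.float64)  # shape (n_quantiles,)
    error = y_true - y_pred                      # broadcasts to (B, H, Q)
    loss = np.maximum(q * error, (q - 1.0) * error)
    return float(loss.mean())

quantiles = [0.1, 0.5, 0.9]
y_true = np.array([[[1.0], [2.0]]])              # (1, 2, 1)
y_pred = np.array([[[0.0, 1.0, 2.0],             # (1, 2, 3)
                    [2.0, 2.0, 2.0]]])
print(pinball_loss(y_true, y_pred, quantiles))
```

In this toy case only two entries contribute: the q=0.1 prediction is 1 below the actual (loss 0.1 × 1) and the q=0.9 prediction is 1 above it (loss 0.1 × 1), so the mean over six entries is 0.2 / 6. For q = 0.5 the loss reduces to half the absolute error, which is why the median is the natural point forecast.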
- Step 1.5: Train the Model
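Inputs follow the [static, dynamic, future] order. Since this model has no static or future inputs (they are None), prepare_model_inputs() assembles the list and prepares dummy tensors for the missing inputs; passing only [sequences_flex_q] is an alternative for the flexible model.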
```python
# Assemble inputs in [static, dynamic, future] order; dummy tensors
# are prepared for the missing static and future inputs.
train_inputs_flex_q = prepare_model_inputs(
    dynamic_input=sequences_flex_q,
    static_input=None, future_input=None,
    model_type='strict')

# Alternative for the flexible model: train_inputs_flex_q = [sequences_flex_q]
print("\nStarting flexible TFT training (quantile)...")
history_flex_q = model_flex_q.fit(
    train_inputs_flex_q,
    targets_flex_q,        # Shape (Samples, Horizon, 1)
    epochs=10,             # Train a bit longer for quantiles
    batch_size=32,
    validation_split=0.2,  # Last 20% of samples held out for validation
    verbose=1              # Show progress
)
print("Flexible TFT training finished.")
if history_flex_q and history_flex_q.history.get('val_loss'):
    val_loss_q = history_flex_q.history['val_loss'][-1]
    print(f"Final validation loss (quantile): {val_loss_q:.4f}")
```
- Expected Output 1.5:
(Keras training logs for 10 epochs, then the final validation loss; exact loss values will vary.)
Starting flexible TFT training (quantile)...
Epoch 1/10
25/25 [==============================] - 7s 47ms/step - loss: 0.2302 - val_loss: 0.1550
Epoch 2/10
25/25 [==============================] - 0s 8ms/step - loss: 0.1629 - val_loss: 0.1312
Epoch 3/10
25/25 [==============================] - 0s 9ms/step - loss: 0.1470 - val_loss: 0.1179
Epoch 4/10
25/25 [==============================] - 0s 9ms/step - loss: 0.1354 - val_loss: 0.1136
Epoch 5/10
25/25 [==============================] - 0s 9ms/step - loss: 0.1278 - val_loss: 0.1080
Epoch 6/10
25/25 [==============================] - 0s 8ms/step - loss: 0.1255 - val_loss: 0.1071
Epoch 7/10
25/25 [==============================] - 0s 9ms/step - loss: 0.1212 - val_loss: 0.1019
Epoch 8/10
25/25 [==============================] - 0s 9ms/step - loss: 0.1161 - val_loss: 0.1003
Epoch 9/10
25/25 [==============================] - 0s 8ms/step - loss: 0.1113 - val_loss: 0.0974
Epoch 10/10
25/25 [==============================] - 0s 8ms/step - loss: 0.1060 - val_loss: 0.0890
Flexible TFT training finished.
Final validation loss (quantile): 0.0890
- Step 1.6: Make and Visualize Quantile Predictions
Predictions will have a shape (Batch, Horizon, NumQuantiles). We visualize the median and the prediction interval.
```python
num_samples_total_flex_q = sequences_flex_q.shape[0]
# Approximately the last 20% of samples, matching validation_split=0.2
val_start_idx_flex_q = int(num_samples_total_flex_q * (1 - 0.2))

val_dynamic_flex_q = sequences_flex_q[val_start_idx_flex_q:]
val_actuals_flex_q = targets_flex_q[val_start_idx_flex_q:]

# Dynamic-only input list (the alternative form mentioned in Step 1.5)
val_inputs_list_flex_q = [val_dynamic_flex_q]

print("\nMaking quantile predictions (flexible TFT)...")
val_predictions_flex_q = model_flex_q.predict(
    val_inputs_list_flex_q, verbose=0
)
print(f"Prediction output shape: {val_predictions_flex_q.shape}")

# --- Visualization for one sample ---
sample_to_plot_flex_q = 0  # First sample of the validation set
actual_vals_plot_flex = val_actuals_flex_q[sample_to_plot_flex_q, :, 0]
pred_quantiles_plot_flex = val_predictions_flex_q[sample_to_plot_flex_q, :, :]

# The targets of sample i start at time index i + sequence_length
plot_start_flex_q = (val_start_idx_flex_q + sequence_length_flex_q
                     + sample_to_plot_flex_q)
plot_time_flex_q = time_flex_q[
    plot_start_flex_q:plot_start_flex_q + forecast_horizon_flex_q
]

plt.figure(figsize=(12, 6))
plt.plot(plot_time_flex_q, actual_vals_plot_flex,
         label='Actual Value', marker='o', linestyle='--')
plt.plot(plot_time_flex_q, pred_quantiles_plot_flex[:, 1],  # Median (q=0.5)
         label='Predicted Median (q=0.5)', marker='x')
plt.fill_between(
    plot_time_flex_q,
    pred_quantiles_plot_flex[:, 0],  # Lower quantile (q=0.1)
    pred_quantiles_plot_flex[:, 2],  # Upper quantile (q=0.9)
    color='skyblue', alpha=0.4,
    label='Prediction Interval (q0.1-q0.9)'
)
plt.title('Flexible TFT Quantile Forecast (Dynamic Inputs Only)')
plt.xlabel('Time')
plt.ylabel('Value')
plt.legend()
plt.grid(True)
plt.tight_layout()
# To save for documentation:
# plt.savefig(os.path.join(exercise_output_dir_quant,
#                          "exercise_quantile_tft_flexible.png"))
plt.show()
print("Flexible TFT quantile plot generated.")
```
Expected Plot 1.6:
Visualization of the quantile forecast (median and interval) against actual validation data using the flexible TemporalFusionTransformer.
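A single plotted sample says little about the interval as a whole. A simple aggregate check is how often the actual values fall inside the q0.1-q0.9 band over the entire validation set (nominally 80%), together with the mean band width and whether the predicted quantiles stay ordered (no quantile crossing). The sketch below assumes predictions shaped (Batch, Horizon, NumQuantiles), actuals shaped (Batch, Horizon, 1), and a quantile list sorted in ascending order, as in Step 1.6. interval_report is a hypothetical helper, and the random arrays merely stand in for val_predictions_flex_q and val_actuals_flex_q.

```python
import numpy as np

def interval_report(y_true, y_pred, quantiles, lower=0.1, upper=0.9):
    """Hypothetical helper: interval coverage, mean width, crossing rate.

    y_true: (batch, horizon, 1); y_pred: (batch, horizon, n_quantiles).
    Quantile columns are looked up by value, not by hard-coded position.
    Assumes `quantiles` is sorted in ascending order.
    """
    idx = {q: i for i, q in enumerate(quantiles)}
    lo = y_pred[..., idx[lower]]
    hi = y_pred[..., idx[upper]]
    actual = y_true[..., 0]
    coverage = np.mean((actual >= lo) & (actual <= hi))
    width = np.mean(hi - lo)
    # Crossing: some lower quantile is predicted above a higher one
    crossing = np.mean(np.any(np.diff(y_pred, axis=-1) < 0, axis=-1))
    return {"coverage": float(coverage), "mean_width": float(width),
            "crossing_rate": float(crossing)}

# Stand-in arrays with the Step 1.6 shapes: standard-normal "actuals" and a
# constant forecast at the N(0, 1) 10th/50th/90th percentiles
rng = np.random.default_rng(0)
quantiles = [0.1, 0.5, 0.9]
y_true = rng.normal(size=(200, 5, 1))
y_pred = np.broadcast_to(np.array([-1.2816, 0.0, 1.2816]), (200, 5, 3))
print(interval_report(y_true, y_pred, quantiles))
```

With the real arrays, the call would be interval_report(val_actuals_flex_q, val_predictions_flex_q, quantiles_to_predict_flex). Coverage well below 0.8 suggests the interval is too narrow, and a nonzero crossing rate means the separately learned quantile outputs are not ordered for some steps.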
Exercise 2: Quantile Forecasting with Stricter TFT
Now, we use the stricter TFT
class, which requires static, dynamic, and future inputs.
Workflow:
1. Generate synthetic data with all three feature types.
2. Define feature roles and scale the numerical features.
3. Use reshape_xtft_data() to prepare the three distinct input arrays.
4. Define and compile the stricter TFT for quantile output.
5. Train the model.
6. Make and visualize quantile predictions.
- Step 2.1: Imports for Stricter TFT
(Most imports are already done. This step adds StandardScaler, which Step 2.3 uses to scale the numerical features.)
```python
from sklearn.preprocessing import StandardScaler  # For scaling numerical features
from fusionlab.datasets.make import make_multi_feature_time_series
from fusionlab.nn.transformers import TFT as TFTStricter  # Alias
from fusionlab.nn.utils import reshape_xtft_data

print("\nLibraries ready for Stricter TFT Quantile Exercise.")
```
- Step 2.2: Generate Synthetic Multi-Feature Data
We use make_multi_feature_time_series for convenience.
```python
n_items_strict_q = 2
n_timesteps_strict_q = 60
rng_seed_strict_q = 123
np.random.seed(rng_seed_strict_q)
tf.random.set_seed(rng_seed_strict_q)

data_bunch_strict_q = make_multi_feature_time_series(
    n_series=n_items_strict_q,
    n_timesteps=n_timesteps_strict_q,
    freq='D', seasonality_period=7, seed=rng_seed_strict_q,
    as_frame=False  # Get Bunch object
)
df_raw_strict_q = data_bunch_strict_q.frame.copy()
print(f"\nGenerated data for stricter TFT: {df_raw_strict_q.shape}")
```
Expected Output 2.2:
Generated data for stricter TFT: (120, 9)
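- Step 2.3: Define Features and Scale
We take the feature lists from data_bunch_strict_q. The series identifier series_id (our ItemID) is already numerical in the generated data, so no categorical encoding is needed. The numerical features, including the target, are standardized, and each fitted scaler is kept for inverse-transforming later.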
```python
target_col_sq = data_bunch_strict_q.target_col
dt_col_sq = data_bunch_strict_q.dt_col
static_cols_sq = data_bunch_strict_q.static_features
dynamic_cols_sq = data_bunch_strict_q.dynamic_features
future_cols_sq = data_bunch_strict_q.future_features
spatial_cols_sq = [data_bunch_strict_q.spatial_id_col]

df_processed_sq = df_raw_strict_q.copy()
scalers_sq = {}
num_cols_to_scale_sq = [
    'base_level', 'dynamic_cov', 'target_lag1', target_col_sq
]
cols_actually_scaled_sq = []
for col in num_cols_to_scale_sq:
    # Scale only columns that exist and are numeric
    if (col in df_processed_sq.columns
            and pd.api.types.is_numeric_dtype(df_processed_sq[col])):
        scaler = StandardScaler()
        df_processed_sq[col] = scaler.fit_transform(df_processed_sq[[col]])
        scalers_sq[col] = scaler  # Kept for the inverse transform in Step 2.7
        cols_actually_scaled_sq.append(col)
print(f"\nNumerical features scaled for stricter TFT: {cols_actually_scaled_sq}")
```
Expected Output 2.3:
Numerical features scaled for stricter TFT: ['base_level', 'dynamic_cov', 'target_lag1', 'target']
- Step 2.4: Prepare Sequences with `reshape_xtft_data`
This utility separates features into static, dynamic, and future arrays.
```python
time_steps_sq = 10
forecast_horizon_sq = 5

s_data_sq, d_data_sq, f_data_sq, t_data_sq = reshape_xtft_data(
    df=df_processed_sq, dt_col=dt_col_sq, target_col=target_col_sq,
    dynamic_cols=dynamic_cols_sq, static_cols=static_cols_sq,
    future_cols=future_cols_sq, spatial_cols=spatial_cols_sq,
    time_steps=time_steps_sq, forecast_horizons=forecast_horizon_sq,
    verbose=0
)
targets_sq = t_data_sq.astype(np.float32)
print("\nStricter TFT - Reshaped Data Shapes:")
print(f" Static : {s_data_sq.shape}, Dynamic: {d_data_sq.shape}")
print(f" Future : {f_data_sq.shape}, Target : {targets_sq.shape}")
```
- Expected Output 2.4:
(Shapes depend on the generation parameters and on T and H. For N = 2 series of 60 time steps each, with T = 10 and H = 5, each series yields 60 - 10 - 5 + 1 = 46 sequences, for 2 × 46 = 92 in total.)
Stricter TFT - Reshaped Data Shapes:
Static : (92, 2), Dynamic: (92, 10, 4)
Future : (92, 15, 3), Target : (92, 5, 1)
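To see where these shapes come from, the sketch below windows each series separately and stacks the results: static features are repeated once per window, dynamic features cover the T past steps, future features cover T + H steps (consistent with the (92, 15, 3) shape above), and targets cover the H forecast steps. window_series is a hypothetical helper that only mimics the reported shapes; it is not the reshape_xtft_data implementation, and the random arrays stand in for the generated features.

```python
import numpy as np

def window_series(static, dynamic, future, target, T, H):
    """Hypothetical per-series windowing (illustrates shapes only).

    static: (n_static,); dynamic: (L, n_dyn); future: (L, n_fut); target: (L,)
    """
    n = len(target) - T - H + 1
    S = np.tile(static, (n, 1))                               # (n, n_static)
    D = np.stack([dynamic[i:i + T] for i in range(n)])        # (n, T, n_dyn)
    F = np.stack([future[i:i + T + H] for i in range(n)])     # (n, T+H, n_fut)
    Y = np.stack([target[i + T:i + T + H] for i in range(n)])[..., None]  # (n, H, 1)
    return S, D, F, Y

rng = np.random.default_rng(123)
T, H, L = 10, 5, 60  # time steps, horizon, points per series
parts = [window_series(rng.normal(size=2), rng.normal(size=(L, 4)),
                       rng.normal(size=(L, 3)), rng.normal(size=L), T, H)
         for _ in range(2)]  # two series, as in Step 2.2
S, D, F, Y = (np.concatenate(arrs) for arrs in zip(*parts))
print(S.shape, D.shape, F.shape, Y.shape)
# (92, 2) (92, 10, 4) (92, 15, 3) (92, 5, 1)
```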
- Step 2.5: Train/Validation Split
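(A chronological split: the first 80% of samples are used for training and the last 20% for validation, mirroring validation_split=0.2 in Exercise 1 but done explicitly on the `_sq` arrays.)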
```python
val_split_sq_frac = 0.2
n_samples_sq_total = s_data_sq.shape[0]
split_idx_sq_val = int(n_samples_sq_total * (1 - val_split_sq_frac))

# First 80% of samples for training, last 20% for validation
X_s_train_sq, X_s_val_sq = s_data_sq[:split_idx_sq_val], s_data_sq[split_idx_sq_val:]
X_d_train_sq, X_d_val_sq = d_data_sq[:split_idx_sq_val], d_data_sq[split_idx_sq_val:]
X_f_train_sq, X_f_val_sq = f_data_sq[:split_idx_sq_val], f_data_sq[split_idx_sq_val:]
y_t_train_sq, y_t_val_sq = targets_sq[:split_idx_sq_val], targets_sq[split_idx_sq_val:]

train_inputs_strict_q = [X_s_train_sq, X_d_train_sq, X_f_train_sq]
val_inputs_strict_q = [X_s_val_sq, X_d_val_sq, X_f_val_sq]
print(f"\nData split for stricter TFT. Train samples: {len(y_t_train_sq)}")
# [out]: Data split for stricter TFT. Train samples: 73
```
- Step 2.6: Define and Train Stricter `TFT` Model
Instantiate the stricter TFT class, providing all three input dimensions and the quantiles list.
```python
quantiles_strict_q = [0.1, 0.5, 0.9]
model_strict_q_ex = TFTStricter(
    static_input_dim=s_data_sq.shape[-1],
    dynamic_input_dim=d_data_sq.shape[-1],
    future_input_dim=f_data_sq.shape[-1],
    forecast_horizon=forecast_horizon_sq,
    quantiles=quantiles_strict_q, output_dim=1,
    hidden_units=16, num_heads=2, num_lstm_layers=1, lstm_units=16
)
print("\nStricter TFT model for quantiles instantiated.")

loss_fn_strict_q = combined_quantile_loss(quantiles=quantiles_strict_q)
model_strict_q_ex.compile(optimizer='adam', loss=loss_fn_strict_q)
print("Stricter TFT compiled.")

print("\nStarting stricter TFT training (quantile)...")
history_strict_q = model_strict_q_ex.fit(
    train_inputs_strict_q, y_t_train_sq,
    validation_data=(val_inputs_strict_q, y_t_val_sq),
    epochs=5, batch_size=16, verbose=0
)
print("Stricter TFT training finished.")
if history_strict_q and history_strict_q.history.get('val_loss'):
    val_loss_sq = history_strict_q.history['val_loss'][-1]
    print(f"Final validation loss (stricter TFT, quantile): {val_loss_sq:.4f}")
```
Expected Output 2.6:
Stricter TFT model for quantiles instantiated.
Stricter TFT compiled.
Starting stricter TFT training (quantile)...
Stricter TFT training finished.
Final validation loss (stricter TFT, quantile): 0.1147
- Step 2.7: Predictions and Visualization (Stricter TFT)
(Prediction and visualization are similar to Exercise 1, using `model_strict_q_ex`, `val_inputs_strict_q`, `y_t_val_sq`, and `scalers_sq`)
```python
print("\nMaking quantile predictions (stricter TFT)...")
val_pred_scaled_sq = model_strict_q_ex.predict(val_inputs_strict_q, verbose=0)
print(f"Prediction output shape: {val_pred_scaled_sq.shape}")

# Inverse-transform to the original target scale. The target scaler was
# fit on a single column, so flatten all quantile outputs into one column,
# invert, and restore the original shape.
target_scaler_sq = scalers_sq.get(target_col_sq)
if target_scaler_sq is not None:
    pred_flat = val_pred_scaled_sq.reshape(-1, 1)
    actual_flat = y_t_val_sq.reshape(-1, 1)
    pred_inv = target_scaler_sq.inverse_transform(pred_flat)
    actual_inv = target_scaler_sq.inverse_transform(actual_flat)
    pred_final_sq = pred_inv.reshape(val_pred_scaled_sq.shape)
    actual_final_sq = actual_inv.reshape(y_t_val_sq.shape)
else:
    pred_final_sq = val_pred_scaled_sq
    actual_final_sq = y_t_val_sq

# Plot one sample (x-axis: forecast step within the horizon)
sample_idx_sq = 0
steps_sq = np.arange(forecast_horizon_sq)
plt.figure(figsize=(10, 5))
plt.plot(steps_sq, actual_final_sq[sample_idx_sq, :, 0],
         label='Actual', marker='o')
plt.plot(steps_sq, pred_final_sq[sample_idx_sq, :, 1],
         label='Median Pred (q=0.5)', marker='x')
plt.fill_between(steps_sq,
                 pred_final_sq[sample_idx_sq, :, 0],  # q=0.1
                 pred_final_sq[sample_idx_sq, :, 2],  # q=0.9
                 color='skyblue', alpha=0.4, label='Interval (q0.1-q0.9)')
plt.title('Stricter TFT Quantile Forecast')
plt.xlabel('Forecast step')
plt.ylabel(target_col_sq)
plt.legend()
plt.grid(True)
# plt.savefig(os.path.join(exercise_output_dir_quant,
#                          "exercise_quantile_tft_stricter.png"))
plt.show()
```
Expected Plot 2.7:
Visualization of the quantile forecast using the stricter TFT model.
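The inverse transform in Step 2.7 works because the target scaler holds a single mean and a single scale (StandardScaler uses the population standard deviation), so every quantile column can be mapped back with the same affine map after flattening to one column. The NumPy-only sketch below mirrors that round trip without scikit-learn; the arrays are stand-ins for the real target and predictions.

```python
import numpy as np

rng = np.random.default_rng(0)
target = rng.normal(loc=50.0, scale=5.0, size=(120, 1))  # stand-in raw target

# The statistics StandardScaler would learn (population std, ddof=0)
mean, scale = target.mean(), target.std()
target_scaled = (target - mean) / scale

# Stand-in scaled predictions shaped (batch, horizon, n_quantiles):
# median equals the scaled target, lower/upper shifted by -0.5/+0.5
pred_scaled = np.stack([target_scaled[:90].reshape(18, 5) + d
                        for d in (-0.5, 0.0, 0.5)], axis=-1)

# Flatten to one column, invert the scaling, restore the original shape
pred_orig = (pred_scaled.reshape(-1, 1) * scale + mean).reshape(pred_scaled.shape)

# The median column maps back exactly to the raw target values
assert np.allclose(pred_orig[..., 1], target[:90].reshape(18, 5))
print(pred_scaled.shape, pred_orig.shape)  # (18, 5, 3) (18, 5, 3)
```

A shift of 0.5 in scaled units becomes 0.5 × scale in original units, so interval widths are stretched by the scaler's standard deviation while their ordering is preserved.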
Discussion of Exercise
In this exercise, you explored quantile forecasting with two TFT variants:
Flexible `TemporalFusionTransformer`: Demonstrated with only dynamic inputs, showcasing its adaptability. Inputs are provided as [None, dynamic_array, None].
Stricter `TFT`: Showcased with all three input types (static, dynamic, future) generated via
make_multi_feature_time_series() and prepared using reshape_xtft_data(). Inputs are provided as [static_array, dynamic_array, future_array].
Key takeaways include:
Setting the quantiles parameter in the model’s __init__ method.
Using combined_quantile_loss() for training.
Understanding that the model's output shape changes to include the number of quantiles.
Visualizing prediction intervals to assess forecast uncertainty.
This exercise provides a foundation for building more complex probabilistic forecasting models.