Quantile Forecasting with TFT Variants
This guide demonstrates how to configure and train Temporal Fusion
Transformer (TFT) models available in fusionlab-learn to produce
quantile forecasts. Instead of predicting a single point value,
the model predicts multiple quantiles (e.g., 10th, 50th, 90th
percentiles), providing an estimate of the prediction uncertainty.
We will show examples using both:
The flexible TemporalFusionTransformer, which handles optional inputs (demonstrated here with dynamic inputs only).
The stricter TFT, which requires static, dynamic, and future inputs.
Prerequisites
Ensure you have fusionlab-learn and its dependencies installed:
pip install fusionlab-learn matplotlib
Example 1: Quantile Forecasting with Flexible TemporalFusionTransformer
This example uses only dynamic (past observed) features and configures the model to output quantile predictions for multiple steps ahead.
Workflow:
Generate a simple synthetic time series.
Prepare sequences and multi-step targets using create_sequences().
Instantiate the flexible TemporalFusionTransformer with the desired quantiles and output_dim.
Compile the model using combined_quantile_loss().
Train the model.
Interpret and visualize the multi-quantile output.
Step 1.1: Imports and Setup
Import standard libraries and fusionlab components.
import numpy as np
import pandas as pd
import tensorflow as tf
import matplotlib.pyplot as plt
import warnings

# FusionLab imports
from fusionlab.nn.transformers import TemporalFusionTransformer
from fusionlab.nn.utils import create_sequences
from fusionlab.nn.losses import combined_quantile_loss

# Suppress warnings and TF logs
warnings.filterwarnings('ignore')
tf.get_logger().setLevel('ERROR')
if hasattr(tf, 'autograph'):
    tf.autograph.set_verbosity(0)
print("Libraries imported for Flexible TFT Quantile Example.")
Step 1.2: Generate Synthetic Data
A noisy sine wave serves as our univariate time series.
time_flex = np.arange(0, 100, 0.1)
amplitude_flex = np.sin(time_flex) + np.random.normal(0, 0.15, len(time_flex))
df_flex = pd.DataFrame({'Value': amplitude_flex})
print(f"Generated data shape for flexible TFT: {df_flex.shape}")
Step 1.3: Prepare Sequences for Multi-Step Forecasting
We use past observations to predict multiple future steps. Targets are reshaped to (Samples, Horizon, OutputDim).
sequence_length_flex = 10
forecast_horizon_flex = 5  # Predict the next 5 steps
target_col_flex = 'Value'

sequences_flex, targets_flex = create_sequences(
    df=df_flex,
    sequence_length=sequence_length_flex,
    target_col=target_col_flex,
    forecast_horizon=forecast_horizon_flex,
    verbose=0
)
sequences_flex = sequences_flex.astype(np.float32)
targets_flex = targets_flex.reshape(
    -1, forecast_horizon_flex, 1  # OutputDim = 1
).astype(np.float32)

print(f"\nFlexible TFT - Input sequences shape (X): {sequences_flex.shape}")
print(f"Flexible TFT - Target values shape (y): {targets_flex.shape}")
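The following minimal NumPy sketch makes the target shape concrete. It shows the windowing we assume create_sequences performs: each sample takes sequence_length consecutive rows as input and the next forecast_horizon target values as output. The helper make_windows is hypothetical and not part of fusionlab. Check the create_sequences documentation for its exact behavior, for example which columns end up in the input features.

```python
import numpy as np

def make_windows(values, sequence_length, forecast_horizon):
    """Hypothetical helper: split a 1-D series into (input, multi-step target) pairs."""
    values = np.asarray(values, dtype=np.float32)
    n_samples = len(values) - sequence_length - forecast_horizon + 1
    if n_samples < 1:
        raise ValueError("series too short for the requested window sizes")
    # Inputs: rows i .. i + sequence_length - 1
    X = np.stack([values[i:i + sequence_length] for i in range(n_samples)])
    # Targets: the forecast_horizon rows that follow each input window
    y = np.stack([
        values[i + sequence_length:i + sequence_length + forecast_horizon]
        for i in range(n_samples)
    ])
    # Add a trailing feature/output axis: (Samples, Steps, 1)
    return X[..., np.newaxis], y[..., np.newaxis]

series = np.arange(20)
X, y = make_windows(series, sequence_length=10, forecast_horizon=5)
print(X.shape, y.shape)  # (6, 10, 1) (6, 5, 1)
print(y[0, :, 0])        # [10. 11. 12. 13. 14.]
```

The first sample's targets start exactly where its input window ends. Step 1.6 relies on the same alignment when it builds the plot's time axis.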
Step 1.4: Define Flexible TFT Model for Quantile Forecast
Instantiate TemporalFusionTransformer with the quantiles list, then compile it with combined_quantile_loss. The static and future input dimensions are left at their default of None.
quantiles_to_predict = [0.1, 0.5, 0.9]  # 10th, 50th, 90th percentiles
num_dynamic_features_flex = sequences_flex.shape[-1]

model_flex = TemporalFusionTransformer(
    dynamic_input_dim=num_dynamic_features_flex,
    # static_input_dim=None,  # Default
    # future_input_dim=None,  # Default
    forecast_horizon=forecast_horizon_flex,
    output_dim=1,  # Univariate target
    hidden_units=16, num_heads=2,
    quantiles=quantiles_to_predict,  # Enable quantile output
    num_lstm_layers=1, lstm_units=16
)
print("\nFlexible TFT for quantiles instantiated.")

# Compile with combined_quantile_loss
loss_fn_flex = combined_quantile_loss(quantiles=quantiles_to_predict)
model_flex.compile(optimizer='adam', loss=loss_fn_flex)
print("Flexible TFT compiled with quantile loss.")
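The quantile (pinball) loss behind quantile training penalizes under-prediction with weight q and over-prediction with weight 1 - q. For error e = y - ŷ, it is L_q(e) = max(q·e, (q - 1)·e). The NumPy sketch below averages this loss over samples, horizon steps, and quantiles. We assume combined_quantile_loss computes something equivalent, but its exact reduction (mean versus sum across quantiles) is not confirmed here.

```python
import numpy as np

def pinball_loss(y_true, y_pred, quantiles):
    """Mean pinball loss.

    y_true: (Batch, Horizon, 1); y_pred: (Batch, Horizon, NumQuantiles).
    """
    q = np.asarray(quantiles, dtype=np.float64)      # (NumQuantiles,)
    error = np.asarray(y_true) - np.asarray(y_pred)  # broadcasts to (B, H, Q)
    loss = np.maximum(q * error, (q - 1.0) * error)
    return float(loss.mean())

y_true = np.array([[[1.0]]])            # one sample, one step
y_pred = np.array([[[0.0, 1.0, 2.0]]])  # predictions for q=0.1, 0.5, 0.9
print(pinball_loss(y_true, y_pred, [0.1, 0.5, 0.9]))  # about 0.0667

# Asymmetry for q=0.9: predicting too low costs more than predicting too high
under = pinball_loss(np.array([[[1.0]]]), np.array([[[0.0]]]), [0.9])  # about 0.9
over = pinball_loss(np.array([[[0.0]]]), np.array([[[1.0]]]), [0.9])   # about 0.1
```

This asymmetry drives the q=0.9 output upward and the q=0.1 output downward, which is what produces a prediction interval.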
Step 1.5: Train the Model
Only dynamic inputs are used, so the dynamic sequence tensor is passed directly as the model input. A single-element list, [sequences_flex], is also accepted, and the prediction step below uses that form. When static and future inputs are used as well, the expected order is [static, dynamic, future].
# Static and future inputs are unused: pass the dynamic tensor directly
# (or as [sequences_flex])
train_inputs_flex = sequences_flex

print("\nStarting flexible TFT training (quantile)...")
history_flex = model_flex.fit(
    train_inputs_flex,
    targets_flex,
    epochs=5, batch_size=32, validation_split=0.2, verbose=0
)
print("Flexible TFT training finished.")
if history_flex and history_flex.history.get('val_loss'):
    val_loss = history_flex.history['val_loss'][-1]
    print(f"Final validation loss (quantile): {val_loss:.4f}")
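Keras's validation_split holds out the last fraction of the samples, selected before any shuffling. Step 1.6 therefore recovers the validation rows with an index computed from the same fraction. The sketch below reproduces that index with NumPy only. The floor-based formula is our understanding of what Keras does internally; treat it as an assumption. The sample count is a made-up value.

```python
import numpy as np

def split_index(n_samples, validation_split):
    """Index where the held-out block (the last fraction of samples) starts."""
    return int(np.floor(n_samples * (1.0 - validation_split)))

n = 986  # hypothetical number of sequences
idx = split_index(n, 0.2)
train_rows, val_rows = np.arange(n)[:idx], np.arange(n)[idx:]
print(idx, len(train_rows), len(val_rows))  # 788 788 198
```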
Step 1.6: Make and Visualize Quantile Predictions
Predictions have shape (Batch, Horizon, NumQuantiles). We plot the median and the 10th to 90th percentile prediction interval for one validation sample.
# Recover the validation rows: validation_split=0.2 held out the last 20%
num_samples_flex = sequences_flex.shape[0]
val_start_idx_flex = int(num_samples_flex * (1 - 0.2))
val_dynamic_inputs_flex = sequences_flex[val_start_idx_flex:]
val_actuals_flex = targets_flex[val_start_idx_flex:]

val_inputs_list_flex = [val_dynamic_inputs_flex]

print("\nMaking quantile predictions (flexible TFT)...")
val_predictions_quantiles = model_flex.predict(
    val_inputs_list_flex, verbose=0
)
print(f"Prediction output shape: {val_predictions_quantiles.shape}")

# Visualization for one sample
sample_to_plot_flex = 0
actual_vals_flex = val_actuals_flex[sample_to_plot_flex, :, 0]
pred_quantiles_flex = val_predictions_quantiles[sample_to_plot_flex, :, :]

# A sample's targets start sequence_length steps after its first input row
plot_start_flex = val_start_idx_flex + sequence_length_flex + sample_to_plot_flex
plot_time_axis_flex = time_flex[
    plot_start_flex:plot_start_flex + forecast_horizon_flex
]

plt.figure(figsize=(12, 6))
plt.plot(plot_time_axis_flex, actual_vals_flex,
         label='Actual Value', marker='o', linestyle='--')
plt.plot(plot_time_axis_flex, pred_quantiles_flex[:, 1],  # Median (q=0.5)
         label='Predicted Median (q=0.5)', marker='x')
plt.fill_between(
    plot_time_axis_flex,
    pred_quantiles_flex[:, 0],  # Lower quantile (q=0.1)
    pred_quantiles_flex[:, 2],  # Upper quantile (q=0.9)
    color='gray', alpha=0.3,
    label='Prediction Interval (q=0.1 to q=0.9)'
)
plt.title('Flexible TFT Quantile Forecast (Dynamic Inputs Only)')
plt.xlabel('Time'); plt.ylabel('Value')
plt.legend(); plt.grid(True); plt.tight_layout()
# plt.savefig("docs/source/images/forecasting_quantile_tft_flexible.png")
plt.show()
print("Flexible TFT quantile plot generated.")
Example Output Plot (Flexible TFT):
Visualization of the quantile forecast (median and interval) against actual validation data using the flexible TemporalFusionTransformer.
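Empirical coverage is a quick sanity check that goes beyond a single plot. For a well-calibrated model, roughly 80% of actual values should fall between the q=0.1 and q=0.9 predictions. The NumPy sketch below computes that coverage. It also counts quantile crossings, where a lower quantile is predicted above a higher one; separately learned quantile outputs can produce these. Both helper names are hypothetical. They assume the quantile axis is last and ordered as in quantiles_to_predict, so they could be applied to val_actuals_flex and val_predictions_quantiles.

```python
import numpy as np

def interval_coverage(y_true, y_pred_q, lower_idx=0, upper_idx=-1):
    """Fraction of targets inside the [lower, upper] quantile predictions."""
    y = np.asarray(y_true)[..., 0]  # (Batch, Horizon)
    lower = np.asarray(y_pred_q)[..., lower_idx]
    upper = np.asarray(y_pred_q)[..., upper_idx]
    return float(np.mean((y >= lower) & (y <= upper)))

def crossing_rate(y_pred_q):
    """Fraction of (sample, step) pairs whose quantiles are not non-decreasing."""
    diffs = np.diff(np.asarray(y_pred_q), axis=-1)
    return float(np.mean(np.any(diffs < 0, axis=-1)))

y_true = np.array([[[0.5], [2.0]]])                      # (1, 2, 1)
y_pred = np.array([[[0.0, 0.4, 1.0], [0.0, 1.2, 1.0]]])  # (1, 2, 3)
print(interval_coverage(y_true, y_pred))  # 0.5
print(crossing_rate(y_pred))              # 0.5
```

If crossings appear, sorting the predictions along the last axis (np.sort(preds, axis=-1)) is a simple post-hoc fix.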
Example 2: Quantile Forecasting with Stricter TFT
This example uses the TFT class, which requires static, dynamic, and future inputs to all be provided (none may be None).
Workflow:
Generate synthetic data with static, dynamic, and future features.
Use reshape_xtft_data() to prepare the three separate input arrays and the multi-step targets.
Define and compile the stricter TFT model with quantile outputs.
Train the model using the required three-part input list.
Make and visualize quantile predictions.
Step 2.1: Imports for Stricter TFT
Import the additional components this example needs: StandardScaler, the stricter TFT class, and reshape_xtft_data.
# Imports from the previous example are assumed
from sklearn.preprocessing import StandardScaler
from fusionlab.nn.transformers import TFT as TFTStricter  # Alias
from fusionlab.nn.utils import reshape_xtft_data
print("\nLibraries imported for Stricter TFT Quantile Example.")
Step 2.2: Generate Synthetic Data (Multi-Feature)
We create data with distinct static, dynamic, and future features.
# Random generator with a fixed seed for reproducibility
rng = np.random.default_rng(seed=42)
n_items_strict = 2
n_timesteps_strict = 60  # 5 years of monthly data per item
date_rng_strict = pd.date_range(
    start='2020-01-01', periods=n_timesteps_strict, freq='MS'
)
df_list_strict = []
for item_id in range(n_items_strict):
    time_idx = np.arange(n_timesteps_strict)
    value = (50 + item_id * 20 + time_idx * 0.8 +
             15 * np.sin(2 * np.pi * time_idx / 12) +
             rng.normal(0, 5, n_timesteps_strict))  # Use the seeded rng
    static_val = item_id * 10
    future_val = (time_idx % 6 == 0).astype(float)  # Event every 6 months
    item_df = pd.DataFrame({
        'Date': date_rng_strict, 'ItemID': item_id,
        'StaticFeature': static_val,
        'Month': date_rng_strict.month,          # Dynamic
        'ValueLag1': pd.Series(value).shift(1),  # Dynamic
        'FutureEvent': future_val,               # Future
        'TargetValue': value
    })
    df_list_strict.append(item_df)
# dropna() removes each item's first row, where ValueLag1 is NaN
df_strict_raw = pd.concat(df_list_strict).dropna().reset_index(drop=True)
print(f"Generated data shape for stricter TFT: {df_strict_raw.shape}")
Step 2.3: Define Features & Scale
Define column roles and scale the numerical features.
target_col_s = 'TargetValue'
dt_col_s = 'Date'
static_cols_s = ['ItemID', 'StaticFeature']
dynamic_cols_s = ['Month', 'ValueLag1']
future_cols_s = ['FutureEvent', 'Month']  # Month is also known in advance
spatial_cols_s = ['ItemID']

# ItemID, Month, and FutureEvent are left unscaled
scaler_s = StandardScaler()
cols_to_scale_s = ['TargetValue', 'ValueLag1', 'StaticFeature']
df_strict_scaled = df_strict_raw.copy()
df_strict_scaled[cols_to_scale_s] = scaler_s.fit_transform(
    df_strict_scaled[cols_to_scale_s]
)
print("Numerical features scaled for stricter TFT.")
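StandardScaler standardizes each column independently as z = (x - mean) / std. It uses the population standard deviation (ddof=0) and leaves zero-variance columns unscaled. The NumPy sketch below reproduces that transform on a small made-up array, so it runs without scikit-learn. It makes explicit that TargetValue gets its own mean and scale, which is what an inverse transform of the forecasts must use later. The helper standardize is hypothetical.

```python
import numpy as np

def standardize(X):
    """Column-wise z-score mirroring StandardScaler's defaults."""
    X = np.asarray(X, dtype=np.float64)
    mean = X.mean(axis=0)
    scale = X.std(axis=0)       # population std (ddof=0)
    scale[scale == 0.0] = 1.0   # constant columns are only centered
    return (X - mean) / scale, mean, scale

X = np.array([[1.0, 10.0], [3.0, 10.0], [5.0, 10.0]])
Z, mean, scale = standardize(X)
print(mean)     # [ 3. 10.]
print(Z[:, 1])  # [0. 0. 0.]
```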
Step 2.4: Prepare Sequences with reshape_xtft_data
This utility separates the static, dynamic, and future features into the three input arrays the model requires and returns the matching multi-step targets.
time_steps_s = 12  # 1-year lookback
forecast_horizon_s = 6  # Predict 6 months ahead

s_data, d_data, f_data, t_data = reshape_xtft_data(
    df=df_strict_scaled, dt_col=dt_col_s, target_col=target_col_s,
    dynamic_cols=dynamic_cols_s, static_cols=static_cols_s,
    future_cols=future_cols_s, spatial_cols=spatial_cols_s,
    time_steps=time_steps_s, forecast_horizons=forecast_horizon_s,
    verbose=0
)
# Target shape for the loss: (Samples, Horizon, OutputDim=1)
targets_s = t_data.astype(np.float32)  # reshape_xtft_data returns (N, H, 1)

print("\nStricter TFT - Reshaped Data Shapes:")
print(f"  Static : {s_data.shape}, Dynamic: {d_data.shape}")
print(f"  Future : {f_data.shape}, Target : {targets_s.shape}")
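Because spatial_cols_s = ['ItemID'], we expect the windows to be built separately for each item, so that no sequence spans two items. The NumPy sketch below illustrates that grouping on a toy two-item series. For each window it takes the static value from the item's first row, a lookback window of dynamic values, and the following horizon of targets. The helper windows_per_item is hypothetical and is not how reshape_xtft_data is implemented. Future features are omitted because we do not assume how the library windows them; the Future shape printed above shows what it actually does.

```python
import numpy as np

def windows_per_item(item_ids, static, dynamic, target, time_steps, horizon):
    """Hypothetical grouped windowing: never mixes rows from different items."""
    S, D, Y = [], [], []
    for item in np.unique(item_ids):
        rows = np.flatnonzero(item_ids == item)  # this item's rows, in order
        n_windows = len(rows) - time_steps - horizon + 1
        for start in range(max(n_windows, 0)):
            lookback = rows[start:start + time_steps]
            ahead = rows[start + time_steps:start + time_steps + horizon]
            S.append(static[rows[0]])   # static: constant per item
            D.append(dynamic[lookback])  # (time_steps, n_dynamic)
            Y.append(target[ahead])      # (horizon,)
    return np.array(S), np.array(D), np.array(Y)[..., np.newaxis]

# Toy data: two items with 10 rows each, one static and one dynamic feature
item_ids = np.repeat([0, 1], 10)
static = (item_ids * 10.0)[:, np.newaxis]
dynamic = np.arange(20, dtype=float)[:, np.newaxis]
target = np.arange(20, dtype=float)

s, d, y = windows_per_item(item_ids, static, dynamic, target,
                           time_steps=4, horizon=2)
print(s.shape, d.shape, y.shape)  # (10, 1) (10, 4, 1) (10, 2, 1)
```

Under this assumption, each item in Step 2.2 (59 rows after dropna()) would give 59 - 12 - 6 + 1 = 42 windows. Compare that with the sample count printed above.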
Step 2.5: Train/Validation Split of Sequences
Split the sequence arrays, keeping the first 80% of samples for training and the last 20% for validation.
val_split_s = 0.2
n_samples_s = s_data.shape[0]
split_idx_s = int(n_samples_s * (1 - val_split_s))

X_s_train, X_s_val = s_data[:split_idx_s], s_data[split_idx_s:]
X_d_train, X_d_val = d_data[:split_idx_s], d_data[split_idx_s:]
X_f_train, X_f_val = f_data[:split_idx_s], f_data[split_idx_s:]
y_t_train, y_t_val = targets_s[:split_idx_s], targets_s[split_idx_s:]

# Input order expected by the model: [Static, Dynamic, Future]
train_inputs_s = [X_s_train, X_d_train, X_f_train]
val_inputs_s = [X_s_val, X_d_val, X_f_val]
print(f"Data split. Train sequences: {len(y_t_train)}")
Step 2.6: Define and Train Stricter TFT Model
Instantiate the stricter TFT class, providing all three input dimensions and the quantiles list.
quantiles_s = [0.1, 0.5, 0.9]
model_strict = TFTStricter(  # The aliased stricter TFT
    static_input_dim=s_data.shape[-1],
    dynamic_input_dim=d_data.shape[-1],
    future_input_dim=f_data.shape[-1],
    forecast_horizon=forecast_horizon_s,
    quantiles=quantiles_s,
    output_dim=1,  # Univariate target
    hidden_units=16, num_heads=2, num_lstm_layers=1, lstm_units=16
)
print("\nStricter TFT model for quantiles instantiated.")

loss_fn_s = combined_quantile_loss(quantiles=quantiles_s)
model_strict.compile(optimizer='adam', loss=loss_fn_s)
print("Stricter TFT compiled with quantile loss.")

print("\nStarting stricter TFT training (quantile)...")
history_s = model_strict.fit(
    train_inputs_s,  # Must be [Static, Dynamic, Future]
    y_t_train,
    validation_data=(val_inputs_s, y_t_val),
    epochs=5, batch_size=16, verbose=0
)
print("Stricter TFT training finished.")
if history_s and history_s.history.get('val_loss'):
    val_loss_s = history_s.history['val_loss'][-1]
    print(f"Final validation loss (stricter TFT): {val_loss_s:.4f}")
Step 2.7: Make Predictions and Visualize (Stricter TFT)
Prediction and plotting follow the same pattern as Example 1. Here the values stay in scaled units and the x-axis is the forecast step.
print("\nMaking quantile predictions (stricter TFT)...")
val_predictions_s = model_strict.predict(val_inputs_s, verbose=0)
print(f"Prediction output shape: {val_predictions_s.shape}")

# Targets and predictions are both in scaled units. Example 1 used unscaled
# data, so it has no inverse-transform step. To plot in original units,
# invert only the 'TargetValue' column: scaler_s was fitted on three
# columns, so scaler_s.inverse_transform() expects all three.

# Plot one sample from the validation set
sample_to_plot_s = 0
actual_s = y_t_val[sample_to_plot_s, :, 0]  # Scaled
pred_q_s = val_predictions_s[sample_to_plot_s, :, :]  # Scaled

# Forecast-step axis (0 to horizon - 1) for this sample
plot_time_axis_s = np.arange(forecast_horizon_s)

plt.figure(figsize=(12, 6))
plt.plot(plot_time_axis_s, actual_s, label='Actual (Scaled)',
         marker='o', linestyle='--')
plt.plot(plot_time_axis_s, pred_q_s[:, 1],  # Median
         label='Predicted Median (q=0.5, Scaled)', marker='x')
plt.fill_between(
    plot_time_axis_s, pred_q_s[:, 0], pred_q_s[:, 2],
    color='gray', alpha=0.3,
    label='Prediction Interval (q=0.1 to q=0.9, Scaled)'
)
plt.title('Stricter TFT Quantile Forecast (Validation Sample - Scaled)')
plt.xlabel('Forecast Step'); plt.ylabel('Scaled Value')
plt.legend(); plt.grid(True); plt.tight_layout()
# plt.savefig("docs/source/images/forecasting_quantile_tft_stricter.png")
plt.show()
print("Stricter TFT quantile plot generated.")
Example Output Plot (Stricter TFT - Scaled Values):
Visualization of the quantile forecast using the stricter TFT model (showing scaled values for simplicity).
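Reporting these forecasts in original units means inverting only the target column, because scaler_s was fitted on three columns. One option is to apply x = z * scale + mean using that column's fitted statistics. With the tutorial's objects, those would be scaler_s.mean_[j] and scaler_s.scale_[j] with j = cols_to_scale_s.index('TargetValue'). The sketch below takes the two numbers as plain floats so it runs without scikit-learn. The helper inverse_target and the example statistics are hypothetical.

```python
import numpy as np

def inverse_target(z, mean, scale):
    """Map standardized values back to original units: x = z * scale + mean."""
    return np.asarray(z, dtype=np.float64) * scale + mean

# Hypothetical fitted statistics for the target column
target_mean, target_scale = 80.0, 20.0

preds_scaled = np.array([[[-0.5, 0.0, 0.5]]])  # (Batch=1, Horizon=1, Quantiles=3)
preds = inverse_target(preds_scaled, target_mean, target_scale)
print(preds[0, 0])  # [70. 80. 90.]
```

Since scale is positive, the map is monotonic, so the quantile order and the interval bounds carry over unchanged. The same function applies to y_t_val and val_predictions_s.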