Temporal Fusion Transformer (TFT) and Variants
This section of the user guide provides a detailed overview of the
Temporal Fusion Transformer (TFT) architecture and its various
implementations available within the fusionlab-learn library. The
TFT is a landmark architecture for multi-horizon time series
forecasting, renowned for its ability to deliver high performance while
maintaining interpretability.
Its core innovation is a hybrid design that combines recurrent layers for local sequence processing, self-attention for long-range dependencies, and gating mechanisms that filter information from static, dynamic, and known-future inputs.
fusionlab-learn offers several implementations of the TFT, each
tailored for different use cases and input requirements:
TemporalFusionTransformer: The primary, flexible, general-purpose implementation that can gracefully handle scenarios where some input types (like static or future features) might be missing.
TFT (Stricter) & DummyTFT: More specialized variants that enforce rigid input patterns (e.g., requiring all inputs, or only static and dynamic), which can be useful for ensuring data integrity in well-defined pipelines.
This guide details each of these models, their unique features, and their typical applications.
TemporalFusionTransformer
- API Reference:
TemporalFusionTransformer
The TemporalFusionTransformer class is the primary, flexible
implementation of the Temporal Fusion Transformer architecture
[1] within fusionlab. It is designed to handle a variety of
input configurations and forecasting tasks.
Key Features:
Flexible Inputs: Supports dynamic (past) inputs (required), optional static inputs (time-invariant metadata), and optional future inputs (known exogenous covariates). The model adapts internally based on which inputs are provided.
Multi-Horizon Forecasting: Directly outputs predictions for multiple future time steps defined by the forecast_horizon parameter.
Probabilistic Forecasts: Can generate quantile forecasts by specifying the quantiles parameter (e.g., [0.1, 0.5, 0.9]) to estimate prediction intervals and uncertainty. If quantiles is None, it produces deterministic (point) forecasts.
Standard TFT Components: Built using core TFT blocks like VariableSelectionNetwork (optional for static/future), LSTM encoders, Static Enrichment (using GatedResidualNetwork), Temporal Self-Attention (TemporalAttentionLayer), and GatedResidualNetwork.
When to Use:
This is generally the recommended starting point for applying the TFT architecture. Use it when:
You need a standard, well-understood TFT implementation.
You have various combinations of dynamic, static, or future inputs (and don’t necessarily have all three).
You require either point forecasts or probabilistic (quantile) forecasts.
Code Example (Instantiation and Call)
This example shows how to instantiate the flexible
TemporalFusionTransformer with different combinations of inputs.
import numpy as np
import tensorflow as tf
from fusionlab.nn import TemporalFusionTransformer  # Flexible version

# --- Dummy Data Dimensions ---
B, T_past, H = 4, 12, 6          # Batch, Lookback, Forecast Horizon
D_dyn, D_stat, D_fut = 5, 3, 2   # Feature dimensions
T_future_total = T_past + H      # Future inputs span lookback + horizon

# Create Dummy Input Tensors
static_in = tf.random.normal((B, D_stat), dtype=tf.float32)
dynamic_in = tf.random.normal((B, T_past, D_dyn), dtype=tf.float32)
future_in = tf.random.normal((B, T_future_total, D_fut), dtype=tf.float32)

# --- Example 1: Dynamic Inputs Only (Point Forecast) ---
print("--- Example 1: Dynamic Only ---")
model_dyn_only = TemporalFusionTransformer(
    dynamic_input_dim=D_dyn,
    static_input_dim=None,        # Explicitly None
    future_input_dim=None,        # Explicitly None
    forecast_horizon=H,
    output_dim=1,                 # Specify output dimension
    hidden_units=8, num_heads=1   # Small params for example
)
# Input is a list: [Static, Dynamic, Future]
# For dynamic only, static and future are None
inputs1 = [None, dynamic_in, None]
try:
    output1 = model_dyn_only(inputs1, training=False)
    print(f"Input: [None, Dynamic ({dynamic_in.shape}), None]")
    print(f"Output Shape (Point): {output1.shape}")
    # Expected: (B, H, OutputDim=1) -> (4, 6, 1)
except Exception as e:
    print(f"Call failed for Dynamic Only: {e}")


# --- Example 2: All Inputs (Quantile Forecast) ---
print("\n--- Example 2: All Inputs (Quantile) ---")
my_quantiles = [0.1, 0.5, 0.9]
model_all_inputs = TemporalFusionTransformer(
    dynamic_input_dim=D_dyn,
    static_input_dim=D_stat,      # Provide static dim
    future_input_dim=D_fut,       # Provide future dim
    forecast_horizon=H,
    quantiles=my_quantiles,       # Set quantiles
    output_dim=1,                 # Specify output dimension
    hidden_units=8, num_heads=1
)
# Input is a list [Static, Dynamic, Future]
inputs2 = [static_in, dynamic_in, future_in]
try:
    output2 = model_all_inputs(inputs2, training=False)
    print(f"Input: [Stat ({static_in.shape}), Dyn ({dynamic_in.shape}), "
          f"Fut ({future_in.shape})]")
    print(f"Output Shape (Quantile): {output2.shape}")
    # Expected: (B, H, NumQuantiles) if output_dim=1 -> (4, 6, 3)
except Exception as e:
    print(f"Call failed for All Inputs: {e}")
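Once a quantile forecast is available, a prediction interval can be read off its last axis. The sketch below uses a random NumPy array as a stand-in for a quantile output of shape (B, H, NumQuantiles); it assumes the last axis follows the order of the quantiles list passed at construction, which is consistent with the documented shape for output_dim=1 but should be confirmed against the API reference.

```python
import numpy as np

quantiles = [0.1, 0.5, 0.9]
rng = np.random.default_rng(0)

# Stand-in for a model output of shape (B, H, NumQuantiles) = (4, 6, 3).
# Sorting along the last axis makes the fake quantiles monotone.
output = np.sort(rng.normal(size=(4, 6, len(quantiles))), axis=-1)

# ASSUMPTION: index i on the last axis corresponds to quantiles[i].
lower = output[..., quantiles.index(0.1)]
median = output[..., quantiles.index(0.5)]
upper = output[..., quantiles.index(0.9)]

# Width of the nominal 80% interval (0.9 - 0.1) at every horizon step.
interval_width = upper - lower

print("median shape:", median.shape)
print("mean interval width:", float(interval_width.mean()))
```

With a real model, the same slicing would be applied to the tensor returned by the call (after converting it to NumPy if needed).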
Important
Input Data Order and Format for TemporalFusionTransformer
The flexible TemporalFusionTransformer expects its inputs argument
in the call method (and subsequently for .fit(), .predict())
to be a list or tuple of three elements:
[static_features, dynamic_features, future_features].
If a particular input type is not used (e.g., no static features), pass None for that element in the list. For example:
Dynamic only: [None, dynamic_array, None]
Dynamic + Static: [static_array, dynamic_array, None]
Dynamic + Future: [None, dynamic_array, future_array]
All three: [static_array, dynamic_array, future_array]
The dynamic_features input is always required.
The model’s __init__ parameters (static_input_dim, dynamic_input_dim, future_input_dim) determine which of these inputs are expected to be actual tensors versus None.
This order is handled by the internal validation logic (validate_model_inputs() when model_name='tft_flex').
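The placeholder convention described in this note can be captured in a small helper. This is an illustrative sketch only: build_tft_inputs is a hypothetical function, not part of fusionlab-learn, and it merely assembles the list; the library's own validation remains the authority on what the model accepts.

```python
from typing import Any, List, Optional


def build_tft_inputs(
    dynamic: Any,
    static: Optional[Any] = None,
    future: Optional[Any] = None,
) -> List[Optional[Any]]:
    """Assemble inputs in the documented order [static, dynamic, future].

    Hypothetical helper for illustration; unused input types stay None.
    """
    if dynamic is None:
        raise ValueError("dynamic features are always required")
    return [static, dynamic, future]


# Strings stand in for arrays/tensors so the example runs without TensorFlow.
print(build_tft_inputs("dyn"))                  # dynamic only
print(build_tft_inputs("dyn", static="stat"))   # dynamic + static
print(build_tft_inputs("dyn", future="fut"))    # dynamic + future
```

Keeping the ordering in one place avoids accidentally swapping the static and dynamic positions when different input combinations are tried.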
Formulation
Here, we describe the core mathematical concepts behind the Temporal Fusion Transformer, following the architecture outlined in the original paper [1]. This provides insight into how different inputs are processed and transformed to generate forecasts.
Notation:
- Inputs:
\(s \in \mathbb{R}^{d_s}\): Static (time-invariant) covariates.
\(z_t \in \mathbb{R}^{d_z}\): Known future inputs at time \(t\).
\(x_t \in \mathbb{R}^{d_x}\): Observed past dynamic inputs at time \(t\).
\(y_t \in \mathbb{R}^{d_y}\): Past target variable(s) at time \(t\) (often included in \(x_t\)).
- Time Indices:
\(t \in [T-k+1, T]\): Past time steps within the lookback window of size \(k\).
\(t \in [T+1, T+\tau]\): Future time steps for the forecast horizon \(\tau\).
- Dimensions:
\(d_s, d_x, d_z, d_y\): Dimensionalities of respective inputs.
\(d_{model}\): The main hidden state dimension of the model (e.g., hidden_units).
- Common Functions:
\(LN(\cdot)\): Layer Normalization.
\(\sigma(\cdot)\): Sigmoid activation function.
\(ReLU(\cdot), ELU(\cdot)\): Activation functions.
\(Linear(\cdot)\): A dense (fully-connected) layer.
\(GLU(a, b) = a \odot \sigma(b)\): Gated Linear Unit, where \(\odot\) is element-wise multiplication.
\(GRN(a, [c])\): Gated Residual Network. A key block roughly defined as: \(GRN(a, c) = LN(a' + GLU(Linear_1(act(Linear_0(a'))), Linear_2(a')))\), where \(a' = a+Linear_c(c)\) if context \(c\) is provided, else \(a'=a\).
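The GLU and GRN definitions above translate almost line for line into NumPy. The following is a minimal sketch under stated assumptions: weights are random, the activation is ELU, Layer Normalization has no learnable scale or shift, and the context projection has no bias. The parameter names (W0, W1, W2, Wc) are made up for the example and do not correspond to the library's layer attributes.

```python
import numpy as np


def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))


def elu(x):
    # exp is only evaluated on the non-positive part to avoid overflow.
    return np.where(x > 0, x, np.exp(np.minimum(x, 0.0)) - 1.0)


def layer_norm(x, eps=1e-6):
    # Normalizes over the last axis; learnable gain/bias omitted.
    mean = x.mean(axis=-1, keepdims=True)
    var = x.var(axis=-1, keepdims=True)
    return (x - mean) / np.sqrt(var + eps)


def glu(a, b):
    # GLU(a, b) = a * sigmoid(b), element-wise.
    return a * sigmoid(b)


def grn(a, params, c=None):
    # a' = a + Linear_c(c) if a context is given, else a' = a.
    a_prime = a if c is None else a + c @ params["Wc"]
    hidden = elu(a_prime @ params["W0"] + params["b0"])   # act(Linear_0(a'))
    value = hidden @ params["W1"] + params["b1"]          # Linear_1(...)
    gate = a_prime @ params["W2"] + params["b2"]          # Linear_2(a')
    return layer_norm(a_prime + glu(value, gate))


rng = np.random.default_rng(0)
d_model, d_ctx = 8, 3
params = {
    "Wc": rng.normal(size=(d_ctx, d_model)) * 0.1,
    "W0": rng.normal(size=(d_model, d_model)) * 0.1,
    "b0": np.zeros(d_model),
    "W1": rng.normal(size=(d_model, d_model)) * 0.1,
    "b1": np.zeros(d_model),
    "W2": rng.normal(size=(d_model, d_model)) * 0.1,
    "b2": np.zeros(d_model),
}

a = rng.normal(size=(4, d_model))   # primary input
c = rng.normal(size=(4, d_ctx))     # optional context
out = grn(a, params, c)
out_no_ctx = grn(a, params)
print(out.shape, out_no_ctx.shape)
```

The gate lets the block suppress its nonlinear branch: when sigmoid(gate) is near zero the output reduces to the normalized skip connection.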
Architectural Flow:
Input Transformations & Variable Selection: Inputs (categorical/continuous) are transformed into numerical vectors (e.g., via embeddings or linear layers). Variable Selection Networks (VSNs) are applied to each input type (static \(s\), past dynamic \(x_t\), known future \(z_t\)), potentially conditioned on static context \(c_s\).
VSN computes feature weights \(\alpha_\chi\) and applies feature-wise GRNs (\(\tilde{\chi}^j = GRN(\chi^j)\)).
Output is a weighted sum: \(\xi = \sum_{j} \alpha_\chi^j \tilde{\chi}^j\).
This yields embeddings: static \(\zeta\), past dynamic \(\xi_t\) (\(t \le T\)), and future \(\xi_t\) (\(t > T\)).
Static Covariate Encoders: The static embedding \(\zeta\) is processed through dedicated GRNs to produce four context vectors for conditioning different parts of the temporal processing: \(c_s\) (VSN context), \(c_e\) (enrichment context), \(c_h\) (LSTM initial hidden state), \(c_c\) (LSTM initial cell state). E.g., \(c_s = GRN_{vs}(\zeta)\).
Locality Enhancement (LSTM Encoder): The sequence of combined past and future VSN embeddings \(\{\xi_t\}_{t=T-k+1}^{T+\tau}\) is fed into a sequence processing layer (typically multi-layer LSTM), initialized with contexts \(c_h, c_c\). \((h_t, cell_t) = LSTM((h_{t-1}, cell_{t-1}), \xi_t)\). The output is a sequence of hidden states \(\{h_t\}\).
Static Enrichment: The LSTM output sequence \(\{h_t\}\) is enriched with static context \(c_e\) using another GRN applied time-wise: \(\phi_t = GRN_{enrich}(h_t, c_e)\).
Temporal Self-Attention: An interpretable multi-head attention mechanism processes the enriched sequence \(\{\phi_t\}\). The static context \(c_s\) may condition the query generation or internal GRNs. It computes attention weights over past time steps relative to the current forecast time step.
Attention Calculation (Simplified): Weights \(\alpha_t^{(h)}\) for head \(h\) at step \(t\) are computed via scaled dot-product attention, typically using \(\phi_t\) to form Queries and \(\{\phi_{t'}\}_{t' \le T}\) to form Keys and Values. \(\alpha_t^{(h)} = \text{Softmax}\left( \dots \right)\).
Output & Gating: The attention output \(Attn_t\) is combined with \(\phi_t\) using gating (GLU) and a residual connection, followed by Layer Normalization: \(\beta_t = LN( \phi_t + GLU(..., Attn_t))\).
Position-wise Feed-forward: The attention output \(\beta_t\) is processed by another GRN applied independently at each time step: \(\delta_t = GRN_{final}(\beta_t)\).
Output Layer: The final features corresponding to the forecast horizon \(\{\delta_t\}_{t=T+1}^{T+\tau}\) are passed through linear layers to produce predictions.
Quantiles: Separate linear layers for each quantile \(q\):
\[\hat{y}_{t, q} = Linear_q(\delta_t)\]
Point: A single linear layer: \(\hat{y}_t = Linear_{point}(\delta_t)\).
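The variable selection step at the start of this flow, \(\xi = \sum_{j} \alpha_\chi^j \tilde{\chi}^j\), can be sketched numerically. The example below is a simplification and not the library's VariableSelectionNetwork: the feature-wise GRNs and the weight-producing network are replaced by plain random linear maps, and the weights are normalized with a softmax as in the original paper.

```python
import numpy as np


def softmax(x, axis=-1):
    z = x - x.max(axis=axis, keepdims=True)   # stabilize the exponent
    e = np.exp(z)
    return e / e.sum(axis=axis, keepdims=True)


rng = np.random.default_rng(0)
B, n_vars, d_model = 4, 5, 8

# chi[:, j, :] is the embedding of variable j, already mapped to d_model.
chi = rng.normal(size=(B, n_vars, d_model))

# Stand-in for the feature-wise GRNs: one independent map per variable.
W_feat = rng.normal(size=(n_vars, d_model, d_model)) * 0.1
chi_tilde = np.tanh(np.einsum("bjd,jde->bje", chi, W_feat))

# Stand-in for the weight network: a linear map on the flattened embeddings,
# followed by a softmax so the weights are positive and sum to one.
W_sel = rng.normal(size=(n_vars * d_model, n_vars)) * 0.1
alpha = softmax(chi.reshape(B, -1) @ W_sel, axis=-1)   # (B, n_vars)

# xi = sum_j alpha^j * chi_tilde^j
xi = (alpha[..., None] * chi_tilde).sum(axis=1)        # (B, d_model)

print("alpha shape:", alpha.shape, "xi shape:", xi.shape)
```

Because the weights sum to one per sample, they can be inspected directly as a per-sample measure of how much each variable contributes.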
This detailed flow illustrates how TFT integrates various components to handle diverse inputs, capture temporal patterns, incorporate static context, and generate interpretable multi-horizon forecasts with uncertainty estimates.
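To make the attention step concrete, here is a single-head scaled dot-product sketch in which queries come from every position of the enriched sequence and keys and values come only from the past steps \(t' \le T\), as described in the simplified attention calculation. It uses random projection matrices and omits the multi-head aggregation, gating, and static conditioning, so it should be read as an illustration of the weight computation rather than the TemporalAttentionLayer itself.

```python
import numpy as np


def softmax(x, axis=-1):
    z = x - x.max(axis=axis, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=axis, keepdims=True)


rng = np.random.default_rng(0)
B, k, tau, d_model = 2, 12, 6, 8   # batch, lookback, horizon, hidden size

# Enriched sequence phi_t over the lookback window plus the horizon.
phi = rng.normal(size=(B, k + tau, d_model))

# Random stand-ins for the learned query/key/value projections.
W_q, W_k, W_v = (rng.normal(size=(d_model, d_model)) * 0.1 for _ in range(3))

queries = phi @ W_q            # (B, k + tau, d_model)
keys = phi[:, :k] @ W_k        # (B, k, d_model), past steps only
values = phi[:, :k] @ W_v      # (B, k, d_model), past steps only

# Scaled dot-product scores, then a softmax over the past positions.
scores = queries @ keys.transpose(0, 2, 1) / np.sqrt(d_model)   # (B, k+tau, k)
alpha = softmax(scores, axis=-1)
attn = alpha @ values          # (B, k + tau, d_model)

print("weights:", alpha.shape, "attention output:", attn.shape)
```

Each row of alpha is a distribution over the lookback window, which is what makes the weights usable for interpreting which past steps drive a given forecast step.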
TFT (Stricter Implementation - All Inputs Required)
- API Reference:
(Note: This refers to a specific TFT class implementation within
fusionlab that enforces stricter input requirements compared to the
more flexible TemporalFusionTransformer described above. It assumes
static, dynamic past, and known future inputs are always provided and
are not None).
This class implements the Temporal Fusion Transformer (TFT) architecture, closely following the structure described in the original paper [1]. It is designed for multi-horizon time series forecasting and explicitly requires static covariates, dynamic (historical) covariates, and known future covariates as inputs.
Compared to implementations allowing optional inputs, this version
mandates all input types, simplifying the internal input handling
logic. It incorporates key TFT components like
VariableSelectionNetwork (VSNs),
GatedResidualNetwork (GRNs) for
static context generation and feature processing, LSTM encoding,
static enrichment, interpretable multi-head attention
(TemporalAttentionLayer),
and position-wise feedforward layers.
Use Case and Importance
This TFT class provides a structured implementation useful when all feature types (static, dynamic past, known future) are readily available and adherence to the paper’s component structure (like distinct static contexts) is desired. It serves as a strong baseline for complex forecasting tasks demanding interpretability and handling of heterogeneous data. Its requirement for all inputs simplifies the call method, making the internal flow potentially easier to follow.
Parameters
dynamic_input_dim (int): Total number of features in the dynamic (past) input tensor.
static_input_dim (int): Total number of features in the static (time-invariant) input tensor.
future_input_dim (int): Total number of features in the known future input tensor.
hidden_units (int, default: 32): Main dimensionality of hidden layers (VSNs, GRNs, Attention).
num_heads (int, default: 4): Number of attention heads in the Temporal Attention Layer.
dropout_rate (float, default: 0.1): Dropout rate for non-recurrent connections (0 to 1).
recurrent_dropout_rate (float, default: 0.0): Dropout rate for LSTM recurrent connections (0 to 1).
forecast_horizon (int, default: 1): Number of future time steps to predict.
quantiles (Optional[List[float]], default: None): List of quantiles (e.g., [0.1, 0.5, 0.9]) for probabilistic forecasting. If None, performs point forecasting.
activation (str, default: ‘elu’): Activation function for GRNs (e.g., ‘relu’, ‘gelu’).
use_batch_norm (bool, default: False): If True, use Batch Normalization in GRNs.
num_lstm_layers (int, default: 1): Number of stacked LSTM layers in the encoder.
lstm_units (Optional[Union[int, List[int]]], default: None): Units per LSTM layer. If int, used for all layers. If list, length must match num_lstm_layers. Defaults to hidden_units.
output_dim (int, default: 1): Number of target variables predicted per step.
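The interaction between lstm_units, num_lstm_layers, and hidden_units described in the parameter list can be written out as a short resolution rule. The function below is a hypothetical helper that mirrors the documented behavior; it is not the library's internal code.

```python
from typing import List, Optional, Union


def resolve_lstm_units(
    lstm_units: Optional[Union[int, List[int]]],
    num_lstm_layers: int,
    hidden_units: int,
) -> List[int]:
    """Return one unit count per LSTM layer (illustrative helper)."""
    if lstm_units is None:
        # Documented default: fall back to hidden_units.
        return [hidden_units] * num_lstm_layers
    if isinstance(lstm_units, int):
        # A single int is used for all layers.
        return [lstm_units] * num_lstm_layers
    units = list(lstm_units)
    if len(units) != num_lstm_layers:
        raise ValueError(
            f"lstm_units has {len(units)} entries but "
            f"num_lstm_layers={num_lstm_layers}"
        )
    return units


print(resolve_lstm_units(None, num_lstm_layers=2, hidden_units=32))
print(resolve_lstm_units(16, num_lstm_layers=3, hidden_units=32))
print(resolve_lstm_units([64, 32], num_lstm_layers=2, hidden_units=32))
```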
Notes
Input Format: This implementation requires inputs to the call method as a list or tuple containing exactly three tensors in the order: [static_inputs, dynamic_inputs, future_inputs]. Expected shapes are generally:
static_inputs: \((B, D_s)\)
dynamic_inputs: \((B, T_{past}, D_{dyn})\)
future_inputs: \((B, T_{future}, D_{fut})\) (Note: The required length \(T_{future}\) depends on how inputs are combined internally before the LSTM. Ensure data preparation aligns, e.g., using reshape_xtft_data()).
Categorical Features: This implementation assumes inputs are numeric. Handling categorical features requires modifications (e.g., adding embedding layers before VSNs).
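A quick pre-flight check on array shapes can catch ordering mistakes before the model is called. The sketch below is a hypothetical helper, not part of fusionlab: it verifies the rank, feature dimension, and batch size of the three inputs, and deliberately does not check \(T_{future}\), since the required length depends on the implementation.

```python
import numpy as np


def check_strict_tft_inputs(inputs, static_dim, dynamic_dim, future_dim):
    """Validate [static, dynamic, future] arrays (illustrative helper)."""
    if not isinstance(inputs, (list, tuple)) or len(inputs) != 3:
        raise ValueError(
            "expected [static_inputs, dynamic_inputs, future_inputs]"
        )
    if any(x is None for x in inputs):
        raise ValueError("all three inputs are required and cannot be None")

    static, dynamic, future = (np.asarray(x) for x in inputs)
    expected = {
        "static_inputs": (static, 2, static_dim),     # (B, D_s)
        "dynamic_inputs": (dynamic, 3, dynamic_dim),  # (B, T_past, D_dyn)
        "future_inputs": (future, 3, future_dim),     # (B, T_future, D_fut)
    }
    for name, (arr, rank, dim) in expected.items():
        if arr.ndim != rank:
            raise ValueError(f"{name} must have rank {rank}, got {arr.ndim}")
        if arr.shape[-1] != dim:
            raise ValueError(
                f"{name} last dimension must be {dim}, got {arr.shape[-1]}"
            )
    if not (static.shape[0] == dynamic.shape[0] == future.shape[0]):
        raise ValueError("batch sizes of the three inputs differ")
    return static.shape, dynamic.shape, future.shape


shapes = check_strict_tft_inputs(
    [np.zeros((4, 3)), np.zeros((4, 12, 5)), np.zeros((4, 18, 2))],
    static_dim=3, dynamic_dim=5, future_dim=2,
)
print(shapes)
```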
Formulation
(This section describes the flow assuming numeric inputs)
Variable Selection: Separate VariableSelectionNetwork (VSNs) process static (\(\mathbf{s}\)), dynamic past (\(\mathbf{x}_t\)), and known future (\(\mathbf{z}_t\)) inputs, potentially conditioned by static context (\(c_s\)). Outputs: \(\zeta\), \(\xi^{dyn}_t\), \(\xi^{fut}_t\).
Static Context Generation: Four distinct GatedResidualNetwork (GRNs) process the static VSN output \(\zeta\) to produce context vectors: \(c_s\) (for VSNs), \(c_e\) (for enrichment), \(c_h\) (LSTM initial hidden state), \(c_c\) (LSTM initial cell state).
Temporal Processing Input: Selected dynamic (\(\xi^{dyn}_t\)) and future (\(\xi^{fut}_t\)) embeddings are combined (e.g., concatenated along features) and augmented with PositionalEncoding (\(\psi_t\)).
LSTM Encoder: A stack of LSTMs processes \(\psi_t\), initialized with \([c_h, c_c]\), outputting hidden states \(\{h_t\}\).
\[\{h_t\} = \text{LSTMStack}(\{\psi_t\}, \text{init}=[c_h, c_c])\]
Static Enrichment: A time-distributed GRN combines LSTM outputs \(h_t\) with the static enrichment context \(c_e\).
\[\phi_t = GRN_{enrich}(h_t, c_e)\]
Temporal Self-Attention: TemporalAttentionLayer processes the enriched sequence \(\{\phi_t\}\) using \(c_s\) as context, outputting \(\beta_t\) after internal gating/residuals.
\[\beta_t = \text{TemporalAttention}(\{\phi_t\}, c_s)\]
Position-wise Feed-Forward: A final time-distributed GRN processes \(\beta_t\).
\[\delta_t = GRN_{final}(\beta_t)\]
Output Projection: Features for the forecast horizon (\(t > T\)) are selected from \(\{\delta_t\}\) (typically the last \(H\) steps) and passed through output Dense layer(s) for point (\(\hat{y}_t\)) or quantile (\(\hat{y}_{t, q}\)) predictions.
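The final projection step can be illustrated with plain arrays: keep the last \(H\) time steps of \(\{\delta_t\}\) and apply one linear head per quantile, or a single head for point forecasts. This is a sketch with random weights for the output_dim=1 case only; the variable names are invented for the example and the library's actual output layers may be organized differently.

```python
import numpy as np

rng = np.random.default_rng(0)
B, T_total, d_model = 4, 18, 8   # batch, lookback + horizon, hidden size
H = 6                            # forecast horizon
quantiles = [0.1, 0.5, 0.9]

# Stand-in for the position-wise GRN output delta_t over the full sequence.
delta = rng.normal(size=(B, T_total, d_model))

# Select the features that correspond to the forecast horizon.
delta_horizon = delta[:, -H:, :]                     # (B, H, d_model)

# Quantile case: one independent linear head (weights, bias) per quantile.
heads = [
    (rng.normal(size=(d_model, 1)) * 0.1, np.zeros(1)) for _ in quantiles
]
y_quant = np.concatenate(
    [delta_horizon @ W + b for W, b in heads], axis=-1
)                                                    # (B, H, NumQuantiles)

# Point case: a single linear head.
W_point, b_point = rng.normal(size=(d_model, 1)) * 0.1, np.zeros(1)
y_point = delta_horizon @ W_point + b_point          # (B, H, 1)

print("quantile output:", y_quant.shape)
print("point output:", y_point.shape)
```

The resulting shapes match the ones quoted in the code examples of this guide: (B, H, NumQuantiles) for quantile forecasts with a single target and (B, H, 1) for point forecasts.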
Code Example (Instantiation & Call):
import numpy as np
import tensorflow as tf
from fusionlab.nn.transformers import TFT

# Dummy Data Dimensions
B, T_past, H = 4, 12, 6
D_dyn, D_stat, D_fut = 5, 3, 2
T_future = T_past + H  # Example: Future covers lookback + horizon

# Create Dummy Input Tensors (ALL REQUIRED)
static_in = tf.random.normal((B, D_stat), dtype=tf.float32)
dynamic_in = tf.random.normal((B, T_past, D_dyn), dtype=tf.float32)
future_in = tf.random.normal((B, T_future, D_fut), dtype=tf.float32)

# Instantiate the revised TFT Model (Point Forecast)
model = TFT(
    dynamic_input_dim=D_dyn,
    static_input_dim=D_stat,
    future_input_dim=D_fut,
    forecast_horizon=H,
    hidden_units=16,
    num_heads=2,
    quantiles=None  # Point forecast
)

# Prepare input list in REQUIRED order: [static, dynamic, future]
model_inputs = [static_in, dynamic_in, future_in]

# Call the model (builds on first call)
predictions = model(model_inputs)

print(f"Input Shapes: S={static_in.shape}, D={dynamic_in.shape}, F={future_in.shape}")
print(f"Output shape (Point): {predictions.shape}")
# Expected: (B, H, OutputDim=1) -> (4, 6, 1)
DummyTFT (Static & Dynamic Inputs Only)
- API Reference:
The DummyTFT (formerly NTemporalFusionTransformer) is a
variant of the TFT model available in fusionlab. It is
characterized by its specific input requirements, focusing on scenarios
where only static and dynamic (past) features are available.
Key Features & Differences:
Mandatory Static & Dynamic Inputs: This class requires both static_input_dim and dynamic_input_dim to be specified during initialization and expects corresponding non-None static and dynamic (past) tensors as input.
No Future Inputs Used: This variant is designed specifically for scenarios where known future covariates are not available or not utilized. The architecture omits pathways for processing future inputs.
Point or Quantile Forecasts: Can produce deterministic (point) forecasts or probabilistic (quantile) forecasts if the quantiles parameter is specified.
Core TFT Architecture: It utilizes fundamental TFT components like VariableSelectionNetwork (VSNs for static and dynamic inputs), LSTM encoders, Static Enrichment, Temporal Self-Attention (TemporalAttentionLayer), and GatedResidualNetwork (GRNs), configured for its two-input structure.
When to Use:
Consider using DummyTFT primarily when:
Your forecasting problem involves only static metadata and dynamic (past) observed features.
You explicitly do not have or require known future covariates.
You need point or quantile forecasts based on this two-input setup.
Formulation
The DummyTFT follows the core mathematical principles of the
standard Temporal Fusion Transformer [1], employing components like
VSNs, static context GRNs, LSTM encoding, static enrichment,
temporal self-attention, and position-wise feed-forward GRNs.
The main distinctions in the formulation are:
No Future Input Path: The architecture omits the processing pathway for known future inputs (\(z_t\)). VSNs are not applied to them, and they are not included in the sequence fed to the LSTM or attention mechanisms. Only static (\(s\)) and past dynamic (\(x_t\)) inputs are processed.
Output Layer: The final output layer processes features derived from static and dynamic inputs to produce point or quantile predictions for the forecast horizon.
Code Example:
import numpy as np
import tensorflow as tf
from fusionlab.nn.transformers import DummyTFT

# Dummy Data Dimensions
B, T_past, H = 4, 12, 6   # Batch, Lookback, Horizon
D_dyn, D_stat = 5, 3      # Dynamic, Static feature dimensions
output_dim = 1            # Univariate target

# Create Dummy Input Tensors (Static and Dynamic ONLY)
static_in = tf.random.normal((B, D_stat), dtype=tf.float32)
dynamic_in = tf.random.normal((B, T_past, D_dyn), dtype=tf.float32)

# Instantiate the DummyTFT Model (Point Forecast)
model_point = DummyTFT(
    static_input_dim=D_stat,
    dynamic_input_dim=D_dyn,
    forecast_horizon=H,
    output_dim=output_dim,
    hidden_units=16, num_heads=2,
    quantiles=None  # Point forecast
)

# Prepare input list: [static, dynamic]
model_inputs_point = [static_in, dynamic_in]

# Call the model
try:
    predictions_point = model_point(model_inputs_point, training=False)
    print("--- DummyTFT Point Forecast ---")
    print(f"Input Shapes: S={static_in.shape}, D={dynamic_in.shape}")
    print(f"Output shape (Point): {predictions_point.shape}")
    # Expected: (B, H, O) -> (4, 6, 1)
except Exception as e:
    print(f"DummyTFT (Point) call failed: {e}")

# Instantiate for Quantile Forecast
my_quantiles = [0.2, 0.5, 0.8]
model_quant = DummyTFT(
    static_input_dim=D_stat,
    dynamic_input_dim=D_dyn,
    forecast_horizon=H,
    output_dim=output_dim,
    quantiles=my_quantiles,
    hidden_units=16, num_heads=2
)
model_inputs_quant = [static_in, dynamic_in]
try:
    predictions_quant = model_quant(model_inputs_quant, training=False)
    print("\n--- DummyTFT Quantile Forecast ---")
    print(f"Output shape (Quantile): {predictions_quant.shape}")
    # Expected for output_dim=1: (B, H, NumQuantiles) -> (4, 6, 3)
except Exception as e:
    print(f"DummyTFT (Quantile) call failed: {e}")
Next Steps
Note
Now that you understand the theory and the complete workflow for
TFT and its variants, you can proceed to the exercises for more hands-on practice:
References