Hybrid Transformer Models: XTFT & SuperXTFT
This section of the user guide covers the XTFT (Extreme Temporal
Fusion Transformer) family of models. These are advanced, hybrid
architectures designed for the most demanding multi-horizon time
series forecasting tasks.
Building upon the foundational concepts of the original Temporal Fusion Transformer (TFT), these models integrate multi-scale recurrent processing using LSTMs with a sophisticated, multi-layered attention framework. This hybrid approach allows them to capture an exceptionally rich set of temporal patterns, from short-term dependencies to very long-range, complex interactions.
This guide details two models in this family:
XTFT: The main, stable implementation, which includes numerous enhancements over the standard TFT, such as advanced attention mechanisms and integrated anomaly detection.
SuperXTFT: An experimental variant of XTFT that introduces additional feature selection and processing layers.
XTFT (Extreme Temporal Fusion Transformer)
- API Reference:
The XTFT model represents a significant evolution of the Temporal
Fusion Transformer, designed to tackle highly complex time series
forecasting tasks with enhanced capabilities for representation
learning, multi-scale analysis, and integrated anomaly detection.
Key Features:
Advanced Input Handling: Requires static, dynamic (past), and future known inputs. Utilizes components like LearnedNormalization and MultiModalEmbedding for input processing. Note: Unlike the revised TFT, XTFT internally uses these components and doesn't rely on VSNs directly at the input stage.
Multi-Scale Temporal Processing: Employs MultiScaleLSTM to analyze temporal dependencies at different user-defined resolutions (via scales). Output aggregation is handled by aggregate_multiscale().
Sophisticated Attention Mechanisms: Incorporates multiple specialized attention layers for richer context modeling, including HierarchicalAttention, CrossAttention, MemoryAugmentedAttention, and MultiResolutionAttentionFusion (described under Formulation).
Dynamic Temporal Focus: Uses a DynamicTimeWindow component to potentially focus on the most relevant recent time steps before final aggregation.
Flexible Aggregation: Aggregates final temporal features using different strategies (final_agg parameter, handled by aggregate_time_window_output()).
Integrated Anomaly Detection: Offers multiple strategies (via the anomaly_detection_strategy parameter) for incorporating anomaly information into the training process:
    'feature_based': Learns anomaly scores from internal features using dedicated attention/scoring layers.
    'prediction_based': Calculates anomaly scores based on prediction errors using a specialized loss function (prediction_based_loss()).
    'from_config': Uses pre-computed anomaly scores provided via the anomaly_config dictionary, integrated into the loss via AnomalyLoss and potentially combined_total_loss(). The contribution of anomaly loss is controlled by anomaly_loss_weight.
Flexible Output: Features a MultiDecoder (generating horizon-specific features) and a QuantileDistributionModeling layer to produce multi-horizon forecasts for specified quantiles (or point forecasts if quantiles is None).
When to Use XTFT
XTFT is designed for challenging forecasting problems where:
Underlying temporal dynamics are highly complex and potentially span multiple time scales.
Rich static, dynamic, and future information needs to be integrated effectively using advanced fusion techniques.
Capturing long-range dependencies is important (leveraging memory attention).
Identifying or accounting for anomalies within the time series is a requirement.
Maximum predictive performance is desired, potentially at the cost of increased model complexity and computational resources compared to standard TFT.
Formulation
XTFT significantly extends the standard TFT architecture. While it builds upon core concepts like GRNs and attention, it introduces many specialized components. We highlight the key additions and modifications here. For full details, please refer to the source code and the documentation of individual components (linked above).
Input Processing:
Static inputs (\(s\)) undergo LearnedNormalization and are processed by internal GRNs/Dense layers (static_dense, static_dropout, grn_static).
Dynamic (\(x_t\)) and Future (\(z_t\)) inputs are jointly processed by MultiModalEmbedding.
PositionalEncoding is added.
Optional residual connections enhance gradient flow.
Multi-Scale LSTM:
Dynamic inputs (\(x_t\) or embeddings derived from them) are processed by MultiScaleLSTM using different temporal scales.
Outputs are aggregated (e.g., 'last' step) into lstm_features.
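To make the idea of temporal scales concrete, the sketch below shows one plausible way a sequence can be viewed at several resolutions before each view is passed to its own LSTM. It assumes that a scale of `k` means "keep every `k`-th time step"; the helper `subsample_by_scale` is hypothetical and is not part of the fusionlab API.

```python
import numpy as np


def subsample_by_scale(x, scales):
    """Return one strided view of `x` per scale (assumed behaviour)."""
    # x has shape (batch, time_steps, features); keep every `scale`-th step.
    return {scale: x[:, ::scale, :] for scale in scales}


batch, time_steps, features = 4, 24, 7
x = np.random.default_rng(0).normal(size=(batch, time_steps, features))

views = subsample_by_scale(x, scales=[1, 3, 6])
lengths = {scale: view.shape[1] for scale, view in views.items()}
print(lengths)  # {1: 24, 3: 8, 6: 4}
```

Coarser scales yield shorter sequences, which is why an aggregation such as taking the last step is a convenient way to bring the per-scale outputs back to a common shape.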
Advanced Attention Layers:
HierarchicalAttention processes dynamic and future inputs.
CrossAttention models interactions between dynamic inputs and combined embeddings.
MemoryAugmentedAttention uses hierarchical attention output to query an external memory.
GRNs are applied after each attention block (grn_attention_*).
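The memory-querying step can be illustrated with plain scaled dot-product attention over a fixed bank of memory slots. This is a minimal NumPy sketch, not the MemoryAugmentedAttention implementation: the function `memory_attention` is hypothetical, and the memory here is random, whereas a trained model would presumably learn it.

```python
import numpy as np


def softmax(scores, axis=-1):
    # Subtract the maximum for numerical stability before exponentiating.
    shifted = scores - scores.max(axis=axis, keepdims=True)
    weights = np.exp(shifted)
    return weights / weights.sum(axis=axis, keepdims=True)


def memory_attention(queries, memory):
    """Scaled dot-product attention of `queries` over a memory matrix."""
    # queries: (batch, time_steps, units); memory: (memory_size, units)
    scores = queries @ memory.T / np.sqrt(memory.shape[-1])
    weights = softmax(scores, axis=-1)  # (batch, time_steps, memory_size)
    return weights @ memory, weights    # read-out: (batch, time_steps, units)


rng = np.random.default_rng(0)
queries = rng.normal(size=(4, 24, 16))  # stand-in for hierarchical attention output
memory = rng.normal(size=(50, 16))      # 50 slots, mirroring memory_size=50

readout, weights = memory_attention(queries, memory)
print(readout.shape, weights.shape)
```

Each time step receives a weighted mix of memory slots, which gives the model a route to context that is not tied to a position in the input window.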
Feature Fusion:
Processed static features, aggregated lstm_features, and outputs from the various attention mechanisms are concatenated.
MultiResolutionAttentionFusion is applied to integrate these diverse feature streams.
Dynamic Windowing & Aggregation:
DynamicTimeWindow selects recent time steps from the fused features.
aggregate_time_window_output() collapses the time dimension based on final_agg strategy.
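The windowing and aggregation steps amount to slicing the time axis and then reducing it. The sketch below assumes the window keeps the most recent `max_window_size` steps and covers only the two aggregation names that appear in this guide ('last' and 'average'); both helper functions are hypothetical stand-ins for the library components.

```python
import numpy as np


def select_recent_window(features, max_window_size):
    # Keep at most the last `max_window_size` time steps.
    return features[:, -max_window_size:, :]


def aggregate_over_time(features, mode="last"):
    # Collapse (batch, time_steps, units) to (batch, units).
    if mode == "last":
        return features[:, -1, :]
    if mode == "average":
        return features.mean(axis=1)
    raise ValueError(f"Unsupported mode: {mode!r}")


fused = np.arange(2 * 30 * 3, dtype=float).reshape(2, 30, 3)
window = select_recent_window(fused, max_window_size=24)
last_step = aggregate_over_time(window, "last")
average = aggregate_over_time(window, "average")
print(window.shape, last_step.shape, average.shape)
```

Whichever strategy is chosen, the decoder receives one feature vector per sample rather than a full sequence.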
Decoding and Output:
MultiDecoder transforms the aggregated features for each horizon step.
A final GRN pipeline (grn_decoder) processes decoder outputs.
QuantileDistributionModeling maps these features to the final quantile or point predictions (\(\hat{y}_{t, q}\) / \(\hat{y}_t\)).
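Quantile outputs such as \(\hat{y}_{t, q}\) are typically trained with the pinball (quantile) loss. The following is a generic NumPy sketch of that standard loss, assuming targets of shape (batch, horizon, 1) and predictions of shape (batch, horizon, n_quantiles); it is not taken from the fusionlab source, and `pinball_loss` is a hypothetical name.

```python
import numpy as np


def pinball_loss(y_true, y_pred, quantiles):
    """Mean pinball loss across samples, horizon steps, and quantiles."""
    q = np.asarray(quantiles, dtype=float).reshape(1, 1, -1)
    error = y_true - y_pred  # broadcasts to (batch, horizon, n_quantiles)
    # Under-prediction is weighted by q, over-prediction by (1 - q).
    return float(np.mean(np.maximum(q * error, (q - 1.0) * error)))


quantiles = [0.1, 0.5, 0.9]
y_true = np.ones((1, 2, 1))
y_under = np.zeros((1, 2, 3))      # every quantile under-predicts by 1
y_exact = np.ones((1, 2, 3))       # every quantile matches the target

loss_under = pinball_loss(y_true, y_under, quantiles)
loss_exact = pinball_loss(y_true, y_exact, quantiles)
print(loss_under, loss_exact)
```

Because the penalty is asymmetric, the 0.9 quantile is pushed above most observations and the 0.1 quantile below them, which is what produces a prediction interval.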
Anomaly Detection Integration:
Feature-Based: Internal anomaly_attention, anomaly_projection, and anomaly_scorer layers compute anomaly_scores during the forward pass.
Config-Based: Pre-computed anomaly_scores are provided via anomaly_config.
Loss Calculation: If anomaly_scores exist, AnomalyLoss calculates an anomaly term, which is added via model.add_loss (used in feature/config modes).
Prediction-Based: A specialized combined loss function is used during compile, and the custom train_step handles calculations.
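The role of anomaly_loss_weight can be shown with a small numeric example. The sketch assumes the anomaly term is the mean squared anomaly score and that it is simply added to the forecast loss after scaling; the exact form used by AnomalyLoss may differ, and both functions below are hypothetical.

```python
import numpy as np


def anomaly_term(anomaly_scores):
    # Assumed penalty: mean of squared anomaly scores.
    return float(np.mean(np.square(anomaly_scores)))


def total_loss(forecast_loss, anomaly_scores, anomaly_loss_weight):
    # Weighted sum of the forecasting objective and the anomaly penalty.
    return forecast_loss + anomaly_loss_weight * anomaly_term(anomaly_scores)


scores = np.array([[0.0, 1.0], [2.0, 1.0]])
loss = total_loss(forecast_loss=0.40, anomaly_scores=scores,
                  anomaly_loss_weight=0.05)
print(round(loss, 3))  # 0.475
```

A small weight keeps the forecasting objective dominant while still letting anomaly information shape the learned features.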
Code Example (Instantiation):
import numpy as np
# Assuming XTFT is importable
from fusionlab.nn.models import XTFT

# Example Configuration
static_dim, dynamic_dim, future_dim = 5, 7, 3
horizon = 12
output_dim = 1
my_quantiles = [0.1, 0.5, 0.9]
my_scales = [1, 3, 6]  # Example scales for MultiScaleLSTM

# Instantiate XTFT with various parameters
xtft_model = XTFT(
    static_input_dim=static_dim,
    dynamic_input_dim=dynamic_dim,
    future_input_dim=future_dim,
    forecast_horizon=horizon,
    quantiles=my_quantiles,
    output_dim=output_dim,
    embed_dim=16,
    hidden_units=32,
    attention_units=16,
    lstm_units=32,
    num_heads=4,
    scales=my_scales,
    multi_scale_agg='last',  # Aggregation for MultiScaleLSTM
    memory_size=50,
    max_window_size=24,  # For DynamicTimeWindow
    final_agg='average',  # Aggregation after DynamicTimeWindow
    anomaly_detection_strategy='prediction_based',  # Example strategy
    anomaly_loss_weight=0.05,
    dropout_rate=0.1
)

# Build the model (e.g., by providing dummy input shapes)
# Note: Actual shapes depend on data preprocessing
dummy_batch_size = 4
dummy_time_steps = 24  # Should match or exceed max_window_size

# Example shapes (adjust T_future as needed)
static_shape = (dummy_batch_size, static_dim)
dynamic_shape = (dummy_batch_size, dummy_time_steps, dynamic_dim)
future_shape = (dummy_batch_size, dummy_time_steps + horizon, future_dim)

# Build using dummy shapes (or use model.fit/predict later)
# xtft_model.build(input_shape=[static_shape, dynamic_shape, future_shape])
# print("XTFT Model Built (example).")

# Display the architecture summary once the model has been built;
# Keras raises an error if summary() is called on an unbuilt model.
# xtft_model.summary()
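Before building or fitting, it can help to generate dummy arrays and confirm that the three input streams have the expected shapes. The sketch below uses NumPy only and mirrors the shapes from the example above; the input order (static, dynamic, future) follows the commented-out build call and should be treated as an assumption, as should the target shape.

```python
import numpy as np

static_dim, dynamic_dim, future_dim = 5, 7, 3
horizon, batch_size, time_steps = 12, 4, 24

rng = np.random.default_rng(42)
static_input = rng.normal(size=(batch_size, static_dim)).astype("float32")
dynamic_input = rng.normal(
    size=(batch_size, time_steps, dynamic_dim)).astype("float32")
# Future covariates span the look-back window plus the forecast horizon here.
future_input = rng.normal(
    size=(batch_size, time_steps + horizon, future_dim)).astype("float32")
# Assumed target layout: one value per horizon step and output dimension.
targets = rng.normal(size=(batch_size, horizon, 1)).astype("float32")

inputs = [static_input, dynamic_input, future_input]
for name, array in zip(("static", "dynamic", "future"), inputs):
    print(f"{name:<8}{array.shape}")
```

Arrays like these can serve as a quick smoke test for a forward pass before real, preprocessed data is wired in.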
SuperXTFT: An Enhanced Hybrid Transformer
- API Reference:
The SuperXTFT is the most advanced and powerful implementation in the
TFT family available in fusionlab-learn. It inherits the entire
robust feature set of the standard XTFT
and enhances it with two significant architectural modifications, designed
to maximize representation learning and predictive performance.
It should be considered the expert choice for tackling the most complex forecasting problems where fine-grained feature selection and deep contextual processing are paramount.
Key Architectural Enhancements (from XTFT)
SuperXTFT improves upon the standard XTFT architecture in two
primary ways:
1. Integrated Input Variable Selection (VSNs)
Unlike the standard XTFT which processes inputs directly into
embeddings, SuperXTFT first passes all three raw input streams
(static, dynamic past, and future) through their own dedicated
VariableSelectionNetwork (VSN) layers.
Benefit: This allows the model to learn the relative importance of each input feature at the very beginning of the pipeline, before they are mixed and processed by downstream components. This can lead to more robust and interpretable feature representations, especially in datasets with a large number of potentially redundant or noisy features. The selected features (\(\mathbf{s}', \mathbf{x}'_t, \mathbf{z}'_t\)) are then fed into the rest of the standard XTFT architecture.
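The core of variable selection is a softmax weighting over per-feature embeddings. The NumPy sketch below is a simplification: in a full VariableSelectionNetwork the selection logits and the per-feature transformations would come from learned layers, whereas here they are random placeholders and `variable_selection` is a hypothetical helper.

```python
import numpy as np


def softmax(scores, axis=-1):
    shifted = scores - scores.max(axis=axis, keepdims=True)
    weights = np.exp(shifted)
    return weights / weights.sum(axis=axis, keepdims=True)


def variable_selection(feature_embeddings, selection_logits):
    """Weight each feature embedding by its softmax importance and sum."""
    # feature_embeddings: (batch, n_features, embed_dim)
    # selection_logits:   (batch, n_features)
    weights = softmax(selection_logits, axis=-1)
    selected = np.sum(weights[..., np.newaxis] * feature_embeddings, axis=1)
    return selected, weights


rng = np.random.default_rng(0)
feature_embeddings = rng.normal(size=(4, 7, 16))  # 7 dynamic features
selection_logits = rng.normal(size=(4, 7))        # placeholder for learned logits

selected, weights = variable_selection(feature_embeddings, selection_logits)
print(selected.shape, weights[0].round(2))
```

The weights sum to one per sample, so they can be read as relative feature importances, which is the source of the interpretability benefit described above.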
2. Post-Component Gated Processing (GRNs)
SuperXTFT strategically inserts additional
GatedResidualNetwork (GRN) layers
immediately after each major attention and decoder block. A GRN is
applied to the outputs of:
Hierarchical Attention
Cross-Attention
Memory-Augmented Attention
The Multi-Decoder layer
Benefit: This adds another layer of deep, non-linear processing and feature gating at critical junctures within the architecture. It allows the model to further refine the contextual representations generated by each attention mechanism before they are fused together, potentially capturing more complex and subtle interactions.
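For readers unfamiliar with the GRN, the sketch below follows the formulation from the original TFT: a dense layer with ELU, a second dense layer, a gated linear unit, and a residual connection followed by layer normalization. It is a NumPy illustration with random weights and equal input/output width, not the GatedResidualNetwork layer itself; all helper names are hypothetical.

```python
import numpy as np


def elu(x):
    return np.where(x > 0, x, np.exp(np.minimum(x, 0.0)) - 1.0)


def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))


def layer_norm(x, eps=1e-6):
    mean = x.mean(axis=-1, keepdims=True)
    var = x.var(axis=-1, keepdims=True)
    return (x - mean) / np.sqrt(var + eps)


def init_grn_params(units, rng):
    # One (weights, bias) pair per dense transformation in the block.
    names = ("hidden", "candidate", "gate", "value")
    return {name: (rng.normal(scale=0.1, size=(units, units)), np.zeros(units))
            for name in names}


def dense(x, weights_and_bias):
    weights, bias = weights_and_bias
    return x @ weights + bias


def gated_residual_network(x, params):
    hidden = elu(dense(x, params["hidden"]))
    candidate = dense(hidden, params["candidate"])
    # Gated linear unit: a sigmoid gate scales a linear projection.
    gated = sigmoid(dense(candidate, params["gate"])) * dense(candidate, params["value"])
    return layer_norm(x + gated)  # residual connection, then normalization


rng = np.random.default_rng(0)
attention_output = rng.normal(size=(4, 24, 16))
params = init_grn_params(units=16, rng=rng)
refined = gated_residual_network(attention_output, params)
print(refined.shape)
```

When the gate saturates near zero the block reduces to a normalized identity, so an added GRN can learn to pass its input through largely unchanged if the extra processing is not useful.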
When to Use SuperXTFT
SuperXTFT is the recommended choice for challenging forecasting
problems where:
You are working with a large number of input features of varying importance and want the model to learn which ones to prioritize (leveraging the input VSNs).
You hypothesize that there are complex, non-linear interactions between the different contexts (static, temporal, memory) that could benefit from the additional deep processing offered by the post-component GRNs.
You are aiming for maximum predictive performance and have the computational resources for a deeper, more parameter-rich model.
For standard use cases, XTFT remains a powerful and efficient baseline.
Code Example
The instantiation of SuperXTFT is identical to XTFT. The
additional VSN and GRN layers are created and integrated automatically
within the model’s constructor and forward pass.
import numpy as np
from fusionlab.nn.models import SuperXTFT

# Configuration is the same as for XTFT
static_dim, dynamic_dim, future_dim = 5, 7, 3
horizon = 12
output_dim = 1

# Instantiate the SuperXTFT model
super_xtft_model = SuperXTFT(
    static_input_dim=static_dim,
    dynamic_input_dim=dynamic_dim,
    future_input_dim=future_dim,
    forecast_horizon=horizon,
    output_dim=output_dim,
    # Other architectural parameters
    hidden_units=32,
    num_heads=4,
    lstm_units=32,
    attention_units=16
)

print("SuperXTFT model instantiated successfully.")
# You can view the deeper architecture with .summary() after building
# super_xtft_model.summary(line_length=110)
Note
SuperXTFT is a production-ready model and represents the most
powerful, feature-rich version in the TFT family.
It also serves as a development platform where new, cutting-edge
features may be introduced first. Future releases might include
experimental options aimed at lightening the architecture or
improving computational efficiency. While any such new features may be
subject to change, the core SuperXTFT architecture is stable,
has been thoroughly tested, and can be confidently used in
production environments where maximum performance is desired.
Next Steps
Note
You now have a deep understanding of the theory and architecture
of the XTFT and SuperXTFT models. To apply these concepts,
you can proceed to the hands-on exercises.