# -*- coding: utf-8 -*-
# License: BSD-3-Clause
# Author: LKouadio <etanoyau@gmail.com>
"""Implements the Temporal Fusion Transformer (TFT), a state-of-the-art
architecture for multi-horizon time-series forecasting.
"""
from numbers import Real, Integral
from typing import List, Optional, Union
from ..._fusionlog import fusionlog
from ...api.property import NNLearner
from ...core.checks import is_iterable
from ...core.diagnose_q import validate_quantiles
from ...compat.sklearn import validate_params, Interval, StrOptions
from ...utils.deps_utils import ensure_pkg
from ...utils.validator import validate_positive_integer
from .. import KERAS_DEPS, KERAS_BACKEND, dependency_message
if KERAS_BACKEND:
from .._tensor_validation import (
validate_model_inputs, combine_temporal_inputs_for_lstm
)
from ..losses import combined_quantile_loss
from ..components import (
VariableSelectionNetwork,
PositionalEncoding,
GatedResidualNetwork,
TemporalAttentionLayer,
)
LSTM = KERAS_DEPS.LSTM
LSTMCell=KERAS_DEPS.LSTMCell
LayerNormalization = KERAS_DEPS.LayerNormalization
TimeDistributed = KERAS_DEPS.TimeDistributed
MultiHeadAttention = KERAS_DEPS.MultiHeadAttention
Model = KERAS_DEPS.Model
BatchNormalization = KERAS_DEPS.BatchNormalization
Input = KERAS_DEPS.Input
Softmax = KERAS_DEPS.Softmax
Flatten = KERAS_DEPS.Flatten
Dropout = KERAS_DEPS.Dropout
Dense = KERAS_DEPS.Dense
Embedding =KERAS_DEPS.Embedding
Concatenate=KERAS_DEPS.Concatenate
Layer = KERAS_DEPS.Layer
register_keras_serializable=KERAS_DEPS.register_keras_serializable
tf_reduce_sum =KERAS_DEPS.reduce_sum
tf_stack =KERAS_DEPS.stack
tf_expand_dims =KERAS_DEPS.expand_dims
tf_tile =KERAS_DEPS.tile
tf_range =KERAS_DEPS.range
tf_rank = KERAS_DEPS.rank
tf_squeeze= KERAS_DEPS.squeeze
tf_concat =KERAS_DEPS.concat
tf_shape =KERAS_DEPS.shape
tf_zeros=KERAS_DEPS.zeros
tf_float32=KERAS_DEPS.float32
tf_reshape=KERAS_DEPS.reshape
tf_autograph=KERAS_DEPS.autograph
tf_multiply=KERAS_DEPS.multiply
tf_reduce_mean = KERAS_DEPS.reduce_mean
tf_get_static_value=KERAS_DEPS.get_static_value
tf_gather=KERAS_DEPS.gather
DEP_MSG = dependency_message('transformers.tft')
logger = fusionlog().get_fusionlab_logger(__name__)
__all__= ['TFT']
[docs]
@register_keras_serializable('fusionlab.nn.transformers', name="TFT")
class TFT(Model, NNLearner):
[docs]
@validate_params({
"dynamic_input_dim": [Interval(Integral, 1, None, closed='left')],
"static_input_dim": [Interval(Integral, 0, None, closed='left')],
"future_input_dim": [Interval(Integral, 0, None, closed='left')],
"hidden_units": [Interval(Integral, 1, None, closed='left')],
"num_heads": [Interval(Integral, 1, None, closed='left')],
"dropout_rate": [Interval(Real, 0, 1, closed="both")],
"recurrent_dropout_rate": [Interval(Real, 0, 1, closed="both")],
"forecast_horizon": [Interval(Integral, 1, None, closed='left')],
"quantiles": ['array-like', None],
"activation": [StrOptions(
{"elu", "relu", "tanh", "sigmoid", "linear", "gelu"})],
"use_batch_norm": [bool],
"num_lstm_layers": [Interval(Integral, 1, None, closed='left')],
"lstm_units": ['array-like', Interval(Integral, 1, None, closed='left'), None],
"output_dim": [Interval(Integral, 1, None, closed='left')],
})
@ensure_pkg(KERAS_BACKEND or "keras", extra=DEP_MSG)
def __init__(
self,
dynamic_input_dim: int,
static_input_dim: int,
future_input_dim: int,
hidden_units: int = 32,
num_heads: int = 4,
dropout_rate: float = 0.1,
recurrent_dropout_rate: float = 0.0,
forecast_horizon: int = 1,
quantiles: Optional[List[float]] = None,
activation: str = 'elu',
use_batch_norm: bool = False,
num_lstm_layers: int = 1,
lstm_units: Optional[Union[int, List[int]]] = None,
output_dim: int = 1,
**kwargs
):
super().__init__(**kwargs)
self.dynamic_input_dim = dynamic_input_dim
self.static_input_dim = static_input_dim
self.future_input_dim = future_input_dim
self.hidden_units = hidden_units
self.num_heads = num_heads
self.dropout_rate = dropout_rate
self.recurrent_dropout_rate = recurrent_dropout_rate
self.forecast_horizon = forecast_horizon
self.activation = activation
self.use_batch_norm = use_batch_norm
self.num_lstm_layers = num_lstm_layers
self.output_dim = output_dim
self.quantiles = validate_quantiles(
quantiles) if quantiles else None
self.num_quantiles = len(
self.quantiles) if self.quantiles else 1
self._lstm_units = lstm_units
# Process LSTM units list
_lstm_units_resolved = lstm_units or hidden_units
self.lstm_units_list = (
[_lstm_units_resolved] * num_lstm_layers
if isinstance(_lstm_units_resolved, int)
else is_iterable(
_lstm_units_resolved,
exclude_string=True,
transform=True
)
)
self.lstm_units_list = [
validate_positive_integer(v, "LSTM units")
for v in self.lstm_units_list
]
if len(self.lstm_units_list) != num_lstm_layers:
raise ValueError(
"'lstm_units' length must match 'num_lstm_layers'.")
# --- Initialize Core TFT Components ---
# 1. Variable Selection Networks
self.static_vsn = VariableSelectionNetwork(
num_inputs=self.static_input_dim,
units=self.hidden_units,
dropout_rate=self.dropout_rate,
activation=self.activation,
use_batch_norm=self.use_batch_norm,
name="static_vsn"
)
self.dynamic_vsn = VariableSelectionNetwork(
num_inputs=self.dynamic_input_dim,
units=self.hidden_units,
dropout_rate=self.dropout_rate,
use_time_distributed=True,
activation=self.activation,
use_batch_norm=self.use_batch_norm,
name="dynamic_vsn"
)
self.future_vsn = VariableSelectionNetwork(
num_inputs=self.future_input_dim,
units=self.hidden_units,
dropout_rate=self.dropout_rate,
use_time_distributed=True,
activation=self.activation,
use_batch_norm=self.use_batch_norm,
name="future_vsn"
)
# 2. Static Context GRNs
self.static_grn_for_vsns = GatedResidualNetwork(
units=self.hidden_units,
dropout_rate=self.dropout_rate,
activation=self.activation,
use_batch_norm=self.use_batch_norm,
name="static_grn_for_vsns"
)
self.static_grn_for_enrichment = GatedResidualNetwork(
units=self.hidden_units,
dropout_rate=self.dropout_rate,
activation=self.activation,
use_batch_norm=self.use_batch_norm,
name="static_grn_for_enrichment"
)
self.static_grn_for_state_h = GatedResidualNetwork(
units=self.lstm_units_list[0],
dropout_rate=self.dropout_rate,
activation=self.activation,
use_batch_norm=self.use_batch_norm,
name="static_grn_for_state_h"
)
self.static_grn_for_state_c = GatedResidualNetwork(
units=self.lstm_units_list[0],
dropout_rate=self.dropout_rate,
activation=self.activation,
use_batch_norm=self.use_batch_norm,
name="static_grn_for_state_c"
)
# 3. LSTM Encoder Layers
self.lstm_layers = [
LSTM(
units=units, return_sequences=True,
dropout=self.dropout_rate,
recurrent_dropout=self.recurrent_dropout_rate,
name=f'encoder_lstm_{i+1}'
) for i, units in enumerate(self.lstm_units_list)
]
# 4. Static Enrichment GRN
self.static_enrichment_grn = GatedResidualNetwork(
units=self.hidden_units,
dropout_rate=self.dropout_rate,
activation=self.activation,
use_batch_norm=self.use_batch_norm,
name="static_enrichment_grn"
)
# 5. Temporal Self-Attention Layer
self.temporal_attention_layer = TemporalAttentionLayer(
units=self.hidden_units,
num_heads=self.num_heads,
dropout_rate=self.dropout_rate,
activation=self.activation,
use_batch_norm=self.use_batch_norm,
name="temporal_self_attention"
)
# 6. Position-wise Feedforward GRN
self.positionwise_grn = GatedResidualNetwork(
units=self.hidden_units,
dropout_rate=self.dropout_rate,
activation=self.activation,
use_batch_norm=self.use_batch_norm,
name="pos_wise_ff_grn"
)
# 7. Output Layer(s)
if self.quantiles:
self.output_layers = [
TimeDistributed(
Dense(self.output_dim), name=f'q_{int(q*100)}_td'
)
for q in self.quantiles
]
else:
self.output_layer = TimeDistributed(
Dense(self.output_dim), name='point_td'
)
# 8. Positional Encoding Layer
self.positional_encoding = PositionalEncoding(name="pos_enc")
[docs]
@tf_autograph.experimental.do_not_convert
def call(self, inputs, training=None):
"""Forward pass for the revised TFT with numerical inputs."""
logger.debug(f"TFT '{self.name}': Entering call method.")
logger.debug(f" Received {len(inputs)} inputs.")
# --- Input Validation and Reordering ---
# User provides [static, dynamic, future]
# Validator expects [dynamic, future, static]
if not isinstance(inputs, (list, tuple)) or len(inputs) != 3:
raise ValueError(
"TFT expects inputs as list/tuple of 3 elements: "
"[static_inputs, dynamic_inputs, future_inputs]."
)
static_inputs_user, dynamic_inputs_user, future_inputs_user = inputs
logger.debug(
f" User inputs shapes: Static={static_inputs_user.shape}, "
f"Dynamic={dynamic_inputs_user.shape}, "
f"Future={future_inputs_user.shape}"
)
# Reorder for internal validation function
validator_input_order = [
static_inputs_user, dynamic_inputs_user, future_inputs_user
]
# Call validator: returns (dynamic, future, static) tensors
# Performs type checks, float32 conversion, dimension checks.
static_inputs, dynamic_inputs, future_inputs = validate_model_inputs(
validator_input_order,
dynamic_input_dim=self.dynamic_input_dim,
static_input_dim=self.static_input_dim,
future_covariate_dim=self.future_input_dim,
)
logger.debug(
" Inputs validated and assigned internally."
f" Shapes: Dyn={dynamic_inputs.shape},"
f" Fut={future_inputs.shape}, Stat={static_inputs.shape}"
)
# --- Static Pathway ---
logger.debug(" Processing Static Pathway...")
# 1a. Reshape Static Input for VSN if needed (B, N) -> (B, N, 1)
# why use static_inputs.shape.rank rather than tf_rank (static_inputs)?
# to avoid issue of unknow rank for autograph conversion
# when it's removed.
if static_inputs.shape.rank == 2:
static_inputs_r = tf_expand_dims(static_inputs, axis=-1)
logger.debug(
" Expanded static input rank to 3:"
f" {static_inputs_r.shape}")
else:
static_inputs_r = static_inputs # already (B, N, F)
# 1b. Static VSN
# Processes static features, potentially
# using context if VSN modified
static_selected = self.static_vsn(
static_inputs_r,
training=training,
# context=None # Context for static VSN usually not needed
)
# Output shape: (B, hidden_units)
logger.debug(
f" Static VSN output shape: {static_selected.shape}")
# 1c. Static Context Vector Generation using GRNs
# Context for conditioning VSNs (passed if VSNs accept context)
context_for_vsns = self.static_grn_for_vsns(
static_selected, training=training)
# Context for enriching temporal features after LSTM
context_for_enrichment = self.static_grn_for_enrichment(
static_selected, training=training)
# Contexts for initializing LSTM states
context_state_h = self.static_grn_for_state_h(
static_selected, training=training)
context_state_c = self.static_grn_for_state_c(
static_selected, training=training)
initial_state = [context_state_h, context_state_c]
logger.debug(
f" Generated static contexts:"
f" VSN={context_for_vsns.shape},"
f" Enrich={context_for_enrichment.shape},"
f" StateH={context_state_h.shape},"
f" StateC={context_state_c.shape}"
)
# --- Temporal Pathway ---
logger.debug(" Processing Temporal Pathway...")
# 3a. Reshape Dynamic/Future Inputs for VSNs if needed
if dynamic_inputs.shape.rank == 3:
dynamic_inputs_r = tf_expand_dims(dynamic_inputs, axis=-1)
else:
dynamic_inputs_r = dynamic_inputs # Assume (B, T, N, F)
if future_inputs.shape.rank == 3:
future_inputs_r = tf_expand_dims(future_inputs, axis=-1)
else:
future_inputs_r = future_inputs # Assume (B, T_fut, N, F)
logger.debug(
f" Temporal input shapes for VSN: Dyn={dynamic_inputs_r.shape},"
f" Fut={future_inputs_r.shape}"
)
# 3b. Dynamic/Future VSNs
# Pass static context derived earlier
dynamic_selected = self.dynamic_vsn(
dynamic_inputs_r, training=training,
context=context_for_vsns)
future_selected = self.future_vsn(
future_inputs_r, training=training, context=context_for_vsns)
# Shapes: (B, T_past, H_units), (B, T_future_total, H_units)
logger.debug(
f" Temporal VSN outputs shapes: Dyn={dynamic_selected.shape},"
f" Fut={future_selected.shape}"
)
# 4. Combine Features for LSTM Input using helper
# Handles slicing future_selected to match T_past and concatenates
logger.debug(
" Combining dynamic and future features for LSTM...")
temporal_features = combine_temporal_inputs_for_lstm(
dynamic_selected, future_selected,
mode='soft' # Use soft? or strict?
)
# Shape: (B, T_past, combined_features = D_dyn_emb + D_fut_emb)
# Assuming VSN outputs hidden_units: (B, T_past, 2 * hidden_units)
logger.debug(
f" Combined temporal features shape:"
f" {temporal_features.shape}")
# 5. Positional Encoding
temporal_features_pos = self.positional_encoding(
temporal_features
)
logger.debug(" Applied positional encoding.")
# 6. LSTM Encoder
logger.debug(" Running LSTM encoder...")
lstm_output = temporal_features_pos
current_state = initial_state
for i, layer in enumerate(self.lstm_layers):
layer_input_shape = lstm_output.shape
if i == 0:
lstm_output = layer(
lstm_output, initial_state=current_state,
training=training
)
else:
lstm_output = layer(lstm_output, training=training)
logger.debug(
f" LSTM layer {i+1} output shape: {layer_input_shape}")
# Final LSTM output shape: (B, T_past, lstm_units)
# 7. Static Enrichment
logger.debug(" Applying static enrichment...")
enriched_output = self.static_enrichment_grn(
lstm_output, context=context_for_enrichment,
training=training
)
# Shape: (B, T_past, hidden_units)
logger.debug(
f" Enriched output shape: {enriched_output.shape}")
# 8. Temporal Self-Attention
logger.debug(" Applying temporal attention...")
attention_output = self.temporal_attention_layer(
enriched_output,
context_vector=context_for_vsns,
training=training
)
# Shape: (B, T_past, hidden_units)
logger.debug(
f" Attention output shape: {attention_output.shape}")
# 9. Position-wise Feedforward
logger.debug(" Applying position-wise feedforward...")
final_temporal_repr = self.positionwise_grn(
attention_output, training=training
)
# Shape: (B, T_past, hidden_units)
logger.debug(
" Final temporal representation shape:"
f" {final_temporal_repr.shape}")
# --- 10. Output Slice and Projection ---
logger.debug(" Generating final predictions...")
# Slice features corresponding to the forecast horizon
output_features_sliced = final_temporal_repr[
:, -self.forecast_horizon:, :]
logger.debug(
" Sliced features for output shape:"
f" {output_features_sliced.shape}")
# Shape: (B, H, hidden_units)
# Apply the final TimeDistributed output layer(s)
if self.quantiles:
quantile_outputs = []
if not hasattr(self, 'output_layers'):
raise AttributeError(
"Quantile output layers not initialized."
)
for i, layer in enumerate(self.output_layers):
out_i = layer(output_features_sliced, training=training)
quantile_outputs.append(out_i)
logger.debug(
f" Quantile output {i} shape: {out_i.shape}")
outputs = tf_stack(quantile_outputs, axis=2) # (B, H, Q, O)
logger.debug(
f" Stacked quantile output shape: {outputs.shape}")
if self.output_dim == 1:
outputs = tf_squeeze(outputs, axis=-1) # (B, H, Q)
logger.debug(
" Squeezed final dimension (output_dim=1).")
else:
# Point Forecast
if not hasattr(self, 'output_layer'):
raise AttributeError("Point output layer not initialized.")
outputs = self.output_layer(
output_features_sliced, training=training
)
# Shape (B, H, O)
logger.debug(
f"TFT '{self.name}': Final output shape: {outputs.shape}")
logger.debug(
f"TFT '{self.name}': Exiting call method.")
return outputs
[docs]
def compile(self, optimizer, loss=None, **kwargs):
if self.quantiles is None:
effective_loss = loss or 'mean_squared_error'
else:
effective_loss = loss or combined_quantile_loss(
self.quantiles)
super().compile(
optimizer=optimizer,
loss=effective_loss,
**kwargs
)
[docs]
def get_config(self):
config = super().get_config()
config.update({
'dynamic_input_dim': self.dynamic_input_dim,
'static_input_dim': self.static_input_dim,
'future_input_dim': self.future_input_dim,
'hidden_units': self.hidden_units,
'num_heads': self.num_heads,
'dropout_rate': self.dropout_rate,
'recurrent_dropout_rate': self.recurrent_dropout_rate,
'forecast_horizon': self.forecast_horizon,
'quantiles': self.quantiles,
'activation': self.activation,
'use_batch_norm': self.use_batch_norm,
'num_lstm_layers': self.num_lstm_layers,
'lstm_units': self._lstm_units,
'output_dim': self.output_dim,
})
return config
[docs]
@classmethod
def from_config(cls, config):
return cls(**config)
TFT.__doc__=r"""
Temporal Fusion Transformer (TFT) required static, dynamic(past)
and future inputs.
This class implements the Temporal Fusion Transformer (TFT)
architecture, closely following the structure described in the
original paper [Lim21]_. It is designed for multi-horizon time
series forecasting and explicitly requires static covariates,
dynamic (historical) covariates, and known future covariates as
inputs.
Compared to more flexible implementations, this version mandates
all input types, simplifying the internal input handling logic. It
incorporates key TFT components like Variable Selection Networks
(VSNs), Gated Residual Networks (GRNs) for static context generation
and feature processing, LSTM encoding, static enrichment, interpretable
multi-head attention, and position-wise feedforward layers.
Parameters
----------
dynamic_input_dim : int
The total number of features present in the dynamic (past)
input tensor. These are features that vary across the lookback
time steps.
static_input_dim : int
The total number of features present in the static
(time-invariant) input tensor. These features provide context
that does not change over time for a given series.
future_input_dim : int
The total number of features present in the known future input
tensor. These features provide information about future events
or conditions known at the time of prediction.
hidden_units : int, default=32
The main dimensionality of the hidden layers used throughout
the network, including VSN outputs, GRN hidden states,
enrichment layers, and attention mechanisms.
num_heads : int, default=4
Number of attention heads used in the
:class:`~fusionlab.nn.components.TemporalAttentionLayer`. More heads
allow attending to different representation subspaces.
dropout_rate : float, default=0.1
Dropout rate applied to non-recurrent connections in LSTMs,
VSNs, GRNs, and Attention layers. Value between 0 and 1.
recurrent_dropout_rate : float, default=0.0
Dropout rate applied specifically to the recurrent connections
within the LSTM layers. Helps regularize the recurrent state
updates. Value between 0 and 1. Note: May impact performance
on GPUs.
forecast_horizon : int, default=1
The number of future time steps the model is trained to predict
simultaneously (multi-horizon forecasting).
quantiles : list[float], optional, default=None
A list of quantiles (e.g., ``[0.1, 0.5, 0.9]``) for probabilistic
forecasting.
* If provided, the model outputs predictions for each specified
quantile, and the :func:`~fusionlab.nn.losses.combined_quantile_loss`
should typically be used for training.
* If ``None``, the model performs point forecasting (outputting a
single value per step), typically trained with MSE loss.
activation : str, default='elu'
Activation function used within GRNs and potentially other dense
layers. Supported options include 'elu', 'relu', 'gelu', 'tanh',
'sigmoid', 'linear'.
use_batch_norm : bool, default=False
If True, applies Batch Normalization within the Gated Residual
Networks (GRNs). Layer Normalization is typically used by default
within GRN implementations as per the TFT paper.
num_lstm_layers : int, default=1
Number of LSTM layers stacked in the sequence encoder module.
lstm_units : int or list[int], optional, default=None
Number of hidden units in each LSTM encoder layer.
* If ``int``: The same number of units is used for all layers
specified by `num_lstm_layers`.
* If ``list[int]``: Specifies the number of units for each LSTM
layer sequentially. The length must match `num_lstm_layers`.
* If ``None``: Defaults to using `hidden_units` for all LSTM layers.
output_dim : int, default=1
The number of target variables the model predicts at each time
step. Typically 1 for univariate forecasting.
**kwargs
Additional keyword arguments passed to the parent Keras `Model`.
Notes
-----
This implementation requires inputs to the `call` method as a list
or tuple containing exactly three tensors in the order:
``[static_inputs, dynamic_inputs, future_inputs]``. The shapes
should be:
* `static_inputs`: `(Batch, StaticFeatures)`
* `dynamic_inputs`: `(Batch, PastTimeSteps, DynamicFeatures)`
* `future_inputs`: `(Batch, TotalTimeSteps, FutureFeatures)` where
`TotalTimeSteps` includes past and future steps relevant for
the LSTM processing.
**Use Case and Importance**
This revised `TFT` class provides a structured implementation that
closely follows the component architecture described in the original
TFT paper, including the distinct GRNs for generating static context
vectors. By requiring all input types (static, dynamic past, known
future), it simplifies the input handling logic compared to versions
allowing optional inputs. This makes it a suitable choice when you
have all three types of features available and want a robust baseline
TFT implementation that explicitly leverages static context for VSNs,
LSTM initialization, and temporal processing enrichment. It serves as
a strong foundation for complex multi-horizon forecasting tasks that
benefit from diverse data integration and interpretability components
like VSNs and attention.
**Mathematical Formulation**
The model processes inputs through the following key stages:
1. **Variable Selection:** Separate Variable Selection Networks (VSNs)
are applied to the static ($\mathbf{s}$), dynamic past
($\mathbf{x}_t, t \le T$), and known future ($\mathbf{z}_t, t > T$)
inputs. This step identifies relevant features within each input type
and transforms them into embeddings of dimension `hidden_units`. Let
the outputs be $\zeta$ (static embedding), $\xi^{dyn}_t$ (dynamic),
and $\xi^{fut}_t$ (future). VSNs may be conditioned by a static
context vector $c_s$.
.. math::
\zeta &= \text{VSN}_{static}(\mathbf{s}, [c_s]) \\
\xi^{dyn}_t &= \text{VSN}_{dyn}(\mathbf{x}_t, [c_s]) \\
\xi^{fut}_t &= \text{VSN}_{fut}(\mathbf{z}_t, [c_s])
2. **Static Context Generation:** Four distinct Gated Residual
Networks (GRNs) process the static embedding $\zeta$ to produce
context vectors: $c_s$ (for VSNs), $c_e$ (for enrichment),
$c_h$ (LSTM initial hidden state), $c_c$ (LSTM initial cell state).
.. math::
c_s = GRN_{vs}(\zeta) \quad ... \quad c_c = GRN_{c}(\zeta)
3. **Temporal Processing Input:** The selected dynamic and future
embeddings are potentially combined (e.g., concatenated along time
or features, depending on preprocessing) and augmented with
positional encoding to form the input sequence for the LSTM.
Let this sequence be $\psi_t$.
.. math::
\psi_t = \text{Combine}(\xi^{dyn}_t, \xi^{fut}_t) + \text{PosEncode}(t)
4. **LSTM Encoder:** A stack of `num_lstm_layers` LSTMs processes
$\psi_t$, initialized with $[c_h, c_c]$.
.. math::
\{h_t\} = \text{LSTMStack}(\{\psi_t\}, \text{initial_state}=[c_h, c_c])
5. **Static Enrichment:** The LSTM outputs $h_t$ are combined with the
static enrichment context $c_e$ using a time-distributed GRN.
.. math::
\phi_t = GRN_{enrich}(h_t, c_e)
6. **Temporal Self-Attention:** Interpretable Multi-Head Attention is
applied to the enriched sequence $\{\phi_t\}$, potentially using
$c_s$ as context within the attention mechanism's internal GRNs.
This results in context vectors $\beta_t$ after residual connection,
gating (GLU), and normalization.
.. math::
\beta_t = \text{TemporalAttention}(\{\phi_t\}, c_s)
7. **Position-wise Feed-Forward:** A final time-distributed GRN is
applied to the attention output.
.. math::
\delta_t = GRN_{final}(\beta_t)
8. **Output Projection:** The features corresponding to the forecast
horizon ($t > T$) are selected from $\{\delta_t\}$ and passed through
a final Dense layer (or multiple layers for quantiles) to produce
the predictions $\hat{y}_{t+1}, ..., \hat{y}_{t+\tau}$.
Methods
-------
call(inputs, training=False)
Performs the forward pass. Expects `inputs` as a list/tuple:
`[static_inputs, dynamic_inputs, future_inputs]`.
compile(optimizer, loss=None, **kwargs)
Compiles the model, automatically selecting 'mse' or quantile
loss based on `quantiles` initialization if `loss` is not given.
Examples
--------
>>> import numpy as np
>>> import tensorflow as tf
>>> from fusionlab.nn.transformers import TFT
>>> from fusionlab.nn.losses import combined_quantile_loss
>>>
>>> # Dummy Data Dimensions
>>> B, T_past, H = 4, 12, 6 # Batch, Lookback, Horizon
>>> D_dyn, D_stat, D_fut = 5, 3, 2
>>> T_future = H # Assume future inputs cover horizon only for LSTM input
>>>
>>> # Create Dummy Input Tensors (Ensure correct shapes and types)
>>> static_in = tf.random.normal((B, D_stat), dtype=tf.float32)
>>> dynamic_in = tf.random.normal((B, T_past, D_dyn), dtype=tf.float32)
>>> # Future input needs shape (B, T_past + T_future, D_fut) for VSN
>>> # or (B, T_future, D_fut) if handled differently before LSTM concat.
>>> # Let's assume preprocessed to match horizon T_future for simplicity here
>>> future_in = tf.random.normal((B, T_future, D_fut), dtype=tf.float32)
>>>
>>> # Instantiate Model for Quantile Forecasting
>>> model = TFT(
... dynamic_input_dim=D_dyn, static_input_dim=D_stat,
... future_input_dim=D_fut, forecast_horizon=H,
... hidden_units=16, num_heads=2, num_lstm_layers=1,
... quantiles=[0.1, 0.5, 0.9], output_dim=1
... )
>>>
>>> # Compile with appropriate loss
>>> loss_fn = combined_quantile_loss([0.1, 0.5, 0.9])
>>> model.compile(optimizer='adam', loss=loss_fn)
>>>
>>> # Prepare input list in correct order: [static, dynamic, future]
>>> model_inputs = [static_in, dynamic_in, future_in]
>>>
>>> # Make a prediction (forward pass)
>>> # Note: Need to build the model first, e.g., by calling it once
>>> # or specifying input_shape in build method if using subclassing.
>>> # Alternatively, fit for one step. For direct call:
>>> # output_shape = model.compute_output_shape(
>>> # [t.shape for t in model_inputs]) # Requires TF >= 2.8 approx
>>> # For simplicity, assume model builds on first call
>>> predictions = model(model_inputs, training=False)
>>> print(f"Output shape: {predictions.shape}")
Output shape: (4, 6, 3)
See Also
--------
fusionlab.nn.components.VariableSelectionNetwork : Core component for VSN.
fusionlab.nn.components.GatedResidualNetwork : Core component for GRN.
fusionlab.nn.components.TemporalAttentionLayer : Core attention block.
tensorflow.keras.layers.LSTM : Recurrent layer used internally.
fusionlab.nn.losses.combined_quantile_loss : Default loss for quantiles.
fusionlab.nn.utils.reshape_xtft_data : Utility to prepare inputs.
fusionlab.nn.XTFT : More advanced related architecture.
tensorflow.keras.Model : Base Keras model class.
References
----------
.. [Lim21] Lim, B., Arık, S. Ö., Loeff, N., & Pfister, T. (2021).
Temporal fusion transformers for interpretable multi-horizon
time series forecasting. *International Journal of Forecasting*,
37(4), 1748-1764.
"""