fusionlab.nn.transformers.TFT

class fusionlab.nn.transformers.TFT

Bases: Model, NNLearner

Temporal Fusion Transformer (TFT) requiring static, dynamic (past), and future inputs.

This class implements the Temporal Fusion Transformer (TFT) architecture, closely following the structure described in the original paper [Lim21]. It is designed for multi-horizon time series forecasting and explicitly requires static covariates, dynamic (historical) covariates, and known future covariates as inputs.

Compared to more flexible implementations, this version mandates all input types, simplifying the internal input handling logic. It incorporates key TFT components like Variable Selection Networks (VSNs), Gated Residual Networks (GRNs) for static context generation and feature processing, LSTM encoding, static enrichment, interpretable multi-head attention, and position-wise feedforward layers.

Parameters:
  • dynamic_input_dim (int) – The total number of features present in the dynamic (past) input tensor. These are features that vary across the lookback time steps.

  • static_input_dim (int) – The total number of features present in the static (time-invariant) input tensor. These features provide context that does not change over time for a given series.

  • future_input_dim (int) – The total number of features present in the known future input tensor. These features provide information about future events or conditions known at the time of prediction.

  • hidden_units (int, default 32) – The main dimensionality of the hidden layers used throughout the network, including VSN outputs, GRN hidden states, enrichment layers, and attention mechanisms.

  • num_heads (int, default 4) – Number of attention heads used in the TemporalAttentionLayer. More heads allow attending to different representation subspaces.

  • dropout_rate (float, default 0.1) – Dropout rate applied to non-recurrent connections in LSTMs, VSNs, GRNs, and Attention layers. Value between 0 and 1.

  • recurrent_dropout_rate (float, default 0.0) – Dropout rate applied specifically to the recurrent connections within the LSTM layers. Helps regularize the recurrent state updates. Value between 0 and 1. Note: a nonzero value prevents Keras from using the cuDNN-accelerated LSTM kernel, which can slow training on GPUs.

  • forecast_horizon (int, default 1) – The number of future time steps the model is trained to predict simultaneously (multi-horizon forecasting).

  • quantiles (list[float], optional, default None) –

    A list of quantiles (e.g., [0.1, 0.5, 0.9]) for probabilistic forecasting.

    • If provided, the model outputs predictions for each specified quantile, and combined_quantile_loss() should typically be used for training.

    • If None, the model performs point forecasting (outputting a single value per step), typically trained with MSE loss.

  • activation (str, default 'elu') – Activation function used within GRNs and potentially other dense layers. Supported options include ‘elu’, ‘relu’, ‘gelu’, ‘tanh’, ‘sigmoid’, ‘linear’.

  • use_batch_norm (bool, default False) – If True, applies Batch Normalization within the Gated Residual Networks (GRNs). Layer Normalization is typically used by default within GRN implementations as per the TFT paper.

  • num_lstm_layers (int, default 1) – Number of LSTM layers stacked in the sequence encoder module.

  • lstm_units (int or list[int], optional, default None) –

    Number of hidden units in each LSTM encoder layer.

    • If int: the same number of units is used for all layers specified by num_lstm_layers.

    • If list[int]: specifies the number of units for each LSTM layer sequentially. The length must match num_lstm_layers.

    • If None: defaults to using hidden_units for all LSTM layers.

  • output_dim (int, default 1) – The number of target variables the model predicts at each time step. Typically 1 for univariate forecasting.

  • **kwargs – Additional keyword arguments passed to the parent Keras Model.
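The three accepted forms of lstm_units can be summarized by a small resolution function. The helper below, resolve_lstm_units, is hypothetical and not part of fusionlab. It is a sketch of the documented rules only, and it assumes a length mismatch is reported as a ValueError.

```python
from typing import List, Optional, Union


def resolve_lstm_units(
    lstm_units: Optional[Union[int, List[int]]],
    num_lstm_layers: int,
    hidden_units: int,
) -> List[int]:
    """Hypothetical helper: units for each stacked LSTM layer."""
    if lstm_units is None:
        # None: every layer falls back to hidden_units.
        return [hidden_units] * num_lstm_layers
    if isinstance(lstm_units, int):
        # int: the same width is reused for every layer.
        return [lstm_units] * num_lstm_layers
    units = list(lstm_units)
    if len(units) != num_lstm_layers:
        # list: exactly one entry per layer is required.
        raise ValueError(
            f"lstm_units has {len(units)} entries, "
            f"but num_lstm_layers={num_lstm_layers}"
        )
    return units


print(resolve_lstm_units(None, 2, 32))      # [32, 32]
print(resolve_lstm_units(64, 3, 32))        # [64, 64, 64]
print(resolve_lstm_units([64, 32], 2, 32))  # [64, 32]
```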

Notes

This implementation requires the inputs to the call method to be a list or tuple of exactly three tensors, in the order [static_inputs, dynamic_inputs, future_inputs], with shapes:

  • static_inputs: (Batch, StaticFeatures)

  • dynamic_inputs: (Batch, PastTimeSteps, DynamicFeatures)

  • future_inputs: (Batch, TotalTimeSteps, FutureFeatures), where TotalTimeSteps includes the past and future steps relevant for the LSTM processing.
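The call method identifies each tensor by its position, not by name, so a quick shape check before training can catch ordering mistakes. The sketch below uses a hypothetical helper, check_tft_inputs, which is not a fusionlab function, together with NumPy arrays. It checks rank, feature count, and batch size only. It leaves the time length of future_inputs unchecked because that length depends on preprocessing.

```python
import numpy as np


def check_tft_inputs(inputs, static_input_dim, dynamic_input_dim, future_input_dim):
    """Hypothetical pre-flight check for [static, dynamic, future] inputs.

    Verifies rank, trailing feature count and batch size only; the number
    of time steps in future_inputs is deliberately left unchecked.
    """
    if not isinstance(inputs, (list, tuple)) or len(inputs) != 3:
        raise ValueError("Expected [static_inputs, dynamic_inputs, future_inputs].")
    static, dynamic, future = (np.asarray(x) for x in inputs)
    specs = [
        ("static_inputs", static, 2, static_input_dim),     # (Batch, StaticFeatures)
        ("dynamic_inputs", dynamic, 3, dynamic_input_dim),  # (Batch, PastTimeSteps, DynamicFeatures)
        ("future_inputs", future, 3, future_input_dim),     # (Batch, TotalTimeSteps, FutureFeatures)
    ]
    for name, arr, rank, n_features in specs:
        if arr.ndim != rank:
            raise ValueError(f"{name} must have rank {rank}, got shape {arr.shape}")
        if arr.shape[-1] != n_features:
            raise ValueError(f"{name} has {arr.shape[-1]} features, expected {n_features}")
    if len({static.shape[0], dynamic.shape[0], future.shape[0]}) != 1:
        raise ValueError("All three inputs must share the same batch size.")
    return static.shape[0]


B, T_past, H = 4, 12, 6
static = np.zeros((B, 3))
dynamic = np.zeros((B, T_past, 5))
future = np.zeros((B, T_past + H, 2))
print(check_tft_inputs([static, dynamic, future], 3, 5, 2))  # 4
```

Passing the tensors in the wrong order, for example [dynamic, static, future], fails the rank check on the first element and raises a ValueError before any model code runs.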

Use Case and Importance

This revised TFT class provides a structured implementation that closely follows the component architecture described in the original TFT paper, including the distinct GRNs for generating static context vectors. By requiring all input types (static, dynamic past, known future), it simplifies the input handling logic compared to versions allowing optional inputs. This makes it a suitable choice when you have all three types of features available and want a robust baseline TFT implementation that explicitly leverages static context for VSNs, LSTM initialization, and temporal processing enrichment. It serves as a strong foundation for complex multi-horizon forecasting tasks that benefit from diverse data integration and interpretability components like VSNs and attention.

Mathematical Formulation

The model processes inputs through the following key stages:

  1. Variable Selection: Separate Variable Selection Networks (VSNs) are applied to the static ($\mathbf{s}$), dynamic past ($\mathbf{x}_t, t \le T$), and known future ($\mathbf{z}_t, t > T$) inputs. This step identifies the relevant features within each input type and transforms them into embeddings of dimension hidden_units. Let the outputs be $\zeta$ (static embedding), $\xi^{dyn}_t$ (dynamic), and $\xi^{fut}_t$ (future). The dynamic and future VSNs may be conditioned on the static context vector $c_s$. The static VSN is not conditioned, because $c_s$ is itself derived from $\zeta$ in step 2.

    \[\begin{split}\zeta &= \text{VSN}_{static}(\mathbf{s}) \\ \xi^{dyn}_t &= \text{VSN}_{dyn}(\mathbf{x}_t, [c_s]) \\ \xi^{fut}_t &= \text{VSN}_{fut}(\mathbf{z}_t, [c_s])\end{split}\]
  2. Static Context Generation: Four distinct Gated Residual Networks (GRNs) process the static embedding $\zeta$ to produce context vectors: $c_s$ (for VSNs), $c_e$ (for enrichment), $c_h$ (LSTM initial hidden state), and $c_c$ (LSTM initial cell state).

    \[c_s = GRN_{vs}(\zeta) \quad ... \quad c_c = GRN_{c}(\zeta)\]
  3. Temporal Processing Input: The selected dynamic and future embeddings are combined (e.g., concatenated along time or features, depending on preprocessing) and augmented with positional encoding to form the input sequence for the LSTM. Let this sequence be $\psi_t$.

    \[\psi_t = \text{Combine}(\xi^{dyn}_t, \xi^{fut}_t) + \text{PosEncode}(t)\]
  4. LSTM Encoder: A stack of num_lstm_layers LSTMs processes $\psi_t$, initialized with $[c_h, c_c]$.

    \[\{h_t\} = \text{LSTMStack}(\{\psi_t\}, \text{initial\_state}=[c_h, c_c])\]
  5. Static Enrichment: The LSTM outputs $h_t$ are combined with the static enrichment context $c_e$ using a time-distributed GRN.

    \[\phi_t = GRN_{enrich}(h_t, c_e)\]
  6. Temporal Self-Attention: Interpretable Multi-Head Attention is applied to the enriched sequence $\{\phi_t\}$, potentially using $c_s$ as context within the attention mechanism’s internal GRNs. This produces the context vectors $\beta_t$ after a residual connection, gating (GLU), and normalization.

    \[\beta_t = \text{TemporalAttention}(\{\phi_t\}, c_s)\]
  7. Position-wise Feed-Forward: A final time-distributed GRN is applied to the attention output.

    \[\delta_t = GRN_{final}(\beta_t)\]
  8. Output Projection: The features corresponding to the forecast horizon ($t > T$) are selected from $\{\delta_t\}$ and passed through a final Dense layer (or multiple layers for quantiles) to produce the predictions $\hat{y}_{T+1}, \dots, \hat{y}_{T+\tau}$, where $\tau$ is forecast_horizon.
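Steps 1 and 2 can be sketched in plain NumPy to make the data flow concrete. This minimal illustration follows the GRN and VSN equations of [Lim21]. It is not fusionlab's implementation. Biases, dropout, and training are omitted, the weights are random, and each numeric feature is assumed to be embedded by its own linear map. The classes GRN and VSN are hypothetical names used only in this sketch.

```python
import numpy as np

rng = np.random.default_rng(0)


def init(*shape):
    # Small random weights; a real model would learn these.
    return rng.normal(0.0, 0.1, shape)


def elu(x):
    # np.minimum avoids overflow in exp for large positive inputs.
    return np.where(x > 0, x, np.exp(np.minimum(x, 0.0)) - 1.0)


def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))


def layer_norm(x, eps=1e-6):
    mean = x.mean(axis=-1, keepdims=True)
    var = x.var(axis=-1, keepdims=True)
    return (x - mean) / np.sqrt(var + eps)


def softmax(x):
    e = np.exp(x - x.max(axis=-1, keepdims=True))
    return e / e.sum(axis=-1, keepdims=True)


class GRN:
    """GRN(a, c) = LayerNorm(skip(a) + GLU(eta_1)), biases omitted."""

    def __init__(self, d_in, d_hidden, d_out, d_context=None):
        self.W2, self.W1 = init(d_in, d_hidden), init(d_hidden, d_hidden)
        self.W3 = init(d_context, d_hidden) if d_context else None
        self.W4, self.W5 = init(d_hidden, d_out), init(d_hidden, d_out)
        # Linear projection on the skip path when the widths differ.
        self.skip = init(d_in, d_out) if d_in != d_out else None

    def __call__(self, a, c=None):
        pre = a @ self.W2
        if self.W3 is not None and c is not None:
            pre = pre + c @ self.W3  # optional static context
        eta1 = elu(pre) @ self.W1
        glu = sigmoid(eta1 @ self.W4) * (eta1 @ self.W5)  # gated linear unit
        residual = a if self.skip is None else a @ self.skip
        return layer_norm(residual + glu)


class VSN:
    """Soft selection over n_vars scalar features, each embedded to width d."""

    def __init__(self, n_vars, d, d_context=None):
        self.embed = init(n_vars, d)  # per-variable linear embedding
        self.var_grns = [GRN(d, d, d) for _ in range(n_vars)]
        self.weight_grn = GRN(n_vars * d, d, n_vars, d_context)

    def __call__(self, x, c=None):
        emb = x[..., :, None] * self.embed              # (..., n_vars, d)
        flat = emb.reshape(*emb.shape[:-2], -1)         # (..., n_vars * d)
        weights = softmax(self.weight_grn(flat, c))     # (..., n_vars), sums to 1
        processed = np.stack(
            [grn(emb[..., j, :]) for j, grn in enumerate(self.var_grns)], axis=-2
        )                                               # (..., n_vars, d)
        return (weights[..., None] * processed).sum(axis=-2), weights


B, T_past, d = 4, 12, 16
s = rng.normal(size=(B, 3))              # static features
x_dyn = rng.normal(size=(B, T_past, 5))  # dynamic past features

zeta, _ = VSN(3, d)(s)                   # step 1: static VSN, no context
c_s = GRN(d, d, d)(zeta)                 # step 2: one of the four context GRNs
xi_dyn, w_dyn = VSN(5, d, d_context=d)(x_dyn, c_s[:, None, :])  # context broadcast over time
print(zeta.shape, xi_dyn.shape, w_dyn.shape)  # (4, 16) (4, 12, 16) (4, 12, 5)
```

The selection weights w_dyn give one softmax-normalized weight per input variable and time step. These per-variable weights are what makes the VSN output interpretable.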

call(inputs, training=None)

Performs the forward pass. Expects inputs as a list/tuple: [static_inputs, dynamic_inputs, future_inputs].

compile(optimizer, loss=None, **kwargs)

Compiles the model. If loss is not given, it is selected automatically based on the quantiles set at initialization: ‘mse’ for point forecasting (quantiles=None), otherwise a quantile loss.
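The loss selection described for compile can be illustrated with the standard pinball (quantile) loss. The sketch below is a stand-in built on stated assumptions and is not fusionlab's combined_quantile_loss. It averages the pinball loss over all quantiles and expects targets of shape (batch, horizon, 1), which are broadcast against predictions of shape (batch, horizon, n_quantiles). The helpers pinball_loss, mse_loss, and default_loss are hypothetical.

```python
import numpy as np


def pinball_loss(y_true, y_pred, quantiles):
    """Mean pinball loss; y_true (B, H, 1) broadcasts against y_pred (B, H, Q)."""
    q = np.asarray(quantiles, dtype=float)  # aligned with the last axis of y_pred
    error = y_true - y_pred
    # Under-prediction costs q * error, over-prediction costs (q - 1) * error.
    return float(np.mean(np.maximum(q * error, (q - 1.0) * error)))


def mse_loss(y_true, y_pred):
    return float(np.mean((y_true - y_pred) ** 2))


def default_loss(quantiles=None):
    """Hypothetical stand-in for the automatic choice made when loss=None."""
    if quantiles is None:
        return mse_loss
    return lambda y_true, y_pred: pinball_loss(y_true, y_pred, quantiles)


loss = default_loss([0.1, 0.5, 0.9])
print(loss(np.zeros((2, 3, 1)), np.ones((2, 3, 3))))
```

In this example every prediction overshoots a zero target by 1, so the three quantiles contribute 0.9, 0.5, and 0.1 and the mean is about 0.5. The low quantile is penalized most for overshooting, and that penalty is what pushes it below the median during training.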

Examples

>>> import tensorflow as tf
>>> from fusionlab.nn.transformers import TFT
>>> from fusionlab.nn.losses import combined_quantile_loss
>>>
>>> # Dummy data dimensions
>>> B, T_past, H = 4, 12, 6  # batch size, lookback steps, forecast horizon
>>> D_dyn, D_stat, D_fut = 5, 3, 2  # features per input type
>>> T_future = H  # future inputs assumed to cover the horizon only
>>>
>>> # Dummy input tensors (float32)
>>> static_in = tf.random.normal((B, D_stat), dtype=tf.float32)
>>> dynamic_in = tf.random.normal((B, T_past, D_dyn), dtype=tf.float32)
>>> # The Notes describe future inputs as (B, TotalTimeSteps, D_fut). For
>>> # simplicity, this example assumes they were preprocessed to span only
>>> # the forecast horizon, giving (B, T_future, D_fut).
>>> future_in = tf.random.normal((B, T_future, D_fut), dtype=tf.float32)
>>>
>>> # Instantiate Model for Quantile Forecasting
>>> model = TFT(
...     dynamic_input_dim=D_dyn, static_input_dim=D_stat,
...     future_input_dim=D_fut, forecast_horizon=H,
...     hidden_units=16, num_heads=2, num_lstm_layers=1,
...     quantiles=[0.1, 0.5, 0.9], output_dim=1
... )
>>>
>>> # Compile with appropriate loss
>>> loss_fn = combined_quantile_loss([0.1, 0.5, 0.9])
>>> model.compile(optimizer='adam', loss=loss_fn)
>>>
>>> # Prepare input list in correct order: [static, dynamic, future]
>>> model_inputs = [static_in, dynamic_in, future_in]
>>>
>>> # Forward pass. A subclassed Keras model builds its weights on the
>>> # first call, so no explicit build() is needed here.
>>> predictions = model(model_inputs, training=False)
>>> print(f"Output shape: {predictions.shape}")
Output shape: (4, 6, 3)
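The (4, 6, 3) output shape follows from step 8 of the formulation: keep the last forecast_horizon steps and apply one projection per quantile. The NumPy sketch below shows only this shape bookkeeping and rests on three assumptions. The final sequence is assumed to span T_past + H steps. The heads are random linear maps. The trailing output_dim axis is assumed to be dropped when output_dim is 1, as the printed shape above suggests.

```python
import numpy as np

rng = np.random.default_rng(0)

B, T_past, H, d = 4, 12, 6, 16
quantiles = [0.1, 0.5, 0.9]
output_dim = 1

# Stand-in for the final GRN outputs delta_t (assumed to span past + horizon).
delta = rng.normal(size=(B, T_past + H, d))

# Keep only the last H steps, i.e. the forecast horizon t > T.
horizon_features = delta[:, -H:, :]                     # (B, H, d)

# One linear head per quantile, each producing output_dim values.
heads = [rng.normal(0.0, 0.1, (d, output_dim)) for _ in quantiles]
per_quantile = [horizon_features @ W for W in heads]   # each (B, H, output_dim)
stacked = np.stack(per_quantile, axis=2)                # (B, H, Q, output_dim)

# Assumption: with output_dim == 1 the last axis is dropped, giving (B, H, Q).
predictions = stacked.squeeze(-1) if output_dim == 1 else stacked
print(predictions.shape)  # (4, 6, 3)
```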

See also

fusionlab.nn.components.VariableSelectionNetwork

Core component for VSN.

fusionlab.nn.components.GatedResidualNetwork

Core component for GRN.

fusionlab.nn.components.TemporalAttentionLayer

Core attention block.

tensorflow.keras.layers.LSTM

Recurrent layer used internally.

fusionlab.nn.losses.combined_quantile_loss

Default loss for quantiles.

fusionlab.nn.utils.reshape_xtft_data

Utility to prepare inputs.

fusionlab.nn.XTFT

More advanced related architecture.

tensorflow.keras.Model

Base Keras model class.

References

[Lim21]

Lim, B., Arık, S. Ö., Loeff, N., & Pfister, T. (2021). Temporal fusion transformers for interpretable multi-horizon time series forecasting. International Journal of Forecasting, 37(4), 1748-1764.

__init__(dynamic_input_dim, static_input_dim, future_input_dim, hidden_units=32, num_heads=4, dropout_rate=0.1, recurrent_dropout_rate=0.0, forecast_horizon=1, quantiles=None, activation='elu', use_batch_norm=False, num_lstm_layers=1, lstm_units=None, output_dim=1, **kwargs)
Parameters:
  • dynamic_input_dim (int)

  • static_input_dim (int)

  • future_input_dim (int)

  • hidden_units (int)

  • num_heads (int)

  • dropout_rate (float)

  • recurrent_dropout_rate (float)

  • forecast_horizon (int)

  • quantiles (List[float] | None)

  • activation (str)

  • use_batch_norm (bool)

  • num_lstm_layers (int)

  • lstm_units (int | List[int] | None)

  • output_dim (int)

Methods

__init__(dynamic_input_dim, ...[, ...])

add_loss(loss)

Can be called inside of the call() method to add a scalar loss.

add_metric(*args, **kwargs)

add_variable(shape, initializer[, dtype, ...])

Add a weight variable to the layer.

add_weight([shape, initializer, dtype, ...])

Add a weight variable to the layer.

build(input_shape)

build_from_config(config)

Builds the layer's states with the supplied config dict.

call(inputs[, training])

Forward pass for the revised TFT with numerical inputs.

compile(optimizer[, loss])

Configures the model for training.

compile_from_config(config)

Compiles the model with the information given in config.

compiled_loss(y, y_pred[, sample_weight, ...])

compute_loss([x, y, y_pred, sample_weight, ...])

Compute the total loss, validate it, and return it.

compute_mask(inputs, previous_mask)

compute_metrics(x, y, y_pred[, sample_weight])

Update metric states and collect all metrics to be returned.

compute_output_shape(*args, **kwargs)

compute_output_spec(*args, **kwargs)

count_params()

Count the total number of scalars composing the weights.

evaluate([x, y, batch_size, verbose, ...])

Returns the loss value & metrics values for the model in test mode.

export(filepath[, format, verbose, ...])

Export the model as an artifact for inference.

fit([x, y, batch_size, epochs, verbose, ...])

Trains the model for a fixed number of epochs (dataset iterations).

from_config(config)

Creates an operation from its config.

get_build_config()

Returns a dictionary with the layer's input shape.

get_compile_config()

Returns a serialized config with information for compiling the model.

get_config()

Returns the config of the object.

get_layer([name, index])

Retrieves a layer based on either its name (unique) or index.

get_metrics_result()

Returns the model's metrics values as a dict.

get_params([deep])

Get the parameters for this learner.

get_state_tree([value_format])

Retrieves tree-like structure of model variables.

get_weights()

Return the values of layer.weights as a list of NumPy arrays.

help(**kwargs)

load(file_path[, format])

Load the learner's state from a specified file in the desired format.

load_own_variables(store)

Loads the state of the layer.

load_weights(filepath[, skip_mismatch])

Load the weights from a single file or sharded files.

loss(y, y_pred[, sample_weight])

make_predict_function([force])

make_test_function([force])

make_train_function([force])

predict(x[, batch_size, verbose, steps, ...])

Generates output predictions for the input samples.

predict_on_batch(x)

Returns predictions for a single batch of samples.

predict_step(data)

quantize(mode[, config])

Quantize the weights of the model.

quantized_build(input_shape, mode)

quantized_call(*args, **kwargs)

rematerialized_call(layer_call, *args, **kwargs)

Enable rematerialization dynamically for layer's call method.

reset_metrics()

save(filepath[, overwrite, zipped])

Saves a model as a .keras file.

save_own_variables(store)

Saves the state of the layer.

save_weights(filepath[, overwrite, ...])

Saves all weights to a single file or sharded files.

set_params(**params)

Set the parameters of this learner.

set_state_tree(state_tree)

Assigns values to variables of the model.

set_weights(weights)

Sets the values of layer.weights from a list of NumPy arrays.

stateless_call(trainable_variables, ...[, ...])

Call the layer without any side effects.

stateless_compute_loss(trainable_variables, ...)

summary([line_length, positions, print_fn, ...])

Prints a string summary of the network.

symbolic_call(*args, **kwargs)

test_on_batch(x[, y, sample_weight, return_dict])

Test the model on a single batch of samples.

test_step(data)

to_json(**kwargs)

Returns a JSON string containing the network configuration.

train_on_batch(x[, y, sample_weight, ...])

Runs a single gradient update on a single batch of data.

train_step(data)

Attributes

compiled_metrics

compute_dtype

The dtype of the computations performed by the layer.

distribute_reduction_method

distribute_strategy

dtype

Alias of layer.variable_dtype.

dtype_policy

input

Retrieves the input tensor(s) of a symbolic operation.

input_dtype

The dtype layer inputs should be converted to.

input_spec

jit_compile

layers

losses

List of scalar losses from add_loss, regularizers and sublayers.

metrics

List of all metrics.

metrics_names

metrics_variables

List of all metric variables.

my_params

non_trainable_variables

List of all non-trainable layer state.

non_trainable_weights

List of all non-trainable weight variables of the layer.

output

Retrieves the output tensor(s) of a layer.

path

The path of the layer.

quantization_mode

The quantization mode of this layer, None if not quantized.

run_eagerly

supports_masking

Whether this layer supports computing a mask using compute_mask.

trainable

Settable boolean, whether this layer should be trainable or not.

trainable_variables

List of all trainable layer state.

trainable_weights

List of all trainable weight variables of the layer.

variable_dtype

The dtype of the state (weights) of the layer.

variables

List of all layer state, including random seeds.

weights

List of all weight variables of the layer.

call(inputs, training=None)

Forward pass for the revised TFT with numerical inputs.

compile(optimizer, loss=None, **kwargs)

Configures the model for training.

Example:

```python
import keras

model.compile(
    optimizer=keras.optimizers.Adam(learning_rate=1e-3),
    loss=keras.losses.BinaryCrossentropy(),
    metrics=[
        keras.metrics.BinaryAccuracy(),
        keras.metrics.FalseNegatives(),
    ],
)
```

Parameters:
  • optimizer – String (name of optimizer) or optimizer instance. See keras.optimizers.

  • loss – Loss function. May be a string (name of loss function) or a keras.losses.Loss instance. See keras.losses. A loss function is any callable with the signature loss = fn(y_true, y_pred), where y_true are the ground truth values and y_pred are the model’s predictions. y_true should have shape (batch_size, d0, .. dN), except for sparse loss functions such as sparse categorical crossentropy, which expect integer arrays of shape (batch_size, d0, .. dN-1). y_pred should have shape (batch_size, d0, .. dN). The loss function should return a float tensor.

  • loss_weights – Optional list or dictionary specifying scalar coefficients (Python floats) to weight the loss contributions of different model outputs. The loss value minimized by the model is then the weighted sum of all individual losses, weighted by the loss_weights coefficients. If a list, it is expected to have a 1:1 mapping to the model’s outputs. If a dict, it is expected to map output names (strings) to scalar coefficients.

  • metrics – List of metrics to be evaluated by the model during training and testing. Each of these can be a string (name of a built-in function), a function, or a keras.metrics.Metric instance. See keras.metrics. Typically you will use metrics=['accuracy']. A function is any callable with the signature result = fn(y_true, y_pred). To specify different metrics for different outputs of a multi-output model, you can also pass a dictionary, such as metrics={'a': 'accuracy', 'b': ['accuracy', 'mse']}. You can also pass a list to specify a metric or a list of metrics for each output, such as metrics=[['accuracy'], ['accuracy', 'mse']] or metrics=['accuracy', ['accuracy', 'mse']]. When you pass the strings 'accuracy' or 'acc', they are converted to one of keras.metrics.BinaryAccuracy, keras.metrics.CategoricalAccuracy, or keras.metrics.SparseCategoricalAccuracy based on the shapes of the targets and of the model output. A similar conversion is done for the strings 'crossentropy' and 'ce'. The metrics passed here are evaluated without sample weighting. If you would like sample weighting to apply, specify your metrics via the weighted_metrics argument instead.

  • weighted_metrics – List of metrics to be evaluated and weighted by sample_weight or class_weight during training and testing.

  • run_eagerly – Bool. If True, this model’s forward pass will never be compiled. It is recommended to leave this as False when training (for best performance) and to set it to True when debugging.

  • steps_per_execution – Int. The number of batches to run during a single compiled function call. Running multiple batches inside a single compiled function call can greatly improve performance on TPUs or on small models with a large Python overhead. At most, one full epoch will be run each execution. If a number larger than the size of the epoch is passed, the execution will be truncated to the size of the epoch. Note that if steps_per_execution is set to N, the Callback.on_batch_begin and Callback.on_batch_end methods will only be called every N batches (i.e., before/after each compiled function execution). Not supported with the PyTorch backend.

  • jit_compile – Bool or "auto". Whether to use XLA compilation when compiling a model. For the jax and tensorflow backends, jit_compile="auto" enables XLA compilation if the model supports it and disables it otherwise. For the torch backend, "auto" defaults to eager execution, and jit_compile=True runs with torch.compile using the "inductor" backend.

  • auto_scale_loss – Bool. If True and the model dtype policy is "mixed_float16", the passed optimizer will be automatically wrapped in a LossScaleOptimizer, which will dynamically scale the loss to prevent underflow.

get_config()

Returns the config of the object.

An object config is a Python dictionary (serializable) containing the information needed to re-instantiate it.

classmethod from_config(config)

Creates an operation from its config.

This method is the reverse of get_config, capable of instantiating the same operation from the config dictionary.

Note: If you override this method, you might receive a serialized dtype config, which is a dict. You can deserialize it as follows:

```python
from keras import dtype_policies

if "dtype" in config and isinstance(config["dtype"], dict):
    policy = dtype_policies.deserialize(config["dtype"])
```

Parameters:

config – A Python dictionary, typically the output of get_config.

Returns:

An operation instance.

help(**kwargs)

my_params = TFT(
    dynamic_input_dim,
    static_input_dim,
    future_input_dim,
    hidden_units=32,
    num_heads=4,
    dropout_rate=0.1,
    recurrent_dropout_rate=0.0,
    forecast_horizon=1,
    quantiles=None,
    activation='elu',
    use_batch_norm=False,
    num_lstm_layers=1,
    lstm_units=None,
    output_dim=1
)