fusionlab.nn.transformers.TFT
- class fusionlab.nn.transformers.TFT
Bases: Model, NNLearner
Temporal Fusion Transformer (TFT) requiring static, dynamic (past), and future inputs.
This class implements the Temporal Fusion Transformer (TFT) architecture, closely following the structure described in the original paper [Lim21]. It is designed for multi-horizon time series forecasting and explicitly requires static covariates, dynamic (historical) covariates, and known future covariates as inputs.
Compared to more flexible implementations, this version mandates all input types, simplifying the internal input handling logic. It incorporates key TFT components like Variable Selection Networks (VSNs), Gated Residual Networks (GRNs) for static context generation and feature processing, LSTM encoding, static enrichment, interpretable multi-head attention, and position-wise feedforward layers.
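Several stages below are described in terms of GRNs, so the Gated Residual Network from [Lim21] is restated here for reference. In this notation, $a$ is the primary input and $c$ is an optional context vector (for example $c_e$ during static enrichment). The $W$ and $b$ terms are learned weights and biases, and ELU corresponds to the default activation='elu'. This is the paper's formulation. The layer in this implementation may differ in detail, for instance when use_batch_norm=True or when a skip projection is needed because input and output widths differ.
\[\begin{split}\text{GRN}(a, c) &= \text{LayerNorm}\big(a + \text{GLU}(\eta_1)\big) \\ \eta_1 &= W_1 \eta_2 + b_1 \\ \eta_2 &= \text{ELU}(W_2 a + W_3 c + b_2) \\ \text{GLU}(\gamma) &= \sigma(W_4 \gamma + b_4) \odot (W_5 \gamma + b_5)\end{split}\]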
- Parameters:
dynamic_input_dim (int) – The total number of features in the dynamic (past) input tensor. These features vary across the lookback time steps.
static_input_dim (int) – The total number of features in the static (time-invariant) input tensor. These features provide context that does not change over time for a given series.
future_input_dim (int) – The total number of features in the known future input tensor. These features describe future events or conditions known at the time of prediction.
hidden_units (int, default 32) – The main dimensionality of the hidden layers used throughout the network, including VSN outputs, GRN hidden states, enrichment layers, and attention mechanisms.
num_heads (int, default 4) – Number of attention heads used in the TemporalAttentionLayer. More heads allow the model to attend to different representation subspaces.
dropout_rate (float, default 0.1) – Dropout rate applied to non-recurrent connections in LSTMs, VSNs, GRNs, and attention layers. Value between 0 and 1.
recurrent_dropout_rate (float, default 0.0) – Dropout rate applied specifically to the recurrent connections within the LSTM layers, regularizing the recurrent state updates. Value between 0 and 1. Note: a non-zero value prevents Keras from using the cuDNN LSTM kernel, so training on GPUs may be noticeably slower.
forecast_horizon (int, default 1) – The number of future time steps the model is trained to predict simultaneously (multi-horizon forecasting).
quantiles (list[float], optional, default None) – A list of quantiles (e.g., [0.1, 0.5, 0.9]) for probabilistic forecasting.
  - If provided, the model outputs a prediction for each specified quantile, and combined_quantile_loss() should typically be used for training.
  - If None, the model performs point forecasting (one value per step), typically trained with MSE loss.
activation (str, default 'elu') – Activation function used within GRNs and possibly other dense layers. Supported options include ‘elu’, ‘relu’, ‘gelu’, ‘tanh’, ‘sigmoid’, and ‘linear’.
use_batch_norm (bool, default False) – If True, applies Batch Normalization within the Gated Residual Networks (GRNs). GRN implementations typically use Layer Normalization by default, as in the TFT paper.
num_lstm_layers (int, default 1) – Number of LSTM layers stacked in the sequence encoder.
lstm_units (int or list[int], optional, default None) – Number of hidden units in each LSTM encoder layer.
  - If int: the same number of units is used for all layers specified by num_lstm_layers.
  - If list[int]: specifies the number of units for each LSTM layer in order. Its length must match num_lstm_layers.
  - If None: defaults to hidden_units for all LSTM layers.
output_dim (int, default 1) – The number of target variables predicted at each time step. Typically 1 for univariate forecasting.
**kwargs – Additional keyword arguments passed to the parent Keras Model.
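The three documented cases for lstm_units fit in a few lines of code. The helper below, resolve_lstm_units, is a hypothetical name used only for illustration. It restates the rules listed above and is not taken from fusionlab's source, which may validate the value differently.

```python
from typing import List, Optional, Union


def resolve_lstm_units(
    hidden_units: int,
    num_lstm_layers: int,
    lstm_units: Optional[Union[int, List[int]]] = None,
) -> List[int]:
    """Return the number of units for each stacked LSTM layer.

    Hypothetical helper mirroring the documented rules for ``lstm_units``;
    it is not part of fusionlab's API.
    """
    if lstm_units is None:
        # None: every layer falls back to hidden_units.
        return [hidden_units] * num_lstm_layers
    if isinstance(lstm_units, int):
        # int: the same width is repeated for every layer.
        return [lstm_units] * num_lstm_layers
    units = list(lstm_units)
    if len(units) != num_lstm_layers:
        # list[int]: exactly one entry per layer is required.
        raise ValueError(
            f"lstm_units has {len(units)} entries but "
            f"num_lstm_layers={num_lstm_layers}"
        )
    return units


print(resolve_lstm_units(32, 2))            # [32, 32]
print(resolve_lstm_units(32, 2, 64))        # [64, 64]
print(resolve_lstm_units(32, 2, [64, 32]))  # [64, 32]
```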
Notes
This implementation requires the inputs to the call method to be a list or tuple containing exactly three tensors, in the order [static_inputs, dynamic_inputs, future_inputs]. The expected shapes are:
- static_inputs: (Batch, StaticFeatures)
- dynamic_inputs: (Batch, PastTimeSteps, DynamicFeatures)
- future_inputs: (Batch, TotalTimeSteps, FutureFeatures), where TotalTimeSteps includes the past and future steps relevant to the LSTM processing.
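The Notes fix the input order and rank, so a shape mismatch is easiest to catch before calling the model. The function below is a hypothetical pre-flight check, not part of fusionlab. It works on plain shape tuples (for example tuple(t.shape) of each tensor). It deliberately does not check the time dimension of future_inputs, because the exact TotalTimeSteps the layer expects depends on preprocessing.

```python
from typing import Tuple

Shape = Tuple[int, ...]


def check_tft_input_shapes(
    static_shape: Shape,
    dynamic_shape: Shape,
    future_shape: Shape,
    static_input_dim: int,
    dynamic_input_dim: int,
    future_input_dim: int,
) -> None:
    """Raise ValueError if [static, dynamic, future] shapes break the documented layout.

    Hypothetical helper for illustration; it is not part of fusionlab.
    """
    # static_inputs: (Batch, StaticFeatures)
    if len(static_shape) != 2 or static_shape[1] != static_input_dim:
        raise ValueError(
            f"static_inputs should be (B, {static_input_dim}), got {static_shape}")
    # dynamic_inputs: (Batch, PastTimeSteps, DynamicFeatures)
    if len(dynamic_shape) != 3 or dynamic_shape[2] != dynamic_input_dim:
        raise ValueError(
            f"dynamic_inputs should be (B, T_past, {dynamic_input_dim}), got {dynamic_shape}")
    # future_inputs: (Batch, TotalTimeSteps, FutureFeatures)
    if len(future_shape) != 3 or future_shape[2] != future_input_dim:
        raise ValueError(
            f"future_inputs should be (B, T_total, {future_input_dim}), got {future_shape}")
    # All three tensors must describe the same batch of series.
    batch_sizes = {static_shape[0], dynamic_shape[0], future_shape[0]}
    if len(batch_sizes) != 1:
        raise ValueError(f"batch sizes differ across inputs: {sorted(batch_sizes)}")


# Shapes matching the Notes: 12 past steps plus a 6-step horizon for future inputs.
check_tft_input_shapes((4, 3), (4, 12, 5), (4, 18, 2), 3, 5, 2)  # passes silently
```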
Use Case and Importance
This revised TFT class provides a structured implementation that closely follows the component architecture described in the original TFT paper, including the distinct GRNs for generating static context vectors. By requiring all input types (static, dynamic past, known future), it simplifies the input handling logic compared to versions allowing optional inputs. This makes it a suitable choice when you have all three types of features available and want a robust baseline TFT implementation that explicitly leverages static context for VSNs, LSTM initialization, and temporal processing enrichment. It serves as a strong foundation for complex multi-horizon forecasting tasks that benefit from diverse data integration and interpretability components like VSNs and attention.
Mathematical Formulation
The model processes inputs through the following key stages:
- Variable Selection: Separate Variable Selection Networks (VSNs) are applied to the static ($\mathbf{s}$), dynamic past ($\mathbf{x}_t, t \le T$), and known future ($\mathbf{z}_t, t > T$) inputs. This step identifies relevant features within each input type and transforms them into embeddings of dimension hidden_units. Let the outputs be $\zeta$ (static embedding), $\xi^{dyn}_t$ (dynamic), and $\xi^{fut}_t$ (future). VSNs may be conditioned on a static context vector $c_s$ (the optional argument is shown in brackets).
\[\begin{split}\zeta &= \text{VSN}_{static}(\mathbf{s}, [c_s]) \\ \xi^{dyn}_t &= \text{VSN}_{dyn}(\mathbf{x}_t, [c_s]) \\ \xi^{fut}_t &= \text{VSN}_{fut}(\mathbf{z}_t, [c_s])\end{split}\]
- Static Context Generation: Four distinct Gated Residual Networks (GRNs) process the static embedding $\zeta$ to produce the context vectors $c_s$ (for VSNs), $c_e$ (for enrichment), $c_h$ (LSTM initial hidden state), and $c_c$ (LSTM initial cell state).
\[c_s = GRN_{vs}(\zeta) \quad \dots \quad c_c = GRN_{c}(\zeta)\]
- Temporal Processing Input: The selected dynamic and future embeddings are combined (e.g., concatenated along time or features, depending on preprocessing) and augmented with positional encoding to form the input sequence for the LSTM. Let this sequence be $\psi_t$.
\[\psi_t = \text{Combine}(\xi^{dyn}_t, \xi^{fut}_t) + \text{PosEncode}(t)\]
- LSTM Encoder: A stack of num_lstm_layers LSTMs processes $\psi_t$, initialized with $[c_h, c_c]$.
\[\{h_t\} = \text{LSTMStack}(\{\psi_t\}, \text{initial\_state}=[c_h, c_c])\]
- Static Enrichment: The LSTM outputs $h_t$ are combined with the static enrichment context $c_e$ using a time-distributed GRN.
\[\phi_t = GRN_{enrich}(h_t, c_e)\]
- Temporal Self-Attention: Interpretable multi-head attention is applied to the enriched sequence $\{\phi_t\}$, possibly using $c_s$ as context within the attention mechanism's internal GRNs. After a residual connection, gating (GLU), and normalization, this yields the vectors $\beta_t$.
\[\beta_t = \text{TemporalAttention}(\{\phi_t\}, c_s)\]
- Position-wise Feed-Forward: A final time-distributed GRN is applied to the attention output.
\[\delta_t = GRN_{final}(\beta_t)\]
- Output Projection: The features corresponding to the forecast horizon ($t > T$) are selected from $\{\delta_t\}$ and passed through a final Dense layer (or multiple layers for quantiles) to produce the predictions $\hat{y}_{T+1}, \dots, \hat{y}_{T+\tau}$, where $\tau$ is forecast_horizon.
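The Temporal Processing Input stage leaves PosEncode unspecified. One common choice is the fixed sinusoidal encoding introduced with the original Transformer. The sketch below assumes that choice purely for illustration; fusionlab's positional encoding may use a different scheme, such as a learned one. Sequences are plain lists of shape (time, hidden_units), and Combine is taken as already applied.

```python
import math
from typing import List


def sinusoidal_positional_encoding(num_steps: int, d_model: int) -> List[List[float]]:
    """Return a num_steps x d_model table of sinusoidal position encodings.

    Even feature indices use sin and odd indices use cos; each (2k, 2k+1)
    pair shares one frequency, and wavelengths grow geometrically with k.
    """
    table = []
    for t in range(num_steps):
        row = []
        for i in range(d_model):
            # Paired dimensions (2k, 2k+1) share the same frequency.
            angle = t / (10000 ** ((2 * (i // 2)) / d_model))
            row.append(math.sin(angle) if i % 2 == 0 else math.cos(angle))
        table.append(row)
    return table


def add_positional_encoding(sequence: List[List[float]]) -> List[List[float]]:
    """Compute psi_t = x_t + PosEncode(t) for a (time, features) sequence."""
    pe = sinusoidal_positional_encoding(len(sequence), len(sequence[0]))
    return [[x + p for x, p in zip(row, pe_row)] for row, pe_row in zip(sequence, pe)]


psi = add_positional_encoding([[0.0] * 4 for _ in range(3)])
print(psi[0])  # [0.0, 1.0, 0.0, 1.0]
```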
- call(inputs, training=False)
Performs the forward pass. Expects inputs as a list/tuple: [static_inputs, dynamic_inputs, future_inputs].
- compile(optimizer, loss=None, **kwargs)
Compiles the model. If loss is not given, it is selected automatically: ‘mse’ when quantiles is None, otherwise a quantile loss (see combined_quantile_loss).
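A quantile loss of this kind is usually built on the pinball loss. For target $y$ and the prediction $\hat{y}_q$ at quantile $q$, the pinball loss is $\max(q(y-\hat{y}_q), (q-1)(y-\hat{y}_q))$. The sketch below averages that term over horizon steps and quantiles in pure Python. How combined_quantile_loss actually reduces across quantiles, batches, and output_dim is an assumption here, so treat this as illustrative rather than as fusionlab's implementation.

```python
from typing import Sequence


def pinball_loss(
    y_true: Sequence[float],
    y_pred: Sequence[Sequence[float]],
    quantiles: Sequence[float],
) -> float:
    """Mean pinball (quantile) loss for one series.

    y_true: one target per horizon step, length H.
    y_pred: per horizon step, one prediction per quantile, shape (H, Q).
    """
    total = 0.0
    count = 0
    for y, preds in zip(y_true, y_pred):
        for q, y_hat in zip(quantiles, preds):
            err = y - y_hat
            # Under-prediction (err > 0) costs q * err;
            # over-prediction (err < 0) costs (1 - q) * |err|.
            total += max(q * err, (q - 1) * err)
            count += 1
    if count == 0:
        raise ValueError("no predictions to score")
    return total / count


y_true = [10.0, 12.0]
y_pred = [[8.0, 10.0, 11.0], [11.0, 12.0, 14.0]]
print(round(pinball_loss(y_true, y_pred, [0.1, 0.5, 0.9]), 6))  # 0.1
```

Unlike MSE, this loss penalizes under- and over-prediction asymmetrically. That asymmetry is what pushes each output toward its own quantile.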
Examples
>>> import tensorflow as tf
>>> from fusionlab.nn.transformers import TFT
>>> from fusionlab.nn.losses import combined_quantile_loss
>>>
>>> # Dummy data dimensions
>>> B, T_past, H = 4, 12, 6  # Batch, lookback, horizon
>>> D_dyn, D_stat, D_fut = 5, 3, 2
>>> T_future = H  # Assume future inputs cover the horizon only
>>>
>>> # Create dummy input tensors
>>> static_in = tf.random.normal((B, D_stat), dtype=tf.float32)
>>> dynamic_in = tf.random.normal((B, T_past, D_dyn), dtype=tf.float32)
>>> # Per the Notes, future inputs may need shape (B, T_past + T_future, D_fut);
>>> # here they are assumed to be preprocessed to cover the horizon only.
>>> future_in = tf.random.normal((B, T_future, D_fut), dtype=tf.float32)
>>>
>>> # Instantiate the model for quantile forecasting
>>> model = TFT(
...     dynamic_input_dim=D_dyn, static_input_dim=D_stat,
...     future_input_dim=D_fut, forecast_horizon=H,
...     hidden_units=16, num_heads=2, num_lstm_layers=1,
...     quantiles=[0.1, 0.5, 0.9], output_dim=1
... )
>>>
>>> # Compile with the matching quantile loss
>>> loss_fn = combined_quantile_loss([0.1, 0.5, 0.9])
>>> model.compile(optimizer='adam', loss=loss_fn)
>>>
>>> # Inputs must be ordered [static, dynamic, future]
>>> model_inputs = [static_in, dynamic_in, future_in]
>>>
>>> # A subclassed Keras model builds its weights on the first call
>>> predictions = model(model_inputs, training=False)
>>> print(f"Output shape: {predictions.shape}")
Output shape: (4, 6, 3)
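The example starts from ready-made tensors. Arrays of those shapes are usually cut from a longer series with a sliding window, and make_tft_windows below is a hypothetical pure-Python sketch of that step. The library's own utility is fusionlab.nn.utils.reshape_xtft_data (listed under See also); its behavior is not reproduced here.

The sketch follows the TotalTimeSteps layout from the Notes, giving each sample lookback + horizon rows of known-future features. Slicing future[split:end] instead would give the horizon-only layout used in the example. Static features are time-invariant, so they would simply be repeated for every window of a series.

```python
from typing import List, Sequence, Tuple


def make_tft_windows(
    dynamic: Sequence[Sequence[float]],
    future: Sequence[Sequence[float]],
    target: Sequence[float],
    lookback: int,
    horizon: int,
) -> Tuple[List[List[List[float]]], List[List[List[float]]], List[List[float]]]:
    """Cut aligned (dynamic, future, target) windows from one series.

    Hypothetical helper, not part of fusionlab. dynamic[t] and future[t] are
    feature rows at time t; target[t] is the value to forecast.
    """
    n = len(target)
    if not (len(dynamic) == len(future) == n):
        raise ValueError("dynamic, future and target must have the same length")
    dyn_windows, fut_windows, y_windows = [], [], []
    for start in range(n - lookback - horizon + 1):
        split = start + lookback   # first forecast step
        end = split + horizon      # one past the last forecast step
        dyn_windows.append([list(r) for r in dynamic[start:split]])  # past only
        fut_windows.append([list(r) for r in future[start:end]])     # past + horizon
        y_windows.append(list(target[split:end]))                    # horizon targets
    return dyn_windows, fut_windows, y_windows


# Toy series: 10 steps, 1 dynamic feature, 2 known-future features.
series = [float(t) for t in range(10)]
d, f, y = make_tft_windows([[v] for v in series], [[v, 1.0] for v in series],
                           series, lookback=4, horizon=2)
print(len(d), len(d[0]), len(f[0]), y[0])  # 5 4 6 [4.0, 5.0]
```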
See also
fusionlab.nn.components.VariableSelectionNetwork – Core component for VSN.
fusionlab.nn.components.GatedResidualNetwork – Core component for GRN.
fusionlab.nn.components.TemporalAttentionLayer – Core attention block.
tensorflow.keras.layers.LSTM – Recurrent layer used internally.
fusionlab.nn.losses.combined_quantile_loss – Default loss for quantiles.
fusionlab.nn.utils.reshape_xtft_data – Utility to prepare inputs.
fusionlab.nn.XTFT – More advanced related architecture.
tensorflow.keras.Model – Base Keras model class.
References
[Lim21] Lim, B., Arık, S. Ö., Loeff, N., & Pfister, T. (2021). Temporal fusion transformers for interpretable multi-horizon time series forecasting. International Journal of Forecasting, 37(4), 1748-1764.
- __init__(dynamic_input_dim, static_input_dim, future_input_dim, hidden_units=32, num_heads=4, dropout_rate=0.1, recurrent_dropout_rate=0.0, forecast_horizon=1, quantiles=None, activation='elu', use_batch_norm=False, num_lstm_layers=1, lstm_units=None, output_dim=1, **kwargs)
- Parameters:
dynamic_input_dim (int)
static_input_dim (int)
future_input_dim (int)
hidden_units (int)
num_heads (int)
dropout_rate (float)
recurrent_dropout_rate (float)
forecast_horizon (int)
quantiles (List[float] | None)
activation (str)
use_batch_norm (bool)
num_lstm_layers (int)
lstm_units (int | List[int] | None)
output_dim (int)
Methods
__init__(dynamic_input_dim, ...[, ...])
add_loss(losses, **kwargs) – Add loss tensor(s), potentially dependent on layer inputs.
add_metric(value[, name]) – Adds metric tensor to the layer.
add_update(updates) – Add update op(s), potentially dependent on layer inputs.
add_variable(*args, **kwargs) – Deprecated, do NOT use! Alias for add_weight.
add_weight([name, shape, dtype, ...]) – Adds a new variable to the layer.
build(input_shape) – Builds the model based on input shapes received.
build_from_config(config) – Builds the layer's states with the supplied config dict.
call(inputs[, training]) – Forward pass for the revised TFT with numerical inputs.
compile(optimizer[, loss]) – Configures the model for training.
compile_from_config(config) – Compiles the model with the information given in config.
compute_loss([x, y, y_pred, sample_weight]) – Compute the total loss, validate it, and return it.
compute_mask(inputs[, mask]) – Computes an output mask tensor.
compute_metrics(x, y, y_pred, sample_weight) – Update metric states and collect all metrics to be returned.
compute_output_shape(input_shape) – Computes the output shape of the layer.
compute_output_signature(input_signature) – Compute the output tensor signature of the layer based on the inputs.
count_params() – Count the total number of scalars composing the weights.
evaluate([x, y, batch_size, verbose, ...]) – Returns the loss value & metrics values for the model in test mode.
evaluate_generator(generator[, steps, ...]) – Evaluates the model on a data generator.
export(filepath) – Create a SavedModel artifact for inference (e.g. via TF-Serving).
finalize_state() – Finalizes the layer's state after updating layer weights.
fit([x, y, batch_size, epochs, verbose, ...]) – Trains the model for a fixed number of epochs (dataset iterations).
fit_generator(generator[, steps_per_epoch, ...]) – Fits the model on data yielded batch-by-batch by a Python generator.
from_config(config) – Creates a layer from its config.
get_build_config() – Returns a dictionary with the layer's input shape.
get_compile_config() – Returns a serialized config with information for compiling the model.
get_config() – Returns the config of the Model.
get_input_at(node_index) – Retrieves the input tensor(s) of a layer at a given node.
get_input_mask_at(node_index) – Retrieves the input mask tensor(s) of a layer at a given node.
get_input_shape_at(node_index) – Retrieves the input shape(s) of a layer at a given node.
get_layer([name, index]) – Retrieves a layer based on either its name (unique) or index.
get_metrics_result() – Returns the model's metrics values as a dict.
get_output_at(node_index) – Retrieves the output tensor(s) of a layer at a given node.
get_output_mask_at(node_index) – Retrieves the output mask tensor(s) of a layer at a given node.
get_output_shape_at(node_index) – Retrieves the output shape(s) of a layer at a given node.
get_params([deep]) – Get the parameters for this learner.
get_weight_paths() – Retrieve all the variables and their paths for the model.
get_weights() – Retrieves the weights of the model.
help(**kwargs)
load(file_path[, format]) – Load the learner's state from a specified file in the desired format.
load_own_variables(store) – Loads the state of the layer.
load_weights(filepath[, skip_mismatch, ...]) – Loads all layer weights from saved files.
make_predict_function([force]) – Creates a function that executes one step of inference.
make_test_function([force]) – Creates a function that executes one step of evaluation.
make_train_function([force]) – Creates a function that executes one step of training.
predict(x[, batch_size, verbose, steps, ...]) – Generates output predictions for the input samples.
predict_generator(generator[, steps, ...]) – Generates predictions for the input samples from a data generator.
predict_on_batch(x) – Returns predictions for a single batch of samples.
predict_step(data) – The logic for one inference step.
reset_metrics() – Resets the state of all the metrics in the model.
reset_states()
save(filepath[, overwrite, save_format]) – Saves a model as a TensorFlow SavedModel or HDF5 file.
save_own_variables(store) – Saves the state of the layer.
save_spec([dynamic_batch]) – Returns the tf.TensorSpec of call args as a tuple (args, kwargs).
save_weights(filepath[, overwrite, ...]) – Saves all layer weights.
set_params(**params) – Set the parameters of this learner.
set_weights(weights) – Sets the weights of the layer, from NumPy arrays.
summary([line_length, positions, print_fn, ...]) – Prints a string summary of the network.
test_on_batch(x[, y, sample_weight, ...]) – Test the model on a single batch of samples.
test_step(data) – The logic for one evaluation step.
to_json(**kwargs) – Returns a JSON string containing the network configuration.
to_yaml(**kwargs) – Returns a yaml string containing the network configuration.
train_on_batch(x[, y, sample_weight, ...]) – Runs a single gradient update on a single batch of data.
train_step(data) – The logic for one training step.
with_name_scope(method) – Decorator to automatically enter the module name scope.
Attributes
activity_regularizer – Optional regularizer function for the output of this layer.
autotune_steps_per_execution – Settable property to enable tuning for steps_per_execution.
compute_dtype – The dtype of the layer's computations.
distribute_reduction_method – The method employed to reduce per-replica values during training.
distribute_strategy – The tf.distribute.Strategy this model was created under.
dtype – The dtype of the layer weights.
dtype_policy – The dtype policy associated with this layer.
dynamic – Whether the layer is dynamic (eager-only); set in the constructor.
inbound_nodes – Return Functional API nodes upstream of this layer.
input – Retrieves the input tensor(s) of a layer.
input_mask – Retrieves the input mask tensor(s) of a layer.
input_shape – Retrieves the input shape(s) of a layer.
input_spec – InputSpec instance(s) describing the input format for this layer.
jit_compile – Specify whether to compile the model with XLA.
layers
losses – List of losses added using the add_loss() API.
metrics – Return metrics added using compile() or add_metric().
metrics_names – Returns the model's display labels for all outputs.
name – Name of the layer (string), set in the constructor.
name_scope – Returns a tf.name_scope instance for this class.
non_trainable_variables – Sequence of non-trainable variables owned by this module and its submodules.
non_trainable_weights – List of all non-trainable weights tracked by this layer.
outbound_nodes – Return Functional API nodes downstream of this layer.
output – Retrieves the output tensor(s) of a layer.
output_mask – Retrieves the output mask tensor(s) of a layer.
output_shape – Retrieves the output shape(s) of a layer.
run_eagerly – Settable attribute indicating whether the model should run eagerly.
state_updates – Deprecated, do NOT use!
stateful
steps_per_execution – Settable steps_per_execution variable. Requires a compiled model.
submodules – Sequence of all sub-modules.
supports_masking – Whether this layer supports computing a mask using compute_mask.
trainable
trainable_variables – Sequence of trainable variables owned by this module and its submodules.
trainable_weights – List of all trainable weights tracked by this layer.
updates
variable_dtype – Alias of Layer.dtype, the dtype of the weights.
variables – Returns the list of all layer variables/weights.
weights – Returns the list of all layer variables/weights.
- compile(optimizer, loss=None, **kwargs)
Configures the model for training.
Example:
```python
import tensorflow as tf

# `model` is an existing Keras model instance.
model.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=1e-3),
              loss=tf.keras.losses.BinaryCrossentropy(),
              metrics=[tf.keras.metrics.BinaryAccuracy(),
                       tf.keras.metrics.FalseNegatives()])
```
- Parameters:
optimizer – String (name of optimizer) or optimizer instance. See tf.keras.optimizers.
loss – Loss function. May be a string (name of loss function), or a tf.keras.losses.Loss instance. See tf.keras.losses. A loss function is any callable with the signature loss = fn(y_true, y_pred), where y_true are the ground truth values, and y_pred are the model’s predictions. y_true should have shape (batch_size, d0, .. dN) (except in the case of sparse loss functions such as sparse categorical crossentropy which expects integer arrays of shape (batch_size, d0, .. dN-1)). y_pred should have shape (batch_size, d0, .. dN). The loss function should return a float tensor. If a custom Loss instance is used and reduction is set to None, return value has shape (batch_size, d0, .. dN-1) i.e. per-sample or per-timestep loss values; otherwise, it is a scalar. If the model has multiple outputs, you can use a different loss on each output by passing a dictionary or a list of losses. The loss value that will be minimized by the model will then be the sum of all individual losses, unless loss_weights is specified.
metrics – List of metrics to be evaluated by the model during training and testing. Each of these can be a string (name of a built-in function), function or a tf.keras.metrics.Metric instance. See tf.keras.metrics. Typically you will use metrics=[‘accuracy’]. A function is any callable with the signature result = fn(y_true, y_pred). To specify different metrics for different outputs of a multi-output model, you could also pass a dictionary, such as metrics={‘output_a’:’accuracy’, ‘output_b’:[‘accuracy’, ‘mse’]}. You can also pass a list to specify a metric or a list of metrics for each output, such as metrics=[[‘accuracy’], [‘accuracy’, ‘mse’]] or metrics=[‘accuracy’, [‘accuracy’, ‘mse’]]. When you pass the strings ‘accuracy’ or ‘acc’, we convert this to one of tf.keras.metrics.BinaryAccuracy, tf.keras.metrics.CategoricalAccuracy, tf.keras.metrics.SparseCategoricalAccuracy based on the shapes of the targets and of the model output. We do a similar conversion for the strings ‘crossentropy’ and ‘ce’ as well. The metrics passed here are evaluated without sample weighting; if you would like sample weighting to apply, you can specify your metrics via the weighted_metrics argument instead.
loss_weights – Optional list or dictionary specifying scalar coefficients (Python floats) to weight the loss contributions of different model outputs. The loss value that will be minimized by the model will then be the weighted sum of all individual losses, weighted by the loss_weights coefficients. If a list, it is expected to have a 1:1 mapping to the model’s outputs. If a dict, it is expected to map output names (strings) to scalar coefficients.
weighted_metrics – List of metrics to be evaluated and weighted by sample_weight or class_weight during training and testing.
run_eagerly – Bool. If True, this Model’s logic will not be wrapped in a tf.function. Recommended to leave this as None unless your Model cannot be run inside a tf.function. run_eagerly=True is not supported when using tf.distribute.experimental.ParameterServerStrategy. Defaults to False.
steps_per_execution – Int or ‘auto’. The number of batches to run during each tf.function call. If set to “auto”, keras will automatically tune steps_per_execution during runtime. Running multiple batches inside a single tf.function call can greatly improve performance on TPUs, when used with distributed strategies such as ParameterServerStrategy, or with small models with a large Python overhead. At most, one full epoch will be run each execution. If a number larger than the size of the epoch is passed, the execution will be truncated to the size of the epoch. Note that if steps_per_execution is set to N, Callback.on_batch_begin and Callback.on_batch_end methods will only be called every N batches (i.e. before/after each tf.function execution). Defaults to 1.
jit_compile – If True, compile the model training step with XLA. [XLA](https://www.tensorflow.org/xla) is an optimizing compiler for machine learning. jit_compile is not enabled by default. Note that jit_compile=True may not necessarily work for all models. For more information on supported operations please refer to the [XLA documentation](https://www.tensorflow.org/xla). Also refer to [known XLA issues](https://www.tensorflow.org/xla/known_issues) for more details.
pss_evaluation_shards – Integer or ‘auto’. Used for tf.distribute.ParameterServerStrategy training only. This arg sets the number of shards to split the dataset into, to enable an exact visitation guarantee for evaluation, meaning the model will be applied to each dataset element exactly once, even if workers fail. The dataset must be sharded to ensure separate workers do not process the same data. The number of shards should be at least the number of workers for good performance. A value of ‘auto’ turns on exact evaluation and uses a heuristic for the number of shards based on the number of workers. A value of 0 means no visitation guarantee is provided. NOTE: Custom implementations of Model.test_step will be ignored when doing exact evaluation. Defaults to 0.
**kwargs – Arguments supported for backwards compatibility only.
- get_config()
Returns the config of the Model.
Config is a Python dictionary (serializable) containing the configuration of an object, which in this case is a Model. This allows the Model to be reinstantiated later (without its trained weights) from this configuration.
Note that get_config() does not guarantee to return a fresh copy of dict every time it is called. The callers should make a copy of the returned dict if they want to modify it.
Developers of subclassed Model are advised to override this method, and continue to update the dict from super(MyModel, self).get_config() to provide the proper configuration of this Model. The default config returns a config dict for init parameters if they are basic types. Raises NotImplementedError in cases where a custom get_config() implementation is required for the subclassed model.
- Returns:
Python dictionary containing the configuration of this Model.
- classmethod from_config(config)
Creates a layer from its config.
This method is the reverse of get_config, capable of instantiating the same layer from the config dictionary. It does not handle layer connectivity (handled by Network), nor weights (handled by set_weights).
- Parameters:
config – A Python dictionary, typically the output of get_config.
- Returns:
A layer instance.
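The round trip that get_config and from_config provide can be shown without TensorFlow. The class below is a hypothetical plain-Python stand-in, not the Keras implementation and not fusionlab code. It records a subset of TFT's constructor arguments, returns them from get_config as basic types, and rebuilds an equivalent object in from_config via cls(**config). Like from_config above, it restores configuration only, not weights.

```python
import copy
from typing import Any, Dict, List, Optional


class ConfigRoundTripStub:
    """Hypothetical plain-Python stand-in for the get_config/from_config contract."""

    def __init__(self, dynamic_input_dim: int, static_input_dim: int,
                 future_input_dim: int, hidden_units: int = 32,
                 quantiles: Optional[List[float]] = None) -> None:
        self.dynamic_input_dim = dynamic_input_dim
        self.static_input_dim = static_input_dim
        self.future_input_dim = future_input_dim
        self.hidden_units = hidden_units
        self.quantiles = quantiles

    def get_config(self) -> Dict[str, Any]:
        # Only basic, serializable constructor arguments; no weights.
        return {
            "dynamic_input_dim": self.dynamic_input_dim,
            "static_input_dim": self.static_input_dim,
            "future_input_dim": self.future_input_dim,
            "hidden_units": self.hidden_units,
            "quantiles": self.quantiles,
        }

    @classmethod
    def from_config(cls, config: Dict[str, Any]) -> "ConfigRoundTripStub":
        # Deep-copy so the new object shares no mutable values (e.g. the
        # quantiles list) with the dict it was built from.
        return cls(**copy.deepcopy(config))


original = ConfigRoundTripStub(5, 3, 2, hidden_units=16, quantiles=[0.1, 0.5, 0.9])
clone = ConfigRoundTripStub.from_config(original.get_config())
print(clone.get_config() == original.get_config())  # True
```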
- help(**kwargs)
- my_params = TFT( dynamic_input_dim, static_input_dim, future_input_dim, hidden_units=32, num_heads=4, dropout_rate=0.1, recurrent_dropout_rate=0.0, forecast_horizon=1, quantiles=None, activation='elu', use_batch_norm=False, num_lstm_layers=1, lstm_units=None, output_dim=1 )