Source code for fusionlab.nn.transformers._tft
# -*- coding: utf-8 -*-
# License: BSD-3-Clause
# Author: LKouadio <etanoyau@gmail.com>
"""Implements the Temporal Fusion Transformer (TFT), a state-of-the-art
architecture for multi-horizon time-series forecasting.
"""
from numbers import Real, Integral
import logging
from typing import Optional, List, Any, Union
from ..._fusionlog import fusionlog
from ...api.docs import (
_shared_nn_params,
_shared_docs,
DocstringComponents,
)
from ...api.property import NNLearner
from ...core.checks import is_iterable
from ...core.diagnose_q import validate_quantiles
from ...core.handlers import param_deprecated_message
from ...compat.sklearn import validate_params, Interval, StrOptions
from ...decorators import Appender
from ...utils.deps_utils import ensure_pkg
from .. import KERAS_DEPS, KERAS_BACKEND, dependency_message
# Backend-dependent setup: bind the validation helpers, the reusable
# components and short aliases for the Keras/TensorFlow symbols used below.
# Every alias is looked up as an attribute of ``KERAS_DEPS``.
# NOTE(review): indentation was lost in this copy of the file, so the exact
# extent of the ``if KERAS_BACKEND:`` body cannot be determined from here --
# confirm against the original source which statements are guarded.
if KERAS_BACKEND:
from .._tensor_validation import (
validate_model_inputs,
align_temporal_dimensions
)
from ..components import (
StaticEnrichmentLayer,
PositionalEncoding,
GatedResidualNetwork,
VariableSelectionNetwork ,
TemporalAttentionLayer,
)
# Layer / model classes.
LSTM = KERAS_DEPS.LSTM
LayerNormalization = KERAS_DEPS.LayerNormalization
TimeDistributed = KERAS_DEPS.TimeDistributed
MultiHeadAttention = KERAS_DEPS.MultiHeadAttention
Model = KERAS_DEPS.Model
BatchNormalization = KERAS_DEPS.BatchNormalization
Input = KERAS_DEPS.Input
Softmax = KERAS_DEPS.Softmax
Flatten = KERAS_DEPS.Flatten
Dropout = KERAS_DEPS.Dropout
Dense = KERAS_DEPS.Dense
Embedding =KERAS_DEPS.Embedding
Concatenate=KERAS_DEPS.Concatenate
Layer = KERAS_DEPS.Layer
# Serialization decorator applied to both model classes in this module.
register_keras_serializable=KERAS_DEPS.register_keras_serializable
# Tensor operations, prefixed with ``tf_`` to keep them apart from locals.
tf_reduce_sum = KERAS_DEPS.reduce_sum
tf_stack = KERAS_DEPS.stack
tf_expand_dims = KERAS_DEPS.expand_dims
tf_tile = KERAS_DEPS.tile
tf_range=KERAS_DEPS.range
tf_concat = KERAS_DEPS.concat
tf_shape = KERAS_DEPS.shape
tf_rank=KERAS_DEPS.rank
tf_autograph=KERAS_DEPS.autograph
# Set autograph verbosity to 0 (module-level side effect at import time).
tf_autograph.set_verbosity(0)
# Extra message handed to ``ensure_pkg`` on the model constructors.
DEP_MSG = dependency_message('transformers')
# Public API of this module.
__all__ = ["TemporalFusionTransformer", "DummyTFT"]
# Shared parameter descriptions, interpolated into ``DummyTFT.__doc__``
# via ``{params.base.<name>}`` placeholders further down.
_param_docs = DocstringComponents.from_nested_components(
base=DocstringComponents(_shared_nn_params),
)
[docs]
@Appender(_shared_docs['tft_math_doc'], join='\n', indents=0)
@param_deprecated_message(
    conditions_params_mappings=[
        {
            'param': 'future_input_dim',
            'condition': lambda v: v is not None,
            'message': (
                "The 'future_input_dim' parameter is accepted by "
                "DummyTFT for API consistency with other TFT models, "
                "but DummyTFT does not utilize future covariates. "
                "This parameter will be ignored and effectively treated "
                "as None. If you need to use known future covariates, "
                "please consider using the standard "
                ":class:`~fusionlab.nn.transformers.TemporalFusionTransformer` "
                "(for flexible input handling) or "
                ":class:`~fusionlab.nn.transformers.TFT` (for stricter "
                "three-input structure)."
            ),
            'default': None
        }
    ]
)
@register_keras_serializable(
    'fusionlab.nn.transformers',
    name="DummyTFT"
)
class DummyTFT(Model, NNLearner):
    """
    DummyTFT: Simplified TFT variant using only Static and Dynamic inputs.
    """

    @validate_params({
        "static_input_dim": [Interval(Integral, 0, None, closed='left')],
        "dynamic_input_dim": [Interval(Integral, 1, None, closed='left')],
        "hidden_units": [Interval(Integral, 1, None, closed='left')],
        "num_heads": [Interval(Integral, 1, None, closed='left')],
        "dropout_rate": [Interval(Real, 0, 1, closed="both")],
        "forecast_horizon": [Interval(Integral, 1, None, closed='left')],
        "quantiles": ['array-like', None],
        "activation": [StrOptions(
            {"elu", "relu", "tanh", "sigmoid", "linear", "gelu"})],
        "use_batch_norm": [bool],
        "num_lstm_layers": [Interval(Integral, 1, None, closed='left')],
        "lstm_units": ['array-like', Interval(
            Integral, 1, None, closed='left'), None],
        "output_dim": [Interval(Integral, 1, None, closed='left')],
        },
    )
    @ensure_pkg(KERAS_BACKEND or "keras", extra=DEP_MSG)
    def __init__(
        self,
        dynamic_input_dim: int,
        static_input_dim: int,
        future_input_dim: Any=None,
        hidden_units: int = 32,
        num_heads: int = 4,
        dropout_rate: float = 0.1,
        forecast_horizon: int = 1,
        quantiles: Optional[List[float]] = None,
        activation: str = 'elu',
        use_batch_norm: bool = False,
        num_lstm_layers: int = 1,
        lstm_units: Optional[Union[int, List[int]]] = None,
        output_dim: int = 1,
        name: Optional[str] = None,
        **kwargs
    ):
        """Store the hyper-parameters and create every sub-layer.

        Raises
        ------
        ValueError
            If ``lstm_units`` is a sequence whose length differs from
            ``num_lstm_layers``.
        TypeError
            If ``lstm_units`` is neither an integer, an iterable nor None.
        """
        if name is None:
            name = "DummyTFT"  # Default name
        super().__init__(name=name, **kwargs)

        # Initialize Logger
        logger_name = f"{__name__}.{self.__class__.__name__}"
        if fusionlog:
            self.logger = fusionlog().get_fusionlab_logger(logger_name)
        else:
            self.logger = logging.getLogger(logger_name)
            self.logger.info(
                "fusionlog not found, using standard Python logging."
            )

        self.static_input_dim = static_input_dim
        self.dynamic_input_dim = dynamic_input_dim
        # Kept only so that it round-trips through get_config();
        # no layer is ever built for future covariates.
        self.future_input_dim = future_input_dim
        self.hidden_units = hidden_units
        self.num_heads = num_heads
        self.dropout_rate = dropout_rate
        self.forecast_horizon = forecast_horizon
        self.activation_str = activation
        self.use_batch_norm = use_batch_norm
        self.num_lstm_layers = num_lstm_layers
        self._lstm_units_config = lstm_units  # as given, for get_config()
        self.output_dim = output_dim
        self.quantiles = quantiles

        # Process quantiles
        if self.quantiles is None:
            # Point forecast: the output layer produces output_dim features.
            self.num_output_features = self.output_dim
        else:
            self.quantiles = validate_quantiles(quantiles)
            # Each quantile predicts output_dim features.
            self.num_output_features = len(self.quantiles) * self.output_dim

        # num_output_features already equals output_dim in the point
        # forecast case, so it can be logged directly (this also avoids
        # truth-testing the validated quantile container).
        self.logger.debug(
            f"Initializing {self.name} with output_dim={self.output_dim}"
            f", num_quantiles_for_output_layer_units="
            f"{self.num_output_features}, "
            f"hidden_units={hidden_units}, "
            f"num_heads={num_heads}, "
            f"dropout_rate={dropout_rate}, "
            f"forecast_horizon={forecast_horizon}, "
            f"quantiles={quantiles}, "
            f"activation={activation}, "
            f"use_batch_norm={use_batch_norm}, "
            f"num_lstm_layers={num_lstm_layers}, "
            f"lstm_units={lstm_units}"
        )

        # --- Resolve the per-layer LSTM sizes (pure Python, no layers) ---
        # Done before any layer is created so that the state projections
        # below always match the first LSTM layer, whatever container type
        # (list, tuple, array, ...) the caller used for `lstm_units`.
        if lstm_units is None:
            lstm_units_per_layer = [hidden_units] * num_lstm_layers
        elif isinstance(lstm_units, Integral):
            lstm_units_per_layer = [int(lstm_units)] * num_lstm_layers
        elif is_iterable(lstm_units):
            lstm_units_per_layer = list(lstm_units)
            if len(lstm_units_per_layer) != num_lstm_layers:
                raise ValueError(
                    "Length of `lstm_units` list must match `num_lstm_layers`."
                )
        else:  # Should be caught by @validate_params
            raise TypeError(
                "`lstm_units` must be int, list of int, or None."
            )
        first_lstm_units = (
            lstm_units_per_layer[0] if lstm_units_per_layer else hidden_units
        )

        # --- Layer Definitions ---
        self.logger.debug(
            f" Initializing internal layers for {self.name}...")

        # Variable Selection Networks
        self.static_var_sel = VariableSelectionNetwork(
            num_inputs=static_input_dim, units=hidden_units,
            dropout_rate=dropout_rate, activation=self.activation_str,
            use_batch_norm=use_batch_norm, name="static_vsn"
        )
        self.dynamic_var_sel = VariableSelectionNetwork(
            num_inputs=dynamic_input_dim, units=hidden_units,
            dropout_rate=dropout_rate, use_time_distributed=True,
            activation=self.activation_str, use_batch_norm=use_batch_norm,
            name="dynamic_vsn"
        )
        self.positional_encoding = PositionalEncoding(
            name="positional_encoding"
        )

        # Static Context GRNs
        self.static_context_grn_attn = GatedResidualNetwork(
            units=hidden_units, dropout_rate=dropout_rate,
            activation=self.activation_str,
            use_batch_norm=use_batch_norm,
            name="static_context_for_attention"
        )
        self.static_context_grn_enrich = GatedResidualNetwork(
            units=hidden_units, dropout_rate=dropout_rate,
            activation=self.activation_str,
            use_batch_norm=use_batch_norm,
            name="static_context_for_enrichment"
        )

        # Projections of the static embedding onto the initial (h, c)
        # state of the first LSTM layer; their width must equal that
        # layer's number of units.
        self.static_to_lstm_state_h = Dense(
            first_lstm_units, activation=self.activation_str,
            name="static_to_h"
        )
        self.static_to_lstm_state_c = Dense(
            first_lstm_units, activation=self.activation_str,
            name="static_to_c"
        )

        # LSTM Encoder
        self.lstm_layers = []
        for i in range(num_lstm_layers):
            self.lstm_layers.append(
                LSTM(
                    units=lstm_units_per_layer[i],
                    return_sequences=True,
                    dropout=dropout_rate,
                    name=f'lstm_encoder_layer_{i+1}'
                )
            )

        # Static Enrichment Layer
        self.static_enrichment = StaticEnrichmentLayer(
            units=hidden_units, activation=self.activation_str,
            use_batch_norm=use_batch_norm, name="static_enrichment"
        )

        # Temporal Attention Layer
        self.temporal_attention = TemporalAttentionLayer(
            units=hidden_units, num_heads=num_heads,
            dropout_rate=dropout_rate, activation=self.activation_str,
            use_batch_norm=use_batch_norm, name="temporal_self_attention"
        )

        # Position-wise Feedforward GRN
        self.positionwise_grn = GatedResidualNetwork(
            units=hidden_units, dropout_rate=dropout_rate,
            activation=self.activation_str,
            use_batch_norm=use_batch_norm, name="positionwise_ff_grn"
        )

        # Output Layer(s)
        if self.quantiles is not None:
            self.logger.debug(
                " Initializing Quantile Output Projection Layers..."
            )
            # Each Dense layer predicts all output_dims for one quantile
            self.output_projection_layers = [
                TimeDistributed(
                    Dense(self.output_dim, name=f'dense_q_{q_idx}'),
                    name=f'output_projection_q_{q_idx}'
                )
                for q_idx in range(len(self.quantiles))
            ]
        else:  # Point forecast
            self.logger.debug(
                " Initializing Point Output Projection Layer..."
            )
            self.output_projection_layer = TimeDistributed(
                Dense(self.output_dim, name='dense_point_output'),
                name='output_projection_point'
            )

    @tf_autograph.experimental.do_not_convert
    def call(self, inputs, training=False, **kwargs):
        """Forward pass for DummyTFT (Static and Dynamic inputs only).

        Parameters
        ----------
        inputs : list or tuple of length 2
            ``[static_inputs, dynamic_inputs]``; neither may be None.
        training : bool, default=False
            Forwarded to every sub-layer that accepts it.

        Returns
        -------
        Tensor
            Output of the projection layer(s) applied to the last
            ``forecast_horizon`` steps. With quantiles the per-quantile
            outputs are concatenated on the last axis when
            ``output_dim == 1`` and stacked on a new axis 2 otherwise.

        Raises
        ------
        ValueError
            If ``inputs`` is not a 2-element list/tuple or contains None.
        """
        self.logger.debug(
            f"DummyTFT '{self.name}': Entering call method."
        )
        # `logger.level` is 0 (NOTSET) for a logger that merely inherits
        # its level, which made `level <= DEBUG` always true. Ask the
        # logger for its effective state instead.
        debug_on = self.logger.isEnabledFor(logging.DEBUG)

        # --- Input Validation ---
        # DummyTFT expects a list of 2: [static_raw, dynamic_raw]
        if not isinstance(inputs, (list, tuple)) or len(inputs) != 2:
            raise ValueError(
                "DummyTFT expects inputs as a list/tuple of 2 elements: "
                "[static_inputs, dynamic_inputs]."
            )
        static_raw, dynamic_raw = inputs[0], inputs[1]
        if static_raw is None or dynamic_raw is None:
            # Both inputs are required by __init__.
            raise ValueError(
                "Static and Dynamic inputs must be provided (not None) "
                "for DummyTFT."
            )

        # validate_model_inputs takes [S, D, F] and returns a 3-tuple;
        # the future slot is passed as None and its result is discarded.
        # NOTE(review): the precise meaning of mode='soft' and
        # model_name='tft_flex' lives in validate_model_inputs -- confirm.
        static_input, dynamic_input, _ = validate_model_inputs(
            inputs=[static_raw, dynamic_raw, None],
            static_input_dim=self.static_input_dim,
            dynamic_input_dim=self.dynamic_input_dim,
            future_covariate_dim=None,
            forecast_horizon=self.forecast_horizon,
            mode='soft',
            model_name='tft_flex',
            verbose=1 if debug_on else 0
        )
        if debug_on:
            self.logger.debug(
                f" Inputs validated. Static: {static_input.shape}, "
                f"Dynamic: {dynamic_input.shape}"
            )

        # --- Static Pathway ---
        static_selected = self.static_var_sel(
            static_input, training=training
        )
        if hasattr(self.static_var_sel, 'variable_importances_'):
            self.static_variable_importances_ = \
                self.static_var_sel.variable_importances_
        if debug_on:
            self.logger.debug(
                f" Static VSN output shape: {static_selected.shape}"
            )
        context_for_attention = self.static_context_grn_attn(
            static_selected, training=training
        )
        context_for_enrichment = self.static_context_grn_enrich(
            static_selected, training=training
        )
        initial_state_h = self.static_to_lstm_state_h(static_selected)
        initial_state_c = self.static_to_lstm_state_c(static_selected)
        lstm_initial_state = [initial_state_h, initial_state_c]

        # --- Temporal Pathway (Dynamic Inputs Only) ---
        dynamic_selected = self.dynamic_var_sel(
            dynamic_input, training=training
        )
        if hasattr(self.dynamic_var_sel, 'variable_importances_'):
            self.dynamic_variable_importances_ = \
                self.dynamic_var_sel.variable_importances_
        temporal_embeddings = self.positional_encoding(
            dynamic_selected, training=training
        )

        lstm_output = temporal_embeddings
        for i, layer in enumerate(self.lstm_layers):
            if debug_on:
                self.logger.debug(
                    f" LSTM layer {i+1} input: {lstm_output.shape}"
                )
            if i == 0:
                # Only the first layer is seeded with the static state.
                lstm_output = layer(
                    lstm_output, initial_state=lstm_initial_state,
                    training=training
                )
            else:
                lstm_output = layer(lstm_output, training=training)
            if debug_on:
                self.logger.debug(
                    f" LSTM layer {i+1} output: {lstm_output.shape}"
                )

        enriched_sequence = self.static_enrichment(
            lstm_output, context_vector=context_for_enrichment,
            training=training
        )
        attention_output = self.temporal_attention(
            enriched_sequence, context_vector=context_for_attention,
            training=training
        )
        final_temporal_representation = self.positionwise_grn(
            attention_output, training=training
        )

        # Output Slice and Projection
        output_features_sliced = final_temporal_representation[
            :, -self.forecast_horizon:, :
        ]
        if self.quantiles is not None:
            quantile_outputs_list = [
                layer(output_features_sliced, training=training)
                for layer in self.output_projection_layers
            ]
            # Each element in list is (B, H, O).
            # If O=1, concat on last axis -> (B, H, Q).
            # If O>1, stack on new axis -> (B, H, Q, O).
            if self.output_dim == 1:
                outputs = tf_concat(quantile_outputs_list, axis=-1)
            else:
                outputs = tf_stack(quantile_outputs_list, axis=2)
        else:  # Point forecast
            outputs = self.output_projection_layer(
                output_features_sliced, training=training
            )  # Shape (B, H, O)

        if debug_on:
            self.logger.debug(f"DummyTFT: Final output shape: {outputs.shape}")
        return outputs

    def get_config(self):
        """Return the constructor arguments needed to rebuild the model.

        Quantiles are emitted as a plain list of Python floats so the
        config does not depend on the container type returned by
        ``validate_quantiles``.
        """
        config = super().get_config()
        quantiles = (
            None if self.quantiles is None
            else [float(q) for q in self.quantiles]
        )
        config.update({
            'static_input_dim': self.static_input_dim,
            'dynamic_input_dim': self.dynamic_input_dim,
            'future_input_dim': self.future_input_dim,
            'hidden_units': self.hidden_units,
            'num_heads': self.num_heads,
            'dropout_rate': self.dropout_rate,
            'forecast_horizon': self.forecast_horizon,
            'quantiles': quantiles,
            'activation': self.activation_str,
            'use_batch_norm': self.use_batch_norm,
            'num_lstm_layers': self.num_lstm_layers,
            'lstm_units': self._lstm_units_config,  # as passed by the caller
            'output_dim': self.output_dim,
        })
        return config
[docs]
@register_keras_serializable(
    'fusionlab.nn.transformers',
    name="TemporalFusionTransformer"
)
class TemporalFusionTransformer(Model, NNLearner):
    # The class docstring is assigned at module level
    # (``TemporalFusionTransformer.__doc__ = ...``), so none is written here.

    @validate_params({
        "dynamic_input_dim": [Interval(Integral, 1, None, closed='left')],
        "static_input_dim" : [Interval(Integral, 0, None, closed='left'), None],
        "future_input_dim" : [Interval(Integral, 0, None, closed='left'), None],
        "hidden_units" : [Interval(Integral, 1, None, closed='left'), None],
        "num_heads" : [Interval(Integral, 1, None, closed='left')],
        "dropout_rate" : [Interval(Real, 0, 1, closed="both")],
        "forecast_horizon": [Interval(Integral, 1, None, closed='left')],
        "quantiles" : ['array-like', None],
        "activation" : [StrOptions({"elu", "relu", "tanh", "sigmoid",
                                    "linear", "gelu"})],
        "use_batch_norm" : [bool],
        "num_lstm_layers" : [Interval(Integral, 1, None, closed='left')],
        "lstm_units" : ['array-like', Interval(Integral, 1, None,
                                               closed='left'), None],
        "output_dim" : [Interval(Integral, 1, None, closed='left')],
    })
    @ensure_pkg(KERAS_BACKEND or "keras", extra=DEP_MSG)
    def __init__(
        self,
        dynamic_input_dim,
        static_input_dim=None,
        future_input_dim=None,
        hidden_units=32,
        num_heads=4,
        dropout_rate=0.1,
        forecast_horizon=1,
        quantiles=None,
        activation='elu',
        use_batch_norm=False,
        num_lstm_layers=1,
        lstm_units=None,
        output_dim = 1,
        **kw
    ):
        """Store the hyper-parameters and create every sub-layer.

        Static and future sub-layers are created only when the matching
        ``*_input_dim`` is truthy (neither None nor 0); otherwise the
        attribute is set to None and the branch is skipped in ``call``.
        ``output_dim`` sets the width of the output projection(s).
        """
        super().__init__(**kw)
        self.logger = fusionlog().get_fusionlab_logger(__name__)
        self.logger.debug(
            "Initializing TemporalFusionTransformer with parameters: "
            f"static_input_dim={static_input_dim}, "
            f"dynamic_input_dim={dynamic_input_dim}, "
            f"future_input_dim={future_input_dim}, "
            f"hidden_units={hidden_units}, "
            f"num_heads={num_heads}, "
            f"dropout_rate={dropout_rate}, "
            f"forecast_horizon={forecast_horizon}, "
            f"quantiles={quantiles}, "
            f"activation={activation}, "
            f"use_batch_norm={use_batch_norm}, "
            f"num_lstm_layers={num_lstm_layers}, "
            f"lstm_units={lstm_units}"
        )
        self.static_input_dim = static_input_dim
        self.dynamic_input_dim = dynamic_input_dim
        self.future_input_dim = future_input_dim
        self.hidden_units = hidden_units
        self.num_heads = num_heads
        self.dropout_rate = dropout_rate
        self.forecast_horizon = forecast_horizon
        self.use_batch_norm = use_batch_norm
        self.num_lstm_layers = num_lstm_layers
        self.lstm_units = lstm_units
        # Previously accepted but silently dropped; now stored and used.
        self.output_dim = output_dim
        # The activation is kept as its string name and handed to the
        # sub-layers unchanged.
        self.activation = activation

        # If quantiles are not provided, interpret as single output
        # (deterministic). Otherwise, validate and store them.
        if quantiles is None:
            self.quantiles = None
            self.num_quantiles = 1
        else:
            self.quantiles = validate_quantiles(quantiles)
            self.num_quantiles = len(self.quantiles)

        # Variable selection networks for static, dynamic,
        # and optionally future inputs.
        self.logger.debug("Initializing Variable Selection Networks...")
        # For static inputs (metadata)
        self.static_var_sel = (
            VariableSelectionNetwork(
                num_inputs=static_input_dim,
                units=hidden_units,
                dropout_rate=dropout_rate,
                activation=self.activation,
                use_batch_norm=use_batch_norm
            ) if static_input_dim else None
        )
        # For dynamic (past) inputs
        self.dynamic_var_sel = VariableSelectionNetwork(
            num_inputs=dynamic_input_dim,
            units=hidden_units,
            dropout_rate=dropout_rate,
            use_time_distributed=True,
            activation=self.activation,
            use_batch_norm=use_batch_norm
        )
        # For future inputs (if needed). Gated by truthiness, exactly like
        # the static branch and like the check `call` has always used, so
        # future_input_dim=0 no longer builds a network that is never run.
        if future_input_dim:
            self.logger.debug("Initializing Future Variable Selection Network...")
            self.future_var_sel = VariableSelectionNetwork(
                num_inputs=future_input_dim,
                units=hidden_units,
                dropout_rate=dropout_rate,
                use_time_distributed=True,
                activation=self.activation,
                use_batch_norm=use_batch_norm
            )
        else:
            self.future_var_sel = None

        # Positional Encoding handles the time step embedding
        # for the dynamic/future sequences.
        self.logger.debug("Initializing Positional Encoding...")
        self.positional_encoding = PositionalEncoding()

        # Static Context GRNs are used to derive static context vectors
        # and enrichment vectors from static inputs (if present).
        self.logger.debug("Initializing Static Context GRNs...")
        self.static_context_grn = (
            GatedResidualNetwork(
                hidden_units,
                dropout_rate,
                activation=self.activation,
                use_batch_norm=use_batch_norm
            ) if static_input_dim else None
        )
        self.static_context_enrichment_grn = (
            GatedResidualNetwork(
                hidden_units,
                dropout_rate,
                activation=self.activation,
                use_batch_norm=use_batch_norm
            ) if static_input_dim else None
        )

        # LSTM Encoder: multi-layer LSTMs to encode the historical
        # (plus possibly future) time series embeddings.
        self.logger.debug("Initializing LSTM Encoder Layers...")
        self.lstm_layers = []
        if self.lstm_units is not None:
            # Convert user-supplied LSTM units to a list if not already.
            self.lstm_units = is_iterable(self.lstm_units, transform=True)
        for i in range(num_lstm_layers):
            # Layers beyond the supplied sizes fall back to hidden_units.
            lstm_units_i = hidden_units
            if self.lstm_units is not None and i < len(self.lstm_units):
                lstm_units_i = self.lstm_units[i]
            self.lstm_layers.append(
                LSTM(
                    lstm_units_i,
                    return_sequences=True,
                    dropout=dropout_rate,
                    name=f'lstm_layer_{i+1}'
                )
            )

        # Static Enrichment: merges static context with the LSTM-encoded
        # dynamic embeddings prior to attention.
        self.logger.debug("Initializing Static Enrichment Layer...")
        self.static_enrichment = (
            StaticEnrichmentLayer(
                hidden_units,
                activation=self.activation,
                use_batch_norm=use_batch_norm
            ) if static_input_dim else None
        )

        # Temporal Attention Layer for interpretability and weighting
        # various time steps.
        self.logger.debug("Initializing Temporal Attention Layer...")
        self.temporal_attention = TemporalAttentionLayer(
            hidden_units,
            num_heads,
            dropout_rate,
            activation=self.activation,
            use_batch_norm=use_batch_norm
        )

        # Position-wise Feedforward (GRN) for final transformation
        # after attention.
        self.logger.debug("Initializing Position-wise Feedforward Network...")
        self.positionwise_grn = GatedResidualNetwork(
            hidden_units,
            dropout_rate,
            activation=self.activation,
            use_batch_norm=use_batch_norm
        )

        # Output Layers for either multiple quantiles or a single
        # deterministic point forecast; each projects to output_dim
        # features (1 by default, matching the previous Dense(1)).
        if self.quantiles is not None:
            self.logger.debug("Initializing Quantile Output Layers...")
            self.quantile_outputs = [
                TimeDistributed(
                    Dense(self.output_dim), name=f'quantile_output_{i+1}')
                for i in range(self.num_quantiles)
            ]
        else:
            self.logger.debug(
                "Initializing Single Output Layer for Point Predictions...")
            self.output_layer = TimeDistributed(
                Dense(self.output_dim), name='output_layer')

    @tf_autograph.experimental.do_not_convert
    def call(self, inputs, training=False, **kw):
        """
        The main forward pass for TemporalFusionTransformer.

        1. Validate and unpack `inputs` using `validate_model_inputs`.
        2. Apply variable selection to static, dynamic, and future inputs.
        3. Perform positional encoding on dynamic+future sequences.
        4. Compute static context vectors if static is present.
        5. Pass through LSTM encoders.
        6. Optionally enrich dynamic with static context.
        7. Temporal attention for interpretable weighting of time steps.
        8. Position-wise feedforward (GRN).
        9. Final slicing (forecast horizon) and output (quantiles or single).

        Parameters
        ----------
        inputs : list or tuple
            Passed unchanged to `validate_model_inputs`, which returns
            the triple ``(static_inputs, dynamic_inputs, future_inputs)``.
        training : bool, default=False
            Whether in training mode (affects dropout, BN, etc.).

        Returns
        -------
        tf.Tensor
            Projection of the last ``forecast_horizon`` steps. Point
            forecasts have ``output_dim`` features on the last axis.
            With quantiles the per-quantile outputs are concatenated on
            the last axis when ``output_dim == 1`` and stacked on a new
            axis 2 otherwise.
        """
        self.logger.debug("Starting call method with inputs.")
        # `validate_model_inputs` expects [S,D,F] and returns (S,D,F).
        static_inputs, dynamic_inputs, future_inputs = validate_model_inputs(
            inputs=inputs,
            static_input_dim=self.static_input_dim,
            dynamic_input_dim=self.dynamic_input_dim,
            future_covariate_dim=self.future_input_dim,
            forecast_horizon=self.forecast_horizon,
            model_name='tft_flex',
            mode="soft",
        )

        # 1. Apply Variable Selection on static (if present).
        self.logger.debug("Applying Variable Selection on Static Inputs (if any).")
        if self.static_var_sel is not None:
            static_embedding = self.static_var_sel(static_inputs, training=training)
            self.static_variable_importances_ = self.static_var_sel.variable_importances_
        else:
            static_embedding = None

        # 2. Variable Selection on dynamic (past) inputs.
        self.logger.debug("Applying Variable Selection on Dynamic Inputs...")
        dynamic_embedding = self.dynamic_var_sel(dynamic_inputs, training=training)
        self.dynamic_variable_importances_ = self.dynamic_var_sel.variable_importances_

        # 3. If future inputs exist, apply future var selection and
        #    concatenate with dynamic along the time axis.
        if self.future_var_sel is not None:
            self.logger.debug("Applying Variable Selection on Future Inputs...")
            fut_embed = self.future_var_sel(future_inputs, training=training)
            self.future_variable_importances_ = self.future_var_sel.variable_importances_
            # dynamic_embedding is the reference for the alignment.
            self.logger.debug(" Aligning temporal inputs for MultiModalEmbedding...")
            _, fut_embed = align_temporal_dimensions(
                tensor_ref=dynamic_embedding,
                tensor_to_align=fut_embed,
                mode='slice_to_ref',
                name="future_input_for_mme"
            )
            dynamic_embedding = tf_concat([dynamic_embedding, fut_embed], axis=1)

        # 4. Positional encoding for combined sequence embedding
        #    (dynamic + future).
        self.logger.debug("Applying Positional Encoding to dynamic/future Embedding...")
        dynamic_embedding = self.positional_encoding(
            dynamic_embedding,
            training=training
        )

        # 5. Compute static context if static_embedding is available.
        if static_embedding is not None:
            self.logger.debug("Generating Static Context Vector & Enrichment...")
            static_context_vector = self.static_context_grn(
                static_embedding, training=training
            )
            static_enrichment_vector = self.static_context_enrichment_grn(
                static_embedding, training=training
            )
        else:
            static_context_vector = None
            static_enrichment_vector = None

        # 6. Pass embeddings through multi-layer LSTM encoders.
        self.logger.debug("Passing through LSTM Encoder Layers...")
        x = dynamic_embedding
        for lstm_layer in self.lstm_layers:
            x = lstm_layer(x, training=training)

        # 7. Static enrichment merges the static context with LSTM output
        #    if static is present.
        if (self.static_enrichment is not None
                and static_enrichment_vector is not None):
            self.logger.debug("Applying Static Enrichment...")
            enriched_lstm_output = self.static_enrichment(
                x, static_enrichment_vector, training=training
            )
        else:
            enriched_lstm_output = x

        # 8. Temporal Attention, possibly using static context as well.
        self.logger.debug("Applying Temporal Attention...")
        attention_output = self.temporal_attention(
            enriched_lstm_output,
            context_vector=static_context_vector,
            training=training
        )

        # 9. Position-wise feedforward transforms the attended output.
        self.logger.debug("Applying Position-wise Feedforward...")
        temporal_feature = self.positionwise_grn(
            attention_output, training=training
        )

        # 10. Slice the last `forecast_horizon` steps for final prediction
        #     and apply output projection (either quantiles or point).
        self.logger.debug("Generating Final Output...")
        decoder_steps = self.forecast_horizon
        outputs = temporal_feature[:, -decoder_steps:, :]

        # `is not None` mirrors the check used in __init__ and does not
        # depend on the truthiness of the validated quantile container.
        if self.quantiles is not None:
            self.logger.debug("Generating Quantile Outputs...")
            quantile_outs = [qout(outputs) for qout in self.quantile_outputs]
            # Plain tensor ops instead of building a new Concatenate layer
            # on every forward pass.
            if self.output_dim == 1:
                final_output = tf_concat(quantile_outs, axis=-1)
            else:
                final_output = tf_stack(quantile_outs, axis=2)
        else:
            self.logger.debug("Generating Point Predictions...")
            final_output = self.output_layer(outputs)

        self.logger.debug("Call method completed.")
        return final_output

    def get_config(self):
        """
        Return the model configuration for serialization.

        Includes all hyperparameters that define the structure
        of the TemporalFusionTransformer, `output_dim` included.
        Quantiles are emitted as a plain list of Python floats.
        """
        config = super().get_config().copy()
        quantiles = (
            None if self.quantiles is None
            else [float(q) for q in self.quantiles]
        )
        config.update({
            'dynamic_input_dim' : self.dynamic_input_dim,
            'static_input_dim' : self.static_input_dim,
            'future_input_dim' : self.future_input_dim,
            'hidden_units' : self.hidden_units,
            'num_heads' : self.num_heads,
            'dropout_rate' : self.dropout_rate,
            'forecast_horizon' : self.forecast_horizon,
            'quantiles' : quantiles,
            'activation' : self.activation,
            'use_batch_norm' : self.use_batch_norm,
            'num_lstm_layers' : self.num_lstm_layers,
            'lstm_units' : self.lstm_units,
            'output_dim' : self.output_dim,
        })
        self.logger.debug("Configuration for get_config has been updated.")
        return config

    @classmethod
    def from_config(cls, config):
        """
        Recreate a TemporalFusionTransformer instance from config.

        This classmethod is invoked by Keras to deserialize the model.
        Configs without an `output_dim` entry fall back to the
        constructor default.
        """
        # Local logger: assigning to `cls.logger` would mutate the class
        # for every instance and subclass; each instance sets its own
        # logger in __init__ anyway.
        logger = fusionlog().get_fusionlab_logger(__name__)
        logger.debug("Creating TemporalFusionTransformer instance from config.")
        return cls(**config)
DummyTFT.__doc__+="""\
The DummyTFT combines high-performance multi-horizon
forecasting with interpretable insights into temporal dynamics [1]_.
It integrates several advanced mechanisms, including:
- Variable Selection Networks (VSNs) for static and dynamic features.
- Gated Residual Networks (GRNs) for processing inputs.
- Static Enrichment Layer to incorporate static features into temporal
processing.
- LSTM Encoder for capturing sequential dependencies.
- Temporal Attention Layer for focusing on important time steps.
- Position-wise Feedforward Layer.
- Final Output Layer for prediction.
Parameters
----------
static_input_dim : int
The input dimension per static variable. Typically ``1`` for scalar
features or higher for embeddings. This defines the number of features
for each static variable. For example, if static variables are
represented using embeddings of size 16, then ``static_input_dim``
would be ``16``.
dynamic_input_dim : int
The input dimension per dynamic variable. This defines the number of
features for each dynamic variable at each time step. For instance, if
dynamic variables are represented using embeddings or multiple features,
specify the appropriate dimension.
{params.base.hidden_units}
{params.base.num_heads}
{params.base.dropout_rate}
forecast_horizon : int, optional
The number of time steps to forecast. Default is ``1``. This parameter
defines the number of future time steps the model will predict. For
multi-step forecasting, set ``forecast_horizon`` to the desired number
of future steps.
{params.base.quantiles}
{params.base.activation}
{params.base.use_batch_norm}
num_lstm_layers : int, optional
Number of LSTM layers in the encoder. Default is ``1``. Adding more
layers can help the model capture more complex sequential patterns.
Each additional layer processes the output of the previous LSTM layer.
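lstm_units : int, list of int or None, optional
Number of units for the LSTM layers. A single int applies the same size
to every layer; a list gives the size of each layer and must contain
``num_lstm_layers`` entries. If ``None``, all LSTM layers have
``hidden_units`` units. Default is ``None``.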
This parameter allows customizing the size of each LSTM layer. For
example, to specify different units for each layer, provide a list like
``[64, 32]``.
Methods
-------
call(inputs, training=False)
Forward pass of the model.
Parameters
----------
inputs : tuple of tensors
A tuple containing ``(static_inputs, dynamic_inputs)``.
- ``static_inputs``: Tensor of shape ``(batch_size, num_static_vars,
static_input_dim)`` representing the static features.
- ``dynamic_inputs``: Tensor of shape ``(batch_size, time_steps,
num_dynamic_vars, dynamic_input_dim)`` representing the dynamic
features.
training : bool, optional
Whether the model is in training mode. Default is ``False``.
Returns
-------
Tensor
The output predictions of the model. The shape depends on the
``forecast_horizon`` and whether ``quantiles`` are used.
get_config()
Returns the configuration of the model for serialization.
from_config(config)
Instantiates the model from a configuration dictionary.
""".format(params=_param_docs)
TemporalFusionTransformer.__doc__="""\
TemporalFusionTransformer model implementation for multi-horizon
forecasting, with optional static, past, and future inputs.
This class extends Keras `Model` and integrates with the fusionlab
NNLearner interface. It supports dynamic (past) inputs, optional
static inputs, and newly added optional future inputs
(``future_input_dim``). By including the future covariates, the
TemporalFusionTransformer can account for known future features
(e.g., events, planned discount rates, etc.) in its predictions.
Parameters
----------
dynamic_input_dim: int
Dimensionality of the dynamic (past) inputs. This is mandatory
for the TFT model.
static_input_dim : int, optional
Dimensionality of static inputs. If not `None`, the call method
will expect static inputs.
future_input_dim : int, optional
Dimensionality of future (known) inputs. If not `None`, the call
method will expect future inputs to handle exogenous covariates
known in the future (e.g., events, planned promotions, etc.).
hidden_units : int, default=32
Number of hidden units for the layers that do not have a distinct
specification (e.g., GRNs, variable selection networks).
num_heads : int, default=4
Number of attention heads in the multi-head attention layer.
dropout_rate : float, default=0.1
Dropout rate for various layers (GRNs, attention, etc.).
forecast_horizon : int, default=1
Number of timesteps to forecast into the future.
quantiles : list of float, optional
List of quantiles for probabilistic forecasting. If `None`, a
single deterministic output is produced.
activation : str, default='elu'
Activation function. Must be one of ``{'elu', 'relu', 'tanh',
'sigmoid', 'linear', 'gelu'}``.
use_batch_norm : bool, default=False
Whether to apply batch normalization in various sub-layers.
num_lstm_layers : int, default=1
Number of LSTM layers in the encoder.
lstm_units : list of int or None, default=None
If provided, each index corresponds to the number of LSTM
units for that layer. If `None`, uses ``hidden_units`` for
each layer.
Examples
--------
>>> from fusionlab.nn._tensor_validation import validate_tft_inputs
>>> from fusionlab.nn.transformers import TemporalFusionTransformer
>>> model = TemporalFusionTransformer(
... dynamic_input_dim=10,
... static_input_dim=5,
... future_input_dim=8,
... hidden_units=32,
... num_heads=4,
... dropout_rate=0.1,
... forecast_horizon=7,
... quantiles=[0.1, 0.5, 0.9],
... activation='elu',
... use_batch_norm=True,
... num_lstm_layers=2,
... lstm_units=[64, 32]
... )
Notes
-----
The newly added ``future_input_dim`` allows the model to incorporate
future covariates known at forecast time. In the ``call`` method, if
``future_input_dim`` is not `None`, the model expects three inputs:
``(static_inputs, dynamic_inputs, future_inputs)``. Otherwise, it
expects only ``(static_inputs, dynamic_inputs)``.
See Also
--------
VariableSelectionNetwork : For feature selection and embedding.
GatedResidualNetwork : A GRN used in various sub-layers.
LSTM : Keras LSTM layers for sequence processing.
References
----------
.. [1] Lim, B., Arık, S. Ö., Loeff, N., & Pfister, T. (2019).
Temporal Fusion Transformers for Interpretable
Multi-horizon Time Series Forecasting.
https://arxiv.org/abs/1912.09363
"""