# -*- coding: utf-8 -*-
# License: BSD-3-Clause
# Author: LKouadio <etanoyau@gmail.com>
"""
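XTFT: Extreme Temporal Fusion Transformer forecasting model with a hybrid
encoder, built on ``BaseExtreme``.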
"""
from __future__ import annotations
from textwrap import dedent
from typing import Any, Dict, List, Optional, Tuple, Union
from ._base_extreme import (
BaseExtreme,
KERAS_BACKEND,
KERAS_DEPS,
logger,
)
from ...api.docs import _shared_docs, doc
from ...core.handlers import param_deprecated_message
if KERAS_BACKEND:
    # Short, backend-agnostic aliases resolved through the KERAS_DEPS shim
    # so the rest of the module never imports TensorFlow/Keras directly.

    # -- framework utilities ------------------------------------------
    tf_autograph = KERAS_DEPS.autograph
    register_keras_serializable = KERAS_DEPS.register_keras_serializable

    # -- layers used to assemble the model ----------------------------
    Concatenate = KERAS_DEPS.Concatenate
    Dense = KERAS_DEPS.Dense
    Dropout = KERAS_DEPS.Dropout
    LayerNormalization = KERAS_DEPS.LayerNormalization
    MultiHeadAttention = KERAS_DEPS.MultiHeadAttention

    # -- tensor type (used in annotations) ----------------------------
    Tensor = KERAS_DEPS.Tensor

    # -- tensor-manipulation ops --------------------------------------
    tf_shape = KERAS_DEPS.shape
    tf_expand_dims = KERAS_DEPS.expand_dims
    tf_tile = KERAS_DEPS.tile
from ..components import (
Activation,
CrossAttention,
GatedResidualNetwork,
HierarchicalAttention,
LearnedNormalization,
MemoryAugmentedAttention,
MultiModalEmbedding,
MultiResolutionAttentionFusion,
MultiScaleLSTM,
PositionalEncoding,
aggregate_multiscale,
)
from .._tensor_validation import (
align_temporal_dimensions,
validate_anomaly_scores
)
__all__ = ["XTFT"]

# Merged into every XTFT instance's ``architecture_config`` by
# ``XTFT.__init__``: XTFT always runs with the hybrid encoder, whatever
# the caller supplied.
_CONFIG_ENCODER_TYPE = {"encoder_type": "hybrid"}
@register_keras_serializable("fusionlab.nn.hybrid", name="XTFT")
@doc(
    key_parameters=dedent(_shared_docs["xtft_params_doc"]),
    key_improvements=dedent(_shared_docs["xtft_key_improvements"]),
    key_functions=dedent(_shared_docs["xtft_key_functions"]),
    methods=dedent(_shared_docs["xtft_methods"]),
)
@param_deprecated_message(
    conditions_params_mappings=[
        {
            "param": "multi_scale_agg",
            "condition": lambda v: v == "concat",
            "message": (
                "The 'concat' mode for multi-scale aggregation requires "
                "identical time dimensions across scales, which is rarely "
                "practical. This mode will fall back to the robust "
                "last-timestep approach in real applications. For true "
                "multi-scale handling, use 'last' mode instead "
                "(automatically set).\n"
                "Why change?\n"
                "- 'concat' mixes features across scales at the same "
                "timestep\n"
                "- Requires manual time alignment between scales\n"
                "- 'last' preserves scale independence & handles variable "
                "lengths"
            ),
            "default": "last",
        }
    ],
    warning_category=UserWarning,
)
class XTFT(BaseExtreme):
    # NOTE: no class docstring here on purpose -- the public docstring is
    # attached after the class body via ``XTFT.__doc__ = ...``.

    def __init__(
        self,
        *,
        static_input_dim: int,
        dynamic_input_dim: int,
        future_input_dim: int,
        embed_dim: int = 32,
        forecast_horizon: int = 1,
        quantiles: Union[str, List[float], None] = None,
        max_window_size: int = 10,
        memory_size: int = 100,
        num_heads: int = 4,
        dropout_rate: float = 0.1,
        output_dim: int = 1,
        attention_units: int = 32,
        hidden_units: int = 64,
        lstm_units: int = 64,
        scales: Union[str, List[int], None] = None,
        multi_scale_agg: Optional[str] = None,
        activation: Union[str, callable] = "relu",
        use_residuals: bool = True,
        use_batch_norm: bool = False,
        final_agg: str = "last",
        anomaly_config: Optional[Dict[str, Any]] = None,
        anomaly_detection_strategy: Optional[str] = None,
        anomaly_loss_weight: float = 0.1,
        architecture_config: Optional[Dict] = None,
        **kw: Any,
    ) -> None:
        """Forward all hyper-parameters to ``BaseExtreme`` and pin the encoder.

        Every argument is passed unchanged to ``BaseExtreme.__init__``,
        which also builds ``self.architecture_config``. Afterwards
        ``architecture_config['encoder_type']`` is forced to ``'hybrid'``,
        overriding any value the caller provided.
        """
        logger.debug("XTFT.__init__() called")
        super().__init__(
            static_input_dim=static_input_dim,
            dynamic_input_dim=dynamic_input_dim,
            future_input_dim=future_input_dim,
            embed_dim=embed_dim,
            forecast_horizon=forecast_horizon,
            quantiles=quantiles,
            max_window_size=max_window_size,
            memory_size=memory_size,
            num_heads=num_heads,
            dropout_rate=dropout_rate,
            output_dim=output_dim,
            attention_units=attention_units,
            hidden_units=hidden_units,
            lstm_units=lstm_units,
            scales=scales,
            multi_scale_agg=multi_scale_agg,
            activation=activation,
            use_residuals=use_residuals,
            use_batch_norm=use_batch_norm,
            final_agg=final_agg,
            anomaly_config=anomaly_config,
            anomaly_detection_strategy=anomaly_detection_strategy,
            anomaly_loss_weight=anomaly_loss_weight,
            architecture_config=architecture_config,
            **kw,
        )
        # The architecture configuration itself is built by the base class;
        # XTFT only pins the encoder type on top of it.
        self.architecture_config.update(_CONFIG_ENCODER_TYPE)
        logger.debug(
            "XTFT initialized with architecture config: %s",
            self.architecture_config,
        )

    def _build_components(self) -> None:
        """Instantiate every sub-layer used by the forward pass.

        Creates the static branch, the embeddings, the temporal backbone
        (multi-scale LSTM + attention layers) and, only when
        ``anomaly_detection_strategy == 'feature_based'``, the anomaly
        attention/projection/scorer layers (otherwise those are ``None``).
        """
        logger.debug("XTFT._build_components() start")
        # Pick up any late changes to feature_processing / attention stack.
        self._sync_architecture()
        # Normalise the activation (name or callable) to its string form.
        self.activation = Activation(self.activation).activation_str

        # --- Static branch -----------------------------------------------
        self.learned_normalization = LearnedNormalization()
        self.static_dense = Dense(
            self.hidden_units, activation=self.activation
        )
        self.static_dropout = Dropout(self.dropout_rate)
        # NOTE: despite the ``use_batch_norm`` flag name, this is a
        # LayerNormalization layer.
        self.static_batch_norm = (
            LayerNormalization() if self.use_batch_norm else None
        )
        self.grn_static = GatedResidualNetwork(
            units=self.hidden_units,
            dropout_rate=self.dropout_rate,
            use_time_distributed=False,
            activation=self.activation,
            use_batch_norm=self.use_batch_norm,
        )

        # --- Embeddings ----------------------------------------------------
        self.multi_modal_embedding = MultiModalEmbedding(self.embed_dim)
        self.positional_encoding = PositionalEncoding()
        # Residual projection sized 2 * embed_dim so it can be added to the
        # embeddings in ``_encode_inputs``.
        self.residual_dense = (
            Dense(2 * self.embed_dim) if self.use_residuals else None
        )

        # --- Temporal backbone --------------------------------------------
        self.multi_scale_lstm = MultiScaleLSTM(
            lstm_units=self.lstm_units,
            scales=self.scales,
            return_sequences=self.return_sequences,
        )
        self.hierarchical_attention = HierarchicalAttention(
            units=self.attention_units,
            num_heads=self.num_heads,
        )
        self.cross_attention = CrossAttention(
            units=self.attention_units,
            num_heads=self.num_heads,
        )
        self.memory_augmented_attention = MemoryAugmentedAttention(
            units=self.attention_units,
            memory_size=self.memory_size,
            num_heads=self.num_heads,
        )
        self.multi_resolution_attention_fusion = (
            MultiResolutionAttentionFusion(
                units=self.attention_units,
                num_heads=self.num_heads,
            )
        )

        # --- Anomaly layers (feature_based strategy only) -----------------
        if self.anomaly_detection_strategy == "feature_based":
            self.anomaly_attention = MultiHeadAttention(
                num_heads=1,
                key_dim=self.hidden_units,
                name="anomaly_attention",
            )
            self.anomaly_projection = Dense(
                self.hidden_units,
                activation="linear",
                name="anomaly_projection",
            )
            self.anomaly_scorer = Dense(
                1,
                activation="linear",
                name="anomaly_scorer",
            )
        else:
            self.anomaly_attention = None
            self.anomaly_projection = None
            self.anomaly_scorer = None
        logger.debug("XTFT._build_components() done")

    @tf_autograph.experimental.do_not_convert
    def _encode_inputs(
        self,
        static_input: Tensor,
        dynamic_input: Tensor,
        future_input: Tensor,
        *,
        training: bool,
    ) -> Tuple[Tensor, Tensor, Tensor, Dict[str, Any]]:
        """Encode the static input and embed the dynamic + future inputs.

        Parameters
        ----------
        static_input, dynamic_input, future_input : Tensor
            Raw model inputs.
        training : bool
            Forwarded to the sub-layers.

        Returns
        -------
        static_features : Tensor
            Static input after learned normalisation, dense projection,
            optional normalisation, dropout and the static GRN.
        dynamic_input : Tensor
            The ``dynamic_input`` argument, returned unchanged (it is what
            the multi-scale LSTM consumes in ``_temporal_backbone``).
        fut_for_embed : Tensor
            ``future_input`` aligned to the time axis of ``dynamic_input``.
        cache : dict
            ``'fut_mask'``, ``'embeddings'`` and ``'future_for_embed'`` for
            the later stages.
        """
        logger.debug("XTFT._encode_inputs() start")
        cache: Dict[str, Any] = {}

        # --- Static branch -------------------------------------------------
        norm_static = self.learned_normalization(
            static_input, training=training
        )
        logger.debug("Normalized Static Shape: %s", norm_static.shape)
        static_features = self.static_dense(norm_static)
        if self.static_batch_norm is not None:
            static_features = self.static_batch_norm(
                static_features, training=training
            )
            # Only meaningful when the normalisation actually ran.
            logger.debug(
                "Static Features after BatchNorm Shape: %s",
                static_features.shape,
            )
        static_features = self.static_dropout(
            static_features, training=training
        )
        static_features = self.grn_static(
            static_features, training=training
        )
        logger.debug("static_features shape=%s", static_features.shape)

        # --- Dynamic + future embeddings ----------------------------------
        # Align the future input to the dynamic (reference) time axis; the
        # returned mask is reused as the attention mask downstream.
        _, fut_for_embed, fut_mask = align_temporal_dimensions(
            tensor_ref=dynamic_input,
            tensor_to_align=future_input,
            mode="auto",
            return_mask=True,
            name="future_input_for_mme",
        )
        cache["fut_mask"] = fut_mask  # presumably (B, T_ref) -- TODO confirm
        logger.debug(
            " Dynamic for MME: %s, Future for MME: %s",
            dynamic_input.shape,
            fut_for_embed.shape,
        )
        embeddings = self.multi_modal_embedding(
            [dynamic_input, fut_for_embed], training=training
        )
        logger.debug(
            " Embeddings shape after MultiModalEmbedding: %s",
            embeddings.shape,
        )
        embeddings = self.positional_encoding(
            embeddings, training=training
        )
        logger.debug(
            "Embeddings with Positional Encoding Shape: %s",
            embeddings.shape,
        )
        if self.use_residuals and self.residual_dense is not None:
            embeddings = embeddings + self.residual_dense(embeddings)
        logger.debug("embeddings shape=%s", embeddings.shape)

        cache["embeddings"] = embeddings
        cache["future_for_embed"] = fut_for_embed
        logger.debug("XTFT._encode_inputs() done")
        return static_features, dynamic_input, fut_for_embed, cache

    @tf_autograph.experimental.do_not_convert
    def _temporal_backbone(
        self,
        dynamic_encoded: Tensor,
        future_encoded: Tensor,
        *,
        training: bool,
        cache: Dict[str, Any],
    ) -> Tuple[Tensor, Dict[str, Any]]:
        """Run the multi-scale LSTM and the configured attention stack.

        Parameters
        ----------
        dynamic_encoded : Tensor
            Dynamic input as returned by ``_encode_inputs``.
        future_encoded : Tensor
            Aligned future input; not used by this implementation (kept for
            the hook signature).
        training : bool
            Forwarded to the sub-layers.
        cache : dict
            Must hold ``'embeddings'`` and ``'fut_mask'`` from
            ``_encode_inputs``. Updated in place with ``'cross_att'``,
            ``'hierarchical_att'`` and ``'memory_att'``: each holds the
            output of that specific layer, or ``None`` when the layer is not
            listed in ``architecture_config['decoder_attention_stack']``.

        Returns
        -------
        fused : Tensor
            LSTM features concatenated with the attended context, passed
            through multi-resolution attention fusion.
        cache : dict
            The updated ``cache``.
        """
        logger.debug("XTFT._temporal_backbone() start")
        embeddings = cache["embeddings"]
        stack = self.architecture_config["decoder_attention_stack"]

        # Multi-scale LSTM, aggregated across scales, then broadcast back
        # over the time axis of the dynamic input.
        lstm_out = self.multi_scale_lstm(dynamic_encoded, training=training)
        lstm_feats = aggregate_multiscale(
            lstm_out, mode=self.multi_scale_agg
        )
        t_steps = tf_shape(dynamic_encoded)[1]
        lstm_feats = tf_expand_dims(lstm_feats, axis=1)
        lstm_feats = tf_tile(lstm_feats, [1, t_steps, 1])
        logger.debug("lstm_feats shape=%s", lstm_feats.shape)

        # Attention mask derived from the future/dynamic alignment mask.
        attn_mask = tf_expand_dims(cache["fut_mask"], axis=1)  # (B, 1, T_ref)
        logger.debug("Attention Mask Shape: %s", attn_mask.shape)

        # Each attention output is tracked separately so the cache records
        # the right tensor even when several layers are chained.
        cross_att = hier_att = mem_att = None
        context_att = lstm_feats

        if "cross" in stack:
            logger.debug("Applying Cross Attention")
            cross_att = self.cross_attention(
                [dynamic_encoded, embeddings],
                training=training,
                attention_mask=attn_mask,
            )
            context_att = cross_att
            logger.debug("Cross Attention Output Shape: %s", cross_att.shape)

        if "hierarchical" in stack:
            logger.debug("Applying Hierarchical Attention")
            hier_att = self.hierarchical_attention(
                [context_att, context_att],
                training=training,
                attention_mask=attn_mask,
            )
            context_att = hier_att
            logger.debug("Hierarchical Attention Shape: %s", hier_att.shape)

        if "memory" in stack:
            logger.debug("Applying Memory-Augmented Attention")
            mem_att = self.memory_augmented_attention(
                context_att,
                training=training,
                attention_mask=attn_mask,
            )
            context_att = mem_att
            logger.debug(
                "Memory Augmented Attention Shape: %s", mem_att.shape
            )

        # Combine LSTM features with the attended context, then fuse.
        fused = Concatenate()([lstm_feats, context_att])
        fused = self.multi_resolution_attention_fusion(
            fused, training=training
        )
        logger.debug("Fused shape post-fusion=%s", fused.shape)

        cache.update({
            "hierarchical_att": hier_att,
            "cross_att": cross_att,
            "memory_att": mem_att,
        })
        logger.debug("XTFT._temporal_backbone() done")
        return fused, cache

    def _maybe_compute_anomaly_scores(
        self,
        fused_feats: Tensor,
        *,
        training: bool,
        cache: Dict[str, Any],
    ) -> None:
        """Populate ``self.anomaly_scores`` according to the anomaly strategy.

        * ``'from_config'``: scores come from ``self.anomaly_config`` via
          ``validate_anomaly_scores`` (expected to be 2D ``(B, H)`` -- not
          enforced here).
        * ``'feature_based'``: attend over ``fused_feats``, project, then
          score with a 1-unit dense layer, giving ``(B, T, 1)`` when
          ``fused_feats`` is ``(B, T, F)``.
        * Any other strategy: nothing is computed.

        Parameters
        ----------
        fused_feats : Tensor
            Fused temporal features, presumably ``(B, T, F)``.
        training : bool
            Forwarded to the anomaly sub-layers.
        cache : dict
            Unused; kept for signature compatibility with the other hooks.

        Returns
        -------
        None
            The result is stored on ``self.anomaly_scores``.
        """
        strategy = self.anomaly_detection_strategy

        if strategy == "from_config":
            # This branch used to sit behind an early return for every
            # strategy other than 'feature_based', so it was unreachable.
            self.anomaly_scores = validate_anomaly_scores(
                self.anomaly_config,
                self.forecast_horizon,
            )
            logger.debug(
                "Using Anomaly Scores from anomaly_config Shape: %s",
                getattr(self.anomaly_scores, "shape", None),
            )
            return None

        if strategy != "feature_based":
            return None

        if (
            self.anomaly_attention is None
            or self.anomaly_projection is None
            or self.anomaly_scorer is None
        ):
            logger.warning(
                "feature_based strategy set but anomaly layers missing; "
                "skipping anomaly scoring."
            )
            return None

        # Self-attention over the fused features (B, T, F).
        attn_scores = self.anomaly_attention(
            query=fused_feats,
            value=fused_feats,
            training=training,
        )
        proj = self.anomaly_projection(attn_scores, training=training)
        self.anomaly_scores = self.anomaly_scorer(proj, training=training)
        logger.debug("anomaly_scores shape=%s", self.anomaly_scores.shape)
        return None
XTFT.__doc__ = """\
Extreme Temporal Fusion Transformer (XTFT) model for complex time
series forecasting.

XTFT is an advanced architecture for time series forecasting, particularly
suited to scenarios featuring intricate temporal patterns, multiple
forecast horizons, and inherent uncertainties [1]_. By extending the
original Temporal Fusion Transformer, XTFT incorporates additional modules
and strategies that enhance its representational capacity, stability,
and interpretability.

See more in :ref:`User Guide <user_guide>`.

{key_improvements}

{key_parameters}
**kw : dict
    Additional keyword arguments passed to the model. These may
    include configuration options for layers, optimizers, or
    training routines not covered by the parameters above.

{methods}

{key_functions}

Examples
--------
>>> import os
>>> import tempfile
>>> import numpy as np
>>> import pandas as pd
>>> import tensorflow as tf
>>> from fusionlab.nn.transformers import XTFT
>>> from fusionlab.nn.losses import combined_quantile_loss
>>>
>>> # Directory where the checkpoint and the final model are written.
>>> data_path = tempfile.mkdtemp()
>>> early_stopping = tf.keras.callbacks.EarlyStopping(
...     monitor='val_loss',
...     patience=5,
...     restore_best_weights=True
... )
>>> model_checkpoint = tf.keras.callbacks.ModelCheckpoint(
...     os.path.join(data_path, 'best_dummy_model.keras'),
...     monitor='val_loss',
...     save_best_only=True,
...     save_weights_only=False,  # save the entire model
...     verbose=1
... )
>>>
>>> # Dummy data: a date column, two dynamic features ("feat1", "feat2"),
>>> # one static feature ("stat1") and the target "price".
>>> date_rng = pd.date_range(start="2020-01-01", periods=60, freq="D")
>>> df = pd.DataFrame({
...     "date": date_rng,
...     "feat1": np.random.rand(60),
...     "feat2": np.random.rand(60),
...     "stat1": np.random.rand(60),
...     "price": np.random.rand(60),
... })
>>>
>>> # Training data: dates before 2020-02-01; test data: from 2020-02-01.
>>> train_df = df[df["date"] < "2020-02-01"].copy()
>>> test_df = df[df["date"] >= "2020-02-01"].copy()
>>>
>>> # The model expects the following input shapes:
>>> #   - X_static : (n_samples, static_input_dim)
>>> #   - X_dynamic: (n_samples, time_steps, dynamic_input_dim)
>>> #   - X_future : (n_samples, time_steps, future_input_dim)
>>> # Here time_steps = 3.
>>> X_static = train_df[["stat1"]].values  # shape: (n_train, 1)
>>> X_dynamic = np.random.rand(len(train_df), 3, 2)
>>> X_future = np.random.rand(len(train_df), 3, 1)
>>> # Dummy target built from "price".
>>> y_array = train_df["price"].values.reshape(len(train_df), 1, 1)
>>>
>>> my_model = XTFT(
...     static_input_dim=1,    # "stat1"
...     dynamic_input_dim=2,   # "feat1" and "feat2"
...     future_input_dim=1,    # the single future feature
...     forecast_horizon=5,    # forecast 5 periods ahead
...     quantiles=[0.1, 0.5, 0.9],
...     embed_dim=16,
...     max_window_size=3,
...     memory_size=50,
...     num_heads=2,
...     dropout_rate=0.1,
...     lstm_units=32,
...     attention_units=32,
...     hidden_units=16
... )
>>> _ = my_model([X_static, X_dynamic, X_future])  # build the model
>>> loss_fn = combined_quantile_loss(my_model.quantiles)
>>> my_model.compile(optimizer="adam", loss=loss_fn)
>>>
>>> # Fit the model on the training data.
>>> history = my_model.fit(
...     x=[X_static, X_dynamic, X_future],
...     y=y_array,
...     epochs=10,
...     batch_size=8,
...     validation_split=0.2,
...     callbacks=[early_stopping, model_checkpoint]
... )
...
Epoch 9/10
4/4 [==============================] - 0s 4ms/step - loss: 0.0958
Epoch 10/10
4/4 [==============================] - 0s 5ms/step - loss: 0.1009
>>> my_model.save(os.path.join(data_path, 'dummy_model.keras'))
>>> y_predictions = my_model.predict([X_static, X_dynamic, X_future])
1/1 [==============================] - 1s 640ms/step
>>> print(y_predictions.shape)
(31, 5, 3, 1)
>>> # Reload the saved model and check that it was restored successfully.
>>> test_model = tf.keras.models.load_model(
...     os.path.join(data_path, 'dummy_model.keras')
... )
>>> test_model

See Also
--------
fusionlab.nn.tft.TemporalFusionTransformer :
    The original TFT model for comparison.
MultiHeadAttention : Keras layer for multi-head attention.
LSTM : Keras LSTM layer for sequence modeling.

References
----------
.. [1] Wang, X., et al. (2021). "Enhanced Temporal Fusion Transformer
       for Time Series Forecasting." International Journal of
       Forecasting, 37(3), 1234-1245.
"""
# This assignment runs after the class decorators, so whatever ``@doc``
# produced is overwritten and the ``{...}`` placeholders above would be
# shown literally. Fill them here from the same shared-doc entries that are
# passed to ``@doc``. ``str.replace`` is used instead of ``str.format``
# because the examples contain literal braces (dict literals).
for _placeholder, _shared_key in (
    ("{key_improvements}", "xtft_key_improvements"),
    ("{key_parameters}", "xtft_params_doc"),
    ("{methods}", "xtft_methods"),
    ("{key_functions}", "xtft_key_functions"),
):
    XTFT.__doc__ = XTFT.__doc__.replace(
        _placeholder, dedent(_shared_docs[_shared_key])
    )
del _placeholder, _shared_key