# Source code for fusionlab.nn.hybrid._xtft

# -*- coding: utf-8 -*-
#   License: BSD-3-Clause
#   Author: LKouadio <etanoyau@gmail.com>

"""Implements the Extreme Temporal Fusion Transformer (XTFT), a state-of-the-art 
architecture for multi-horizon time-series forecasting.
"""

from textwrap import dedent 
from numbers import Real, Integral  
from typing import List, Optional, Union, Dict, Any  
import numpy as np 

from ..._fusionlog import fusionlog, OncePerMessageFilter
from ...api.docs import _shared_docs, doc 
from ...api.property import NNLearner 
from ...compat.sklearn import validate_params, Interval, StrOptions 
from ...core.handlers import param_deprecated_message 
from ...utils.deps_utils import ensure_pkg
from ...decorators import Appender

from .. import KERAS_DEPS, KERAS_BACKEND, dependency_message

if KERAS_BACKEND:
    from ...compat.tf import optional_tf_function 
    from .._tensor_validation import validate_anomaly_scores 
    from .._tensor_validation import validate_model_inputs
    from .._tensor_validation import validate_anomaly_config 
    from .._tensor_validation import align_temporal_dimensions
    from ..losses import ( 
        combined_quantile_loss, 
        combined_total_loss, 
        prediction_based_loss
    )
    from ..utils import set_default_params
    from ..components import (
        Activation, 
        AdaptiveQuantileLoss,
        AnomalyLoss,
        CrossAttention,
        DynamicTimeWindow,
        GatedResidualNetwork,
        HierarchicalAttention,
        LearnedNormalization,
        MemoryAugmentedAttention,
        MultiDecoder,
        MultiModalEmbedding,
        MultiObjectiveLoss,
        MultiResolutionAttentionFusion,
        MultiScaleLSTM,
        QuantileDistributionModeling,
        VariableSelectionNetwork,
        PositionalEncoding, 
        aggregate_multiscale, 
        aggregate_time_window_output
        )
         
# ---------------------------------------------------------------------------
# Layer / model aliases, all read off the ``KERAS_DEPS`` object imported from
# the parent ``nn`` package.
# ---------------------------------------------------------------------------
LSTM = KERAS_DEPS.LSTM
Dense = KERAS_DEPS.Dense
Flatten = KERAS_DEPS.Flatten
Dropout = KERAS_DEPS.Dropout
Layer = KERAS_DEPS.Layer
LayerNormalization = KERAS_DEPS.LayerNormalization
MultiHeadAttention = KERAS_DEPS.MultiHeadAttention
Model = KERAS_DEPS.Model
Input = KERAS_DEPS.Input
Concatenate = KERAS_DEPS.Concatenate
Tensor = KERAS_DEPS.Tensor
register_keras_serializable = KERAS_DEPS.register_keras_serializable

# ---------------------------------------------------------------------------
# Tensor-operation aliases (``tf_`` prefix), also taken from ``KERAS_DEPS``.
# NOTE(review): ``tf_range_`` keeps its trailing underscore because other
# code may reference that exact name.
# ---------------------------------------------------------------------------
tf_reduce_sum = KERAS_DEPS.reduce_sum
tf_stack = KERAS_DEPS.stack
tf_expand_dims = KERAS_DEPS.expand_dims
tf_tile = KERAS_DEPS.tile
tf_range_ = KERAS_DEPS.range
tf_concat = KERAS_DEPS.concat
tf_shape = KERAS_DEPS.shape
tf_reshape = KERAS_DEPS.reshape
tf_add = KERAS_DEPS.add
tf_maximum = KERAS_DEPS.maximum
tf_reduce_mean = KERAS_DEPS.reduce_mean
tf_add_n = KERAS_DEPS.add_n
tf_float32 = KERAS_DEPS.float32
tf_constant = KERAS_DEPS.constant
tf_square = KERAS_DEPS.square
tf_GradientTape = KERAS_DEPS.GradientTape
tf_unstack = KERAS_DEPS.unstack
tf_errors = KERAS_DEPS.errors
tf_is_nan = KERAS_DEPS.is_nan
tf_reduce_all = KERAS_DEPS.reduce_all
tf_zeros_like = KERAS_DEPS.zeros_like
tf_squeeze = KERAS_DEPS.squeeze

# AutoGraph handle; it is also used as a decorator source further down the
# file. Its verbosity is lowered to 0 at import time.
# NOTE(review): this looks like a global setting applied on import -- confirm
# that silencing it for every importer is intended.
tf_autograph = KERAS_DEPS.autograph
tf_autograph.set_verbosity(0)

# Message passed to ``ensure_pkg`` when the Keras backend is missing.
DEP_MSG = dependency_message('nn.transformers')

# Module logger; the filter suppresses repeated identical messages
# (presumably, going by its name ``OncePerMessageFilter``).
logger = fusionlog().get_fusionlab_logger(__name__)
logger.addFilter(OncePerMessageFilter())

# Public API of this module.
__all__ = ["XTFT", "SuperXTFT"]

[docs] @register_keras_serializable('fusionlab.nn.transformers', name="XTFT") @doc ( key_improvements= dedent(_shared_docs['xtft_key_improvements']), key_functions= dedent(_shared_docs['xtft_key_functions']), methods= dedent( _shared_docs['xtft_methods'] ) ) @param_deprecated_message( conditions_params_mappings=[ { 'param': 'multi_scale_agg', 'condition': lambda v: v == "concat", 'message': ( "The 'concat' mode for multi-scale aggregation requires identical " "time dimensions across scales, which is rarely practical. " "This mode will fall back to the robust last-timestep approach " "in real applications. For true multi-scale handling, use 'last' " "mode instead (automatically set).\n" "Why change?\n" "- 'concat' mixes features across scales at the same timestep\n" "- Requires manual time alignment between scales\n" "- 'last' preserves scale independence & handles variable lengths" ), 'default': "last" } ], warning_category=UserWarning ) class XTFT(Model, NNLearner):
@validate_params({
    "static_input_dim": [Interval(Integral, 1, None, closed='left')],
    "dynamic_input_dim": [Interval(Integral, 1, None, closed='left')],
    "future_input_dim": [Interval(Integral, 1, None, closed='left')],
    "embed_dim": [Interval(Integral, 1, None, closed='left')],
    "forecast_horizon": [Interval(Integral, 1, None, closed='left')],
    "quantiles": ['array-like', StrOptions({'auto'}), None],
    "max_window_size": [Interval(Integral, 1, None, closed='left')],
    "memory_size": [Interval(Integral, 1, None, closed='left')],
    "num_heads": [Interval(Integral, 1, None, closed='left')],
    "dropout_rate": [Interval(Real, 0, 1, closed="both")],
    "output_dim": [Interval(Integral, 1, None, closed='left')],
    "attention_units": [
        'array-like',
        Interval(Integral, 1, None, closed='left')
    ],
    "hidden_units": [
        'array-like',
        Interval(Integral, 1, None, closed='left')
    ],
    "lstm_units": [
        'array-like',
        Interval(Integral, 1, None, closed='left'),
        None
    ],
    "activation": [
        StrOptions({"elu", "relu", "tanh", "sigmoid", "linear", "gelu"}),
        callable
    ],
    "multi_scale_agg": [
        StrOptions({"last", "average", "flatten", "auto", "sum", "concat"}),
        None
    ],
    "scales": ['array-like', StrOptions({"auto"}), None],
    "use_batch_norm": [bool, Interval(Integral, 0, 1, closed="both")],
    "use_residuals": [bool, Interval(Integral, 0, 1, closed="both")],
    "final_agg": [StrOptions({"last", "average", "flatten"})],
    "anomaly_detection_strategy": [
        StrOptions({"prediction_based", "feature_based", "from_config"}),
        None
    ],
    'anomaly_loss_weight': [Real]
},
)
@ensure_pkg(KERAS_BACKEND or "keras", extra=DEP_MSG)
def __init__(
    self,
    static_input_dim: int,
    dynamic_input_dim: int,
    future_input_dim: int,
    embed_dim: int = 32,
    forecast_horizon: int = 1,
    quantiles: Union[str, List[float], None] = None,
    max_window_size: int = 10,
    memory_size: int = 100,
    num_heads: int = 4,
    dropout_rate: float = 0.1,
    output_dim: int = 1,
    attention_units: int = 32,
    hidden_units: int = 64,
    lstm_units: int = 64,
    scales: Union[str, List[int], None] = None,
    multi_scale_agg: Optional[str] = None,
    activation: str = 'relu',
    use_residuals: bool = True,
    use_batch_norm: bool = False,
    final_agg: str = 'last',
    anomaly_config: Optional[Dict[str, Any]] = None,
    anomaly_detection_strategy: Optional[str] = None,
    anomaly_loss_weight: float = .1,
    **kw,
):
    """
    Store the XTFT configuration and build all sub-layers.

    Argument constraints are enforced by ``validate_params`` above.

    ``quantiles`` and ``scales`` are normalised through
    ``set_default_params``. That helper also returns the
    ``return_sequences`` flag handed to ``MultiScaleLSTM``.

    The anomaly settings are normalised through
    ``validate_anomaly_config``. The config, strategy and loss weight it
    returns replace the values passed in. The feature-based anomaly
    layers are built only when the resolved strategy is
    ``'feature_based'``.

    Extra keyword arguments (``**kw``) are forwarded to the Keras
    ``Model`` base initialiser.
    """
    super().__init__(**kw)
    self.activation = Activation(activation).activation_str

    logger.debug(
        "Initializing XTFT with parameters: "
        f"static_input_dim={static_input_dim}, "
        f"dynamic_input_dim={dynamic_input_dim}, "
        f"future_input_dim={future_input_dim}, "
        f"embed_dim={embed_dim}, "
        f"forecast_horizon={forecast_horizon}, "
        f"quantiles={quantiles}, "
        f"max_window_size={max_window_size}, "
        f"memory_size={memory_size}, num_heads={num_heads}, "
        f"dropout_rate={dropout_rate}, output_dim={output_dim}, "
        f"attention_units={attention_units}, "
        f"hidden_units={hidden_units}, "
        f"lstm_units={lstm_units}, "
        f"scales={scales}, "
        f"activation={self.activation}, "
        f"use_residuals={use_residuals}, "
        f"use_batch_norm={use_batch_norm}, "
        f"final_agg={final_agg}"
    )

    # Handle default quantiles, scales and multi_scale_agg. The third
    # returned value is passed to MultiScaleLSTM as ``return_sequences``.
    quantiles, scales, return_sequences = set_default_params(
        quantiles, scales, multi_scale_agg
    )

    self.static_input_dim = static_input_dim
    self.dynamic_input_dim = dynamic_input_dim
    self.future_input_dim = future_input_dim
    self.embed_dim = embed_dim
    self.forecast_horizon = forecast_horizon
    self.quantiles = quantiles
    self.max_window_size = max_window_size
    self.memory_size = memory_size
    self.num_heads = num_heads
    self.dropout_rate = dropout_rate
    self.output_dim = output_dim
    self.attention_units = attention_units
    self.hidden_units = hidden_units
    self.lstm_units = lstm_units
    self.scales = scales
    self.multi_scale_agg = multi_scale_agg
    self.use_residuals = use_residuals
    self.use_batch_norm = use_batch_norm
    self.final_agg = final_agg
    # Provisional values; both are overwritten below by the output of
    # validate_anomaly_config (which reads self.anomaly_loss_weight as the
    # default loss weight).
    self.anomaly_detection_strategy = anomaly_detection_strategy
    self.anomaly_loss_weight = anomaly_loss_weight

    # ---- Layers (keep creation order: it fixes sub-layer tracking) ----
    self.multi_modal_embedding = MultiModalEmbedding(embed_dim)
    # Add PositionalEncoding layer
    self.positional_encoding = PositionalEncoding()
    self.multi_scale_lstm = MultiScaleLSTM(
        lstm_units=self.lstm_units,
        scales=self.scales,
        return_sequences=return_sequences
    )
    self.hierarchical_attention = HierarchicalAttention(
        units=attention_units,
        num_heads=num_heads
    )
    self.cross_attention = CrossAttention(
        units=attention_units,
        num_heads=num_heads
    )
    self.memory_augmented_attention = MemoryAugmentedAttention(
        units=attention_units,
        memory_size=memory_size,
        num_heads=num_heads
    )
    self.multi_decoder = MultiDecoder(
        output_dim=output_dim,
        num_horizons=forecast_horizon
    )
    self.multi_resolution_attention_fusion = MultiResolutionAttentionFusion(
        units=attention_units,
        num_heads=num_heads
    )
    self.dynamic_time_window = DynamicTimeWindow(
        max_window_size=max_window_size
    )
    self.quantile_distribution_modeling = QuantileDistributionModeling(
        quantiles=self.quantiles,
        output_dim=output_dim
    )

    # Validate anomaly configuration; the returned strategy and loss
    # weight are the effective values used from here on.
    (
        self.anomaly_config,
        self.anomaly_detection_strategy,
        self.anomaly_loss_weight,
    ) = validate_anomaly_config(
        anomaly_config=anomaly_config,
        forecast_horizon=self.forecast_horizon,
        strategy=anomaly_detection_strategy,
        default_anomaly_loss_weight=self.anomaly_loss_weight,
        return_loss_weight=True
    )
    # Log the *validated* values, with a separator between fields
    # (previously the last two fields were fused and showed raw inputs).
    logger.debug(
        f"anomaly_config={self.anomaly_config.keys()}, "
        f"anomaly_detection_strategy={self.anomaly_detection_strategy}, "
        f"anomaly_loss_weight={self.anomaly_loss_weight}"
    )

    # Initialize/Fetch anomaly scores
    self.anomaly_scores = self.anomaly_config.get('anomaly_scores')

    # Anomaly scores handling
    self.anomaly_loss_layer = AnomalyLoss(
        weight=self.anomaly_loss_weight
    )

    # Initialize anomaly detection layers
    if self.anomaly_detection_strategy == 'feature_based':
        self._init_feature_based_components()

    # ---------------------------------------------------------------------
    # The MultiObjectiveLoss encapsulates both quantile and anomaly losses
    # to allow simultaneous training on multiple objectives. While this
    # functionality can currently be bypassed, note that it may be removed
    # in a future release. Users who rely on multi-objective training
    # strategies should keep an eye on upcoming changes.
    #
    # Here, we instantiate the MultiObjectiveLoss with an adaptive quantile
    # loss function built from the configured quantiles, and with the
    # anomaly loss layer created above.
    # ---------------------------------------------------------------------
    self.multi_objective_loss = MultiObjectiveLoss(
        quantile_loss_fn=AdaptiveQuantileLoss(self.quantiles),
        anomaly_loss_fn=self.anomaly_loss_layer
    )
    # ---------------------------------------------------------------------

    self.learned_normalization = LearnedNormalization()
    self.static_dense = Dense(hidden_units, activation=self.activation)
    self.static_dropout = Dropout(dropout_rate)
    if self.use_batch_norm:
        # NOTE(review): despite the flag's name, a LayerNormalization layer
        # is created here.
        self.static_batch_norm = LayerNormalization()

    # Initialize Gated Residual Networks (GRNs) for attention outputs
    self.grn_static = GatedResidualNetwork(
        units=hidden_units,
        dropout_rate=dropout_rate,
        use_time_distributed=False,
        activation=self.activation,
        use_batch_norm=self.use_batch_norm
    )

    self.residual_dense = Dense(2 * embed_dim) if use_residuals else None
def _init_feature_based_components(self):
    """
    Build the layers used by the ``'feature_based'`` anomaly strategy.

    Three layers are created and stored on the instance:

    ``anomaly_attention``
        Single-head ``MultiHeadAttention`` whose ``key_dim`` equals
        ``hidden_units``.
    ``anomaly_projection``
        Linear ``Dense(hidden_units)`` applied to the attention output.
    ``anomaly_scorer``
        Linear ``Dense(1)`` turning the projected features into a raw
        anomaly score.

    Design rationale
    ----------------
    - ``key_dim`` follows ``hidden_units`` to keep widths compatible.
    - One head is used so the layer looks for a single global pattern.
    - Linear activations keep the relative magnitude of the scores.
    """
    width = self.hidden_units

    # Attention over the fused features (single head, width-tied keys).
    self.anomaly_attention = MultiHeadAttention(
        num_heads=1,
        key_dim=width,
        name='anomaly_attention',
    )

    # Linear projection of the attention output to ``hidden_units``.
    self.anomaly_projection = Dense(
        width,
        activation='linear',
        name='anomaly_projection',
    )

    # Single linear output unit -> raw anomaly score.
    self.anomaly_scorer = Dense(
        1,
        activation='linear',
        name='anomaly_scorer',
    )
@tf_autograph.experimental.do_not_convert
def call(self, inputs, training=False, **kwargs):
    """
    Forward pass of the XTFT model.

    Parameters
    ----------
    inputs : tuple or list
        Input data containing three elements:
        1. Static features (batch_size, static_input_dim)
        2. Dynamic historical features
           (batch_size, time_steps, dynamic_input_dim)
        3. Future covariates (batch_size, horizon, future_input_dim)
    training : bool, optional
        Whether the model is in training mode, by default False.
    **kwargs
        Additional keyword arguments (not used by this method).

    Returns
    -------
    tf.Tensor
        Output of ``quantile_distribution_modeling``:

        - when ``quantiles`` is set and ``output_dim == 1``, a trailing
          size-1 axis of a rank-4 prediction is squeezed, giving
          ``(batch_size, horizon, len(quantiles))``;
        - in every other case (point forecast, or quantiles with
          ``output_dim > 1``) the predictions are returned unchanged.

    Raises
    ------
    ValueError
        Presumably raised by ``validate_model_inputs`` when the inputs
        do not match the configured dimensions (TODO confirm).

    Notes
    -----
    - Anomaly detection strategies:
      1. ``'feature_based'``: scores are derived from the fused
         attention features and stored on ``self.anomaly_scores``.
      2. ``'prediction_based'``: handled in the loss function.
      3. ``'from_config'``: scores are taken from ``anomaly_config``
         via ``validate_anomaly_scores``.
      Whenever ``self.anomaly_scores`` is not None, a weighted anomaly
      loss is registered with ``self.add_loss``.
    - Temporal processing includes positional encoding, hierarchical,
      cross and memory-augmented attention, and dynamic time windowing.
    """
    static_input, dynamic_input, future_input = validate_model_inputs(
        inputs=inputs,
        static_input_dim=self.static_input_dim,
        dynamic_input_dim=self.dynamic_input_dim,
        future_covariate_dim=self.future_input_dim,
        forecast_horizon=self.forecast_horizon,
        mode='strict',  # XTFT is generally strict
        model_name='xtft',  # For specific validation logic if any
        verbose=1 if logger.level <= 10 else 0  # DEBUG level
    )

    # ---- Static branch: normalization -> dense -> (norm) -> dropout -> GRN
    normalized_static = self.learned_normalization(
        static_input, training=training
    )
    logger.debug(
        f"Normalized Static Shape: {normalized_static.shape}"
    )

    static_features = self.static_dense(normalized_static)
    if self.use_batch_norm:
        static_features = self.static_batch_norm(
            static_features, training=training
        )
        logger.debug(
            "Static Features after BatchNorm Shape: "
            f"{static_features.shape}"
        )

    static_features = self.static_dropout(
        static_features, training=training
    )
    logger.debug(
        f"Static Features Shape: {static_features.shape}"
    )

    # XXX TODO: check apply --> GRN
    static_features = self.grn_static(
        static_features, training=training
    )

    # ---- Prepare inputs for MultiModalEmbedding ----
    # dynamic_input is the reference for T_past (lookback period);
    # future_input is sliced so its time axis matches dynamic_input.
    logger.debug(" Aligning temporal inputs for MultiModalEmbedding...")
    _, future_input_for_embedding = align_temporal_dimensions(
        tensor_ref=dynamic_input,
        tensor_to_align=future_input,
        mode='slice_to_ref',  # Slice future if longer
        name="future_input_for_mme"
    )
    logger.debug(
        f" Dynamic for MME: {dynamic_input.shape}, "
        f"Future for MME: {future_input_for_embedding.shape}"
    )

    embeddings = self.multi_modal_embedding(
        [dynamic_input, future_input_for_embedding],  # ALIGNED inputs
        training=training
    )
    logger.debug(
        f" Embeddings shape after MultiModalEmbedding: {embeddings.shape}"
    )

    # Positional information is added before the attention blocks.
    embeddings = self.positional_encoding(
        embeddings, training=training
    )
    logger.debug(
        f"Embeddings with Positional Encoding Shape: {embeddings.shape}"
    )

    if self.use_residuals:
        embeddings = embeddings + self.residual_dense(embeddings)
        logger.debug(
            "Embeddings with Residuals Shape: "
            f"{embeddings.shape}"
        )

    # ---- Multi-scale LSTM ----
    lstm_output = self.multi_scale_lstm(
        dynamic_input, training=training
    )
    lstm_features = aggregate_multiscale(
        lstm_output, mode=self.multi_scale_agg
    )

    # The features are concatenated later along the last (feature)
    # axis, so every tensor must share the same time length:
    # broadcast the aggregated LSTM vector across all time steps.
    time_steps = tf_shape(dynamic_input)[1]
    lstm_features = tf_expand_dims(lstm_features, axis=1)
    lstm_features = tf_tile(lstm_features, [1, time_steps, 1])
    logger.debug(
        f"LSTM Features Shape: {lstm_features.shape}"
    )

    # ---- Attention mechanisms ----
    # future_input_for_embedding is already aligned to T_past, so no
    # further alignment is needed for HierarchicalAttention.
    logger.debug(
        " Aligning temporal inputs for HierarchicalAttention..."
    )
    hierarchical_att = self.hierarchical_attention(
        [dynamic_input, future_input_for_embedding],
        training=training
    )
    logger.debug(
        f"Hierarchical Attention Shape: {hierarchical_att.shape}"
    )

    cross_attention_output = self.cross_attention(
        [dynamic_input, embeddings], training=training
    )
    logger.debug(
        f"Cross Attention Output Shape: {cross_attention_output.shape}"
    )

    memory_attention_output = self.memory_augmented_attention(
        hierarchical_att, training=training
    )
    logger.debug(
        "Memory Augmented Attention Output Shape: "
        f"{memory_attention_output.shape}"
    )

    # ---- Combine all features (time_steps computed above) ----
    static_features_expanded = tf_tile(
        tf_expand_dims(static_features, axis=1),
        [1, time_steps, 1]
    )
    logger.debug(
        "Static Features Expanded Shape: "
        f"{static_features_expanded.shape}"
    )

    combined_features = Concatenate()([
        static_features_expanded,
        lstm_features,
        cross_attention_output,
        hierarchical_att,
        memory_attention_output,
    ])
    logger.debug(
        f"Combined Features Shape: {combined_features.shape}"
    )

    attention_fusion_output = self.multi_resolution_attention_fusion(
        combined_features, training=training
    )
    logger.debug(
        "Attention Fusion Output Shape: "
        f"{attention_fusion_output.shape}"
    )

    time_window_output = self.dynamic_time_window(
        attention_fusion_output, training=training
    )
    logger.debug(
        f"Time Window Output Shape: {time_window_output.shape}"
    )

    # final_agg: last/average/flatten applied on time_window_output
    final_features = aggregate_time_window_output(
        time_window_output, self.final_agg
    )

    decoder_outputs = self.multi_decoder(
        final_features, training=training
    )
    logger.debug(
        f"Decoder Outputs Shape: {decoder_outputs.shape}"
    )

    predictions = self.quantile_distribution_modeling(
        decoder_outputs, training=training
    )

    # ---- Anomaly detection branch ----
    if self.anomaly_detection_strategy == 'feature_based':
        # Self-attention over the fused features.
        attn_scores = self.anomaly_attention(
            query=attention_fusion_output,
            value=attention_fusion_output,
            training=training
        )
        projected_attn = self.anomaly_projection(
            attn_scores, training=training
        )
        # NOTE(review): scores are stored on the instance during
        # call(); shape presumably (B, T, 1) -- confirm.
        self.anomaly_scores = self.anomaly_scorer(
            projected_attn, training=training
        )
    elif self.anomaly_detection_strategy == 'from_config':
        # Expected to be a 2D tensor (B, H) per the original comment.
        self.anomaly_scores = validate_anomaly_scores(
            self.anomaly_config, self.forecast_horizon
        )
        logger.debug(
            "Using Anomaly Scores from anomaly_config"
            f" Shape: {self.anomaly_scores.shape}")

    # ---- Anomaly loss ----
    if self.anomaly_scores is not None:
        # Zeros act as the placeholder "y_pred" for the anomaly loss.
        default_y_pred = tf_zeros_like(self.anomaly_scores)
        logger.debug(
            "Using Anomaly Scores from anomaly_config with"
            f" weight: {self.anomaly_loss_weight}.")
        anomaly_loss = self.anomaly_loss_layer(
            self.anomaly_scores,
            default_y_pred,
        )
        self.add_loss(self.anomaly_loss_weight * anomaly_loss)
        logger.debug(
            f"Anomaly Loss Computed and Added: {anomaly_loss}")
    else:
        logger.warning(
            "Anomaly scores are None. Skipping anomaly loss."
        )

    logger.debug(
        f"Predictions Shape: {predictions.shape}"
    )

    # Default: return predictions unchanged. This also covers
    # quantiles with output_dim > 1, which previously left
    # ``final_output`` unbound (UnboundLocalError).
    final_output = predictions
    if self.quantiles is not None and self.output_dim == 1:
        # Squeeze ONLY the trailing O=1 axis of a rank-4 tensor.
        rank = len(predictions.shape)
        if rank == 4:
            final_output = tf_squeeze(predictions, axis=-1)
            logger.debug(
                f"Squeezed final quantile output dim (O=1)."
                f" New shape: {tf_shape(final_output)}"
            )
        elif rank != 3:
            # Rank 3 is already (B, H, Q); anything else is unexpected.
            logger.warning(f"Unexpected prediction shape before squeeze:"
                           f" {predictions.shape}. Returning as is.")

    logger.debug(f"TFT '{self.name}': Final returned output shape:"
                 f" {tf_shape(final_output)}")

    return final_output
def compile(self, optimizer, loss=None, **kws):
    """
    Compile the XTFT model, choosing a default loss when none is given.

    An explicitly supplied ``loss`` is always used as-is. When ``loss``
    is None, the loss is selected as follows:

    1. ``anomaly_detection_strategy == 'prediction_based'``:
       ``prediction_based_loss(quantiles, anomaly_loss_weight)``.
    2. ``quantiles is None``: ``"mean_squared_error"``.
    3. Otherwise ``combined_quantile_loss(quantiles)``. With the
       ``'from_config'`` strategy, ``self.anomaly_scores`` is refreshed
       from ``anomaly_config``. If those scores are not None, the loss
       is replaced by ``combined_total_loss``, which adds the anomaly
       term.

    See also
    --------
    fusionlab.nn.losses.combined_quantile_loss
    fusionlab.nn.losses.combined_total_loss
    """
    if loss is None:
        strategy = self.anomaly_detection_strategy

        if strategy == 'prediction_based':
            loss = prediction_based_loss(
                quantiles=self.quantiles,
                anomaly_loss_weight=self.anomaly_loss_weight
            )
        elif self.quantiles is None:
            # Deterministic (point) forecast.
            loss = "mean_squared_error"
        else:
            # Probabilistic forecast: quantile loss by default.
            loss = combined_quantile_loss(self.quantiles)
            if strategy == 'from_config':
                self.anomaly_scores = self.anomaly_config.get(
                    "anomaly_scores")
                if self.anomaly_scores is not None:
                    loss = combined_total_loss(
                        quantiles=self.quantiles,
                        anomaly_layer=self.anomaly_loss_layer,
                        anomaly_scores=self.anomaly_scores
                    )

    super().compile(
        optimizer=optimizer,
        loss=loss,
        **kws
    )
@optional_tf_function
def train_step(self, data):
    """
    Custom training step with anomaly detection strategy handling.

    Parameters
    ----------
    data : tuple or list
        One training batch. For the ``'prediction_based'`` strategy the
        first two elements are used as ``(x, y)``. Any further elements
        (e.g. sample weights) are not used by this custom step.

    Returns
    -------
    dict
        Mapping of metric name to its current result.

    Notes
    -----
    - ``'prediction_based'`` strategy:
      * requires a tuple/list with at least two elements, otherwise
        falls back to ``super().train_step(data)``;
      * also falls back when ``y`` is a scalar or every value of ``y``
        is NaN;
      * otherwise runs a manual forward/backward pass with
        ``self.compiled_loss`` and updates ``self.compiled_metrics``.
    - All other strategies delegate to ``super().train_step(data)``.

    Raises
    ------
    Nothing directly; fallbacks are reported via ``logger.warning``.

    Example
    -------
    >>> model.compile(...)
    >>> model.fit(dataset, epochs=10)
    """
    if self.anomaly_detection_strategy == 'prediction_based':
        # Only an explicit (x, y, ...) container is accepted. The former
        # fallback ``tf_unstack(data, num=2, axis=0)`` unstacked the
        # *batch* axis, so a batch of exactly two samples was silently
        # trained as x=sample_0, y=sample_1.
        if not (isinstance(data, (list, tuple)) and len(data) >= 2):
            logger.warning(
                "Prediction-based strategy requires (x, y) data pairs. "
                "Falling back to standard training step."
            )
            return super().train_step(data)

        x, y = data[0], data[1]

        # Reject scalar targets or targets that are entirely NaN.
        if y.shape.ndims == 0 or tf_reduce_all(tf_is_nan(y)):
            logger.warning(
                "Invalid y_true provided for prediction-based strategy. "
                "Contains NaN values or incorrect shape."
            )
            return super().train_step(data)

        with tf_GradientTape() as tape:
            y_pred = self(x, training=True)
            loss = self.compiled_loss(y, y_pred)

        # Gradient updates
        trainable_vars = self.trainable_variables
        gradients = tape.gradient(loss, trainable_vars)
        self.optimizer.apply_gradients(zip(gradients, trainable_vars))

        # Update metrics
        self.compiled_metrics.update_state(y, y_pred)
        return {m.name: m.result() for m in self.metrics}

    # Standard processing for other strategies
    return super().train_step(data)
def get_config(self):
    """
    Get serialization configuration for model saving/loading.

    Returns
    -------
    dict
        The base-class config updated with the XTFT constructor
        arguments: input/embedding dimensions, horizon, quantiles,
        attention/LSTM/hidden sizes, scales, activation, aggregation
        modes, and the anomaly settings (``anomaly_config``,
        ``anomaly_loss_weight``, ``anomaly_detection_strategy``).

    Notes
    -----
    - ``quantiles`` and ``scales`` are stored as lists.
    - ``anomaly_config`` always contains ``anomaly_loss_weight``
      (``1.`` when the weight is None). When user-supplied
      ``anomaly_scores`` are present in ``self.anomaly_config`` they
      are stored as a nested list; ``from_config`` converts them back
      to a float32 NumPy array.

    Example
    -------
    >>> config = model.get_config()
    >>> with open('model_config.json', 'w') as f:
    ...     json.dump(config, f)
    """
    # Retrieve the base configuration from the superclass.
    config = super().get_config().copy()

    anomaly_config = {
        'anomaly_loss_weight': (
            float(self.anomaly_loss_weight)
            if self.anomaly_loss_weight is not None else 1.
        )
    }
    # Persist user-supplied scores so they survive a save/load
    # round-trip (from_config turns the list back into an array).
    stored_scores = (self.anomaly_config or {}).get('anomaly_scores')
    if stored_scores is not None:
        anomaly_config['anomaly_scores'] = np.asarray(
            stored_scores, dtype=np.float32
        ).tolist()

    # Update configuration with XTFT-specific parameters.
    config.update({
        'static_input_dim': int(self.static_input_dim),
        'dynamic_input_dim': int(self.dynamic_input_dim),
        'future_input_dim': int(self.future_input_dim),
        'embed_dim': int(self.embed_dim),
        'forecast_horizon': int(self.forecast_horizon),
        'quantiles': (list(self.quantiles)
                      if self.quantiles is not None else None),
        'max_window_size': int(self.max_window_size),
        'memory_size': int(self.memory_size),
        'num_heads': int(self.num_heads),
        'dropout_rate': float(self.dropout_rate),
        'output_dim': int(self.output_dim),
        'attention_units': int(self.attention_units),
        'hidden_units': int(self.hidden_units),
        'lstm_units': (int(self.lstm_units)
                       if self.lstm_units is not None else None),
        'scales': (list(self.scales)
                   if self.scales is not None else None),
        'activation': self.activation,
        'use_residuals': bool(self.use_residuals),
        'use_batch_norm': bool(self.use_batch_norm),
        'final_agg': self.final_agg,
        'multi_scale_agg': (str(self.multi_scale_agg)
                            if self.multi_scale_agg is not None
                            else None),
        'anomaly_config': anomaly_config,
        'anomaly_loss_weight': self.anomaly_loss_weight,
        'anomaly_detection_strategy': self.anomaly_detection_strategy,
    })

    # Log that the configuration has been updated.
    logger.debug(
        "Configuration for XTFT has been updated in get_config."
    )
    return config
@classmethod
def from_config(cls, config):
    """
    Reconstruct a model instance from a configuration dictionary.

    Parameters
    ----------
    config : dict
        Configuration dictionary, typically produced by ``get_config()``.
        It is not modified.

    Returns
    -------
    XTFT
        A new instance built with ``cls(**config)``.

    Notes
    -----
    - If ``config["anomaly_config"]["anomaly_scores"]`` is present, it
      is converted to a float32 NumPy array before construction.
    - A missing or None ``anomaly_config`` is passed through unchanged.

    Example
    -------
    >>> with open('model_config.json') as f:
    ...     loaded_model = XTFT.from_config(json.load(f))
    """
    logger.debug("Creating XTFT instance from configuration.")

    # Shallow copies keep the caller's dictionaries untouched.
    config = dict(config)
    anomaly_config = config.get("anomaly_config")
    if anomaly_config and anomaly_config.get("anomaly_scores") is not None:
        anomaly_config = dict(anomaly_config)
        anomaly_config["anomaly_scores"] = np.array(
            anomaly_config["anomaly_scores"], dtype=np.float32
        )
        config["anomaly_config"] = anomaly_config

    # Return a new instance created using the updated configuration.
    return cls(**config)
@Appender(
    dedent(
        XTFT.__doc__.replace('XTFT', 'SuperXTFT'),
    ),
    join='\n',
)
@register_keras_serializable('fusionlab.nn.transformers', name="SuperXTFT")
class SuperXTFT(XTFT):
    """
    SuperXTFT: An enhanced version of XTFT with Variable Selection
    Networks (VSNs) and integrated Gate → Add & Norm → GRN pipeline
    in attention layers.
    """

    def __init__(
        self,
        static_input_dim: int,
        dynamic_input_dim: int,
        future_input_dim: int,
        embed_dim: int = 32,
        forecast_horizon: int = 1,
        quantiles: Union[str, List[float], None] = None,
        max_window_size: int = 10,
        memory_size: int = 100,
        num_heads: int = 4,
        dropout_rate: float = 0.1,
        output_dim: int = 1,
        attention_units: int = 32,
        hidden_units: int = 64,
        lstm_units: int = 64,
        scales: Union[str, List[int], None] = None,
        multi_scale_agg: Optional[str] = 'auto',
        activation: str = 'relu',
        use_residuals: bool = True,
        use_batch_norm: bool = False,
        final_agg: str = 'last',
        anomaly_config: Optional[Dict[str, Any]] = None,
        anomaly_detection_strategy: Optional[str] = None,
        anomaly_loss_weight: float = 1.0,
        **kw
    ):
        """Build the XTFT backbone, then add the VSN and GRN layers."""
        super().__init__(
            static_input_dim=static_input_dim,
            dynamic_input_dim=dynamic_input_dim,
            future_input_dim=future_input_dim,
            embed_dim=embed_dim,
            forecast_horizon=forecast_horizon,
            quantiles=quantiles,
            max_window_size=max_window_size,
            memory_size=memory_size,
            num_heads=num_heads,
            dropout_rate=dropout_rate,
            output_dim=output_dim,
            attention_units=attention_units,
            hidden_units=hidden_units,
            lstm_units=lstm_units,
            scales=scales,
            multi_scale_agg=multi_scale_agg,
            activation=activation,
            use_residuals=use_residuals,
            use_batch_norm=use_batch_norm,
            final_agg=final_agg,
            anomaly_config=anomaly_config,
            anomaly_detection_strategy=anomaly_detection_strategy,
            anomaly_loss_weight=anomaly_loss_weight,
            **kw,
        )

        # Variable Selection Networks (VSNs): static (no time axis),
        # dynamic and future covariates (time-distributed).
        self.variable_selection_static = VariableSelectionNetwork(
            num_inputs=static_input_dim,
            units=hidden_units,
            dropout_rate=dropout_rate,
            use_time_distributed=False,
            activation=activation,
            use_batch_norm=use_batch_norm
        )
        self.variable_selection_dynamic = VariableSelectionNetwork(
            num_inputs=dynamic_input_dim,
            units=hidden_units,
            dropout_rate=dropout_rate,
            use_time_distributed=True,
            activation=activation,
            use_batch_norm=use_batch_norm
        )
        self.variable_future_covariate = VariableSelectionNetwork(
            num_inputs=future_input_dim,
            units=hidden_units,
            dropout_rate=dropout_rate,
            use_time_distributed=True,
            activation=activation,
            use_batch_norm=use_batch_norm
        )

        # Positional encoding (replaces the instance created by XTFT).
        self.positional_encoding = PositionalEncoding()

        # GRNs applied to each attention output.
        self.grn_attention_hierarchical = GatedResidualNetwork(
            units=attention_units,
            dropout_rate=dropout_rate,
            activation=activation,
            use_batch_norm=use_batch_norm
        )
        self.grn_attention_cross = GatedResidualNetwork(
            units=attention_units,
            dropout_rate=dropout_rate,
            activation=activation,
            use_batch_norm=use_batch_norm
        )
        self.grn_memory_attention = GatedResidualNetwork(
            units=attention_units,
            dropout_rate=dropout_rate,
            activation=activation,
            use_batch_norm=use_batch_norm
        )

        # Gate -> Add & Norm -> GRN pipeline for decoder outputs.
        self.grn_decoder = GatedResidualNetwork(
            units=output_dim,
            dropout_rate=dropout_rate,
            activation=activation,
            use_batch_norm=use_batch_norm
        )

    @tf_autograph.experimental.do_not_convert
    def call(self, inputs, training=False, **kwargs):
        """
        Forward pass of SuperXTFT.

        Same pipeline as ``XTFT.call``, with these differences:

        - the static, dynamic and future inputs first pass through
          Variable Selection Networks;
        - a GRN is applied to each attention output;
        - a Gate -> Add & Norm -> GRN pipeline is applied to the
          decoder outputs.

        Returns
        -------
        tf.Tensor
            As in ``XTFT.call``: a trailing size-1 axis is squeezed only
            when ``quantiles`` is set, ``output_dim == 1`` and the
            predictions are rank 4. Otherwise the predictions are
            returned unchanged.
        """
        static_input, dynamic_input, future_input = validate_model_inputs(
            inputs=inputs,
            static_input_dim=self.static_input_dim,
            dynamic_input_dim=self.dynamic_input_dim,
            future_covariate_dim=self.future_input_dim,
        )

        # ---- Variable selection ----
        selected_static = self.variable_selection_static(
            static_input, training=training)
        selected_dynamic = self.variable_selection_dynamic(
            dynamic_input, training=training)
        selected_future = self.variable_future_covariate(
            future_input, training=training)
        logger.debug(
            f"Selected Static Features Shape: {selected_static.shape}"
        )
        logger.debug(
            f"Selected Dynamic Features Shape: {selected_dynamic.shape}"
        )
        logger.debug(
            f"Selected Covariate Features Shape: {selected_future.shape}"
        )

        # ---- Static branch ----
        normalized_static = self.learned_normalization(
            selected_static, training=training
        )
        logger.debug(
            f"Normalized Static Shape: {normalized_static.shape}"
        )

        static_features = self.static_dense(normalized_static)
        if self.use_batch_norm:
            static_features = self.static_batch_norm(
                static_features, training=training
            )
            logger.debug(
                "Static Features after BatchNorm Shape: "
                f"{static_features.shape}"
            )

        static_features = self.static_dropout(
            static_features, training=training
        )
        logger.debug(
            f"Static Features Shape: {static_features.shape}"
        )

        # ---- Embeddings ----
        # Align future covariates to the lookback length of the dynamic
        # inputs (kept for consistency with XTFT).
        logger.debug(" Aligning temporal inputs for MultiModalEmbedding...")
        _, selected_future_input_for_embedding = align_temporal_dimensions(
            tensor_ref=selected_dynamic,
            tensor_to_align=selected_future,
            mode='slice_to_ref',
            name="future_input_for_mme"
        )
        logger.debug(
            f" Dynamic for MME: {selected_dynamic.shape}, "
            f"Future for MME: {selected_future_input_for_embedding.shape}"
        )

        embeddings = self.multi_modal_embedding(
            [selected_dynamic, selected_future_input_for_embedding],
            training=training
        )
        logger.debug(
            f"Embeddings Shape: {embeddings.shape}"
        )

        embeddings = self.positional_encoding(
            embeddings, training=training
        )
        logger.debug(
            f"Embeddings Shape after Positional Encoding: {embeddings.shape}"
        )

        if self.use_residuals:
            embeddings = embeddings + self.residual_dense(embeddings)
            logger.debug(
                "Embeddings with Residuals Shape: "
                f"{embeddings.shape}"
            )

        # ---- Multi-scale LSTM ----
        lstm_output = self.multi_scale_lstm(
            selected_dynamic, training=training
        )
        lstm_features = aggregate_multiscale(
            lstm_output, mode=self.multi_scale_agg
        )

        # Broadcast the LSTM vector over time: (B, F) -> (B, T, F).
        time_steps = tf_shape(dynamic_input)[1]
        lstm_features = tf_expand_dims(lstm_features, axis=1)
        lstm_features = tf_tile(lstm_features, [1, time_steps, 1])
        logger.debug(
            f"LSTM Features Shape: {lstm_features.shape}"
        )

        # ---- Attention mechanisms with integrated GRNs ----
        hierarchical_att = self.hierarchical_attention(
            [selected_dynamic, selected_future_input_for_embedding],
            training=training
        )
        logger.debug(
            f"Hierarchical Attention Shape: {hierarchical_att.shape}"
        )
        hierarchical_att_grn = self.grn_attention_hierarchical(
            hierarchical_att, training=training
        )
        logger.debug(
            f"Hierarchical Attention after GRN Shape: "
            f"{hierarchical_att_grn.shape}"
        )

        cross_attention_output = self.cross_attention(
            [selected_dynamic, embeddings], training=training
        )
        logger.debug(
            f"Cross Attention Output Shape: {cross_attention_output.shape}"
        )
        cross_attention_grn = self.grn_attention_cross(
            cross_attention_output, training=training
        )
        logger.debug(
            f"Cross Attention after GRN Shape: {cross_attention_grn.shape}"
        )

        memory_attention_output = self.memory_augmented_attention(
            hierarchical_att_grn, training=training
        )
        logger.debug(
            "Memory Augmented Attention Output Shape: "
            f"{memory_attention_output.shape}"
        )
        # The memory GRN must consume the memory-attention output.
        # Previously it was fed hierarchical_att_grn, so the
        # memory-augmented attention was computed and then discarded.
        memory_attention_grn = self.grn_memory_attention(
            memory_attention_output, training=training
        )
        logger.debug(
            f"Memory Attention after GRN Shape: {memory_attention_grn.shape}"
        )

        # ---- Combine all features along the last axis ----
        static_features_expanded = tf_tile(
            tf_expand_dims(static_features, axis=1),
            [1, time_steps, 1]
        )
        logger.debug(
            "Static Features Expanded Shape: "
            f"{static_features_expanded.shape}"
        )

        combined_features = Concatenate()([
            static_features_expanded,
            lstm_features,
            cross_attention_grn,
            hierarchical_att_grn,
            memory_attention_grn,
        ])
        logger.debug(
            f"Combined Features Shape: {combined_features.shape}"
        )

        attention_fusion_output = self.multi_resolution_attention_fusion(
            combined_features, training=training
        )
        logger.debug(
            "Attention Fusion Output Shape: "
            f"{attention_fusion_output.shape}"
        )

        # ---- Anomaly scores ----
        if self.anomaly_detection_strategy == 'feature_based':
            attn_scores = self.anomaly_attention(
                query=attention_fusion_output,
                value=attention_fusion_output,
                training=training
            )
            projected_attn = self.anomaly_projection(attn_scores)
            self.anomaly_scores = self.anomaly_scorer(projected_attn)
        elif self.anomaly_detection_strategy == 'from_config':
            self.anomaly_scores = validate_anomaly_scores(
                self.anomaly_config, self.forecast_horizon
            )

        time_window_output = self.dynamic_time_window(
            attention_fusion_output, training=training
        )
        logger.debug(
            f"Time Window Output Shape: {time_window_output.shape}"
        )

        # Final aggregation over the time window.
        final_features = aggregate_time_window_output(
            time_window_output, self.final_agg
        )

        decoder_outputs = self.multi_decoder(
            final_features, training=training
        )
        logger.debug(
            f"Decoder Outputs Shape: {decoder_outputs.shape}"
        )

        # Gate -> Add & Norm -> GRN pipeline on decoder outputs, using
        # the GRN's own gate and normalization sub-layers.
        G = self.grn_decoder.gate_dense(decoder_outputs)
        Z_norm = self.grn_decoder.layer_norm(decoder_outputs + G)
        Z_grn = self.grn_decoder(Z_norm, training=training)
        logger.debug(
            f"Decoder Outputs after GRN Pipeline Shape: {Z_grn.shape}"
        )

        predictions = self.quantile_distribution_modeling(
            Z_grn, training=training
        )

        # Register the weighted anomaly loss when scores exist.
        if self.anomaly_scores is not None:
            logger.debug(
                "Using Anomaly Scores from anomaly_config "
                f"Shape: {self.anomaly_scores.shape}"
            )
            anomaly_loss = self.anomaly_loss_layer(
                self.anomaly_scores,
                tf_zeros_like(self.anomaly_scores)
            )
            self.add_loss(self.anomaly_loss_weight * anomaly_loss)
            logger.debug(
                f"Anomaly Loss Computed and Added: {anomaly_loss}"
            )

        logger.debug(
            f"Predictions Shape: {predictions.shape}"
        )

        # Default: unchanged predictions. This also covers quantiles
        # with output_dim > 1, which previously left ``final_output``
        # unbound.
        final_output = predictions
        if self.quantiles is not None and self.output_dim == 1:
            rank = len(predictions.shape)
            if rank == 4:
                final_output = tf_squeeze(predictions, axis=-1)
                logger.debug(
                    f"Squeezed final quantile output dim (O=1)."
                    f" New shape: {tf_shape(final_output)}"
                )
            elif rank != 3:
                # Rank 3 is already (B, H, Q); anything else is unexpected.
                logger.warning(f"Unexpected prediction shape before squeeze:"
                               f" {predictions.shape}. Returning as is.")

        logger.debug(f"TFT '{self.name}': Final returned output shape:"
                     f" {tf_shape(final_output)}")

        return final_output

    @classmethod
    def from_config(cls, config):
        """
        Rebuild a SuperXTFT from a config dictionary.

        Mirrors ``XTFT.from_config``: if ``anomaly_config`` contains
        ``anomaly_scores``, they are converted to a float32 NumPy array.
        The caller's dictionaries are not modified.
        """
        logger.debug("Creating SuperXTFT instance from config.")
        config = dict(config)
        anomaly_config = config.get("anomaly_config")
        if anomaly_config and anomaly_config.get("anomaly_scores") is not None:
            anomaly_config = dict(anomaly_config)
            anomaly_config["anomaly_scores"] = np.array(
                anomaly_config["anomaly_scores"], dtype=np.float32
            )
            config["anomaly_config"] = anomaly_config
        return cls(**config)
XTFT.__doc__="""\ Extreme Temporal Fusion Transformer (XTFT) model for complex time series forecasting. XTFT is an advanced architecture for time series forecasting, particularly suited to scenarios featuring intricate temporal patterns, multiple forecast horizons, and inherent uncertainties [1]_. By extending the original Temporal Fusion Transformer, XTFT incorporates additional modules and strategies that enhance its representational capacity, stability, and interpretability. See more in :ref:`User Guide <user_guide>`. {key_improvements} Parameters ---------- dynamic_input_dim : int Dimensionality of dynamic input features. These features vary over time steps and typically include historical observations of the target variable and any time-dependent covariates such as past sales, weather variables, or sensor readings. A higher `dynamic_input_dim` enables the model to incorporate more complex patterns from a richer set of temporal signals. These features help the model understand seasonality, trends, and evolving conditions over time. future_input_dim : int Dimensionality of future known covariates. These are features known ahead of time for future predictions (e.g., holidays, promotions, scheduled events, or future weather forecasts). Increasing `future_input_dim` enhances the model’s ability to leverage external information about the future, improving the accuracy and stability of multi-horizon forecasts. static_input_dim : int Dimensionality of static input features (no time dimension). These features remain constant over time steps and provide global context or attributes related to the time series, for example a store ID or geographic location. Increasing this dimension allows the model to utilize more contextual signals that do not vary with time. A larger `static_input_dim` can help the model specialize predictions for different entities or conditions and improve personalized forecasts. embed_dim : int, optional Dimension of feature embeddings.
Default is ``32``. After variable transformations, inputs are projected into embeddings of size `embed_dim`. Larger embeddings can capture more nuanced relationships but may increase model complexity. A balanced choice prevents overfitting while ensuring the representation capacity is sufficient for complex patterns. forecast_horizon : int, optional Number of future time steps to predict. Default is ``1``. This parameter specifies how many steps ahead the model provides forecasts. For instance, `forecast_horizon=3` means the model predicts values for three future periods simultaneously. Increasing this allows multi-step forecasting, but may complicate learning if too large. quantiles : list of float or str, optional Quantiles to predict for probabilistic forecasting. For example, ``[0.1, 0.5, 0.9]`` indicates lower, median, and upper bounds. If set to ``'auto'``, defaults to ``[0.1, 0.5, 0.9]``. If `None`, the model makes deterministic predictions. Providing quantiles helps the model estimate prediction intervals and uncertainty, offering more informative and robust forecasts. max_window_size : int, optional Maximum dynamic time window size. Default is ``10``. Defines the length of the dynamic windowing mechanism that selects relevant recent time steps for modeling. A larger `max_window_size` enables the model to consider more historical data at once, potentially capturing longer-term patterns, but may also increase computational cost. memory_size : int, optional Size of the memory for memory-augmented attention. Default is ``100``. Introduces a fixed-size memory that the model can attend to, providing a global context or reference to distant past information. Larger `memory_size` can help the model recall patterns from further back in time, improving long-term forecasting stability. num_heads : int, optional Number of attention heads. Default is ``4``. Multi-head attention allows the model to attend to different representation subspaces of the input sequence. 
Increasing `num_heads` can improve model performance by capturing various aspects of the data, but also raises the computational complexity and the number of parameters. dropout_rate : float, optional Dropout rate for regularization. Default is ``0.1``. Controls the fraction of units dropped out randomly during training. Higher values can prevent overfitting but may slow convergence. A small to moderate `dropout_rate` (e.g. 0.1 to 0.3) is often a good starting point. output_dim : int, optional Dimensionality of the output. Default is ``1``. Determines how many target variables are predicted at each forecast horizon. For univariate forecasting, `output_dim=1` is typical. For multi-variate forecasting, set a larger value to predict multiple targets simultaneously. anomaly_loss_weight : float, optional Weight of the anomaly loss term. Default is ``.1``. attention_units : int, optional Number of units in attention layers. Default is ``32``. Controls the dimensionality of internal representations in attention mechanisms. More `attention_units` can allow the model to represent more complex dependencies, but may also increase risk of overfitting and computation. hidden_units : int, optional Number of units in hidden layers. Default is ``64``. Influences the capacity of various dense layers within the model, such as those processing static features or for residual connections. More units allow modeling more intricate functions, but can lead to overfitting if not regularized. lstm_units : int or None, optional Number of units in LSTM layers. Default is ``64``. If `None`, LSTM layers may be disabled or replaced with another mechanism. Increasing `lstm_units` improves the model’s ability to capture temporal dependencies, but also raises computational cost and potential overfitting. scales : list of int, str or None, optional Scales for multi-scale LSTM. If ``'auto'``, defaults are chosen internally. 
This parameter configures multiple LSTMs to operate at different temporal resolutions. For example, `[1, 7, 30]` might represent daily, weekly, and monthly scales. Multi-scale modeling can enhance the model’s understanding of hierarchical time structures and seasonalities. multi_scale_agg : str or None, optional Aggregation method for multi-scale outputs. Options: ``'last'``, ``'average'``, ``'flatten'``, ``'auto'``. If `None`, no special aggregation is applied. This parameter determines how the multiple scales’ outputs are combined. For instance, `average` can produce a more stable representation by averaging across scales, while `flatten` preserves all scale information in a concatenated form. activation : str or callable, optional Activation function. Default is ``'relu'``. Common choices include ``'tanh'``, ``'elu'``, or a custom callable. The choice of activation affects the model’s nonlinearity and can influence convergence speed and final accuracy. For complex datasets, experimenting with different activations may yield better results. use_residuals : bool, optional Whether to use residual connections. Default is ``True``. Residuals help in stabilizing and speeding up training by allowing gradients to flow more easily through the model and mitigating vanishing gradients. They also enable deeper model architectures without significant performance degradation. use_batch_norm : bool, optional Whether to use batch normalization. Default is ``False``. Batch normalization can accelerate training by normalizing layer inputs, reducing internal covariate shift. It often makes model training more stable and can improve convergence, especially in deeper architectures. However, it adds complexity and may not always be beneficial. final_agg : str, optional Final aggregation of the time window. Options: ``'last'``, ``'average'``, ``'flatten'``. Default is ``'last'``. 
Determines how the time-windowed representations are reduced into a final vector before decoding into forecasts. For example, `last` takes the most recent time step's feature vector, while `average` merges information across the entire window. Choosing a suitable aggregation can influence forecast stability and sensitivity to recent or aggregate patterns. anomaly_config : dict, optional Configuration dictionary for anomaly detection. It may contain the following keys: - ``'anomaly_scores'`` : array-like, optional Precomputed anomaly scores tensor of shape `(batch_size, forecast_horizon)`. If not provided, anomaly loss will not be applied. - ``'anomaly_loss_weight'`` : float, optional Weight for the anomaly loss in the total loss computation. Balances the contribution of anomaly detection against the primary forecasting task. A higher value emphasizes identifying and penalizing anomalies, potentially improving robustness to irregularities in the data, while a lower value prioritizes general forecasting performance. If not provided, anomaly loss will not be applied. **Behavior:** If `anomaly_config` is `None`, both `'anomaly_scores'` and `'anomaly_loss_weight'` default to `None`, and anomaly loss is disabled. This means the model will perform forecasting without considering any anomaly detection mechanisms. **Examples:** - **Without Anomaly Detection:** ```python model = XTFT( static_input_dim=10, dynamic_input_dim=45, future_input_dim=5, anomaly_config=None, ... ) ``` - **With Anomaly Detection:** ```python import tensorflow as tf # Define precomputed anomaly scores precomputed_anomaly_scores = tf.random.normal((batch_size, forecast_horizon)) # Create anomaly_config dictionary anomaly_config = {{ 'anomaly_scores': precomputed_anomaly_scores, 'anomaly_loss_weight': 1.0 }} # Initialize the model with anomaly_config model = XTFT( static_input_dim=10, dynamic_input_dim=45, future_input_dim=5, anomaly_config=anomaly_config, ... 
) ``` **kw : dict Additional keyword arguments passed to the model. These may include configuration options for layers, optimizers, or training routines not covered by the parameters above. {methods} {key_functions} Examples -------- >>> import os >>> import tensorflow as tf >>> import pandas as pd >>> import numpy as np >>> from fusionlab.nn.transformers import XTFT >>> from fusionlab.nn.losses import combined_quantile_loss >>> from fusionlab.nn.utils import generate_forecast >>> >>> # Create a dummy training DataFrame with a date column, >>> # dynamic features "feat1", "feat2", static feature "stat1", >>> # and target "price". >>> date_rng = pd.date_range(start="2020-01-01", periods=50, freq="D") >>> train_df = pd.DataFrame({ ... "date": date_rng, ... "feat1": np.random.rand(50), ... "feat2": np.random.rand(50), ... "stat1": np.random.rand(50), ... "price": np.random.rand(50) ... }) >>> # Prepare a dummy XTFT model with example parameters. >>> # Note: The model expects the following input shapes: >>> # - X_static: (n_samples, static_input_dim) >>> # - X_dynamic: (n_samples, time_steps, dynamic_input_dim) >>> # - X_future: (n_samples, time_steps, future_input_dim) >>> # We just want to test the saved model >>> data_path =r'J:\test_saved_models' >>> early_stopping = tf.keras.callbacks.EarlyStopping( ... monitor = 'val_loss', ... patience = 5, ... restore_best_weights = True ... ) >>> model_checkpoint = tf.keras.callbacks.ModelCheckpoint( ... os.path.join( data_path, 'dummy_model'), ... monitor = 'val_loss', ... save_best_only = True, ... save_weights_only = False, # Save entire model ... verbose = 1 ... ) >>> # Create a dummy DataFrame with a date column, >>> # two dynamic features ("feat1", "feat2"), one static feature ("stat1"), >>> # and target "price". >>> date_rng = pd.date_range(start="2020-01-01", periods=60, freq="D") >>> data = { ... "date": date_rng, ... "feat1": np.random.rand(60), ... "feat2": np.random.rand(60), ... "stat1": np.random.rand(60), ... 
"price": np.random.rand(60) ... } >>> df = pd.DataFrame(data) >>> df.head(5) >>> >>> >>> # Split the DataFrame into training and test sets. >>> # Training data: dates before 2020-02-01 >>> # Test data: dates from 2020-02-01 onward. >>> train_df = df[df["date"] < "2020-02-01"].copy() >>> test_df = df[df["date"] >= "2020-02-01"].copy() >>> >>> # Create dummy input arrays for model fitting. >>> # Assume time_steps = 3. >>> X_static = train_df[["stat1"]].values # Shape: (n_train, 1) >>> X_dynamic = np.random.rand(len(train_df), 3, 2) >>> X_future = np.random.rand(len(train_df), 3, 1) >>> # Create dummy target output from "price". >>> y_array = train_df["price"].values.reshape(len(train_df), 1, 1) >>> >>> # Instantiate a dummy XTFT model. >>> my_model = XTFT( ... static_input_dim=1, # "stat1" ... dynamic_input_dim=2, # "feat1" and "feat2" ... future_input_dim=1, # For the provided future feature ... forecast_horizon=5, # Forecasting 5 periods ahead ... quantiles=[0.1, 0.5, 0.9], ... embed_dim=16, ... max_window_size=3, ... memory_size=50, ... num_heads=2, ... dropout_rate=0.1, ... lstm_units=32, ... attention_units=32, ... hidden_units=16 ... ) >>> # build the model >>> _=my_model([X_static, X_dynamic, X_future]) # ... input_shape=[ # ... (None, X_static.shape[1]), # ... (None, X_dynamic.shape[1], X_dynamic.shape[2]), # ... (None, X_future.shape[1], X_future.shape[2]) # ... ] # ... ) >>> loss_fn = combined_quantile_loss(my_model.quantiles) >>> my_model.compile(optimizer="adam", loss=loss_fn) >>> >>> # Fit the model on the training data. >>> my_model.fit( ... x=[X_static, X_dynamic, X_future], ... y=y_array, ... epochs=10, ... batch_size=8, ... validation_split= 0.2, ... callbacks = [early_stopping, model_checkpoint] ... 
) >>> my_model.save(os.path.join(data_path, 'dummy_model.keras')) Epoch 9/10 4/4 [==============================] - 0s 4ms/step - loss: 0.0958 Epoch 10/10 4/4 [==============================] - 0s 5ms/step - loss: 0.1009 Out[10]: <keras.src.callbacks.History at 0x1c7a9114c10> >>> y_predictions=my_model.predict([X_static, X_dynamic, X_future]) 1/1 [==============================] - 1s 640ms/step >>> print(y_predictions.shape) (31, 5, 3, 1) >>> # now let reload the model 'dummy_model' and check whether >>> # it's successfully releaded. >>> test_model = tf.keras.models.load_model (os.path.join( data_path, 'dummy_model.keras')) >>> test_model See Also -------- fusionlab.nn.tft.TemporalFusionTransformer : The original TFT model for comparison. MultiHeadAttention : Keras layer for multi-head attention. LSTM : Keras LSTM layer for sequence modeling. References ---------- .. [1] Wang, X., et al. (2021). "Enhanced Temporal Fusion Transformer for Time Series Forecasting." International Journal of Forecasting, 37(3), 1234-1245. """