fusionlab.nn.models.XTFT
- class fusionlab.nn.models.XTFT
Bases: Model, NNLearner
Extreme Temporal Fusion Transformer (XTFT) model for complex time series forecasting.
XTFT is an advanced architecture for time series forecasting, particularly suited to scenarios featuring intricate temporal patterns, multiple forecast horizons, and inherent uncertainties [1]. By extending the original Temporal Fusion Transformer, XTFT incorporates additional modules and strategies that enhance its representational capacity, stability, and interpretability.
See more in User Guide.
- Parameters:
dynamic_input_dim (int) – Dimensionality of dynamic input features. These features vary over time steps and typically include historical observations of the target variable, and any time-dependent covariates such as past sales, weather variables, or sensor readings. A higher dynamic_input_dim enables the model to incorporate more complex patterns from a richer set of temporal signals. These features help the model understand seasonality, trends, and evolving conditions over time.
future_input_dim (int) – Dimensionality of future known covariates. These are features known ahead of time for future predictions (e.g., holidays, promotions, scheduled events, or future weather forecasts). Increasing future_input_dim enhances the model’s ability to leverage external information about the future, improving the accuracy and stability of multi-horizon forecasts.
static_input_dim (int) – Dimensionality of static input features (no time dimension). These features remain constant over time steps and provide global context or attributes related to the time series. For example, a store ID or geographic location. Increasing this dimension allows the model to utilize more contextual signals that do not vary with time. A larger static_input_dim can help the model specialize predictions for different entities or conditions and improve personalized forecasts.
embed_dim (int, optional) – Dimension of feature embeddings. Default is 32. After variable transformations, inputs are projected into embeddings of size embed_dim. Larger embeddings can capture more nuanced relationships but may increase model complexity. A balanced choice prevents overfitting while ensuring the representation capacity is sufficient for complex patterns.
forecast_horizon (int, optional) – Number of future time steps to predict. Default is 1. This parameter specifies how many steps ahead the model provides forecasts. For instance, forecast_horizon=3 means the model predicts values for three future periods simultaneously. Increasing this allows multi-step forecasting, but may complicate learning if too large.
quantiles (list of float or str, optional) – Quantiles to predict for probabilistic forecasting. For example, [0.1, 0.5, 0.9] indicates lower, median, and upper bounds. If set to 'auto', defaults to [0.1, 0.5, 0.9]. If None, the model makes deterministic predictions. Providing quantiles helps the model estimate prediction intervals and uncertainty, offering more informative and robust forecasts.
max_window_size (int, optional) – Maximum dynamic time window size. Default is 10. Defines the length of the dynamic windowing mechanism that selects relevant recent time steps for modeling. A larger max_window_size enables the model to consider more historical data at once, potentially capturing longer-term patterns, but may also increase computational cost.
memory_size (int, optional) – Size of the memory for memory-augmented attention. Default is 100. Introduces a fixed-size memory that the model can attend to, providing a global context or reference to distant past information. Larger memory_size can help the model recall patterns from further back in time, improving long-term forecasting stability.
num_heads (int, optional) – Number of attention heads. Default is 4. Multi-head attention allows the model to attend to different representation subspaces of the input sequence. Increasing num_heads can improve model performance by capturing various aspects of the data, but also raises the computational complexity and the number of parameters.
dropout_rate (float, optional) – Dropout rate for regularization. Default is 0.1. Controls the fraction of units dropped out randomly during training. Higher values can prevent overfitting but may slow convergence. A small to moderate dropout_rate (e.g. 0.1 to 0.3) is often a good starting point.
output_dim (int, optional) – Dimensionality of the output. Default is 1. Determines how many target variables are predicted at each forecast horizon. For univariate forecasting, output_dim=1 is typical. For multi-variate forecasting, set a larger value to predict multiple targets simultaneously.
- anomaly_loss_weight : float, optional
Weight of the anomaly loss term. Default is 0.1.
- attention_units : int, optional
Number of units in attention layers. Default is 32. Controls the dimensionality of internal representations in attention mechanisms. More attention_units can allow the model to represent more complex dependencies, but may also increase risk of overfitting and computation.
- hidden_units : int, optional
Number of units in hidden layers. Default is 64. Influences the capacity of various dense layers within the model, such as those processing static features or for residual connections. More units allow modeling more intricate functions, but can lead to overfitting if not regularized.
- lstm_units : int or None, optional
Number of units in LSTM layers. Default is 64. If None, LSTM layers may be disabled or replaced with another mechanism. Increasing lstm_units improves the model’s ability to capture temporal dependencies, but also raises computational cost and potential overfitting.
- scales : list of int, str or None, optional
Scales for multi-scale LSTM. If 'auto', defaults are chosen internally. This parameter configures multiple LSTMs to operate at different temporal resolutions. For example, [1, 7, 30] might represent daily, weekly, and monthly scales. Multi-scale modeling can enhance the model’s understanding of hierarchical time structures and seasonalities.
- multi_scale_agg : str or None, optional
Aggregation method for multi-scale outputs. Options: 'last', 'average', 'flatten', 'auto'. If None, no special aggregation is applied. This parameter determines how the multiple scales’ outputs are combined. For instance, average can produce a more stable representation by averaging across scales, while flatten preserves all scale information in a concatenated form.
- activation : str or callable, optional
Activation function. Default is 'relu'. Common choices include 'tanh', 'elu', or a custom callable. The choice of activation affects the model’s nonlinearity and can influence convergence speed and final accuracy. For complex datasets, experimenting with different activations may yield better results.
- use_residuals : bool, optional
Whether to use residual connections. Default is True. Residuals help in stabilizing and speeding up training by allowing gradients to flow more easily through the model and mitigating vanishing gradients. They also enable deeper model architectures without significant performance degradation.
- use_batch_norm : bool, optional
Whether to use batch normalization. Default is False. Batch normalization can accelerate training by normalizing layer inputs, reducing internal covariate shift. It often makes model training more stable and can improve convergence, especially in deeper architectures. However, it adds complexity and may not always be beneficial.
- final_agg : str, optional
Final aggregation of the time window. Options: 'last', 'average', 'flatten'. Default is 'last'. Determines how the time-windowed representations are reduced into a final vector before decoding into forecasts. For example, last takes the most recent time step’s feature vector, while average merges information across the entire window. Choosing a suitable aggregation can influence forecast stability and sensitivity to recent or aggregate patterns.
- anomaly_config : dict, optional
Configuration dictionary for anomaly detection. It may contain the following keys:
'anomaly_scores' : array-like, optional
Precomputed anomaly scores tensor of shape (batch_size, forecast_horizon). If not provided, anomaly loss will not be applied.
'anomaly_loss_weight' : float, optional
Weight for the anomaly loss in the total loss computation. Balances the contribution of anomaly detection against the primary forecasting task. A higher value emphasizes identifying and penalizing anomalies, potentially improving robustness to irregularities in the data, while a lower value prioritizes general forecasting performance. If not provided, anomaly loss will not be applied.
Behavior: If anomaly_config is None, both 'anomaly_scores' and 'anomaly_loss_weight' default to None, and anomaly loss is disabled. This means the model will perform forecasting without considering any anomaly detection mechanisms.
Examples:
- With Anomaly Detection:
```python
import tensorflow as tf
from fusionlab.nn.models import XTFT

# Example sizes, chosen only for illustration
batch_size, forecast_horizon = 32, 5

# Define precomputed anomaly scores
precomputed_anomaly_scores = tf.random.normal((batch_size, forecast_horizon))

# Create anomaly_config dictionary
anomaly_config = {
    'anomaly_scores': precomputed_anomaly_scores,
    'anomaly_loss_weight': 1.0,
}

# Initialize the model with anomaly_config
model = XTFT(
    static_input_dim=10,
    dynamic_input_dim=45,
    future_input_dim=5,
    anomaly_config=anomaly_config,
    # …
)
```
- **kw : dict
Additional keyword arguments passed to the model. These may include configuration options for layers, optimizers, or training routines not covered by the parameters above.
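The 'last', 'average', and 'flatten' options described for final_agg can be illustrated with a small pure-Python sketch. This is not fusionlab's implementation: the helper name `aggregate_window` is hypothetical, and the window is modeled as a plain list of per-step feature vectors rather than a tensor.

```python
from typing import List

Window = List[List[float]]  # shape: (time_steps, features)


def aggregate_window(window: Window, mode: str = "last") -> List[float]:
    """Reduce a (time_steps, features) window to a single vector.

    Hypothetical helper mirroring the documented option names only.
    """
    if not window:
        raise ValueError("window must contain at least one time step")
    if mode == "last":
        # Keep only the most recent time step's feature vector.
        return list(window[-1])
    if mode == "average":
        # Average each feature across all time steps.
        n_steps = len(window)
        return [sum(column) / n_steps for column in zip(*window)]
    if mode == "flatten":
        # Concatenate every time step, preserving all information.
        return [value for step in window for value in step]
    raise ValueError(f"unknown aggregation mode: {mode!r}")


window = [[1.0, 10.0], [2.0, 20.0], [3.0, 30.0]]
last_vec = aggregate_window(window, "last")
mean_vec = aggregate_window(window, "average")
flat_vec = aggregate_window(window, "flatten")
```

Note that 'flatten' produces a vector whose length grows with the window size, whereas 'last' and 'average' keep the feature dimension unchanged.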
Examples
>>> import os
>>> import tensorflow as tf
>>> import pandas as pd
>>> import numpy as np
>>> from fusionlab.nn.transformers import XTFT
>>> from fusionlab.nn.losses import combined_quantile_loss
>>> from fusionlab.nn.utils import generate_forecast
>>>
>>> # Create a dummy training DataFrame with a date column,
>>> # dynamic features "feat1", "feat2", static feature "stat1",
>>> # and target "price".
>>> date_rng = pd.date_range(start="2020-01-01", periods=50, freq="D")
>>> train_df = pd.DataFrame({
...     "date": date_rng,
...     "feat1": np.random.rand(50),
...     "feat2": np.random.rand(50),
...     "stat1": np.random.rand(50),
...     "price": np.random.rand(50)
... })
>>> # Prepare a dummy XTFT model with example parameters.
>>> # Note: The model expects the following input shapes:
>>> #   - X_static: (n_samples, static_input_dim)
>>> #   - X_dynamic: (n_samples, time_steps, dynamic_input_dim)
>>> #   - X_future: (n_samples, time_steps, future_input_dim)
>>> # We just want to test the saved model
>>> data_path = r'J:\test_saved_models'
>>> early_stopping = tf.keras.callbacks.EarlyStopping(
...     monitor='val_loss',
...     patience=5,
...     restore_best_weights=True
... )
>>> model_checkpoint = tf.keras.callbacks.ModelCheckpoint(
...     os.path.join(data_path, 'dummy_model'),
...     monitor='val_loss',
...     save_best_only=True,
...     save_weights_only=False,  # Save entire model
...     verbose=1
... )
>>> # Create a dummy DataFrame with a date column,
>>> # two dynamic features ("feat1", "feat2"), one static feature ("stat1"),
>>> # and target "price".
>>> date_rng = pd.date_range(start="2020-01-01", periods=60, freq="D")
>>> data = {
...     "date": date_rng,
...     "feat1": np.random.rand(60),
...     "feat2": np.random.rand(60),
...     "stat1": np.random.rand(60),
...     "price": np.random.rand(60)
... }
>>> df = pd.DataFrame(data)
>>> df.head(5)
>>>
>>> # Split the DataFrame into training and test sets.
>>> # Training data: dates before 2020-02-01
>>> # Test data: dates from 2020-02-01 onward.
>>> train_df = df[df["date"] < "2020-02-01"].copy()
>>> test_df = df[df["date"] >= "2020-02-01"].copy()
>>>
>>> # Create dummy input arrays for model fitting.
>>> # Assume time_steps = 3.
>>> X_static = train_df[["stat1"]].values  # Shape: (n_train, 1)
>>> X_dynamic = np.random.rand(len(train_df), 3, 2)
>>> X_future = np.random.rand(len(train_df), 3, 1)
>>> # Create dummy target output from "price".
>>> y_array = train_df["price"].values.reshape(len(train_df), 1, 1)
>>>
>>> # Instantiate a dummy XTFT model.
>>> my_model = XTFT(
...     static_input_dim=1,    # "stat1"
...     dynamic_input_dim=2,   # "feat1" and "feat2"
...     future_input_dim=1,    # For the provided future feature
...     forecast_horizon=5,    # Forecasting 5 periods ahead
...     quantiles=[0.1, 0.5, 0.9],
...     embed_dim=16,
...     max_window_size=3,
...     memory_size=50,
...     num_heads=2,
...     dropout_rate=0.1,
...     lstm_units=32,
...     attention_units=32,
...     hidden_units=16
... )
>>> # Build the model by calling it once on the inputs.
>>> _ = my_model([X_static, X_dynamic, X_future])
>>> #     input_shape=[
>>> #         (None, X_static.shape[1]),
>>> #         (None, X_dynamic.shape[1], X_dynamic.shape[2]),
>>> #         (None, X_future.shape[1], X_future.shape[2])
>>> #     ]
>>> # )
>>> loss_fn = combined_quantile_loss(my_model.quantiles)
>>> my_model.compile(optimizer="adam", loss=loss_fn)
>>>
>>> # Fit the model on the training data.
>>> my_model.fit(
...     x=[X_static, X_dynamic, X_future],
...     y=y_array,
...     epochs=10,
...     batch_size=8,
...     validation_split=0.2,
...     callbacks=[early_stopping, model_checkpoint]
... )
Epoch 9/10
4/4 [==============================] - 0s 4ms/step - loss: 0.0958
Epoch 10/10
4/4 [==============================] - 0s 5ms/step - loss: 0.1009
<keras.src.callbacks.History at 0x1c7a9114c10>
>>> my_model.save(os.path.join(data_path, 'dummy_model.keras'))
>>> y_predictions = my_model.predict([X_static, X_dynamic, X_future])
1/1 [==============================] - 1s 640ms/step
>>> print(y_predictions.shape)
(31, 5, 3, 1)
>>> # Now let's reload the model 'dummy_model' and check whether
>>> # it is successfully reloaded.
>>> test_model = tf.keras.models.load_model(
...     os.path.join(data_path, 'dummy_model.keras'))
>>> test_model
See also
fusionlab.nn.tft.TemporalFusionTransformer: The original TFT model for comparison.
MultiHeadAttention: Keras layer for multi-head attention.
LSTM: Keras LSTM layer for sequence modeling.
References
- __init__(static_input_dim, dynamic_input_dim, future_input_dim, embed_dim=32, forecast_horizon=1, quantiles=None, max_window_size=10, memory_size=100, num_heads=4, dropout_rate=0.1, output_dim=1, attention_units=32, hidden_units=64, lstm_units=64, scales=None, multi_scale_agg=None, activation='relu', use_residuals=True, use_batch_norm=False, final_agg='last', anomaly_config=None, anomaly_detection_strategy=None, anomaly_loss_weight=0.1, **kw)
- Parameters:
static_input_dim (int)
dynamic_input_dim (int)
future_input_dim (int)
embed_dim (int)
forecast_horizon (int)
quantiles (str | List[float] | None)
max_window_size (int)
memory_size (int)
num_heads (int)
dropout_rate (float)
output_dim (int)
attention_units (int)
hidden_units (int)
lstm_units (int)
scales (str | List[int] | None)
multi_scale_agg (str | None)
activation (str)
use_residuals (bool)
use_batch_norm (bool)
final_agg (str)
anomaly_config (Dict[str, Any] | None)
anomaly_detection_strategy (str | None)
anomaly_loss_weight (float)
Methods
__init__(static_input_dim, ...[, embed_dim, ...])
add_loss(losses, **kwargs): Add loss tensor(s), potentially dependent on layer inputs.
add_metric(value[, name]): Adds metric tensor to the layer.
add_update(updates): Add update op(s), potentially dependent on layer inputs.
add_variable(*args, **kwargs): Deprecated, do NOT use! Alias for add_weight.
add_weight([name, shape, dtype, ...]): Adds a new variable to the layer.
build(input_shape): Builds the model based on input shapes received.
build_from_config(config): Builds the layer's states with the supplied config dict.
call(inputs[, training]): Forward pass of the XTFT model.
compile(optimizer[, loss]): Compile the XTFT model, allowing an explicit user-specified loss to override the defaults.
compile_from_config(config): Compiles the model with the information given in config.
compute_loss([x, y, y_pred, sample_weight]): Compute the total loss, validate it, and return it.
compute_mask(inputs[, mask]): Computes an output mask tensor.
compute_metrics(x, y, y_pred, sample_weight): Update metric states and collect all metrics to be returned.
compute_output_shape(input_shape): Computes the output shape of the layer.
compute_output_signature(input_signature): Compute the output tensor signature of the layer based on the inputs.
count_params(): Count the total number of scalars composing the weights.
evaluate([x, y, batch_size, verbose, ...]): Returns the loss value & metrics values for the model in test mode.
evaluate_generator(generator[, steps, ...]): Evaluates the model on a data generator.
export(filepath): Create a SavedModel artifact for inference (e.g. via TF-Serving).
finalize_state(): Finalizes the layers state after updating layer weights.
fit([x, y, batch_size, epochs, verbose, ...]): Trains the model for a fixed number of epochs (dataset iterations).
fit_generator(generator[, steps_per_epoch, ...]): Fits the model on data yielded batch-by-batch by a Python generator.
from_config(config): Reconstruct model instance from configuration dictionary.
get_build_config(): Returns a dictionary with the layer's input shape.
get_compile_config(): Returns a serialized config with information for compiling the model.
get_config(): Get serialization configuration for model saving/loading.
get_input_at(node_index): Retrieves the input tensor(s) of a layer at a given node.
get_input_mask_at(node_index): Retrieves the input mask tensor(s) of a layer at a given node.
get_input_shape_at(node_index): Retrieves the input shape(s) of a layer at a given node.
get_layer([name, index]): Retrieves a layer based on either its name (unique) or index.
get_metrics_result(): Returns the model's metrics values as a dict.
get_output_at(node_index): Retrieves the output tensor(s) of a layer at a given node.
get_output_mask_at(node_index): Retrieves the output mask tensor(s) of a layer at a given node.
get_output_shape_at(node_index): Retrieves the output shape(s) of a layer at a given node.
get_params([deep]): Get the parameters for this learner.
get_weight_paths(): Retrieve all the variables and their paths for the model.
get_weights(): Retrieves the weights of the model.
help(**kwargs)
load(file_path[, format]): Load the learner's state from a specified file in the desired format.
load_own_variables(store): Loads the state of the layer.
load_weights(filepath[, skip_mismatch, ...]): Loads all layer weights from a saved file.
make_predict_function([force]): Creates a function that executes one step of inference.
make_test_function([force]): Creates a function that executes one step of evaluation.
make_train_function([force]): Creates a function that executes one step of training.
predict(x[, batch_size, verbose, steps, ...]): Generates output predictions for the input samples.
predict_generator(generator[, steps, ...]): Generates predictions for the input samples from a data generator.
predict_on_batch(x): Returns predictions for a single batch of samples.
predict_step(data): The logic for one inference step.
reset_metrics(): Resets the state of all the metrics in the model.
reset_states()
save(filepath[, overwrite, save_format]): Saves a model as a TensorFlow SavedModel or HDF5 file.
save_own_variables(store): Saves the state of the layer.
save_spec([dynamic_batch]): Returns the tf.TensorSpec of call args as a tuple (args, kwargs).
save_weights(filepath[, overwrite, ...]): Saves all layer weights.
set_params(**params): Set the parameters of this learner.
set_weights(weights): Sets the weights of the layer, from NumPy arrays.
summary([line_length, positions, print_fn, ...]): Prints a string summary of the network.
test_on_batch(x[, y, sample_weight, ...]): Test the model on a single batch of samples.
test_step(data): The logic for one evaluation step.
to_json(**kwargs): Returns a JSON string containing the network configuration.
to_yaml(**kwargs): Returns a yaml string containing the network configuration.
train_on_batch(x[, y, sample_weight, ...]): Runs a single gradient update on a single batch of data.
train_step(data): Custom training step with anomaly detection strategy handling.
with_name_scope(method): Decorator to automatically enter the module name scope.
Attributes
activity_regularizer: Optional regularizer function for the output of this layer.
autotune_steps_per_execution: Settable property to enable tuning for steps_per_execution.
compute_dtype: The dtype of the layer's computations.
distribute_reduction_method: The method employed to reduce per-replica values during training.
distribute_strategy: The tf.distribute.Strategy this model was created under.
dtype: The dtype of the layer weights.
dtype_policy: The dtype policy associated with this layer.
dynamic: Whether the layer is dynamic (eager-only); set in the constructor.
inbound_nodes: Return Functional API nodes upstream of this layer.
input: Retrieves the input tensor(s) of a layer.
input_mask: Retrieves the input mask tensor(s) of a layer.
input_shape: Retrieves the input shape(s) of a layer.
input_spec: InputSpec instance(s) describing the input format for this layer.
jit_compile: Specify whether to compile the model with XLA.
layers
losses: List of losses added using the add_loss() API.
metrics: Return metrics added using compile() or add_metric().
metrics_names: Returns the model's display labels for all outputs.
name: Name of the layer (string), set in the constructor.
name_scope: Returns a tf.name_scope instance for this class.
non_trainable_variables: Sequence of non-trainable variables owned by this module and its submodules.
non_trainable_weights: List of all non-trainable weights tracked by this layer.
outbound_nodes: Return Functional API nodes downstream of this layer.
output: Retrieves the output tensor(s) of a layer.
output_mask: Retrieves the output mask tensor(s) of a layer.
output_shape: Retrieves the output shape(s) of a layer.
run_eagerly: Settable attribute indicating whether the model should run eagerly.
state_updates: Deprecated, do NOT use!
stateful
steps_per_execution: Settable steps_per_execution variable. Requires a compiled model.
submodules: Sequence of all sub-modules.
supports_masking: Whether this layer supports computing a mask using compute_mask.
trainable
trainable_variables: Sequence of trainable variables owned by this module and its submodules.
trainable_weights: List of all trainable weights tracked by this layer.
updates
variable_dtype: Alias of Layer.dtype, the dtype of the weights.
variables: Returns the list of all layer variables/weights.
weights: Returns the list of all layer variables/weights.
- call(inputs, training=False, **kwargs)
Forward pass of the XTFT model.
- Parameters:
inputs (tuple or list) – Input data containing three elements:
1. Static features (batch_size, static_input_dim)
2. Dynamic historical features (batch_size, time_steps, dynamic_input_dim)
3. Future covariates (batch_size, horizon, future_input_dim)
training (bool, optional) – Whether the model is in training mode, by default False
**kwargs – Additional keyword arguments
- Returns:
Predictions tensor of shape:
- (batch_size, horizon, len(quantiles)) if quantiles specified
- (batch_size, horizon, output_dim) otherwise
- Return type:
tf.Tensor
- Raises:
ValueError – If input validation fails through validate_xtft_inputs
Notes
Handles three types of anomaly detection strategies:
1. 'feature_based': Generates scores from attention mechanisms
2. 'prediction_based': Handled in loss function
3. 'from_config': Uses precomputed anomaly scores
Implements multi-scale temporal processing with:
- Positional encoding
- Hierarchical attention
- Memory-augmented attention
- Dynamic time windowing
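The three-element input contract described above can be checked before calling the model. The following is a minimal sketch operating on shape tuples only; `check_input_shapes` is a hypothetical helper and does not reproduce the behavior or signature of validate_xtft_inputs.

```python
from typing import Sequence, Tuple

Shape = Tuple[int, ...]


def check_input_shapes(
    shapes: Sequence[Shape],
    static_input_dim: int,
    dynamic_input_dim: int,
    future_input_dim: int,
) -> None:
    """Raise ValueError if shapes do not match the documented layout.

    Hypothetical helper; expects [static, dynamic, future] shape tuples.
    """
    if len(shapes) != 3:
        raise ValueError(f"expected 3 inputs, got {len(shapes)}")
    static, dynamic, future = shapes
    # (name, shape, expected rank, expected size of the last axis)
    expected = (
        ("static", static, 2, static_input_dim),
        ("dynamic", dynamic, 3, dynamic_input_dim),
        ("future", future, 3, future_input_dim),
    )
    for name, shape, rank, last_dim in expected:
        if len(shape) != rank:
            raise ValueError(f"{name} input must have rank {rank}, got {shape}")
        if shape[-1] != last_dim:
            raise ValueError(
                f"{name} input last axis must be {last_dim}, got {shape[-1]}"
            )
    # All three inputs must describe the same number of samples.
    if not (static[0] == dynamic[0] == future[0]):
        raise ValueError("batch sizes differ across inputs")


# Shapes taken from the class-level example: 31 samples, 3 time steps.
check_input_shapes([(31, 1), (31, 3, 2), (31, 3, 1)], 1, 2, 1)
```

With real arrays, the shape tuples would come from each array's `.shape` attribute.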
- compile(optimizer, loss=None, **kws)
Compile the XTFT model, allowing an explicit user-specified loss to override the defaults.
If the user provides a loss (loss=…), it is used regardless of quantiles or anomaly scores. Otherwise, the method uses the following logic:
- If self.quantiles is None, defaults to "mse" (or the user-supplied loss).
- If self.quantiles is not None, uses a quantile-based loss. If anomaly_scores is present, a total loss is used that adds anomaly loss on top.
See also:
fusionlab.nn.losses.combined_quantile_loss
fusionlab.nn.losses.combined_total_loss
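The loss-selection order described for compile can be sketched in plain Python. Both functions below are hypothetical illustrations: `resolve_loss` only returns a label for the branch taken, and `pinball_loss` is the standard per-sample quantile (pinball) loss, which is assumed here to be the kind of term a quantile-based loss builds on. Neither reproduces combined_quantile_loss or combined_total_loss.

```python
from typing import Optional, Sequence


def pinball_loss(y_true: float, y_pred: float, q: float) -> float:
    """Standard quantile (pinball) loss for one sample and one quantile q."""
    error = y_true - y_pred
    # Under-prediction is weighted by q, over-prediction by (1 - q).
    return max(q * error, (q - 1.0) * error)


def resolve_loss(
    loss: Optional[str] = None,
    quantiles: Optional[Sequence[float]] = None,
    has_anomaly_scores: bool = False,
) -> str:
    """Return a label for the branch the documented logic would take."""
    if loss is not None:
        return "user-supplied"       # explicit loss always wins
    if quantiles is None:
        return "mse"                 # deterministic forecasting default
    if has_anomaly_scores:
        return "quantile+anomaly"    # quantile loss plus anomaly term
    return "quantile"


# With q = 0.9, under-predicting costs more than over-predicting.
under = pinball_loss(10.0, 8.0, 0.9)   # prediction too low
over = pinball_loss(10.0, 12.0, 0.9)   # prediction too high
branch = resolve_loss(quantiles=[0.1, 0.5, 0.9], has_anomaly_scores=True)
```

The asymmetry between `under` and `over` is what pushes a 0.9-quantile forecast toward the upper part of the target distribution.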
- train_step(data)
Custom training step with anomaly detection strategy handling.
- Parameters:
data (tuple / tf.data.Dataset) – Training data containing:
- For prediction-based strategy: (x, y) pairs
- Other strategies: Standard Keras-compatible format
- Returns:
Metric results dictionary
- Return type:
dict
Notes
Special handling for prediction-based anomaly detection:
- Requires explicit (x, y) pairs
- Validates y_true integrity
- Falls back to standard training if data format invalid
For other strategies, uses native Keras training logic.
- Raises:
Warning (logged) –
- For missing y_true in prediction-based mode
- For invalid/nan values in y_true
Example
>>> model.compile(...)
>>> model.fit(dataset, epochs=10)
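The fallback behavior in the notes above (explicit (x, y) pairs required, y_true validated, standard training otherwise) can be sketched as follows. All three function names are hypothetical, targets are modeled as a flat sequence of floats, and the actual train_step may apply different checks.

```python
import math
from typing import Any, Optional, Sequence, Tuple


def split_batch(data: Any) -> Tuple[Any, Optional[Sequence[float]]]:
    """Return (x, y) when data is an explicit pair, else (data, None)."""
    if isinstance(data, (tuple, list)) and len(data) == 2:
        return data[0], data[1]
    return data, None


def y_true_is_valid(y_true: Optional[Sequence[float]]) -> bool:
    """True when targets exist, are non-empty, and contain only finite values."""
    if y_true is None or len(y_true) == 0:
        return False
    return all(math.isfinite(value) for value in y_true)


def choose_training_path(data: Any) -> str:
    """Pick the prediction-based path only when the targets are usable."""
    _, y_true = split_batch(data)
    if y_true_is_valid(y_true):
        return "prediction_based"
    # Missing or invalid targets: fall back to the standard path.
    return "standard"


good = choose_training_path(([1.0, 2.0], [0.5, 0.7]))
missing = choose_training_path([1.0, 2.0, 3.0])
bad = choose_training_path(([1.0, 2.0], [0.5, float("nan")]))
```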
- get_config()
Get serialization configuration for model saving/loading.
- Returns:
Complete configuration dictionary containing:
- Model architecture parameters
- Anomaly detection configuration
- Component hyperparameters
- Training configuration
- Return type:
dict
Notes
Handles special cases for:
- Quantile list serialization
- Numpy array conversion for anomaly scores
- Custom layer configurations
Logs configuration changes via model logger.
Example
>>> import json
>>> config = model.get_config()
>>> with open('model_config.json', 'w') as f:
...     json.dump(config, f)
- classmethod from_config(config)
Reconstruct model instance from configuration dictionary.
- Parameters:
config (dict) – Configuration dictionary generated by get_config()
- Returns:
Fully reconstructed model instance
- Return type:
XTFT
Notes
Handles special conversions:
- Anomaly scores list -> numpy array
- Quantile list restoration
- Custom layer reconstruction
Maintains logger instance during reconstruction.
Example
>>> import json
>>> with open('model_config.json') as f:
...     loaded_model = XTFT.from_config(json.load(f))
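The array-to-list conversion mentioned in the get_config and from_config notes matters because JSON cannot store array objects directly. A minimal sketch of such a round trip is shown below; `to_serializable` is a hypothetical helper, the config keys are borrowed from the constructor signature for illustration, and nested lists stand in for NumPy arrays so the example has no third-party dependencies.

```python
import json
from typing import Any


def to_serializable(value: Any) -> Any:
    """Recursively convert array-likes and tuples to JSON-friendly values."""
    if hasattr(value, "tolist"):
        # NumPy arrays (and similar objects) expose .tolist().
        return value.tolist()
    if isinstance(value, dict):
        return {key: to_serializable(item) for key, item in value.items()}
    if isinstance(value, (list, tuple)):
        return [to_serializable(item) for item in value]
    return value


# Illustrative config; a real one would come from model.get_config().
config = {
    "static_input_dim": 1,
    "dynamic_input_dim": 2,
    "future_input_dim": 1,
    "forecast_horizon": 5,
    "quantiles": (0.1, 0.5, 0.9),
    "anomaly_config": {
        "anomaly_scores": [[0.0, 0.2], [0.1, 0.3]],
        "anomaly_loss_weight": 1.0,
    },
}

payload = json.dumps(to_serializable(config))
restored = json.loads(payload)
```

After the round trip, tuples come back as lists, so code that rebuilds a model from `restored` should not rely on the original container types.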
- help(**kwargs)
- my_params = XTFT( static_input_dim, dynamic_input_dim, future_input_dim, embed_dim=32, forecast_horizon=1, quantiles=None, max_window_size=10, memory_size=100, num_heads=4, dropout_rate=0.1, output_dim=1, attention_units=32, hidden_units=64, lstm_units=64, scales=None, multi_scale_agg=None, activation='relu', use_residuals=True, use_batch_norm=False, final_agg='last', anomaly_config=None, anomaly_detection_strategy=None, anomaly_loss_weight=0.1 )