fusionlab.nn.transformers.TFT

class fusionlab.nn.transformers.TFT

Bases: Model, NNLearner

Temporal Fusion Transformer (TFT) that requires static, dynamic (past), and future inputs.

This class implements the Temporal Fusion Transformer (TFT) architecture, closely following the structure described in the original paper [Lim21]. It is designed for multi-horizon time series forecasting and explicitly requires static covariates, dynamic (historical) covariates, and known future covariates as inputs.

Compared to more flexible implementations, this version mandates all input types, simplifying the internal input handling logic. It incorporates key TFT components like Variable Selection Networks (VSNs), Gated Residual Networks (GRNs) for static context generation and feature processing, LSTM encoding, static enrichment, interpretable multi-head attention, and position-wise feedforward layers.

Parameters:
  • dynamic_input_dim (int) – The total number of features present in the dynamic (past) input tensor. These are features that vary across the lookback time steps.

  • static_input_dim (int) – The total number of features present in the static (time-invariant) input tensor. These features provide context that does not change over time for a given series.

  • future_input_dim (int) – The total number of features present in the known future input tensor. These features provide information about future events or conditions known at the time of prediction.

  • hidden_units (int, default 32) – The main dimensionality of the hidden layers used throughout the network, including VSN outputs, GRN hidden states, enrichment layers, and attention mechanisms.

  • num_heads (int, default 4) – Number of attention heads used in the TemporalAttentionLayer. More heads allow attending to different representation subspaces.

  • dropout_rate (float, default 0.1) – Dropout rate applied to non-recurrent connections in LSTMs, VSNs, GRNs, and Attention layers. Value between 0 and 1.

  • recurrent_dropout_rate (float, default 0.0) – Dropout rate applied specifically to the recurrent connections within the LSTM layers. It helps regularize the recurrent state updates. Value between 0 and 1. Note: a non-zero value prevents Keras from using the cuDNN-accelerated LSTM kernel, so training on GPUs can be noticeably slower.

  • forecast_horizon (int, default 1) – The number of future time steps the model is trained to predict simultaneously (multi-horizon forecasting).

  • quantiles (list[float], optional, default None) –

    A list of quantiles (e.g., [0.1, 0.5, 0.9]) for probabilistic forecasting.

    • If provided, the model outputs predictions for each specified quantile, and combined_quantile_loss() should typically be used for training.

    • If None, the model performs point forecasting (outputting a single value per step), typically trained with MSE loss.

  • activation (str, default 'elu') – Activation function used within GRNs and potentially other dense layers. Supported options include ‘elu’, ‘relu’, ‘gelu’, ‘tanh’, ‘sigmoid’, ‘linear’.

  • use_batch_norm (bool, default False) – If True, applies Batch Normalization within the Gated Residual Networks (GRNs). Layer Normalization is typically used by default within GRN implementations as per the TFT paper.

  • num_lstm_layers (int, default 1) – Number of LSTM layers stacked in the sequence encoder module.

  • lstm_units (int or list[int], optional, default None) –

    Number of hidden units in each LSTM encoder layer.

    • If int: The same number of units is used for all layers specified by num_lstm_layers.

    • If list[int]: Specifies the number of units for each LSTM layer sequentially. The length must match num_lstm_layers.

    • If None: Defaults to using hidden_units for all LSTM layers.

  • output_dim (int, default 1) – The number of target variables the model predicts at each time step. Typically 1 for univariate forecasting.

  • **kwargs – Additional keyword arguments passed to the parent Keras Model.
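The lstm_units rules above amount to a small resolution step. The function below is a hypothetical helper (resolve_lstm_units is not a documented fusionlab function). It turns the int / list / None forms into one unit count per stacked LSTM layer, and it is a sketch of the documented behaviour, not the library's code.

```python
from typing import List, Optional, Union


def resolve_lstm_units(
    lstm_units: Optional[Union[int, List[int]]],
    num_lstm_layers: int,
    hidden_units: int,
) -> List[int]:
    """Return one unit count per LSTM layer, following the documented rules."""
    if lstm_units is None:
        # None: every layer falls back to hidden_units
        return [hidden_units] * num_lstm_layers
    if isinstance(lstm_units, int):
        # int: the same width is reused for every stacked layer
        return [lstm_units] * num_lstm_layers
    units = list(lstm_units)
    if len(units) != num_lstm_layers:
        # list: one entry per layer, so the lengths must agree
        raise ValueError(
            f"lstm_units has {len(units)} entries but "
            f"num_lstm_layers={num_lstm_layers}"
        )
    return units


print(resolve_lstm_units(None, 2, 32))      # [32, 32]
print(resolve_lstm_units(64, 3, 32))        # [64, 64, 64]
print(resolve_lstm_units([64, 32], 2, 32))  # [64, 32]
```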

Notes

This implementation requires the inputs to the call method to be a list or tuple containing exactly three tensors, in the order [static_inputs, dynamic_inputs, future_inputs]. The expected shapes are:

  • static_inputs: (Batch, StaticFeatures)

  • dynamic_inputs: (Batch, PastTimeSteps, DynamicFeatures)

  • future_inputs: (Batch, TotalTimeSteps, FutureFeatures), where TotalTimeSteps includes the past and future steps relevant for LSTM processing.
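The input contract above can be checked before calling the model. The function below is a minimal NumPy sketch of such a check. check_tft_input_shapes is a hypothetical helper, not part of fusionlab. It verifies only the list length, the tensor ranks, and the shared batch size. It does not enforce any particular relationship between TotalTimeSteps and PastTimeSteps, because that relationship depends on preprocessing.

```python
import numpy as np


def check_tft_input_shapes(inputs):
    """Validate the [static, dynamic, future] contract described in the Notes."""
    if not isinstance(inputs, (list, tuple)) or len(inputs) != 3:
        raise ValueError(
            "Expected [static_inputs, dynamic_inputs, future_inputs]."
        )
    static_in, dynamic_in, future_in = inputs
    # static is (Batch, Features); temporal inputs are (Batch, Time, Features)
    for name, arr, rank in (
        ("static_inputs", static_in, 2),
        ("dynamic_inputs", dynamic_in, 3),
        ("future_inputs", future_in, 3),
    ):
        if np.ndim(arr) != rank:
            raise ValueError(
                f"{name} must have rank {rank}, got shape {np.shape(arr)}"
            )
    # All three tensors must describe the same batch of series
    batch_sizes = {np.shape(arr)[0] for arr in inputs}
    if len(batch_sizes) != 1:
        raise ValueError(f"Inconsistent batch sizes: {sorted(batch_sizes)}")
    return {
        "batch": np.shape(static_in)[0],
        "past_steps": np.shape(dynamic_in)[1],
        "total_steps": np.shape(future_in)[1],
    }


# Batch=4, lookback=12, horizon=6, future inputs spanning past + horizon
info = check_tft_input_shapes(
    [np.zeros((4, 3)), np.zeros((4, 12, 5)), np.zeros((4, 18, 2))]
)
print(info)  # {'batch': 4, 'past_steps': 12, 'total_steps': 18}
```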

Use Case and Importance

This revised TFT class provides a structured implementation that closely follows the component architecture described in the original TFT paper, including the distinct GRNs for generating static context vectors. By requiring all input types (static, dynamic past, known future), it simplifies the input handling logic compared to versions allowing optional inputs. This makes it a suitable choice when you have all three types of features available and want a robust baseline TFT implementation that explicitly leverages static context for VSNs, LSTM initialization, and temporal processing enrichment. It serves as a strong foundation for complex multi-horizon forecasting tasks that benefit from diverse data integration and interpretability components like VSNs and attention.

Mathematical Formulation

The model processes inputs through the following key stages:

  1. Variable Selection: Separate Variable Selection Networks (VSNs) are applied to the static ($\mathbf{s}$), dynamic past ($\mathbf{x}_t, t \le T$), and known future ($\mathbf{z}_t, t > T$) inputs. This step identifies the relevant features within each input type and transforms them into embeddings of dimension hidden_units. Let the outputs be $\zeta$ (static embedding), $\xi^{dyn}_t$ (dynamic), and $\xi^{fut}_t$ (future). VSNs may be conditioned on a static context vector $c_s$.

    \[\begin{split}\zeta &= \text{VSN}_{static}(\mathbf{s}, [c_s]) \\ \xi^{dyn}_t &= \text{VSN}_{dyn}(\mathbf{x}_t, [c_s]) \\ \xi^{fut}_t &= \text{VSN}_{fut}(\mathbf{z}_t, [c_s])\end{split}\]
  2. Static Context Generation: Four distinct Gated Residual Networks (GRNs) process the static embedding $\zeta$ to produce context vectors: $c_s$ (for VSNs), $c_e$ (for enrichment), $c_h$ (LSTM initial hidden state), and $c_c$ (LSTM initial cell state).

    \[c_s = GRN_{vs}(\zeta), \quad c_e = GRN_{e}(\zeta), \quad c_h = GRN_{h}(\zeta), \quad c_c = GRN_{c}(\zeta)\]
  3. Temporal Processing Input: The selected dynamic and future embeddings are combined (e.g., concatenated along the time or feature axis, depending on preprocessing) and augmented with a positional encoding to form the LSTM input sequence $\psi_t$.

    \[\psi_t = \text{Combine}(\xi^{dyn}_t, \xi^{fut}_t) + \text{PosEncode}(t)\]
  4. LSTM Encoder: A stack of num_lstm_layers LSTMs processes $\psi_t$, initialized with $[c_h, c_c]$.

    \[\{h_t\} = \text{LSTMStack}(\{\psi_t\}, \text{initial\_state}=[c_h, c_c])\]
  5. Static Enrichment: The LSTM outputs $h_t$ are combined with the static enrichment context $c_e$ using a time-distributed GRN.

    \[\phi_t = GRN_{enrich}(h_t, c_e)\]
  6. Temporal Self-Attention: Interpretable Multi-Head Attention is applied to the enriched sequence $\{\phi_t\}$, potentially using $c_s$ as context within the attention mechanism’s internal GRNs. After a residual connection, gating (GLU), and normalization, this yields the context vectors $\beta_t$.

    \[\beta_t = \text{TemporalAttention}(\{\phi_t\}, c_s)\]
  7. Position-wise Feed-Forward: A final time-distributed GRN is applied to the attention output.

    \[\delta_t = GRN_{final}(\beta_t)\]
  8. Output Projection: The features corresponding to the forecast horizon ($t > T$) are selected from $\{\delta_t\}$ and passed through a final Dense layer (or multiple layers for quantiles) to produce the predictions $\hat{y}_{T+1}, \dots, \hat{y}_{T+\tau}$, where $\tau$ is forecast_horizon.
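The GRN used in steps 2, 5 and 7 can be sketched in NumPy from the equations in [Lim21]: $GRN(a, c) = \text{LayerNorm}(a + GLU(\eta_1))$, with $\eta_1 = W_1 \eta_2 + b_1$ and $\eta_2 = \text{ELU}(W_2 a + W_3 c + b_2)$. This is an illustration only, not fusionlab's GatedResidualNetwork. The weights are random and untrained, and dropout is omitted. LayerNorm has no learned scale or shift. The residual always passes through a linear projection, whereas the paper uses the identity when dimensions already match. All names (init_grn, grn, W_skip, and so on) are hypothetical.

```python
import numpy as np

rng = np.random.default_rng(0)


def elu(x):
    # ELU: x for x > 0, exp(x) - 1 otherwise
    return np.where(x > 0, x, np.expm1(np.minimum(x, 0.0)))


def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))


def layer_norm(x, eps=1e-6):
    # Normalize over the feature axis (no learned scale/shift in this sketch)
    mean = x.mean(axis=-1, keepdims=True)
    var = x.var(axis=-1, keepdims=True)
    return (x - mean) / np.sqrt(var + eps)


def init_grn(in_dim, hidden, ctx_dim=None, scale=0.1):
    """Random, untrained weights for one GRN (illustration only)."""
    shapes = {
        "W2": (in_dim, hidden), "W1": (hidden, hidden),
        "W4": (hidden, hidden), "W5": (hidden, hidden),
        "W_skip": (in_dim, hidden),  # linear projection for the residual
    }
    if ctx_dim is not None:
        shapes["W3"] = (ctx_dim, hidden)  # context weights
    params = {k: rng.normal(scale=scale, size=s) for k, s in shapes.items()}
    for b in ("b1", "b2", "b4", "b5"):
        params[b] = np.zeros(hidden)
    return params


def grn(params, a, c=None):
    """LayerNorm(skip(a) + GLU(eta1)), following Lim et al. (2021)."""
    eta2 = a @ params["W2"] + params["b2"]
    if c is not None:
        ctx = c @ params["W3"]
        if ctx.ndim < eta2.ndim:
            # Broadcast a static context (B, hidden) over the time axis
            ctx = ctx[:, None, :]
        eta2 = eta2 + ctx
    eta2 = elu(eta2)
    eta1 = eta2 @ params["W1"] + params["b1"]
    # Gated linear unit: sigmoid gate times a linear branch
    gate = sigmoid(eta1 @ params["W4"] + params["b4"])
    glu = gate * (eta1 @ params["W5"] + params["b5"])
    return layer_norm(a @ params["W_skip"] + glu)


B, T, d = 4, 18, 16
zeta = rng.normal(size=(B, d))  # stand-in static embedding

# Step 2: four separate GRNs turn zeta into c_s, c_e, c_h, c_c
context = {name: grn(init_grn(d, d), zeta) for name in ("c_s", "c_e", "c_h", "c_c")}

# Step 5: static enrichment of stand-in LSTM outputs h_t with c_e
h = rng.normal(size=(B, T, d))
phi = grn(init_grn(d, d, ctx_dim=d), h, context["c_e"])
print(context["c_e"].shape, phi.shape)  # (4, 16) (4, 18, 16)
```

Broadcasting $c_e$ over the time axis mirrors the time-distributed enrichment in step 5: every $h_t$ receives the same static context.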

call(inputs, training=False)

Performs the forward pass. Expects inputs as a list/tuple: [static_inputs, dynamic_inputs, future_inputs].

compile(optimizer, loss=None, **kwargs)

Compiles the model. If loss is not given, it automatically selects ‘mse’ (point forecasting) or the quantile loss (when quantiles was set at initialization).
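The default-loss rule can be illustrated with the standard pinball (quantile) loss. The NumPy sketch below shows the selection step and the loss itself. It assumes targets of shape (batch, horizon) and predictions of shape (batch, horizon, n_quantiles). select_default_loss and pinball_loss are hypothetical names. fusionlab's combined_quantile_loss may differ in shape handling and reduction.

```python
import numpy as np


def pinball_loss(y_true, y_pred, quantiles):
    """Mean pinball loss over all quantiles.

    y_true: (batch, horizon); y_pred: (batch, horizon, n_quantiles).
    """
    q = np.asarray(quantiles)            # shape (n_quantiles,)
    error = y_true[..., None] - y_pred   # broadcast targets over quantiles
    # Under-prediction (error > 0) costs q; over-prediction costs (1 - q)
    loss = np.maximum(q * error, (q - 1) * error)
    return loss.mean()


def mse(y_true, y_pred):
    return np.mean((y_true - y_pred) ** 2)


def select_default_loss(quantiles):
    """Quantile loss if quantiles were given, otherwise MSE."""
    if quantiles is None:
        return mse
    return lambda y_true, y_pred: pinball_loss(y_true, y_pred, quantiles)


y_true = np.array([[1.0, 2.0]])                          # (batch=1, horizon=2)
y_pred = np.array([[[0.0, 1.0, 2.0], [2.0, 2.0, 2.0]]])  # (1, 2, 3 quantiles)
loss_fn = select_default_loss([0.1, 0.5, 0.9])
print(loss_fn(y_true, y_pred))  # approximately 0.0333
```

The asymmetric weights are what make each output track its quantile. At q=0.9, under-predicting costs nine times as much as over-predicting.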

Examples

>>> import numpy as np
>>> import tensorflow as tf
>>> from fusionlab.nn.transformers import TFT
>>> from fusionlab.nn.losses import combined_quantile_loss
>>>
>>> # Dummy Data Dimensions
>>> B, T_past, H = 4, 12, 6 # Batch, Lookback, Horizon
>>> D_dyn, D_stat, D_fut = 5, 3, 2
>>> T_future = H # Assume future inputs cover horizon only for LSTM input
>>>
>>> # Create Dummy Input Tensors (Ensure correct shapes and types)
>>> static_in = tf.random.normal((B, D_stat), dtype=tf.float32)
>>> dynamic_in = tf.random.normal((B, T_past, D_dyn), dtype=tf.float32)
>>> # Depending on preprocessing, future inputs have shape
>>> # (B, T_past + T_future, D_fut) or (B, T_future, D_fut).
>>> # For simplicity, this example covers the horizon only.
>>> future_in = tf.random.normal((B, T_future, D_fut), dtype=tf.float32)
>>>
>>> # Instantiate Model for Quantile Forecasting
>>> model = TFT(
...     dynamic_input_dim=D_dyn, static_input_dim=D_stat,
...     future_input_dim=D_fut, forecast_horizon=H,
...     hidden_units=16, num_heads=2, num_lstm_layers=1,
...     quantiles=[0.1, 0.5, 0.9], output_dim=1
... )
>>>
>>> # Compile with appropriate loss
>>> loss_fn = combined_quantile_loss([0.1, 0.5, 0.9])
>>> model.compile(optimizer='adam', loss=loss_fn)
>>>
>>> # Prepare input list in correct order: [static, dynamic, future]
>>> model_inputs = [static_in, dynamic_in, future_in]
>>>
>>> # Make a prediction (forward pass)
>>> # The subclassed model builds its weights on the first call;
>>> # alternatively, call model.build() or fit for one step first.
>>> predictions = model(model_inputs, training=False)
>>> print(f"Output shape: {predictions.shape}")
Output shape: (4, 6, 3)
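The printed shape follows from step 8 of the formulation. The NumPy sketch below shows one way to arrive at (Batch, Horizon, n_quantiles) = (4, 6, 3): keep the last H time steps, then apply one linear head per quantile. It assumes the temporal sequence spans T_past + H steps and that output_dim=1. The random weights and variable names are illustrative and do not reflect fusionlab internals.

```python
import numpy as np

rng = np.random.default_rng(0)

B, T_past, H = 4, 12, 6   # batch, lookback, forecast_horizon (as in the example)
hidden_units = 16
quantiles = [0.1, 0.5, 0.9]
output_dim = 1

# Stand-in for delta_t, assuming the temporal sequence spans T_past + H steps
delta = rng.normal(size=(B, T_past + H, hidden_units))

# Keep only the forecast positions t > T, i.e. the last H steps
horizon_features = delta[:, -H:, :]                      # (4, 6, 16)

# One untrained linear head per quantile, each emitting output_dim values
heads = [rng.normal(scale=0.1, size=(hidden_units, output_dim))
         for _ in quantiles]
per_quantile = [horizon_features @ W for W in heads]     # each (4, 6, 1)

# With output_dim=1, concatenating on the last axis gives (B, H, n_quantiles)
preds = np.concatenate(per_quantile, axis=-1)
print(preds.shape)  # (4, 6, 3)
```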

See also

fusionlab.nn.components.VariableSelectionNetwork

Core component for VSN.

fusionlab.nn.components.GatedResidualNetwork

Core component for GRN.

fusionlab.nn.components.TemporalAttentionLayer

Core attention block.

tensorflow.keras.layers.LSTM

Recurrent layer used internally.

fusionlab.nn.losses.combined_quantile_loss

Default loss for quantiles.

fusionlab.nn.utils.reshape_xtft_data

Utility to prepare inputs.

fusionlab.nn.XTFT

More advanced related architecture.

tensorflow.keras.Model

Base Keras model class.

References

[Lim21]

Lim, B., Arık, S. Ö., Loeff, N., & Pfister, T. (2021). Temporal fusion transformers for interpretable multi-horizon time series forecasting. International Journal of Forecasting, 37(4), 1748-1764.

__init__(dynamic_input_dim, static_input_dim, future_input_dim, hidden_units=32, num_heads=4, dropout_rate=0.1, recurrent_dropout_rate=0.0, forecast_horizon=1, quantiles=None, activation='elu', use_batch_norm=False, num_lstm_layers=1, lstm_units=None, output_dim=1, **kwargs)
Parameters:
  • dynamic_input_dim (int)

  • static_input_dim (int)

  • future_input_dim (int)

  • hidden_units (int)

  • num_heads (int)

  • dropout_rate (float)

  • recurrent_dropout_rate (float)

  • forecast_horizon (int)

  • quantiles (List[float] | None)

  • activation (str)

  • use_batch_norm (bool)

  • num_lstm_layers (int)

  • lstm_units (int | List[int] | None)

  • output_dim (int)

Methods

__init__(dynamic_input_dim, ...[, ...])

add_loss(losses, **kwargs)

Add loss tensor(s), potentially dependent on layer inputs.

add_metric(value[, name])

Adds metric tensor to the layer.

add_update(updates)

Add update op(s), potentially dependent on layer inputs.

add_variable(*args, **kwargs)

Deprecated, do NOT use! Alias for add_weight.

add_weight([name, shape, dtype, ...])

Adds a new variable to the layer.

build(input_shape)

Builds the model based on input shapes received.

build_from_config(config)

Builds the layer's states with the supplied config dict.

call(inputs[, training])

Forward pass for the revised TFT with numerical inputs.

compile(optimizer[, loss])

Configures the model for training.

compile_from_config(config)

Compiles the model with the information given in config.

compute_loss([x, y, y_pred, sample_weight])

Compute the total loss, validate it, and return it.

compute_mask(inputs[, mask])

Computes an output mask tensor.

compute_metrics(x, y, y_pred, sample_weight)

Update metric states and collect all metrics to be returned.

compute_output_shape(input_shape)

Computes the output shape of the layer.

compute_output_signature(input_signature)

Compute the output tensor signature of the layer based on the inputs.

count_params()

Count the total number of scalars composing the weights.

evaluate([x, y, batch_size, verbose, ...])

Returns the loss value & metrics values for the model in test mode.

evaluate_generator(generator[, steps, ...])

Evaluates the model on a data generator.

export(filepath)

Create a SavedModel artifact for inference (e.g. via TF-Serving).

finalize_state()

Finalizes the layers state after updating layer weights.

fit([x, y, batch_size, epochs, verbose, ...])

Trains the model for a fixed number of epochs (dataset iterations).

fit_generator(generator[, steps_per_epoch, ...])

Fits the model on data yielded batch-by-batch by a Python generator.

from_config(config)

Creates a layer from its config.

get_build_config()

Returns a dictionary with the layer's input shape.

get_compile_config()

Returns a serialized config with information for compiling the model.

get_config()

Returns the config of the Model.

get_input_at(node_index)

Retrieves the input tensor(s) of a layer at a given node.

get_input_mask_at(node_index)

Retrieves the input mask tensor(s) of a layer at a given node.

get_input_shape_at(node_index)

Retrieves the input shape(s) of a layer at a given node.

get_layer([name, index])

Retrieves a layer based on either its name (unique) or index.

get_metrics_result()

Returns the model's metrics values as a dict.

get_output_at(node_index)

Retrieves the output tensor(s) of a layer at a given node.

get_output_mask_at(node_index)

Retrieves the output mask tensor(s) of a layer at a given node.

get_output_shape_at(node_index)

Retrieves the output shape(s) of a layer at a given node.

get_params([deep])

Get the parameters for this learner.

get_weight_paths()

Retrieve all the variables and their paths for the model.

get_weights()

Retrieves the weights of the model.

help(**kwargs)

load(file_path[, format])

Load the learner's state from a specified file in the desired format.

load_own_variables(store)

Loads the state of the layer.

load_weights(filepath[, skip_mismatch, ...])

Loads all layer weights from saved files.

make_predict_function([force])

Creates a function that executes one step of inference.

make_test_function([force])

Creates a function that executes one step of evaluation.

make_train_function([force])

Creates a function that executes one step of training.

predict(x[, batch_size, verbose, steps, ...])

Generates output predictions for the input samples.

predict_generator(generator[, steps, ...])

Generates predictions for the input samples from a data generator.

predict_on_batch(x)

Returns predictions for a single batch of samples.

predict_step(data)

The logic for one inference step.

reset_metrics()

Resets the state of all the metrics in the model.

reset_states()

save(filepath[, overwrite, save_format])

Saves a model as a TensorFlow SavedModel or HDF5 file.

save_own_variables(store)

Saves the state of the layer.

save_spec([dynamic_batch])

Returns the tf.TensorSpec of call args as a tuple (args, kwargs).

save_weights(filepath[, overwrite, ...])

Saves all layer weights.

set_params(**params)

Set the parameters of this learner.

set_weights(weights)

Sets the weights of the layer, from NumPy arrays.

summary([line_length, positions, print_fn, ...])

Prints a string summary of the network.

test_on_batch(x[, y, sample_weight, ...])

Test the model on a single batch of samples.

test_step(data)

The logic for one evaluation step.

to_json(**kwargs)

Returns a JSON string containing the network configuration.

to_yaml(**kwargs)

Returns a yaml string containing the network configuration.

train_on_batch(x[, y, sample_weight, ...])

Runs a single gradient update on a single batch of data.

train_step(data)

The logic for one training step.

with_name_scope(method)

Decorator to automatically enter the module name scope.

Attributes

activity_regularizer

Optional regularizer function for the output of this layer.

autotune_steps_per_execution

Settable property to enable tuning for steps_per_execution

compute_dtype

The dtype of the layer's computations.

distribute_reduction_method

The method employed to reduce per-replica values during training.

distribute_strategy

The tf.distribute.Strategy this model was created under.

dtype

The dtype of the layer weights.

dtype_policy

The dtype policy associated with this layer.

dynamic

Whether the layer is dynamic (eager-only); set in the constructor.

inbound_nodes

Return Functional API nodes upstream of this layer.

input

Retrieves the input tensor(s) of a layer.

input_mask

Retrieves the input mask tensor(s) of a layer.

input_shape

Retrieves the input shape(s) of a layer.

input_spec

InputSpec instance(s) describing the input format for this layer.

jit_compile

Specify whether to compile the model with XLA.

layers

losses

List of losses added using the add_loss() API.

metrics

Return metrics added using compile() or add_metric().

metrics_names

Returns the model's display labels for all outputs.

my_params

name

Name of the layer (string), set in the constructor.

name_scope

Returns a tf.name_scope instance for this class.

non_trainable_variables

Sequence of non-trainable variables owned by this module and its submodules.

non_trainable_weights

List of all non-trainable weights tracked by this layer.

outbound_nodes

Return Functional API nodes downstream of this layer.

output

Retrieves the output tensor(s) of a layer.

output_mask

Retrieves the output mask tensor(s) of a layer.

output_shape

Retrieves the output shape(s) of a layer.

run_eagerly

Settable attribute indicating whether the model should run eagerly.

state_updates

Deprecated, do NOT use!

stateful

steps_per_execution

Settable steps_per_execution variable. Requires a compiled model.

submodules

Sequence of all sub-modules.

supports_masking

Whether this layer supports computing a mask using compute_mask.

trainable

trainable_variables

Sequence of trainable variables owned by this module and its submodules.

trainable_weights

List of all trainable weights tracked by this layer.

updates

variable_dtype

Alias of Layer.dtype, the dtype of the weights.

variables

Returns the list of all layer variables/weights.

weights

Returns the list of all layer variables/weights.

call(inputs, training=None)

Forward pass for the revised TFT with numerical inputs.

compile(optimizer, loss=None, **kwargs)

Configures the model for training.

Example:

```python
model.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=1e-3),
              loss=tf.keras.losses.BinaryCrossentropy(),
              metrics=[tf.keras.metrics.BinaryAccuracy(),
                       tf.keras.metrics.FalseNegatives()])
```

Parameters:
  • optimizer – String (name of optimizer) or optimizer instance. See tf.keras.optimizers.

  • loss – Loss function. May be a string (name of loss function), or a tf.keras.losses.Loss instance. See tf.keras.losses. A loss function is any callable with the signature loss = fn(y_true, y_pred), where y_true are the ground truth values, and y_pred are the model’s predictions. y_true should have shape (batch_size, d0, .. dN) (except in the case of sparse loss functions such as sparse categorical crossentropy which expects integer arrays of shape (batch_size, d0, .. dN-1)). y_pred should have shape (batch_size, d0, .. dN). The loss function should return a float tensor. If a custom Loss instance is used and reduction is set to None, return value has shape (batch_size, d0, .. dN-1) i.e. per-sample or per-timestep loss values; otherwise, it is a scalar. If the model has multiple outputs, you can use a different loss on each output by passing a dictionary or a list of losses. The loss value that will be minimized by the model will then be the sum of all individual losses, unless loss_weights is specified.

  • metrics – List of metrics to be evaluated by the model during training and testing. Each of these can be a string (name of a built-in function), function or a tf.keras.metrics.Metric instance. See tf.keras.metrics. Typically you will use metrics=[‘accuracy’]. A function is any callable with the signature result = fn(y_true, y_pred). To specify different metrics for different outputs of a multi-output model, you could also pass a dictionary, such as metrics={‘output_a’: ‘accuracy’, ‘output_b’: [‘accuracy’, ‘mse’]}. You can also pass a list to specify a metric or a list of metrics for each output, such as metrics=[[‘accuracy’], [‘accuracy’, ‘mse’]] or metrics=[‘accuracy’, [‘accuracy’, ‘mse’]]. When you pass the strings ‘accuracy’ or ‘acc’, we convert this to one of tf.keras.metrics.BinaryAccuracy, tf.keras.metrics.CategoricalAccuracy, tf.keras.metrics.SparseCategoricalAccuracy based on the shapes of the targets and of the model output. We do a similar conversion for the strings ‘crossentropy’ and ‘ce’ as well. The metrics passed here are evaluated without sample weighting; if you would like sample weighting to apply, you can specify your metrics via the weighted_metrics argument instead.

  • loss_weights – Optional list or dictionary specifying scalar coefficients (Python floats) to weight the loss contributions of different model outputs. The loss value that will be minimized by the model will then be the weighted sum of all individual losses, weighted by the loss_weights coefficients. If a list, it is expected to have a 1:1 mapping to the model’s outputs. If a dict, it is expected to map output names (strings) to scalar coefficients.

  • weighted_metrics – List of metrics to be evaluated and weighted by sample_weight or class_weight during training and testing.

  • run_eagerly – Bool. If True, this Model’s logic will not be wrapped in a tf.function. Recommended to leave this as None unless your Model cannot be run inside a tf.function. run_eagerly=True is not supported when using tf.distribute.experimental.ParameterServerStrategy. Defaults to False.

  • steps_per_execution – Int or ‘auto’. The number of batches to run during each tf.function call. If set to “auto”, keras will automatically tune steps_per_execution during runtime. Running multiple batches inside a single tf.function call can greatly improve performance on TPUs, when used with distributed strategies such as ParameterServerStrategy, or with small models with a large Python overhead. At most, one full epoch will be run each execution. If a number larger than the size of the epoch is passed, the execution will be truncated to the size of the epoch. Note that if steps_per_execution is set to N, Callback.on_batch_begin and Callback.on_batch_end methods will only be called every N batches (i.e. before/after each tf.function execution). Defaults to 1.

  • jit_compile – If True, compile the model training step with XLA. [XLA](https://www.tensorflow.org/xla) is an optimizing compiler for machine learning. jit_compile is not enabled by default. Note that jit_compile=True may not necessarily work for all models. For more information on supported operations please refer to the [XLA documentation](https://www.tensorflow.org/xla). Also refer to [known XLA issues](https://www.tensorflow.org/xla/known_issues) for more details.

  • pss_evaluation_shards – Integer or ‘auto’. Used for tf.distribute.ParameterServerStrategy training only. This arg sets the number of shards to split the dataset into, to enable an exact visitation guarantee for evaluation, meaning the model will be applied to each dataset element exactly once, even if workers fail. The dataset must be sharded to ensure separate workers do not process the same data. The number of shards should be at least the number of workers for good performance. A value of ‘auto’ turns on exact evaluation and uses a heuristic for the number of shards based on the number of workers. A value of 0 means no visitation guarantee is provided. NOTE: Custom implementations of Model.test_step will be ignored when doing exact evaluation. Defaults to 0.

  • **kwargs – Arguments supported for backwards compatibility only.

get_config()

Returns the config of the Model.

Config is a Python dictionary (serializable) containing the configuration of an object, which in this case is a Model. This allows the Model to be reinstantiated later (without its trained weights) from this configuration.

Note that get_config() does not guarantee to return a fresh copy of dict every time it is called. The callers should make a copy of the returned dict if they want to modify it.

Developers of subclassed Model are advised to override this method, and continue to update the dict from super(MyModel, self).get_config() to provide the proper configuration of this Model. The default config will return config dict for init parameters if they are basic types. Raises NotImplementedError in cases where a custom get_config() implementation is required for the subclassed model.

Returns:

Python dictionary containing the configuration of this Model.

classmethod from_config(config)

Creates a layer from its config.

This method is the reverse of get_config, capable of instantiating the same layer from the config dictionary. It does not handle layer connectivity (handled by Network), nor weights (handled by set_weights).

Parameters:

config – A Python dictionary, typically the output of get_config.

Returns:

A layer instance.

help(**kwargs)

my_params = TFT(dynamic_input_dim, static_input_dim, future_input_dim, hidden_units=32, num_heads=4, dropout_rate=0.1, recurrent_dropout_rate=0.0, forecast_horizon=1, quantiles=None, activation='elu', use_batch_norm=False, num_lstm_layers=1, lstm_units=None, output_dim=1)