fusionlab.nn.models.HALNet
- class fusionlab.nn.models.HALNet
Bases: Model, NNLearner
Hybrid Attentive LSTM Network (HAL-Net).
A data‑driven encoder–decoder architecture that couples multi‑scale LSTMs with hierarchical attention to deliver accurate multi‑horizon forecasts from static, dynamic‑past, and known‑future covariates. HAL‑Net is a pared‑down variant of PIHALNet: it omits physics constraints and anomaly modules, making it suitable for purely statistical settings.
See the User Guide for more details.
- Parameters:
static_input_dim (int) – Dimensionality of the static (time-invariant) input features. These features do not change over time for a given sample, such as a sensor's location ID, soil type, or a product category. If 0, no static features are used.
dynamic_input_dim (int) – Dimensionality of the dynamic (time-varying) input features that are known in the past (the "lookback" window). This is a required parameter and typically includes the target variable itself (lagged) and other historical drivers such as rainfall, temperature, or sales figures.
future_input_dim (int) – Dimensionality of the time-varying features whose values are known in advance for the forecast period. Examples include holidays, scheduled promotions, or day-of-week indicators. If 0, no future features are used.
output_dim (int, default 1) – Number of target variables produced at each forecast step. The model outputs a tensor of shape \((B, H, Q, \text{output\_dim})\) when quantiles are provided, or \((B, H, \text{output\_dim})\) for point forecasts, where
\[B = \text{batch size},\qquad H = \text{forecast horizon},\qquad Q = |\text{quantiles}|.\]
forecast_horizon (int, default 1) – Length of the prediction window into the future. The dynamic encoder ingests max_window_size past steps and the decoder emits \(H\) steps ahead, where \(H = \text{forecast\_horizon}\). Setting \(H > 1\) enables multi-horizon sequence-to-sequence forecasts.
quantiles (list[float] or None, default None) – Optional quantile levels \(0 < q_1 < \dots < q_Q < 1\). When supplied, a fusionlab.nn.components.QuantileDistributionModeling head maps the point forecast \(\hat{y}\) to quantile estimates
\[\hat{y}^{(q)} = \hat{y} + \sigma\,\Phi^{-1}(q),\]
where \(\sigma\) is a learned spread parameter and \(\Phi^{-1}\) is the probit function (the inverse standard-normal CDF). Omit or set to None to obtain deterministic forecasts.
embed_dim (int, default 32) – The base dimensionality of the model's internal feature space. The static, dynamic, and future inputs are projected into this common dimension so that they can interact in downstream layers such as the LSTMs and attention mechanisms. It is a key parameter for controlling model capacity.
hidden_units (int, default 64) – The number of units in the hidden layers of the Gated Residual Networks (GRNs). GRNs are the core non-linear transformation blocks used throughout the architecture. A larger value increases the model's capacity to learn complex patterns.
lstm_units (int, default 64) – The number of hidden units in each LSTM layer within the MultiScaleLSTM block. This parameter determines the memory capacity of the recurrent cells that process the historical sequence.
attention_units (int, default 32) – The dimensionality of the output space of the attention mechanisms (e.g., CrossAttention, HierarchicalAttention), often referred to as the model dimension \(d_{model}\). It must be divisible by num_heads.
num_heads (int, default 4) – The number of attention heads in each MultiHeadAttention sub-layer. Multiple heads let the model jointly attend to information from different representation subspaces at different positions, which can improve learning.
dropout_rate (float, default 0.1) – The dropout rate applied within components such as the GRNs and after some attention layers to reduce overfitting. It must be a float between 0.0 and 1.0.
max_window_size (int, default 10) – The number of past time steps (the lookback window) the model considers. It should match the time_steps parameter used during data preparation and is used by components such as DynamicTimeWindow.
memory_size (int, default 100) – The number of memory slots in the MemoryAugmentedAttention layer. This external memory lets the model learn and access patterns over long-range dependencies that standard LSTMs or attention might miss.
scales (list of int, optional) – Scale factors for the MultiScaleLSTM. Each scale s creates an LSTM that processes the input sequence by taking every s-th time step. For example, scales=[1, 3] processes the sequence at its original resolution and at a coarser, every-third-step resolution. If None or 'auto', defaults to [1].
multi_scale_agg ({'last', 'average', 'concat', ...}, default 'last') – The strategy used to combine the outputs of the different LSTMs in MultiScaleLSTM:
  - 'concat': (3D output) Pads the sequences from the different scales to the same length and concatenates them along the feature axis. This is the primary mode for building a rich sequence representation for downstream attention layers in an encoder–decoder setup.
  - 'last' or 'auto': (2D output) Builds a context vector by taking the last hidden state of each LSTM scale and concatenating them.
  - 'average' or 'sum': (2D output) Builds a context vector by averaging or summing over the time dimension of each scale.
final_agg ({'last', 'average', 'flatten'}, default 'last') – The aggregation strategy used to collapse the final temporal feature map (whose time dimension equals forecast_horizon) into a single feature vector before the final decoding step.
activation (str, default 'relu') – Name of the activation function used in Dense layers and GRNs throughout the model. Common choices include 'relu', 'gelu', 'swish', and 'tanh'.
use_residuals (bool, default True) – If True, enables residual "add & norm" connections after key sub-layers (such as attention and GRNs). These shortcut connections help prevent vanishing gradients and ease optimization in deep networks.
use_vsn (bool, default True) – If True, the model uses VariableSelectionNetwork (VSN) layers at the input stage. VSNs perform learnable feature selection, up-weighting or down-weighting each input variable. This can improve performance and indicates which features matter most. If False, simpler Dense layers are used for the initial projection.
vsn_units (int, optional) – The number of units in the internal GRNs of the Variable Selection Networks, controlling the capacity of the feature-selection sub-networks. If None, a default derived from hidden_units is typically used.
mode ({'pihal_like', 'tft_like'}, default 'tft_like') – Controls how future_features are sliced and routed.
  - 'pihal_like' expects future_input.shape[1] == forecast_horizon and feeds the tensor only to the decoder.
  - 'tft_like' expects time_steps + forecast_horizon rows. It sends the first time_steps rows to the encoder and the remaining rows to the decoder, emulating the Temporal Fusion Transformer.
name (str, default "HALNet") – Model identifier passed to tf.keras.Model. Appears in weight filenames and TensorBoard scopes.
**kwargs – Additional keyword arguments forwarded verbatim to the tf.keras.Model constructor, e.g. dtype="float64". Training options such as run_eagerly belong to compile() instead.
Notes
The composite latent size produced by the cross‑attention block is \(d_\text{model} = \text{attention\_units}\). \(d_\text{model}\) must be divisible by num_heads.
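The divisibility rule in the note above can be enforced before the model is built. This is a minimal sketch. It assumes the usual multi-head convention that each head works on attention_units // num_heads features. `check_attention_dims` is a hypothetical helper, not part of fusionlab, which may validate this differently.

```python
def check_attention_dims(attention_units: int, num_heads: int) -> int:
    """Return the per-head width d_model // num_heads, or raise if it is not an integer.

    Hypothetical pre-flight check; fusionlab may validate this differently.
    """
    if num_heads <= 0:
        raise ValueError("num_heads must be a positive integer")
    if attention_units % num_heads != 0:
        raise ValueError(
            f"attention_units={attention_units} is not divisible by num_heads={num_heads}"
        )
    return attention_units // num_heads


print(check_attention_dims(32, 4))  # the constructor defaults
print(check_attention_dims(64, 8))  # the values used in the class example
try:
    check_attention_dims(30, 4)
except ValueError as err:
    print("rejected:", err)
```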
See also
fusionlab.nn.pinn.PIHALNet – physics‑informed extension.
fusionlab.utils.data_utils.widen_temporal_columns() – prepares wide data frames for plotting forecasts.
Examples
>>> import tensorflow as tf
>>> from fusionlab.nn.models import HALNet
>>> model = HALNet(
...     static_input_dim=4, dynamic_input_dim=8, future_input_dim=6,
...     output_dim=2, forecast_horizon=24, quantiles=[0.1, 0.5, 0.9],
...     max_window_size=10, scales=[1, 3], multi_scale_agg="concat",
...     final_agg="last", attention_units=64, num_heads=8,
...     dropout_rate=0.15, mode="tft_like",
... )
>>> x_static = tf.random.normal([32, 4])        # B × S
>>> x_dynamic = tf.random.normal([32, 10, 8])   # B × T × D
>>> x_future = tf.random.normal([32, 34, 6])    # B × (T + H) × F under 'tft_like'
>>> y_hat = model([x_static, x_dynamic, x_future], training=False)
>>> y_hat.shape  # B × H × Q × output_dim
TensorShape([32, 24, 3, 2])
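The shapes above require the dynamic block to have max_window_size rows. Under 'tft_like', the future block must also have max_window_size + forecast_horizon rows. The following sketch shows one way to cut such windows from aligned per-step data. `make_windows` is a hypothetical helper written for illustration; fusionlab's own data-preparation utilities are not assumed here. The target is taken to be the first dynamic column, which is also only an illustrative choice.

```python
def make_windows(dynamic, future, time_steps, horizon, mode="tft_like"):
    """Cut aligned per-step rows into (dynamic, future, target) training windows.

    dynamic, future: lists of feature rows sharing the same time index.
    time_steps:      lookback length (should match max_window_size).
    horizon:         forecast_horizon.
    mode:            'tft_like' keeps time_steps + horizon future rows,
                     'pihal_like' keeps only the horizon rows.
    The target is, for illustration only, the first dynamic column over the horizon.
    """
    if mode not in ("tft_like", "pihal_like"):
        raise ValueError(f"unknown mode: {mode!r}")
    samples = []
    for start in range(len(dynamic) - time_steps - horizon + 1):
        split = start + time_steps  # index of the first forecast step
        end = split + horizon       # one past the last forecast step
        dyn = dynamic[start:split]
        fut = future[start:end] if mode == "tft_like" else future[split:end]
        target = [row[0] for row in dynamic[split:end]]
        samples.append((dyn, fut, target))
    return samples


# Toy data: 40 steps, 2 dynamic features, 1 known-future feature.
dynamic = [[float(t), float(t % 7)] for t in range(40)]
future = [[float(t)] for t in range(40)]
windows = make_windows(dynamic, future, time_steps=10, horizon=24)
dyn, fut, target = windows[0]
print(len(windows), len(dyn), len(fut), len(target))
```

Stacking these windows along a new batch axis gives arrays shaped like x_dynamic (B × 10 × D) and x_future (B × 34 × F) in the example.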
See also
fusionlab.nn.pinn.PIHALNet, fusionlab.nn.components.MultiScaleLSTM, fusionlab.nn.components.VariableSelectionNetwork
- __init__(static_input_dim, dynamic_input_dim, future_input_dim, output_dim=1, forecast_horizon=1, quantiles=None, embed_dim=32, hidden_units=64, lstm_units=64, attention_units=32, num_heads=4, dropout_rate=0.1, max_window_size=10, memory_size=100, scales=None, multi_scale_agg='last', final_agg='last', activation='relu', use_residuals=True, use_vsn=True, vsn_units=None, mode=None, name='HALNet', **kwargs)
- Parameters:
static_input_dim (int)
dynamic_input_dim (int)
future_input_dim (int)
output_dim (int)
forecast_horizon (int)
quantiles (List[float] | None)
embed_dim (int)
hidden_units (int)
lstm_units (int)
attention_units (int)
num_heads (int)
dropout_rate (float)
max_window_size (int)
memory_size (int)
scales (List[int] | None)
multi_scale_agg (str)
final_agg (str)
activation (str)
use_residuals (bool)
use_vsn (bool)
vsn_units (int | None)
mode (str | None)
name (str)
Methods
__init__(static_input_dim, ...[, ...])
add_loss(loss) – Can be called inside of the call() method to add a scalar loss.
add_metric(*args, **kwargs)
add_variable(shape, initializer[, dtype, ...]) – Add a weight variable to the layer.
add_weight([shape, initializer, dtype, ...]) – Add a weight variable to the layer.
build(input_shape)
build_from_config(config) – Builds the layer's states with the supplied config dict.
call(inputs[, training]) – Forward pass of HALNet that produces point or quantile forecasts.
compile([optimizer, loss, loss_weights, ...]) – Configures the model for training.
compile_from_config(config) – Compiles the model with the information given in config.
compiled_loss(y, y_pred[, sample_weight, ...])
compute_loss([x, y, y_pred, sample_weight, ...]) – Compute the total loss, validate it, and return it.
compute_mask(inputs, previous_mask)
compute_metrics(x, y, y_pred[, sample_weight]) – Update metric states and collect all metrics to be returned.
compute_output_shape(*args, **kwargs)
compute_output_spec(*args, **kwargs)
count_params() – Count the total number of scalars composing the weights.
evaluate([x, y, batch_size, verbose, ...]) – Returns the loss value & metrics values for the model in test mode.
export(filepath[, format, verbose, ...]) – Export the model as an artifact for inference.
fit([x, y, batch_size, epochs, verbose, ...]) – Trains the model for a fixed number of epochs (dataset iterations).
from_config(config[, custom_objects]) – Re‑instantiate a HALNet object from a configuration dictionary produced by get_config().
get_build_config() – Returns a dictionary with the layer's input shape.
get_compile_config() – Returns a serialized config with information for compiling the model.
get_config() – Serialize the HALNet configuration for tf.keras.models.clone_model() or model saving.
get_layer([name, index]) – Retrieves a layer based on either its name (unique) or index.
get_metrics_result() – Returns the model's metrics values as a dict.
get_params([deep]) – Get the parameters for this learner.
get_state_tree([value_format]) – Retrieves tree-like structure of model variables.
get_weights() – Return the values of layer.weights as a list of NumPy arrays.
help(**kwargs)
load(file_path[, format]) – Load the learner's state from a specified file in the desired format.
load_own_variables(store) – Loads the state of the layer.
load_weights(filepath[, skip_mismatch]) – Load the weights from a single file or sharded files.
loss(y, y_pred[, sample_weight])
make_predict_function([force])
make_test_function([force])
make_train_function([force])
predict(x[, batch_size, verbose, steps, ...]) – Generates output predictions for the input samples.
predict_on_batch(x) – Returns predictions for a single batch of samples.
predict_step(data)
quantize(mode[, config]) – Quantize the weights of the model.
quantized_build(input_shape, mode)
quantized_call(*args, **kwargs)
rematerialized_call(layer_call, *args, **kwargs) – Enable rematerialization dynamically for layer's call method.
reset_metrics()
run_halnet_core(static_input, dynamic_input, ...) – Execute the data‑driven encoder–decoder backbone of HALNet.
save(filepath[, overwrite, zipped]) – Saves a model as a .keras file.
save_own_variables(store) – Saves the state of the layer.
save_weights(filepath[, overwrite, ...]) – Saves all weights to a single file or sharded files.
set_params(**params) – Set the parameters of this learner.
set_state_tree(state_tree) – Assigns values to variables of the model.
set_weights(weights) – Sets the values of layer.weights from a list of NumPy arrays.
stateless_call(trainable_variables, ...[, ...]) – Call the layer without any side effects.
stateless_compute_loss(trainable_variables, ...)
summary([line_length, positions, print_fn, ...]) – Prints a string summary of the network.
symbolic_call(*args, **kwargs)
test_on_batch(x[, y, sample_weight, return_dict]) – Test the model on a single batch of samples.
test_step(data)
to_json(**kwargs) – Returns a JSON string containing the network configuration.
train_on_batch(x[, y, sample_weight, ...]) – Runs a single gradient update on a single batch of data.
train_step(data)
Attributes
compiled_metrics
compute_dtype – The dtype of the computations performed by the layer.
distribute_reduction_method
distribute_strategy
dtype – Alias of layer.variable_dtype.
dtype_policy
input – Retrieves the input tensor(s) of a symbolic operation.
input_dtype – The dtype layer inputs should be converted to.
input_spec
jit_compile
layers
losses – List of scalar losses from add_loss, regularizers and sublayers.
metrics – List of all metrics.
metrics_names
metrics_variables – List of all metric variables.
non_trainable_variables – List of all non-trainable layer state.
non_trainable_weights – List of all non-trainable weight variables of the layer.
output – Retrieves the output tensor(s) of a layer.
path – The path of the layer.
quantization_mode – The quantization mode of this layer, None if not quantized.
run_eagerly
supports_masking – Whether this layer supports computing a mask using compute_mask.
trainable – Settable boolean, whether this layer should be trainable or not.
trainable_variables – List of all trainable layer state.
trainable_weights – List of all trainable weight variables of the layer.
variable_dtype – The dtype of the state (weights) of the layer.
variables – List of all layer state, including random seeds.
weights – List of all weight variables of the layer.
- run_halnet_core(static_input, dynamic_input, future_input, training)
Execute the data‑driven encoder–decoder backbone of HALNet.
The routine handles three covariate blocks—static, past‑dynamic, and known‑future—then applies multi‑scale LSTM encoding, several attention modules, and an optional residual fusion before returning a single representation per sample–horizon pair.
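The multi-scale encoding step can be pictured without any LSTM. Following the scales parameter, each scale s looks at every s-th time step, and the multi_scale_agg modes then combine the per-scale results. The following pure-Python sketch replaces each LSTM with the identity map. It assumes subsampling starts at step 0 and that 'concat' zero-pads at the end of shorter sequences. Both choices are assumptions, and `subsample` and `aggregate` are hypothetical names.

```python
def subsample(seq, scale):
    """Keep every `scale`-th step of seq (a list of feature rows), starting at step 0."""
    return seq[::scale]


def aggregate(per_scale, mode="last"):
    """Combine per-scale sequences following the documented multi_scale_agg modes.

    'last'    -> 2D context: last row of each scale, concatenated.
    'average' -> 2D context: per-scale mean over time, concatenated.
    'sum'     -> 2D context: per-scale sum over time, concatenated.
    'concat'  -> sequence: zero-pad every scale to the longest length
                 (end padding is an assumption) and join features per step.
    """
    if mode == "last":
        return [v for seq in per_scale for v in seq[-1]]
    if mode in ("average", "sum"):
        return [
            sum(col) / (len(seq) if mode == "average" else 1)
            for seq in per_scale
            for col in zip(*seq)
        ]
    if mode == "concat":
        length = max(len(seq) for seq in per_scale)
        padded = [seq + [[0.0] * len(seq[0])] * (length - len(seq)) for seq in per_scale]
        return [[v for seq in padded for v in seq[t]] for t in range(length)]
    raise ValueError(f"unsupported mode: {mode!r}")


# 10 past steps with one feature; scales=[1, 3] as in the class example.
past = [[float(t)] for t in range(10)]
per_scale = [subsample(past, s) for s in (1, 3)]  # lengths 10 and 4
print(aggregate(per_scale, "last"))
print(aggregate(per_scale, "average"))
print(len(aggregate(per_scale, "concat")))
```

The 2D modes discard the time axis, while 'concat' keeps one row per step. This matches the parameter description's note that 'concat' is the mode that feeds sequence-level attention downstream.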
The behaviour of future_input depends on self.mode:
- 'tft_like': future_input.shape[1] must equal \(T_\text{past} + H\). The first \(T_\text{past}\) rows join the encoder; the last \(H\) rows feed the decoder, mirroring the original Temporal Fusion Transformer.
- 'pihal_like': future_input.shape[1] must equal \(H\). Future covariates are used only in the decoder (PIHALNet style).
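The two routing conventions above reduce to a length check followed by a slice along the time axis. This is a minimal pure-Python sketch in which future_input is a list of rows. `split_future` is a hypothetical helper. The documentation states that HALNet performs the check on tensors with tf.debugging.assert_equal; the ValueError here is a stand-in for that check.

```python
def split_future(future_rows, t_past, horizon, mode="tft_like"):
    """Split known-future covariates into (encoder_part, decoder_part).

    'tft_like'  : needs t_past + horizon rows; the first t_past go to the encoder.
    'pihal_like': needs exactly horizon rows; all of them go to the decoder.
    """
    n = len(future_rows)
    if mode == "tft_like":
        if n != t_past + horizon:
            raise ValueError(f"'tft_like' needs {t_past + horizon} rows, got {n}")
        return future_rows[:t_past], future_rows[t_past:]
    if mode == "pihal_like":
        if n != horizon:
            raise ValueError(f"'pihal_like' needs {horizon} rows, got {n}")
        return None, future_rows  # nothing is routed to the encoder
    raise ValueError(f"unknown mode: {mode!r}")


rows = [[float(t)] for t in range(34)]  # T_past = 10, H = 24
enc, dec = split_future(rows, t_past=10, horizon=24, mode="tft_like")
print(len(enc), len(dec))
```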
- Parameters:
static_input (Tensor) – Batch of time‑invariant features, (B, static_input_dim).
dynamic_input (Tensor) – Historical covariates, (B, T_past, dynamic_input_dim).
future_input (Tensor) – Known future covariates whose temporal span is dictated by mode (see above).
training (bool) – Keras training flag passed to all dropout‑bearing layers.
- Returns:
Collapsed decoder context: (B, forecast_horizon, attention_units) if self.final_agg is None, or (B, attention_units) when an aggregator such as 'last' or 'average' is selected.
- Return type:
Tensor
Notes
Variable Selection Networks (VSNs) are applied only when self.use_vsn is True.
Residual connections and their normalisations are created when self.use_residuals is True.
The final call to fusionlab.ops.aggregate_time_window_output() collapses or pools along the horizon dimension according to self.final_agg.
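The horizon-collapsing step described in the last note can be sketched in pure Python for one sample. The sketch assumes the three final_agg options listed in the class parameters ('last', 'average', 'flatten'), with None meaning "keep the horizon axis". `collapse_horizon` is a hypothetical stand-in written for illustration. It is not the source of fusionlab.ops.aggregate_time_window_output, whose exact signature is not assumed here.

```python
def collapse_horizon(feature_map, final_agg="last"):
    """Collapse one sample's (H, D) feature map, given as H rows of length D.

    'last'    -> the last row, length D
    'average' -> mean over the H rows, length D
    'flatten' -> all rows concatenated, length H * D
    None      -> the (H, D) map unchanged
    """
    if final_agg is None:
        return feature_map
    if final_agg == "last":
        return list(feature_map[-1])
    if final_agg == "average":
        h = len(feature_map)
        return [sum(col) / h for col in zip(*feature_map)]
    if final_agg == "flatten":
        return [v for row in feature_map for v in row]
    raise ValueError(f"unsupported final_agg: {final_agg!r}")


fmap = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]  # H = 3 steps, D = 2 features
for mode in ("last", "average", "flatten", None):
    print(mode, collapse_horizon(fmap, mode))
```

Note that 'flatten' grows the feature size to H × D, whereas 'last' and 'average' keep it at D.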
See also
fusionlab.layers.VariableSelectionNetwork, fusionlab.layers.MultiScaleLSTM, fusionlab.layers.CrossAttention
- call(inputs, training=False)
Forward pass of HALNet that produces point or quantile forecasts.
The method:
1. Validates and casts the static, dynamic and future input blocks via validate_model_inputs().
2. Verifies that the temporal length of future_input matches the requirement imposed by self.mode ('tft_like' or 'pihal_like') using the graph‑compatible tf.debugging.assert_equal().
3. Extracts high‑level features with run_halnet_core().
4. Decodes those features into either a deterministic forecast or a full quantile distribution.
- Parameters:
inputs (list[Tensor | None]) – Ordered list [static, dynamic, future] that originates from process_pinn_inputs().
training (bool, default False) – Keras training flag propagated to every dropout‑bearing layer.
- Returns:
If self.quantiles is None: (B, H, output_dim), the mean forecast.
Otherwise: (B, H, Q, output_dim), the stacked quantile forecasts.
- Return type:
Tensor
- Raises:
tf.errors.InvalidArgumentError – When the time dimension of future_input is inconsistent with the selected mode.
ValueError – When validate_model_inputs rejects the shape signature.
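The shape rule in the Returns field can be written as a small helper, which is convenient for assertions in tests. `expected_output_shape` is hypothetical and encodes only the two documented cases.

```python
from typing import Optional, Sequence, Tuple


def expected_output_shape(
    batch_size: int,
    forecast_horizon: int,
    output_dim: int,
    quantiles: Optional[Sequence[float]] = None,
) -> Tuple[int, ...]:
    """Shape documented for HALNet.call: (B, H, output_dim) or (B, H, Q, output_dim)."""
    if quantiles is None:
        return (batch_size, forecast_horizon, output_dim)
    return (batch_size, forecast_horizon, len(quantiles), output_dim)


print(expected_output_shape(32, 24, 2, [0.1, 0.5, 0.9]))
print(expected_output_shape(32, 24, 1))
```

With a real model, one would compare tuple(preds.shape) against the returned tuple.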
Notes
The function is decorated with @tf_autograph.experimental.do_not_convert higher up in the class to keep the eager semantics intact.
Examples
>>> preds = model([s, d, f], training=False)
>>> preds.shape
TensorShape([32, 24, 3, 1])
See also
fusionlab.nn.models.HALNet.run_halnet_core, fusionlab.utils.validate_model_inputs
- get_config()
Serialize the HALNet configuration for tf.keras.models.clone_model() or model saving.
All hyper‑parameters required to reconstruct the network are returned as plain Python types compatible with JSON/YAML serialization.
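Because the configuration consists of plain Python types, it should survive a JSON round trip unchanged. The sketch below checks this on a dictionary built from the constructor defaults in the signature, with the three required dimensions taken from the class example. The exact keys returned by get_config, including any keys added by the Keras base class, are an assumption here.

```python
import json

# Hypothetical stand-in for model.get_config(): constructor arguments only.
config = {
    "static_input_dim": 4, "dynamic_input_dim": 8, "future_input_dim": 6,
    "output_dim": 1, "forecast_horizon": 1, "quantiles": None,
    "embed_dim": 32, "hidden_units": 64, "lstm_units": 64,
    "attention_units": 32, "num_heads": 4, "dropout_rate": 0.1,
    "max_window_size": 10, "memory_size": 100, "scales": None,
    "multi_scale_agg": "last", "final_agg": "last", "activation": "relu",
    "use_residuals": True, "use_vsn": True, "vsn_units": None,
    "mode": None, "name": "HALNet",
}

restored = json.loads(json.dumps(config))  # serialize, then parse back
print(restored == config)
```

JSON has no tuple type, so a tuple such as scales=(1, 3) comes back as a list. Passing scales and quantiles as lists keeps the round trip exact.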
- Returns:
A mapping that can be fed to from_config() to recreate an identical model.
- Return type:
dict
Examples
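>>> import numpy as np
>>> cfg = model.get_config()
>>> clone = HALNet.from_config(cfg)
>>> inputs = [x_static, x_dynamic, x_future]  # as in the class-level example
>>> _ = clone(inputs, training=False)          # build the clone's weights
>>> clone.set_weights(model.get_weights())     # from_config does not copy weights
>>> np.allclose(
...     model(inputs, training=False),
...     clone(inputs, training=False),
...     atol=1e-5,
... )
True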
- classmethod from_config(config, custom_objects=None)
Re‑instantiate a HALNet object from a configuration dictionary produced by get_config().
- Parameters:
config (dict) – Dictionary returned by get_config().
custom_objects (dict or None, default None) – Optional mapping for resolving custom layers or functions during deserialization (passed through to tf.keras.Model.from_config()).
- Returns:
A new model instance built from config.
- Return type:
HALNet
Examples
>>> cfg = model.get_config()
>>> new_model = HALNet.from_config(cfg)
- help(**kwargs)
- my_params = HALNet( static_input_dim, dynamic_input_dim, future_input_dim, output_dim=1, forecast_horizon=1, quantiles=None, embed_dim=32, hidden_units=64, lstm_units=64, attention_units=32, num_heads=4, dropout_rate=0.1, max_window_size=10, memory_size=100, scales=None, multi_scale_agg='last', final_agg='last', activation='relu', use_residuals=True, use_vsn=True, vsn_units=None, mode=None, name='HALNet' )