Tuning HALNet with the XTFTTuner
The HALNet model, with its rich set of
components like MultiScaleLSTM and multiple attention layers, has
many hyperparameters that can be optimized to achieve peak performance
on a given dataset.
This guide explains how to use the XTFTTuner to automatically find
the best set of hyperparameters for your HALNet model.
Note
Because HALNet shares its core data-driven architecture with
the more advanced XTFT model, the same XTFTTuner can be
used to tune both. The process is nearly identical; the main
difference is the set of hyperparameters defined in the
search_space.
End-to-End Workflow
The process involves preparing your data, defining a search space
specific to HALNet, and then using the XTFTTuner to run the
optimization process.
Step 1: Prepare Data
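First, prepare your data in the three-part input format used by
HALNet: NumPy arrays of static, dynamic (past), and known-future
features, plus a corresponding targets array. In the example below,
the three input arrays are grouped in a list in the order
[static, dynamic, future] for both the training and validation splits.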
import numpy as np

# --- Define Data Dimensions for the Example ---
# B: number of samples, T_PAST: lookback steps, HORIZON: forecast steps
B, T_PAST, HORIZON = 256, 21, 7
# Feature counts for the static, dynamic, and future inputs
S_DIM, D_DIM, F_DIM = 4, 6, 3
O_DIM = 1  # Number of target variables

# --- Generate Dummy Data Arrays ---
static_features = np.random.rand(B, S_DIM)
dynamic_features = np.random.rand(B, T_PAST, D_DIM)
# For 'tft_like' mode, future features span past + horizon
future_features = np.random.rand(B, T_PAST + HORIZON, F_DIM)
targets = np.random.rand(B, HORIZON, O_DIM)

# Create a validation split: hold out the last 50 samples (206 train / 50 val)
val_split = -50
train_inputs = [arr[:val_split] for arr in [static_features, dynamic_features, future_features]]
val_inputs = [arr[val_split:] for arr in [static_features, dynamic_features, future_features]]
train_targets, val_targets = targets[:val_split], targets[val_split:]

print("Generated dummy data for HALNet tuning.")
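Before you start a search, check that the arrays agree on batch size and time spans. Shape mismatches are much cheaper to catch here than partway through a trial. The following is a minimal standard-library sketch. `check_tft_like_shapes` and `split_sizes` are hypothetical helpers, not part of fusionlab. The shape check encodes only the rule stated in the code comment above: in 'tft_like' mode, the future features span T_PAST + HORIZON steps. The sketch also shows how the negative index `val_split = -50` divides the 256 samples.

```python
def check_tft_like_shapes(static_shape, dynamic_shape, future_shape, target_shape):
    """Hypothetical helper: validate array shapes for 'tft_like' mode.

    Assumes static is (B, S), dynamic is (B, T_past, D),
    future is (B, T_past + horizon, F) and targets are (B, horizon, O).
    Returns (batch, t_past, horizon) if the shapes are consistent.
    """
    batch = static_shape[0]
    for name, shape in [("dynamic", dynamic_shape), ("future", future_shape),
                        ("targets", target_shape)]:
        if shape[0] != batch:
            raise ValueError(f"{name} batch size {shape[0]} != static batch size {batch}")
    t_past = dynamic_shape[1]
    horizon = target_shape[1]
    # Rule from the example: future features cover lookback + forecast horizon
    if future_shape[1] != t_past + horizon:
        raise ValueError(
            f"future span {future_shape[1]} != t_past + horizon ({t_past + horizon})")
    return batch, t_past, horizon


def split_sizes(n_samples, val_split):
    """Sizes produced by arr[:val_split] and arr[val_split:] for a negative val_split."""
    items = list(range(n_samples))
    return len(items[:val_split]), len(items[val_split:])


# Shapes from the example: B=256, T_PAST=21, HORIZON=7, S_DIM=4, D_DIM=6, F_DIM=3, O_DIM=1
print(check_tft_like_shapes((256, 4), (256, 21, 6), (256, 28, 3), (256, 7, 1)))
print(split_sizes(256, -50))
```

With real NumPy arrays, pass each array's `.shape` attribute to the helper.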
Step 2: Define the HALNet Search Space
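The search-space dictionary is the core of the tuning experiment. It lists every architectural and optimization hyperparameter the tuner should explore. As the example shows, you can specify a hyperparameter either as a plain list of candidate values or as a dict with an explicit "type" ("choice", "float", or "bool").
When tuning HALNet, simply omit any hyperparameters that are
specific to XTFT, such as anomaly_detection_strategy or
anomaly_loss_weight.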
halnet_search_space = {
    # --- Architectural Hyperparameters ---
    "embed_dim": [32, 64],
    "hidden_units": [32, 64],
    "lstm_units": [32, 64],
    "attention_units": [16, 32],
    "num_heads": {"type": "choice", "values": [2, 4]},
    "dropout_rate": {"type": "float", "min_value": 0.1, "max_value": 0.4},
    "use_vsn": {"type": "bool"},  # Tune whether to use a Variable Selection Network (VSN)

    # --- Compile-time Hyperparameters ---
    "learning_rate": [1e-3, 5e-4]
}
print("Defined hyperparameter search space for HALNet.")
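The search space mixes plain lists with typed dicts. The following standard-library sketch draws one random candidate from it and counts its discrete combinations, which makes those two formats more concrete. This is an illustration only. `sample_candidate`, `count_discrete_combinations`, and `to_halnet_space` are hypothetical helpers, and the XTFTTuner's actual sampling strategy may differ. The sketch assumes the following semantics:

- A plain list means "pick one value".
- "choice" picks from "values".
- "float" samples uniformly between "min_value" and "max_value".
- "bool" picks True or False.

It also shows how a shared XTFT search space could be reduced to a HALNet one by dropping XTFT-only keys. The two keys are the ones named above; the values assigned to them are placeholders.

```python
import random

# HALNet search space copied from the example above.
halnet_search_space = {
    "embed_dim": [32, 64],
    "hidden_units": [32, 64],
    "lstm_units": [32, 64],
    "attention_units": [16, 32],
    "num_heads": {"type": "choice", "values": [2, 4]},
    "dropout_rate": {"type": "float", "min_value": 0.1, "max_value": 0.4},
    "use_vsn": {"type": "bool"},
    "learning_rate": [1e-3, 5e-4],
}

# Keys this guide names as XTFT-specific.
XTFT_ONLY_KEYS = {"anomaly_detection_strategy", "anomaly_loss_weight"}


def to_halnet_space(xtft_space):
    """Drop XTFT-only hyperparameters from a shared search space."""
    return {k: v for k, v in xtft_space.items() if k not in XTFT_ONLY_KEYS}


def sample_candidate(space, rng):
    """Draw one candidate configuration (assumed semantics, see lead-in)."""
    candidate = {}
    for name, spec in space.items():
        if isinstance(spec, list):
            candidate[name] = rng.choice(spec)
        elif spec["type"] == "choice":
            candidate[name] = rng.choice(spec["values"])
        elif spec["type"] == "float":
            candidate[name] = rng.uniform(spec["min_value"], spec["max_value"])
        elif spec["type"] == "bool":
            candidate[name] = rng.choice([True, False])
        else:
            raise ValueError(f"Unsupported spec for {name!r}: {spec}")
    return candidate


def count_discrete_combinations(space):
    """Number of combinations of the discrete entries ("float" ones are continuous)."""
    total = 1
    for spec in space.values():
        if isinstance(spec, list):
            total *= len(spec)
        elif spec["type"] == "choice":
            total *= len(spec["values"])
        elif spec["type"] == "bool":
            total *= 2
    return total


# Placeholder values for the XTFT-only keys; only the key names matter here.
xtft_space = {**halnet_search_space,
              "anomaly_detection_strategy": ["placeholder"],
              "anomaly_loss_weight": ["placeholder"]}
assert to_halnet_space(xtft_space) == halnet_search_space

print(sample_candidate(halnet_search_space, random.Random(0)))
print(count_discrete_combinations(halnet_search_space))
```

The discrete entries alone give 2^7 = 128 combinations, before the continuous dropout_rate is even considered. A run with max_trials=5 therefore explores only a small fraction of this space. A more thorough search needs a larger max_trials.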
Step 3: Create and Run the Tuner
Next, create the tuner with the XTFTTuner.create() factory method.
The key step is to pass the HALNet class to the model_name_or_cls
argument. The tuner then builds HALNet instances, rather than XTFT
instances, for each trial of the search. Put any values that should
stay constant across trials in fixed_params.
import tensorflow as tf
from fusionlab.nn.forecast_tuner import XTFTTuner  # Use the XTFT Tuner
from fusionlab.nn.models import HALNet  # But for the HALNet model

# 1. Create the tuner instance, passing the HALNet class
tuner = XTFTTuner.create(
    model_name_or_cls=HALNet,  # <-- Specify HALNet here
    inputs_data={"static": static_features, "dynamic": dynamic_features},
    targets_data=targets,
    search_space=halnet_search_space,
    # Provide any fixed params that shouldn't be tuned
    fixed_params={
        "future_input_dim": F_DIM,
        "mode": "tft_like",
        "max_window_size": T_PAST
    },
    # Keras Tuner settings
    objective="val_loss",
    max_trials=5,  # Use a small number for this example
    project_name="HALNet_Tuning_Example",
    directory="./halnet_tuner_results",
    overwrite=True
)

# 2. Run the search process
print("\nStarting hyperparameter search for HALNet...")
best_model, best_hps, _ = tuner.run(
    inputs=train_inputs,
    y=train_targets,
    validation_data=(val_inputs, val_targets),
    epochs=5,
    batch_size=64,
    callbacks=[tf.keras.callbacks.EarlyStopping(monitor="val_loss", patience=3)]
)
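The model_name_or_cls argument is what allows a single tuner class to serve both HALNet and XTFT. The following sketch illustrates that general pattern. It is hypothetical and does not reproduce fusionlab's implementation; `resolve_model`, `build_trial_model`, `MODEL_REGISTRY`, and the stand-in classes `ToyHALNet` and `ToyXTFT` are all illustrative. The sketch resolves a name or a class to a constructor, then combines one trial's sampled hyperparameters with fixed_params. It makes two assumptions: fixed values win on a key collision, and learning_rate is routed to compile-time settings rather than to the constructor. The real tuner may handle both details differently.

```python
class ToyHALNet:
    """Hypothetical stand-in for HALNet that just records its arguments."""
    def __init__(self, **kwargs):
        self.config = kwargs


class ToyXTFT(ToyHALNet):
    """Hypothetical stand-in for XTFT."""


# Hypothetical name -> class registry.
MODEL_REGISTRY = {"halnet": ToyHALNet, "xtft": ToyXTFT}

# Hyperparameters assumed to configure the optimizer rather than the model.
COMPILE_TIME_KEYS = {"learning_rate"}


def resolve_model(model_name_or_cls):
    """Accept either a class or a registered (case-insensitive) name."""
    if isinstance(model_name_or_cls, type):
        return model_name_or_cls
    try:
        return MODEL_REGISTRY[model_name_or_cls.lower()]
    except KeyError:
        raise ValueError(f"Unknown model: {model_name_or_cls!r}") from None


def build_trial_model(model_name_or_cls, sampled_hps, fixed_params):
    """Instantiate one trial model; fixed params override sampled ones."""
    model_cls = resolve_model(model_name_or_cls)
    params = {**sampled_hps, **fixed_params}
    # Pull compile-time settings out so they are not passed to the constructor
    compile_kwargs = {k: params.pop(k) for k in COMPILE_TIME_KEYS if k in params}
    return model_cls(**params), compile_kwargs


model, compile_kwargs = build_trial_model(
    ToyHALNet,
    sampled_hps={"embed_dim": 32, "num_heads": 2, "learning_rate": 1e-3},
    fixed_params={"future_input_dim": 3, "mode": "tft_like", "max_window_size": 21},
)
print(type(model).__name__, model.config, compile_kwargs)
```

Passing the class directly, rather than a string name, avoids any dependence on a registry. The same builder then works for any compatible model class.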
Step 4: Analyze the Results
After the search completes, you can inspect the best hyperparameters found for your HALNet model.
print("\n--- Tuning Complete: Best Hyperparameters for HALNet ---")
if best_hps:
    for hp, value in best_hps.values.items():
        if isinstance(value, float):
            print(f" - {hp}: {value:.4f}")
        else:
            print(f" - {hp}: {value}")
else:
    print("Search did not find any best hyperparameters.")
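Example output (values will differ from run to run, because the data are random and the search itself is stochastic):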
--- Tuning Complete: Best Hyperparameters for HALNet ---
- embed_dim: 32
- hidden_units: 64
- lstm_units: 32
- attention_units: 16
- num_heads: 2
- dropout_rate: 0.1827
- use_vsn: True
- learning_rate: 0.0010
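Because best_hps.values is a plain dictionary, you can save the winning configuration and reuse it outside the tuner. The following standard-library sketch starts from the example values printed above; a real run will produce different ones. It round-trips those values through JSON and splits them into constructor arguments and a compile-time learning rate, mirroring the "Architectural" and "Compile-time" groups in halnet_search_space. `split_best_hps` is a hypothetical helper, and the split rule is an assumption.

```python
import json

# Example values from the output above; a real run will produce different ones.
best_values = {
    "embed_dim": 32, "hidden_units": 64, "lstm_units": 32,
    "attention_units": 16, "num_heads": 2, "dropout_rate": 0.1827,
    "use_vsn": True, "learning_rate": 0.001,
}

# Mirrors the "Compile-time Hyperparameters" group of halnet_search_space.
COMPILE_TIME_KEYS = {"learning_rate"}


def split_best_hps(values):
    """Split tuned values into model-constructor kwargs and compile kwargs."""
    model_kwargs = {k: v for k, v in values.items() if k not in COMPILE_TIME_KEYS}
    compile_kwargs = {k: v for k, v in values.items() if k in COMPILE_TIME_KEYS}
    return model_kwargs, compile_kwargs


# Serialize to JSON text (e.g., to store next to the tuner results) and read it back.
payload = json.dumps(best_values, indent=2, sort_keys=True)
restored = json.loads(payload)

model_kwargs, compile_kwargs = split_best_hps(restored)
print(compile_kwargs)  # {'learning_rate': 0.001}
```

To rebuild the model, you would combine model_kwargs with the same fixed_params used in Step 3.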
Because the tuner accepts a model class through model_name_or_cls, the same workflow can be reused for related architectures such as HALNet and XTFT. Only the model class and the search space change; the rest of the tuning code stays the same.