TFT Forecast Tuner Guide¶
Finding the optimal set of hyperparameters for deep learning models
like TemporalFusionTransformer,
TFT (stricter version),
XTFT, and
SuperXTFT is crucial for achieving the best
possible forecasting performance. Hyperparameters control aspects of
the model architecture (e.g., number of hidden units, attention
heads) and the training process (e.g., learning rate, batch size).
fusionlab provides utility functions and classes within the
forecast_tuner module that leverage the
powerful Keras Tuner library (keras-tuner) to automate this
search process.
Prerequisites¶
To use the tuning functions, you must have Keras Tuner installed:
pip install keras-tuner -q
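Before launching a search, it can help to confirm that the package is importable. The following is a minimal sketch using only the standard library. The helper name `has_module` is hypothetical and not part of fusionlab; `keras_tuner` is the import name of the keras-tuner distribution.

```python
import importlib.util


def has_module(name: str) -> bool:
    """Return True if a top-level module `name` can be found in this environment."""
    return importlib.util.find_spec(name) is not None


# Hypothetical pre-flight check before calling any tuner function.
if not has_module("keras_tuner"):
    print("Keras Tuner is missing; run: pip install keras-tuner")
```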
Choosing Your Tuning Approach: Functions vs. Classes¶
As of version 0.2.3, fusionlab-learn offers two primary ways to
perform hyperparameter tuning:
- Class-Based Tuners (Recommended for New Projects): This modern, object-oriented approach uses dedicated tuner classes like XTFTTuner and TFTTuner. It offers improved structure, reusability (configure once, fit multiple times), and flexibility. For detailed information and examples, please refer to the Class-Based Forecast Tuner Guide.
- Function-Based Tuners (Legacy, Still Supported): The original approach uses standalone functions like xtft_tuner() and tft_tuner(). These functions remain fully supported for backward compatibility and are detailed in the sections below.
For new development, we recommend exploring the class-based tuners for a more streamlined and maintainable workflow.
Tuner Functions (Legacy Approach)¶
The fusionlab.nn.forecast_tuner module offers dedicated
functions to tune different model types. These are suitable for
quick experiments or maintaining existing codebases.
xtft_tuner¶
- API Reference:
Purpose:
To perform hyperparameter optimization for the
XTFT and
SuperXTFT models. It can also be used
to tune TFT (stricter) and
TemporalFusionTransformer (flexible) by
specifying the appropriate model_name.
Functionality: This function orchestrates the tuning process:
- Inputs: Takes prepared input data as a list inputs = [X_static, X_dynamic, X_future] (where static or future can be None if model_name='tft_flex') and the target array y.
- Search Space: Uses a default space (DEFAULT_PS) for common hyperparameters. Users can provide their own param_space dictionary to override or extend these.
- Model Builder: Employs an internal default _model_builder_factory() (or a user-provided model_builder) to construct model instances for given hyperparameters (hp). The builder samples values using Keras Tuner's hp object. Models are compiled with the Adam optimizer and an appropriate loss (MSE or quantile).
- Tuner Initialization: Creates a Keras Tuner instance (RandomSearch or BayesianOptimization) configured with the objective, max_trials, tuner_dir, and project_name.
- Search Execution: Iterates through batch_sizes. For each batch size, it:
  - Runs tuner.search() using the data, epochs (for trials), validation_split, and callbacks.
  - Retrieves the best hyperparameters for that batch size.
  - Builds and fully trains a model using these HPs and batch size for the user-specified epochs.
- Best Model Selection: Compares validation loss across all tested batch_sizes to find the overall best_hps, best_model, and best_batch_size.
- Output: Returns (best_hps, best_model, tuner_object). Results are logged to a JSON file.
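The per-batch-size search and the final selection step can be sketched in plain Python. This is only an illustration of the control flow described above, not the library's implementation. `select_best_batch_size` and `fake_search` are hypothetical names, and `fake_search` stands in for "run tuner.search(), then retrain and report the validation loss".

```python
from typing import Callable, Dict, List, Tuple


def select_best_batch_size(
    batch_sizes: List[int],
    search_and_train: Callable[[int], Tuple[Dict, float]],
) -> Tuple[Dict, int, float]:
    """Run one search per batch size and keep the lowest validation loss.

    `search_and_train` is a hypothetical stand-in: it receives a batch size
    and returns (best_hps, val_loss) for that batch size.
    """
    best_hps, best_batch_size, best_loss = None, None, float("inf")
    for batch_size in batch_sizes:
        hps, val_loss = search_and_train(batch_size)
        if val_loss < best_loss:
            best_hps, best_batch_size, best_loss = hps, batch_size, val_loss
    return best_hps, best_batch_size, best_loss


def fake_search(batch_size: int) -> Tuple[Dict, float]:
    """Toy stand-in that pretends batch size 32 gives the lowest loss."""
    return {"hidden_units": 16}, abs(batch_size - 32) / 100 + 0.1


hps, batch_size, loss = select_best_batch_size([16, 32, 64], fake_search)
print(hps, batch_size, loss)
```

Keeping the comparison outside the search itself means batch size is treated as an outer-loop hyperparameter, which matches the description of one search per entry in batch_sizes.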
Usage Context: Use after preparing training data into the required list format. Provide data, forecast_horizon, quantiles (if any), and optionally customize param_space, max_trials, epochs, etc. Crucially, set model_name to "xtft", "superxtft", "tft", or "tft_flex" to guide the internal model builder.
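A quick sanity check on the input list can catch format mistakes before a long search starts. The sketch below assumes the rules stated above: a three-element list, None allowed for static or future only with "tft_flex", and a shared batch dimension. `check_tuner_inputs` is a hypothetical helper, not a fusionlab function, and the requirement that dynamic inputs are always present is an assumption drawn from the text.

```python
import numpy as np

VALID_MODEL_NAMES = ("xtft", "superxtft", "tft", "tft_flex")


def check_tuner_inputs(inputs, y, model_name):
    """Sanity-check the [static, dynamic, future] list before tuning."""
    if model_name not in VALID_MODEL_NAMES:
        raise ValueError(f"Unknown model_name: {model_name!r}")
    if len(inputs) != 3:
        raise ValueError("inputs must be [X_static, X_dynamic, X_future]")
    x_static, x_dynamic, x_future = inputs
    # Assumption: dynamic inputs are required for every model type.
    if x_dynamic is None:
        raise ValueError("X_dynamic is required")
    if model_name != "tft_flex" and (x_static is None or x_future is None):
        raise ValueError(f"{model_name!r} needs static, dynamic and future inputs")
    # Every array that is present must share the batch dimension with y.
    for arr in (x_static, x_dynamic, x_future):
        if arr is not None and arr.shape[0] != y.shape[0]:
            raise ValueError("Batch dimensions of inputs and y differ")
    return True


X_d = np.zeros((8, 12, 5), dtype=np.float32)
y = np.zeros((8, 6, 1), dtype=np.float32)
print(check_tuner_inputs([None, X_d, None], y, "tft_flex"))
```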
Code Example (Tuning XTFT):
import numpy as np
import os
import tensorflow as tf
from fusionlab.nn.forecast_tuner import xtft_tuner
# from fusionlab.nn import XTFT # For context

# 1. Prepare Dummy Data (Static, Dynamic, Future)
B, T_past, H_out = 8, 12, 6
D_s, D_d, D_f = 3, 5, 2
T_future_total = T_past + H_out

X_static_train = np.random.rand(B, D_s).astype(np.float32)
X_dynamic_train = np.random.rand(B, T_past, D_d).astype(np.float32)
X_future_train = np.random.rand(
    B, T_future_total, D_f).astype(np.float32)
y_train = np.random.rand(B, H_out, 1).astype(np.float32)

# Inputs for tuner: [Static, Dynamic, Future]
train_inputs = [X_static_train, X_dynamic_train, X_future_train]

# 2. Define Minimal Search Space & Case Info
custom_param_space = {
    'hidden_units': [16],  # Fixed for speed
    'num_heads': [2],
    'learning_rate': [1e-3]
}
case_info_xtft = {
    'quantiles': None,  # Point forecast
    'forecast_horizon': H_out,
    'static_input_dim': D_s,
    'dynamic_input_dim': D_d,
    'future_input_dim': D_f,
    'output_dim': 1
}

# 3. Define Tuning Parameters
output_dir = "./xtft_tuning_example_output"
project_name = "XTFT_Point_Tuning"

# 4. Run the Tuner for XTFT
print("Starting XTFT tuning...")
best_hps, best_model, tuner = xtft_tuner(
    inputs=train_inputs,
    y=y_train,
    param_space=custom_param_space,
    forecast_horizon=H_out,  # Passed directly to tuner
    quantiles=None,  # Passed directly to tuner
    case_info=case_info_xtft,  # For model builder
    max_trials=1,  # Minimal for demo
    objective='val_loss',
    epochs=2,  # Minimal for demo
    batch_sizes=[8],  # Single small batch
    validation_split=0.25,
    tuner_dir=output_dir,
    project_name=project_name,
    tuner_type='random',
    model_name="xtft",  # Crucial: tells builder to make XTFT
    verbose=0
)

# 5. Display Results
print("\nXTFT Tuning complete.")
if best_hps:
    print("--- Best Hyperparameters (XTFT) ---")
    print(best_hps)
    # best_model.summary()
else:
    print("XTFT Tuning failed to find a best model.")
# tuner.results_summary(num_trials=1)
tft_tuner¶
- API Reference:
Purpose:
A convenience wrapper for tuning Temporal Fusion Transformer models.
It calls xtft_tuner() internally, passing the model_name
parameter to differentiate between the stricter
TFT (which requires all static,
dynamic, and future inputs) and the more flexible
TemporalFusionTransformer (which can handle
optional static and/or future inputs).
Functionality:
Accepts the same parameters as xtft_tuner(). The key is the model_name argument:
- Set model_name="tft" to tune the stricter TFT class. In this case, inputs must be a list of three non-None tensors [X_static, X_dynamic, X_future].
- Set model_name="tft_flex" to tune the flexible TemporalFusionTransformer. In this case, inputs can be [X_static, X_dynamic, X_future], where X_static and/or X_future can be None (or even a single tensor for dynamic-only input).
The internal default model builder
(_model_builder_factory())
constructs the appropriate TFT variant and uses relevant
hyperparameters.
Usage Context: Use this when your primary goal is to tune a TFT model. Choose model_name="tft" for the standard three-input architecture or model_name="tft_flex" if you are working with scenarios that might not include all input types.
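The choice between the two model_name values follows directly from which inputs are available. The sketch below shows one way to normalize a dynamic-only tensor into the three-slot list and pick a name accordingly. Both helpers (`as_three_inputs`, `suggest_model_name`) are hypothetical illustrations and are not provided by fusionlab.

```python
import numpy as np


def as_three_inputs(inputs):
    """Return inputs as [static, dynamic, future], padding with None."""
    if isinstance(inputs, np.ndarray):
        # A single array is treated as dynamic-only input.
        return [None, inputs, None]
    static, dynamic, future = inputs
    return [static, dynamic, future]


def suggest_model_name(inputs):
    """Pick "tft" when all three inputs exist, otherwise "tft_flex"."""
    static, dynamic, future = as_three_inputs(inputs)
    if dynamic is None:
        raise ValueError("Dynamic inputs are required")
    if static is not None and future is not None:
        return "tft"
    return "tft_flex"


X_dynamic_only = np.zeros((8, 12, 5), dtype=np.float32)
print(suggest_model_name(X_dynamic_only))
```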
Code Example 1 (Tuning Stricter `TFT`):
import numpy as np
import os
import tensorflow as tf
from fusionlab.nn.forecast_tuner import tft_tuner
# from fusionlab.nn.transformers import TFT # For context

# 1. Prepare Dummy Data (ALL inputs required for stricter TFT)
B, T_past, H_out = 8, 12, 6
D_s, D_d, D_f = 3, 5, 2
T_future_total = T_past + H_out

X_s_train = np.random.rand(B, D_s).astype(np.float32)
X_d_train = np.random.rand(B, T_past, D_d).astype(np.float32)
X_f_train = np.random.rand(
    B, T_future_total, D_f).astype(np.float32)
y_train_tft = np.random.rand(B, H_out, 1).astype(np.float32)

train_inputs_strict_tft = [X_s_train, X_d_train, X_f_train]

# 2. Define Case Info & Minimal Param Space
case_info_strict_tft = {
    'quantiles': None, 'forecast_horizon': H_out,
    'static_input_dim': D_s, 'dynamic_input_dim': D_d,
    'future_input_dim': D_f, 'output_dim': 1
}
param_space_tft = {'hidden_units': [16], 'learning_rate': [1e-3]}

# 3. Run Tuner for Stricter TFT
print("\nStarting stricter TFT tuning...")
best_hps_s, _, _ = tft_tuner(
    inputs=train_inputs_strict_tft, y=y_train_tft,
    param_space=param_space_tft,
    forecast_horizon=H_out, quantiles=None,
    case_info=case_info_strict_tft,
    max_trials=1, epochs=1, batch_sizes=[4],
    validation_split=0.5, tuner_dir="./tft_strict_tuning",
    project_name="TFT_Strict_Tune", model_name="tft",  # Key
    verbose=0
)
print("Stricter TFT Tuning complete.")
if best_hps_s:
    print(" Best HPs (Stricter TFT):", best_hps_s)
Code Example 2 (Tuning Flexible `TemporalFusionTransformer`):
This example tunes the flexible TFT, providing only dynamic inputs.
import numpy as np
import os
import tensorflow as tf
from fusionlab.nn.forecast_tuner import tft_tuner
# from fusionlab.nn import TemporalFusionTransformer # For context

# 1. Prepare Dummy Data (Dynamic inputs only)
B, T_past, H_out = 8, 12, 6
D_d = 5  # Dynamic features
X_d_train_flex = np.random.rand(B, T_past, D_d).astype(np.float32)
y_train_flex = np.random.rand(B, H_out, 1).astype(np.float32)

# Inputs for flexible TFT (static and future are None)
train_inputs_flex = [None, X_d_train_flex, None]

# 2. Define Case Info & Minimal Param Space
case_info_flex_tft = {
    'quantiles': None, 'forecast_horizon': H_out,
    'dynamic_input_dim': D_d,  # Static/Future dims are None
    'static_input_dim': None,
    'future_input_dim': None,
    'output_dim': 1
}
param_space_flex = {'hidden_units': [16], 'learning_rate': [1e-3]}

# 3. Run Tuner for Flexible TFT
print("\nStarting flexible TFT (tft_flex) tuning...")
best_hps_f, _, _ = tft_tuner(
    inputs=train_inputs_flex, y=y_train_flex,
    param_space=param_space_flex,
    forecast_horizon=H_out, quantiles=None,
    case_info=case_info_flex_tft,
    max_trials=1, epochs=1, batch_sizes=[4],
    validation_split=0.5, tuner_dir="./tft_flex_tuning",
    project_name="TFT_Flex_Tune", model_name="tft_flex",  # Key
    verbose=0
)
print("Flexible TFT Tuning complete.")
if best_hps_f:
    print(" Best HPs (Flexible TFT):", best_hps_f)
Internal Model Builder¶
- API Reference:
_model_builder_factory() (private function)
Note: Users typically do not interact with this function directly, but understanding its role is helpful.
This internal helper function is used by default if no custom model_builder is provided to the tuner functions. Its responsibilities are:
- Accepts the Keras Tuner hp object.
- Determines the correct model class to instantiate (XTFT, SuperXTFT, TFT, or TemporalFusionTransformer) based on the model_name.
- Defines the range or set of choices for each hyperparameter relevant to the chosen model class, using hp.Choice, hp.Boolean, etc., based on the param_space provided to the tuner or the internal DEFAULT_PS.
- Instantiates the model class with the sampled hyperparameters.
- Compiles the model with an Adam optimizer (learning rate is also tuned) and an appropriate loss function (MSE or quantile loss).
- Returns the compiled model instance to the Keras Tuner for evaluation during the search process.
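The way a user-supplied param_space overrides defaults and is then sampled through hp.Choice can be sketched without any deep learning dependencies. In the sketch below, `FirstChoiceHP` is a hypothetical stand-in for Keras Tuner's hp object that always picks the first option, and `ILLUSTRATIVE_DEFAULTS` is an invented dictionary, not the actual contents of DEFAULT_PS. The real builder may handle more value types, such as hp.Boolean.

```python
class FirstChoiceHP:
    """Stand-in for Keras Tuner's hp object: always picks the first option."""

    def __init__(self):
        self.values = {}

    def Choice(self, name, values):
        # Record and return the first candidate, mimicking a sampled value.
        self.values[name] = values[0]
        return values[0]


def sample_params(hp, param_space=None, defaults=None):
    """Merge user overrides into the defaults, then sample each entry via hp."""
    space = dict(defaults or {})
    space.update(param_space or {})
    return {name: hp.Choice(name, list(choices)) for name, choices in space.items()}


# Invented values for illustration only; not the library's DEFAULT_PS.
ILLUSTRATIVE_DEFAULTS = {
    "hidden_units": [32, 64],
    "num_heads": [2, 4],
    "learning_rate": [1e-3, 1e-4],
}

sampled = sample_params(FirstChoiceHP(), {"hidden_units": [16]}, ILLUSTRATIVE_DEFAULTS)
print(sampled)
```

Because the builder only relies on the hp interface, the same function would accept a real Keras Tuner HyperParameters object during a search.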
By providing a custom model_builder function to xtft_tuner or tft_tuner, users can gain finer control over the architecture variations or compilation settings explored during tuning.
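One compilation setting a custom builder might change is the loss used for quantile forecasts. For reference, the standard pinball (quantile) loss for a single quantile level can be written in NumPy as below. This is the textbook formula shown for illustration and is not necessarily how fusionlab implements its quantile loss; `pinball_loss` is a hypothetical name.

```python
import numpy as np


def pinball_loss(y_true, y_pred, q):
    """Mean pinball (quantile) loss for a single quantile level q in (0, 1)."""
    error = np.asarray(y_true, dtype=float) - np.asarray(y_pred, dtype=float)
    # Under-prediction (error > 0) is weighted by q, over-prediction by (1 - q).
    return float(np.mean(np.maximum(q * error, (q - 1.0) * error)))


# With q = 0.9, under-predicting by 1 costs more than over-predicting by 1.
print(pinball_loss([1.0], [0.0], 0.9), pinball_loss([0.0], [1.0], 0.9))
```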