TFT Forecast Tuner Guide

Finding the optimal set of hyperparameters for deep learning models like TemporalFusionTransformer, TFT (stricter version), XTFT, and SuperXTFT is crucial for achieving the best possible forecasting performance. Hyperparameters control aspects of the model architecture (e.g., number of hidden units, attention heads) and the training process (e.g., learning rate, batch size).

fusionlab provides functions and classes in the fusionlab.nn.forecast_tuner module that use the Keras Tuner library (keras-tuner) to automate this search.

Prerequisites

To use the tuning functions, you must have Keras Tuner installed:

pip install keras-tuner -q
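The pip distribution is named keras-tuner, but the Python import name is keras_tuner. You can check that the package is importable before calling the tuners. The sketch below uses only the standard library. The helper name keras_tuner_available is hypothetical and is not part of fusionlab.

```python
import importlib.util


def keras_tuner_available() -> bool:
    """Return True if the keras_tuner package can be imported (hypothetical helper)."""
    # find_spec returns None when no top-level module with this name exists
    return importlib.util.find_spec("keras_tuner") is not None


if keras_tuner_available():
    print("keras_tuner is installed.")
else:
    print("keras_tuner is missing; install it with: pip install keras-tuner")
```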

Choosing Your Tuning Approach: Functions vs. Classes

As of version 0.2.3, fusionlab-learn offers two primary ways to perform hyperparameter tuning:

  1. Class-Based Tuners (Recommended for New Projects): This modern, object-oriented approach utilizes dedicated tuner classes like XTFTTuner and TFTTuner. It offers improved structure, reusability (configure once, fit multiple times), and flexibility. For detailed information and examples, please refer to the Class-Based Forecast Tuner Guide.

  2. Function-Based Tuners (Legacy, Still Supported): The original approach uses standalone functions like xtft_tuner() and tft_tuner(). These functions remain fully supported for backward compatibility and are detailed in the sections below.

For new development, we recommend exploring the class-based tuners for a more streamlined and maintainable workflow.

Tuner Functions (Legacy Approach)

The fusionlab.nn.forecast_tuner module offers dedicated functions to tune different model types. These are suitable for quick experiments or maintaining existing codebases.

xtft_tuner

API Reference: xtft_tuner()

Purpose: To perform hyperparameter optimization for the XTFT and SuperXTFT models. It can also be used to tune TFT (stricter) and TemporalFusionTransformer (flexible) by specifying the appropriate model_name.

Functionality: This function orchestrates the tuning process:

  1. Inputs: Takes the prepared input data as a list inputs = [X_static, X_dynamic, X_future] and the target array y. X_static and/or X_future may be None when model_name='tft_flex'.

  2. Search Space: Uses a default space (DEFAULT_PS) for common hyperparameters. Users can provide their own param_space dictionary to override or extend these.

  3. Model Builder: Uses the internal _model_builder_factory() by default, or a user-provided model_builder, to build a model for each set of sampled hyperparameters. The builder samples values through Keras Tuner's hp object. It compiles each model with the Adam optimizer and an appropriate loss (MSE or quantile loss).

  4. Tuner Initialization: Creates a Keras Tuner instance (RandomSearch or BayesianOptimization) configured with the objective, max_trials, tuner_dir, and project_name.

  5. Search Execution: Iterates through batch_sizes. For each batch size it:

    • Runs tuner.search() using the data, epochs (for trials), validation_split, and callbacks.

    • Retrieves the best hyperparameters for that batch size.

    • Builds and fully trains a model with these HPs and this batch size for the user-specified epochs.

  6. Best Model Selection: Compares validation loss across all tested batch_sizes to find the overall best_hps, best_model, and best_batch_size.

  7. Output: Returns (best_hps, best_model, tuner_object). Results are logged to a JSON file.
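Steps 5 and 6 amount to a loop that keeps the batch size with the lowest validation loss. The sketch below shows only that selection logic and is not fusionlab's implementation. The function run_search_for is a hypothetical stand-in for "run tuner.search(), then fully train the best trial". The losses in fake_search are made-up placeholders.

```python
from typing import Any, Callable, Dict, List, Optional, Tuple

# Hypothetical stand-in for "search + full training" at one batch size:
# it returns (best_hps, trained_model, val_loss).
SearchFn = Callable[[int], Tuple[Dict[str, Any], Any, float]]


def select_best_over_batch_sizes(
    batch_sizes: List[int], run_search_for: SearchFn
) -> Tuple[Optional[Dict[str, Any]], Any, Optional[int]]:
    """Keep the (hps, model, batch_size) triple with the lowest val_loss."""
    best_hps, best_model, best_batch_size = None, None, None
    best_val_loss = float("inf")
    for batch_size in batch_sizes:
        hps, model, val_loss = run_search_for(batch_size)
        # A strictly lower validation loss replaces the current best
        if val_loss < best_val_loss:
            best_val_loss = val_loss
            best_hps, best_model, best_batch_size = hps, model, batch_size
    return best_hps, best_model, best_batch_size


def fake_search(batch_size: int):
    """Placeholder search whose losses are invented for illustration."""
    fake_losses = {16: 0.42, 32: 0.35, 64: 0.39}
    hps = {"hidden_units": 32, "learning_rate": 1e-3}
    return hps, f"model_bs{batch_size}", fake_losses[batch_size]


hps, model, bs = select_best_over_batch_sizes([16, 32, 64], fake_search)
print(bs, model)  # → 32 model_bs32
```

If batch_sizes is empty, this sketch returns (None, None, None). That matches the "tuning failed" branch in the example further below.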

Usage Context: Use it once the training data is in the required list format. Provide the data, forecast_horizon, and quantiles (if any), and optionally customize param_space, max_trials, epochs, etc. Crucially, set model_name to "xtft", "superxtft", "tft", or "tft_flex" to guide the internal model builder.
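Step 2 says a user param_space can override or extend the defaults. The sketch below assumes a shallow dictionary merge: user keys replace default choices, and new keys are added. The contents of EXAMPLE_DEFAULT_PS are hypothetical and stand in for the real DEFAULT_PS. resolve_search_space is also a hypothetical name.

```python
from typing import Any, Dict, List, Optional

# Hypothetical defaults for illustration only; the real DEFAULT_PS in
# fusionlab.nn.forecast_tuner may use different keys and values.
EXAMPLE_DEFAULT_PS: Dict[str, List[Any]] = {
    "hidden_units": [16, 32, 64],
    "num_heads": [1, 2, 4],
    "learning_rate": [1e-2, 1e-3, 1e-4],
}

# The model_name values named in this guide
VALID_MODEL_NAMES = {"xtft", "superxtft", "tft", "tft_flex"}


def resolve_search_space(
    model_name: str, param_space: Optional[Dict[str, List[Any]]] = None
) -> Dict[str, List[Any]]:
    """Return the defaults with user entries applied on top (shallow merge)."""
    if model_name not in VALID_MODEL_NAMES:
        raise ValueError(f"Unsupported model_name: {model_name!r}")
    merged = dict(EXAMPLE_DEFAULT_PS)  # copy so the defaults are not mutated
    merged.update(param_space or {})   # user keys win; unknown keys are added
    return merged


space = resolve_search_space(
    "xtft", {"hidden_units": [16], "dropout_rate": [0.1]}
)
print(space["hidden_units"], space["num_heads"], space["dropout_rate"])
# → [16] [1, 2, 4] [0.1]
```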

Code Example (Tuning XTFT):

```python
import numpy as np
from fusionlab.nn.forecast_tuner import xtft_tuner
# from fusionlab.nn import XTFT  # For context

# 1. Prepare dummy data (static, dynamic, future)
B, T_past, H_out = 8, 12, 6
D_s, D_d, D_f = 3, 5, 2
T_future_total = T_past + H_out

X_static_train = np.random.rand(B, D_s).astype(np.float32)
X_dynamic_train = np.random.rand(B, T_past, D_d).astype(np.float32)
X_future_train = np.random.rand(
    B, T_future_total, D_f).astype(np.float32)
y_train = np.random.rand(B, H_out, 1).astype(np.float32)

# Inputs for the tuner: [static, dynamic, future]
train_inputs = [X_static_train, X_dynamic_train, X_future_train]

# 2. Define a minimal search space and case info
custom_param_space = {
    'hidden_units': [16],  # Fixed for speed
    'num_heads': [2],
    'learning_rate': [1e-3]
}
case_info_xtft = {
    'quantiles': None,  # Point forecast
    'forecast_horizon': H_out,
    'static_input_dim': D_s,
    'dynamic_input_dim': D_d,
    'future_input_dim': D_f,
    'output_dim': 1
}

# 3. Define tuning parameters
output_dir = "./xtft_tuning_example_output"
project_name = "XTFT_Point_Tuning"

# 4. Run the tuner for XTFT
print("Starting XTFT tuning...")
best_hps, best_model, tuner = xtft_tuner(
    inputs=train_inputs,
    y=y_train,
    param_space=custom_param_space,
    forecast_horizon=H_out,    # Passed directly to the tuner
    quantiles=None,            # Passed directly to the tuner
    case_info=case_info_xtft,  # Used by the model builder
    max_trials=1,              # Minimal for demo
    objective='val_loss',
    epochs=2,                  # Minimal for demo
    batch_sizes=[8],           # Single small batch
    validation_split=0.25,
    tuner_dir=output_dir,
    project_name=project_name,
    tuner_type='random',
    model_name="xtft",         # Crucial: tells the builder to make XTFT
    verbose=0
)

# 5. Display results
print("\nXTFT Tuning complete.")
if best_hps:
    print("--- Best Hyperparameters (XTFT) ---")
    print(best_hps)
    # best_model.summary()
else:
    print("XTFT Tuning failed to find a best model.")
# tuner.results_summary(num_trials=1)
```

tft_tuner

API Reference: tft_tuner()

Purpose: A convenience wrapper for tuning Temporal Fusion Transformer models. It calls xtft_tuner() internally, passing the model_name parameter to differentiate between the stricter TFT (which requires all static, dynamic, and future inputs) and the more flexible TemporalFusionTransformer (which can handle optional static and/or future inputs).

Functionality: Accepts the same parameters as xtft_tuner(). The key is the model_name argument:

  • Set model_name="tft" to tune the stricter TFT class. In this case, inputs must be a list of three non-None tensors [X_static, X_dynamic, X_future].

  • Set model_name="tft_flex" to tune the flexible TemporalFusionTransformer. In this case, inputs can be [X_static, X_dynamic, X_future], where X_static and/or X_future can be None. For dynamic-only data, inputs can even be a single tensor.

The internal default model builder (_model_builder_factory()) constructs the appropriate TFT variant and uses relevant hyperparameters.
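The input rules for the two variants can be written as a small check. The sketch below is not fusionlab's own validation code, and check_tft_inputs is a hypothetical name. It makes two assumptions: a single non-list object counts as dynamic-only input, and every variant needs dynamic inputs.

```python
from typing import Any, List, Optional, Sequence, Union


def check_tft_inputs(
    inputs: Union[Any, Sequence[Optional[Any]]], model_name: str
) -> List[Optional[Any]]:
    """Normalize inputs to [static, dynamic, future] and apply the TFT rules."""
    if model_name not in ("tft", "tft_flex"):
        raise ValueError("model_name must be 'tft' or 'tft_flex'")
    # Assumption: a single non-list object is treated as dynamic-only input
    if not isinstance(inputs, (list, tuple)):
        inputs = [None, inputs, None]
    if len(inputs) != 3:
        raise ValueError("Expected [X_static, X_dynamic, X_future]")
    static, dynamic, future = inputs
    # Assumption: every variant needs dynamic (past) inputs
    if dynamic is None:
        raise ValueError("X_dynamic must not be None")
    # The stricter TFT needs all three inputs
    if model_name == "tft" and (static is None or future is None):
        raise ValueError(
            "model_name='tft' needs non-None static, dynamic and future inputs"
        )
    return [static, dynamic, future]


print(check_tft_inputs("X_dyn", "tft_flex"))  # → [None, 'X_dyn', None]
```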

Usage Context: Use this when your primary goal is to tune a TFT model. Choose model_name="tft" for the standard three-input architecture, or model_name="tft_flex" if your data might not include all input types.

Code Example 1 (Tuning the Stricter TFT):

```python
import numpy as np
from fusionlab.nn.forecast_tuner import tft_tuner
# from fusionlab.nn.transformers import TFT  # For context

# 1. Prepare dummy data (ALL inputs are required for the stricter TFT)
B, T_past, H_out = 8, 12, 6
D_s, D_d, D_f = 3, 5, 2
T_future_total = T_past + H_out

X_s_train = np.random.rand(B, D_s).astype(np.float32)
X_d_train = np.random.rand(B, T_past, D_d).astype(np.float32)
X_f_train = np.random.rand(
    B, T_future_total, D_f).astype(np.float32)
y_train_tft = np.random.rand(B, H_out, 1).astype(np.float32)

train_inputs_strict_tft = [X_s_train, X_d_train, X_f_train]

# 2. Define case info and a minimal parameter space
case_info_strict_tft = {
    'quantiles': None, 'forecast_horizon': H_out,
    'static_input_dim': D_s, 'dynamic_input_dim': D_d,
    'future_input_dim': D_f, 'output_dim': 1
}
param_space_tft = {'hidden_units': [16], 'learning_rate': [1e-3]}

# 3. Run the tuner for the stricter TFT
print("\nStarting stricter TFT tuning...")
best_hps_s, _, _ = tft_tuner(
    inputs=train_inputs_strict_tft, y=y_train_tft,
    param_space=param_space_tft,
    forecast_horizon=H_out, quantiles=None,
    case_info=case_info_strict_tft,
    max_trials=1, epochs=1, batch_sizes=[4],
    validation_split=0.5, tuner_dir="./tft_strict_tuning",
    project_name="TFT_Strict_Tune", model_name="tft",  # Key setting
    verbose=0
)
print("Stricter TFT Tuning complete.")
if best_hps_s:
    print("  Best HPs (Stricter TFT):", best_hps_s)
```

Code Example 2 (Tuning the Flexible TemporalFusionTransformer):

This example tunes the flexible TFT, providing only dynamic inputs.

```python
import numpy as np
from fusionlab.nn.forecast_tuner import tft_tuner
# from fusionlab.nn import TemporalFusionTransformer  # For context

# 1. Prepare dummy data (dynamic inputs only)
B, T_past, H_out = 8, 12, 6
D_d = 5  # Dynamic features
X_d_train_flex = np.random.rand(B, T_past, D_d).astype(np.float32)
y_train_flex = np.random.rand(B, H_out, 1).astype(np.float32)

# Inputs for the flexible TFT (static and future are None)
train_inputs_flex = [None, X_d_train_flex, None]

# 2. Define case info and a minimal parameter space
case_info_flex_tft = {
    'quantiles': None, 'forecast_horizon': H_out,
    'dynamic_input_dim': D_d,  # Static/future dims are None
    'static_input_dim': None,
    'future_input_dim': None,
    'output_dim': 1
}
param_space_flex = {'hidden_units': [16], 'learning_rate': [1e-3]}

# 3. Run the tuner for the flexible TFT
print("\nStarting flexible TFT (tft_flex) tuning...")
best_hps_f, _, _ = tft_tuner(
    inputs=train_inputs_flex, y=y_train_flex,
    param_space=param_space_flex,
    forecast_horizon=H_out, quantiles=None,
    case_info=case_info_flex_tft,
    max_trials=1, epochs=1, batch_sizes=[4],
    validation_split=0.5, tuner_dir="./tft_flex_tuning",
    project_name="TFT_Flex_Tune", model_name="tft_flex",  # Key setting
    verbose=0
)
print("Flexible TFT Tuning complete.")
if best_hps_f:
    print("  Best HPs (Flexible TFT):", best_hps_f)
```

Internal Model Builder

API Reference: _model_builder_factory() (Note: private function)

Note: Users typically do not call this function directly, but understanding its role is helpful.

This internal helper function is used by default if no custom model_builder is provided to the tuner functions. Its responsibilities are:

  1. Accepts the Keras Tuner hp object.

  2. Determines the correct model class to instantiate (XTFT, SuperXTFT, TFT, or TemporalFusionTransformer) based on model_name.

  3. Defines the range or set of choices for each hyperparameter relevant to the chosen model class, using hp.Choice, hp.Boolean, etc., based on the param_space provided to the tuner or the internal DEFAULT_PS.

  4. Instantiates the model class with the sampled hyperparameters.

  5. Compiles the model with an Adam optimizer (learning rate is also tuned) and an appropriate loss function (MSE or quantile loss).

  6. Returns the compiled model instance to the Keras Tuner for evaluation during the search process.
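Steps 3 to 5 can be sketched without Keras. The sketch below returns a plain config dict instead of instantiating and compiling a fusionlab model. It makes several assumptions:

- FirstChoiceHP is a hypothetical stand-in for Keras Tuner's hp object and always returns the first option. A real tuner samples from the choices.
- build_config and make_pinball_loss are hypothetical names.
- The quantile loss shown is the standard pinball loss averaged over samples and quantiles. fusionlab's own quantile loss may reduce differently.

```python
from typing import Any, Callable, Dict, List, Optional, Sequence


class FirstChoiceHP:
    """Hypothetical hp stand-in: always returns the first option.

    A real Keras Tuner hp object samples from the choices during a search.
    """

    def __init__(self) -> None:
        self.values: Dict[str, Any] = {}

    def Choice(self, name: str, values: Sequence[Any]) -> Any:
        self.values[name] = values[0]
        return values[0]


def mse(y_true: List[float], y_pred: List[float]) -> float:
    """Mean squared error for point forecasts."""
    return sum((t - p) ** 2 for t, p in zip(y_true, y_pred)) / len(y_true)


def make_pinball_loss(quantiles: Sequence[float]) -> Callable[..., float]:
    """Pinball (quantile) loss averaged over samples and quantiles.

    y_pred[i][k] is the forecast for sample i at quantiles[k].
    """
    def loss(y_true: List[float], y_pred: List[List[float]]) -> float:
        total = 0.0
        for t, preds in zip(y_true, y_pred):
            for q, p in zip(quantiles, preds):
                err = t - p
                # Under-prediction costs q per unit, over-prediction (1 - q)
                total += max(q * err, (q - 1) * err)
        return total / (len(y_true) * len(quantiles))

    return loss


def build_config(
    hp: FirstChoiceHP,
    param_space: Dict[str, Sequence[Any]],
    quantiles: Optional[Sequence[float]] = None,
) -> Dict[str, Any]:
    """Sample every hyperparameter, then choose the loss from `quantiles`."""
    config = {name: hp.Choice(name, choices) for name, choices in param_space.items()}
    config["loss"] = mse if quantiles is None else make_pinball_loss(quantiles)
    return config


hp = FirstChoiceHP()
config = build_config(
    hp, {"hidden_units": [16, 32], "learning_rate": [1e-3, 1e-4]},
    quantiles=[0.1, 0.5, 0.9],
)
print(config["hidden_units"], config["learning_rate"])  # → 16 0.001
```

In the real builder, the sampled values would go to the model class constructor. The learning rate and loss would then be used when compiling with Adam.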

By providing a custom model_builder function to xtft_tuner or tft_tuner, users can gain finer control over the architecture variations or compilation settings explored during tuning.