.. _example_hyperparameter_tuning:
TFT & XTFT Tuning Examples¶
Finding optimal hyperparameters is crucial for maximizing the
performance of complex models like
XTFT,
TFT (stricter version), and
the flexible TemporalFusionTransformer.
This example demonstrates how to use the
xtft_tuner() and
tft_tuner() utilities
from fusionlab-learn to automate this search process.
We will cover:
- Tuning XTFT for quantile forecasting.
- Tuning the stricter TFT (all inputs required) for point forecasting.
- Tuning the flexible TemporalFusionTransformer (using model_name="tft_flex") for point forecasting, demonstrating its ability to handle optional inputs (e.g., only dynamic features).
Prerequisites¶
Ensure you have fusionlab-learn, keras-tuner, and matplotlib installed:
pip install fusionlab-learn keras-tuner matplotlib
Common Setup for All Examples¶
The following imports and directory setup are common.
1import numpy as np
2import pandas as pd
3import tensorflow as tf
4import os
5import shutil # For cleaning up tuner directories
6import warnings
7
8# FusionLab imports
9from fusionlab.datasets.make import make_multi_feature_time_series
10from fusionlab.nn.forecast_tuner import xtft_tuner, tft_tuner
11from fusionlab.nn.transformers import (
12 XTFT,
13 TFT as TFTStricter, # Alias for the stricter TFT
14 TemporalFusionTransformer as TFTFlexible # Alias for flexible TFT
15)
16from fusionlab.nn.losses import combined_quantile_loss
17from fusionlab.nn.utils import reshape_xtft_data
18import keras_tuner as kt
19
20# Suppress warnings and TF logs for cleaner output
21warnings.filterwarnings('ignore')
22tf.get_logger().setLevel('ERROR')
23if hasattr(tf, 'autograph'):
24 tf.autograph.set_verbosity(0)
25
26# Configuration for outputs
27base_output_dir_tuning = "./gallery_tuning_runs"
28if not os.path.exists(base_output_dir_tuning):
29 os.makedirs(base_output_dir_tuning, exist_ok=True)
30
31print("Libraries imported and base setup complete for tuning examples.")
Example 1: Tuning XTFT for Quantile Forecasting¶
This section demonstrates tuning XTFT.
Step 1.1: Generate Synthetic Data for XTFT¶
We use make_multi_feature_time_series()
to create a dataset with static, dynamic, and future features.
1# Data generation parameters for XTFT
2N_SERIES_XTFT = 2
3N_TIMESTEPS_XTFT = 60
4FREQ_XTFT = 'MS'
5SEED_XTFT = 42
6
7data_bunch_xtft = make_multi_feature_time_series(
8 n_series=N_SERIES_XTFT, n_timesteps=N_TIMESTEPS_XTFT,
9 freq=FREQ_XTFT, seasonality_period=12,
10 seed=SEED_XTFT, as_frame=False
11)
12df_for_xtft_tuning = data_bunch_xtft.frame
13print(f"Generated data for XTFT tuning. Shape: {df_for_xtft_tuning.shape}")
14
15# Prepare data for reshape_xtft_data (assuming numerical readiness)
16dt_col_xtft = data_bunch_xtft.dt_col
17target_col_xtft = data_bunch_xtft.target_col
18static_cols_xtft = data_bunch_xtft.static_features
19dynamic_cols_xtft = data_bunch_xtft.dynamic_features
20future_cols_xtft = data_bunch_xtft.future_features
21spatial_cols_xtft = [data_bunch_xtft.spatial_id_col]
22
23time_steps_xtft = 12
24forecast_horizon_xtft = 6
25
26s_data_xtft, d_data_xtft, f_data_xtft, t_data_xtft = reshape_xtft_data(
27 df=df_for_xtft_tuning, dt_col=dt_col_xtft,
28 target_col=target_col_xtft,
29 dynamic_cols=dynamic_cols_xtft, static_cols=static_cols_xtft,
30 future_cols=future_cols_xtft, spatial_cols=spatial_cols_xtft,
31 time_steps=time_steps_xtft,
32 forecast_horizons=forecast_horizon_xtft,
33 verbose=0
34)
35train_inputs_xtft = [
36 tf.constant(s_data_xtft, dtype=tf.float32),
37 tf.constant(d_data_xtft, dtype=tf.float32),
38 tf.constant(f_data_xtft, dtype=tf.float32)
39]
40y_train_xtft = tf.constant(t_data_xtft, dtype=tf.float32)
41print(f"XTFT Reshaped: S={s_data_xtft.shape}, D={d_data_xtft.shape}, "
42 f"F={f_data_xtft.shape}, T={t_data_xtft.shape}")
Step 1.2: Define XTFT Search Space and Case Info¶
Define quantiles, a custom search space, and fixed case_info.
1quantiles_xtft = [0.1, 0.5, 0.9]
2custom_param_space_xtft = {
3 'hidden_units': [16, 32], 'num_heads': [1, 2],
4 'lstm_units': [16], 'dropout_rate': [0.05, 0.1],
5 'learning_rate': [5e-4, 1e-3]
6}
7case_info_xtft = {
8 'quantiles': quantiles_xtft,
9 'forecast_horizon': forecast_horizon_xtft,
10 'output_dim': y_train_xtft.shape[-1],
11 'static_input_dim': train_inputs_xtft[0].shape[-1],
12 'dynamic_input_dim': train_inputs_xtft[1].shape[-1],
13 'future_input_dim': train_inputs_xtft[2].shape[-1],
14 'embed_dim': 16, 'max_window_size': time_steps_xtft,
15 'memory_size': 20, 'attention_units': 16,
16 'recurrent_dropout_rate': 0.0,
17 'use_residuals_choices': [True], 'final_agg': 'last',
18 'multi_scale_agg': 'last', 'scales_options': ['no_scales'],
19 'use_batch_norm_choices': [False], 'verbose_build': 0
20}
Step 1.3: Run the XTFT Tuner¶
1output_dir_xtft = os.path.join(base_output_dir_tuning, "xtft_run")
2project_name_xtft = "XTFT_Gallery_Quantile_Tuning"
3if os.path.exists(os.path.join(output_dir_xtft, project_name_xtft)):
4 shutil.rmtree(os.path.join(output_dir_xtft, project_name_xtft))
5
6print("\nStarting XTFT hyperparameter tuning...")
7best_hps_xtft, best_model_xtft, tuner_xtft = xtft_tuner(
8 inputs=train_inputs_xtft, y=y_train_xtft,
9 param_space=custom_param_space_xtft,
10 forecast_horizon=forecast_horizon_xtft,
11 quantiles=quantiles_xtft,
12 case_info=case_info_xtft,
13 max_trials=1, epochs=1, batch_sizes=[4], # Minimal for demo
14 validation_split=0.5,
15 tuner_dir=output_dir_xtft, project_name=project_name_xtft,
16 tuner_type='random', model_name="xtft", verbose=0
17)
18print("\nXTFT Tuning complete.")
19if best_hps_xtft:
20 print("--- Best Hyperparameters (XTFT) ---")
21 print(best_hps_xtft)
22 # if best_model_xtft: best_model_xtft.summary()
23else:
24 print("XTFT Tuning did not yield best HPs.")
Tuning Standard TFT Variants¶
This section covers tuning the stricter TFT
and the flexible TemporalFusionTransformer
(referred to as tft_flex). We use the
tft_tuner() function, which is a
wrapper around xtft_tuner(),
setting the model_name appropriately.
Tuning Stricter TFT (All Inputs Required)¶
The stricter TFT requires static,
dynamic, and future inputs to be non-None.
Step 2.1: Prepare Data for Stricter TFT¶
We use the same data generation as for XTFT, as it includes all three input types.
1# Re-use data from XTFT example (s_data_xtft, d_data_xtft, etc.)
2# Or generate new if needed, ensuring all D_s, D_d, D_f are > 0
3train_inputs_strict_tft = [
4 tf.constant(s_data_xtft, dtype=tf.float32),
5 tf.constant(d_data_xtft, dtype=tf.float32),
6 tf.constant(f_data_xtft, dtype=tf.float32)
7]
8y_train_strict_tft = tf.constant(t_data_xtft, dtype=tf.float32)
9print("\nData prepared for Stricter TFT tuning.")
Step 2.2: Define Stricter TFT Search Space and Case Info¶
The search space will focus on parameters relevant to the standard TFT.
1# Point forecast for this example
2param_space_strict_tft = {
3 'hidden_units': [16, 32],
4 'num_heads': [1, 2],
5 'num_lstm_layers': [1], # Tune number of LSTM layers
6 'lstm_units': [16, 32], # Tune LSTM units
7 'dropout_rate': [0.0, 0.1],
8 'recurrent_dropout_rate': [0.0], # Often fixed or small
9 'learning_rate': [1e-3]
10}
11case_info_strict_tft = {
12 'quantiles': None, # Point forecast
13 'forecast_horizon': forecast_horizon_xtft, # Use same as XTFT example
14 'output_dim': y_train_strict_tft.shape[-1],
15 'static_input_dim': train_inputs_strict_tft[0].shape[-1],
16 'dynamic_input_dim': train_inputs_strict_tft[1].shape[-1],
17 'future_input_dim': train_inputs_strict_tft[2].shape[-1],
18 'activation': 'relu', # Fixed activation
19 'use_batch_norm_choices': [False], # Fixed
20 'verbose_build': 0
21}
Step 2.3: Run the Tuner for Stricter TFT¶
1output_dir_strict_tft = os.path.join(base_output_dir_tuning, "tft_strict_run")
2project_name_strict_tft = "TFT_Strict_Gallery_Point_Tuning"
3if os.path.exists(os.path.join(output_dir_strict_tft, project_name_strict_tft)):
4 shutil.rmtree(os.path.join(output_dir_strict_tft, project_name_strict_tft))
5
6print("\nStarting Stricter TFT hyperparameter tuning...")
7best_hps_tft_s, _, _ = tft_tuner( # Use tft_tuner
8 inputs=train_inputs_strict_tft,
9 y=y_train_strict_tft,
10 param_space=param_space_strict_tft,
11 forecast_horizon=forecast_horizon_xtft,
12 quantiles=None,
13 case_info=case_info_strict_tft,
14 max_trials=1, epochs=1, batch_sizes=[4],
15 validation_split=0.5,
16 tuner_dir=output_dir_strict_tft,
17 project_name=project_name_strict_tft,
18 model_name="tft", # Key: specifies the stricter TFT
19 verbose=0
20)
21print("\nStricter TFT Tuning complete.")
22if best_hps_tft_s:
23 print("--- Best Hyperparameters (Stricter TFT) ---")
24 print(best_hps_tft_s)
Tuning Flexible TemporalFusionTransformer (tft_flex)¶
This demonstrates tuning the flexible
TemporalFusionTransformer
using only dynamic inputs.
Step 3.1: Prepare Data for Flexible TFT (Dynamic Only)¶
We’ll use only the dynamic part of the previously generated data.
1# Use d_data_xtft and t_data_xtft from the XTFT data prep
2# Inputs for flexible TFT: [Static, Dynamic, Future]
3# Here, Static and Future will be None.
4# Rather than passing the full three-element list with None placeholders:
5train_inputs_flex_tft = [
6 None, # No static input
7 tf.constant(d_data_xtft, dtype=tf.float32), # Only dynamic
8 None # No future input
9]
10# pass only the dynamic input, and TemporalFusionTransformer will
11# handle it
12train_inputs_flex_tft = [
13 tf.constant(d_data_xtft, dtype=tf.float32), # Only dynamic
14]
15y_train_flex_tft = tf.constant(t_data_xtft, dtype=tf.float32)
16print("\nData prepared for Flexible TFT (Dynamic Only) tuning.")
17print(f" Dynamic Input Shape: {train_inputs_flex_tft[0].shape}")
Step 3.2: Define Flexible TFT Search Space and Case Info¶
The case_info sets static_input_dim and future_input_dim to None, reflecting that only dynamic inputs are supplied.
1# Point forecast for this example
2param_space_flex_tft = {
3 'hidden_units': [8, 16], # Smaller search space
4 'num_heads': [1],
5 'num_lstm_layers': [1],
6 'lstm_units': [16],
7 'dropout_rate': [0.0],
8 'learning_rate': [1e-3]
9}
10case_info_flex_tft = {
11 'quantiles': None, # Point forecast
12 'forecast_horizon': forecast_horizon_xtft,
13 'output_dim': y_train_flex_tft.shape[-1],
14 'static_input_dim': None, # Explicitly None
15 'dynamic_input_dim': train_inputs_flex_tft[0].shape[-1],
16 'future_input_dim': None, # Explicitly None
17 'activation': 'relu',
18 'use_batch_norm_choices': [False],
19 'verbose_build': 0
20}
Step 3.3: Run the Tuner for Flexible TFT¶
1output_dir_flex_tft = os.path.join(base_output_dir_tuning, "tft_flex_run")
2project_name_flex_tft = "TFT_Flexible_Gallery_Point_Tuning"
3if os.path.exists(os.path.join(output_dir_flex_tft, project_name_flex_tft)):
4 shutil.rmtree(os.path.join(output_dir_flex_tft, project_name_flex_tft))
5
6print("\nStarting Flexible TFT (tft_flex) hyperparameter tuning...")
7best_hps_tft_f, _, _ = tft_tuner( # Use tft_tuner
8 inputs=train_inputs_flex_tft, # [Dynamic] only (no static/future inputs)
9 y=y_train_flex_tft,
10 param_space=param_space_flex_tft,
11 forecast_horizon=forecast_horizon_xtft,
12 quantiles=None,
13 case_info=case_info_flex_tft,
14 max_trials=1, epochs=1, batch_sizes=[4],
15 validation_split=0.5,
16 tuner_dir=output_dir_flex_tft,
17 project_name=project_name_flex_tft,
18 model_name="tft_flex", # Key: specifies flexible TemporalFusionTransformer
19 verbose=0
20)
21print("\nFlexible TFT (tft_flex) Tuning complete.")
22if best_hps_tft_f:
23 print("--- Best Hyperparameters (Flexible TFT) ---")
24 print(best_hps_tft_f)