XTFT Tuning¶
Finding optimal hyperparameters is crucial for maximizing the
performance of complex models like
XTFT. This example demonstrates how to
use the xtft_tuner() utility
from fusionlab-learn to automate this search process for an
XTFT model configured for quantile forecasting.
The workflow includes:
Generating synthetic multi-feature time series data suitable for XTFT.
Defining a hyperparameter search space.
Configuring and running the xtft_tuner.
Retrieving and inspecting the best hyperparameters and model.
Prerequisites¶
Ensure you have fusionlab-learn and keras-tuner installed (the command below also installs matplotlib, which the code shown in this example does not import):
pip install fusionlab-learn keras-tuner matplotlib
Step 1: Imports and Setup¶
Import necessary libraries, including fusionlab-learn components
for data generation, model tuning, and the XTFT model itself.
# --- Standard library ---
import os
import shutil  # Used to delete tuner directories left by earlier runs
import warnings

# --- Numerical / deep-learning stack ---
import numpy as np
import pandas as pd
import tensorflow as tf

# --- FusionLab ---
from fusionlab.datasets.make import make_multi_feature_time_series
from fusionlab.nn.forecast_tuner import xtft_tuner
from fusionlab.nn.transformers import XTFT  # For type context
from fusionlab.nn.losses import combined_quantile_loss  # For loss definition
from fusionlab.nn.utils import reshape_xtft_data
import keras_tuner as kt  # Keras Tuner

# Keep the console readable: silence Python warnings and TensorFlow logging.
warnings.filterwarnings('ignore')
tf.get_logger().setLevel('ERROR')
if hasattr(tf, 'autograph'):
    tf.autograph.set_verbosity(0)

# Every tuner artifact of this example is written below this directory.
output_dir_tuning = "./gallery_tuning_output"

# Start from an empty directory: drop the results of any previous run,
# then (re)create the folder.
if os.path.exists(output_dir_tuning):
    shutil.rmtree(output_dir_tuning)
os.makedirs(output_dir_tuning, exist_ok=True)

print("Libraries imported and setup complete for tuning example.")
Step 2: Generate Synthetic Data for XTFT¶
We use make_multi_feature_time_series()
to create a dataset with static, dynamic, and future features. This
data will be used for training and validation during the tuning process.
# Data generation parameters
N_SERIES_TUNE = 2
N_TIMESTEPS_TUNE = 60  # 60 monthly steps, i.e. 5 years of monthly data
FREQ_TUNE = 'MS'
SEED_TUNE = 42

# Generate data as a Bunch object
data_bunch = make_multi_feature_time_series(
    n_series=N_SERIES_TUNE,
    n_timesteps=N_TIMESTEPS_TUNE,
    freq=FREQ_TUNE,
    seasonality_period=12,  # Yearly seasonality for monthly data
    seed=SEED_TUNE,
    as_frame=False,  # Get Bunch to easily access column names
)
df_for_tuning = data_bunch.frame
print(f"Generated data for tuning. Shape: {df_for_tuning.shape}")

# --- Prepare data for reshape_xtft_data ---
# This step would normally involve scaling, encoding etc.
# For this example, we assume data is numerically ready.
# In a real workflow, use load_processed_subsidence_data or similar.

# Column roles are read from the Bunch rather than hard-coded.
dt_col_tune = data_bunch.dt_col
target_col_tune = data_bunch.target_col
static_cols_tune = data_bunch.static_features
dynamic_cols_tune = data_bunch.dynamic_features
future_cols_tune = data_bunch.future_features
spatial_cols_tune = [data_bunch.spatial_id_col]

# Reshape data into sequences
time_steps_tune = 12  # 1 year lookback
forecast_horizon_tune = 6  # Predict 6 months ahead

s_data, d_data, f_data, t_data = reshape_xtft_data(
    df=df_for_tuning,
    dt_col=dt_col_tune,
    target_col=target_col_tune,
    dynamic_cols=dynamic_cols_tune,
    static_cols=static_cols_tune,
    future_cols=future_cols_tune,
    spatial_cols=spatial_cols_tune,
    time_steps=time_steps_tune,
    forecast_horizons=forecast_horizon_tune,
    verbose=0,
)

# All three inputs are required by XTFT. Validate them *before* touching
# `.shape`: otherwise a missing array (presumably returned as None by
# reshape_xtft_data -- confirm against its documentation) would surface as
# an AttributeError in the report below instead of this explicit error.
if s_data is None or d_data is None or f_data is None:
    raise ValueError("XTFT requires static, dynamic, and future inputs.")

print("\nReshaped data for tuning:")
print(f" Static : {s_data.shape}, Dynamic: {d_data.shape}")
print(f" Future : {f_data.shape}, Target : {t_data.shape}")

# For tuner, inputs are [Static, Dynamic, Future]
train_inputs_tune = [
    tf.constant(s_data, dtype=tf.float32),
    tf.constant(d_data, dtype=tf.float32),
    tf.constant(f_data, dtype=tf.float32),
]
y_train_tune = tf.constant(t_data, dtype=tf.float32)
Step 3: Define Hyperparameter Search Space and Case Info¶
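We define a custom search space, custom_param_space_tune, that explores a few hyperparameters; it is passed to the tuner as param_space. The case_info_tune dictionary, passed as case_info, provides the fixed parameters required by the model builder.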
# Quantile levels of the probabilistic forecast
quantiles_tune = [0.1, 0.5, 0.9]

# Custom search space (subset of DEFAULT_PS in forecast_tuner).
# Each entry lists the candidate values offered to the tuner.
custom_param_space_tune = dict(
    hidden_units=[16, 32],        # Two hidden unit sizes
    num_heads=[1, 2],             # 1 or 2 attention heads
    lstm_units=[16],              # Single value: fixed for this demo
    dropout_rate=[0.05, 0.1],
    learning_rate=[5e-4, 1e-3],   # Two learning rates
)

# Feature dimension (last axis) of the static, dynamic and future inputs,
# in the order they appear in `train_inputs_tune`.
_static_dim, _dynamic_dim, _future_dim = (
    tensor.shape[-1] for tensor in train_inputs_tune
)

# Case info: fixed parameters handed to the model builder.
# It must include all required dimensions for the model.
case_info_tune = dict(
    quantiles=quantiles_tune,
    forecast_horizon=forecast_horizon_tune,
    output_dim=y_train_tune.shape[-1],  # Should be 1 for this example
    static_input_dim=_static_dim,
    dynamic_input_dim=_dynamic_dim,
    future_input_dim=_future_dim,
    # Other XTFT parameters that are held fixed instead of tuned:
    embed_dim=16,  # Example fixed value
    max_window_size=time_steps_tune,
    memory_size=20,
    attention_units=16,
    recurrent_dropout_rate=0.0,
    use_residuals_choices=[True],     # Single choice: use_residuals stays True
    final_agg='last',
    multi_scale_agg='last',
    scales_options=['no_scales'],     # Single choice: scales are not tuned
    use_batch_norm_choices=[False],   # Single choice: use_batch_norm stays False
    verbose_build=0,                  # Suppress model builder logs
)
print("\nHyperparameter search space and case info defined.")
Step 4: Run the XTFT Tuner¶
Call xtft_tuner() with the prepared
data, search space, and tuning configurations. We use a small number
of max_trials and epochs for a quick demonstration.
project_name_tune = "XTFT_Gallery_Quantile_Tuning"

# Remove the directory of an earlier run of this project, if there is one.
project_path = os.path.join(output_dir_tuning, project_name_tune)
if os.path.exists(project_path):
    shutil.rmtree(project_path)

print("\nStarting XTFT hyperparameter tuning...")

# Search and training settings, gathered in one place for readability.
_tuner_settings = dict(
    param_space=custom_param_space_tune,
    # forecast_horizon and quantiles are primarily passed via case_info for
    # the model builder, but the tuner function also needs them for defaults.
    forecast_horizon=forecast_horizon_tune,
    quantiles=quantiles_tune,
    case_info=case_info_tune,   # Crucial for model instantiation
    max_trials=2,               # Number of HP combinations to try per batch size
    objective='val_loss',
    epochs=3,                   # Epochs for FULL training of best HP per batch
    batch_sizes=[8],            # A single small batch size keeps the demo quick
    validation_split=0.3,       # 30% of the data is held out during search
    tuner_dir=output_dir_tuning,
    project_name=project_name_tune,
    tuner_type='random',        # 'random' or 'bayesian'
    # "xtft" selects the default XTFT builder; change to
    # model_name='super_xtft' for SuperXTFT tuning.
    model_name="xtft",
    verbose=1,                  # Show some tuner progress
)

best_hps, best_model, tuner = xtft_tuner(
    inputs=train_inputs_tune,
    y=y_train_tune,
    **_tuner_settings,
)
print("\nXTFT Tuning complete.")
Step 5: Display Results¶
The tuner returns the best hyperparameters found, the corresponding fully trained model, and the Keras Tuner object.
# Report what the tuner found; handle the "nothing found" case first.
if not best_hps:
    print("Tuning did not yield best hyperparameters (e.g., all trials failed).")
else:
    print("\n--- Best Hyperparameters Found ---")
    for hp_name, hp_value in best_hps.items():
        print(f" {hp_name}: {hp_value}")

    # 'batch_size' is looked up with a fallback in case it is absent.
    chosen_batch_size = best_hps.get('batch_size', 'N/A')
    print(f"\nOptimal Batch Size (among tested): {chosen_batch_size}")

    print("\n--- Summary of Best Model Architecture ---")
    if not best_model:
        print("Best model was not returned from tuning.")
    else:
        best_model.summary(line_length=100)

# For more details, you can inspect the tuner object:
# if tuner:
#     tuner.results_summary()

# The `best_model` can now be used for forecasting or saved.