XTFT Tuning

Finding optimal hyperparameters is crucial for maximizing the performance of complex models like XTFT. This example demonstrates how to use the xtft_tuner() utility from fusionlab-learn to automate this search process for an XTFT model configured for quantile forecasting.

The workflow includes:

  1. Generating synthetic multi-feature time series data suitable for XTFT.

  2. Defining a hyperparameter search space.

  3. Configuring and running the xtft_tuner.

  4. Retrieving and inspecting the best hyperparameters and model.

Prerequisites

Ensure you have fusionlab-learn, keras-tuner, and matplotlib installed:

pip install fusionlab-learn keras-tuner matplotlib
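To confirm the prerequisites before running the example, a small check like the following can help. It is a sketch that assumes the import names used later in this example (fusionlab for fusionlab-learn, keras_tuner for keras-tuner). It also checks for tensorflow, which the example imports directly. missing_packages is a hypothetical helper, not part of fusionlab-learn.

```python
import importlib.util

# Assumed mapping from pip package name to the module name the example imports.
REQUIRED_MODULES = {
    "fusionlab-learn": "fusionlab",
    "keras-tuner": "keras_tuner",
    "matplotlib": "matplotlib",
    "tensorflow": "tensorflow",
}

def missing_packages(required=REQUIRED_MODULES):
    """Hypothetical helper: return pip names whose module cannot be found."""
    return [
        pip_name
        for pip_name, module_name in required.items()
        if importlib.util.find_spec(module_name) is None
    ]

missing = missing_packages()
if missing:
    print("Missing packages:", ", ".join(missing))
else:
    print("All prerequisites are importable.")
```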

Step 1: Imports and Setup

Import necessary libraries, including fusionlab-learn components for data generation, model tuning, and the XTFT model itself.

import numpy as np
import pandas as pd
import tensorflow as tf
import os
import shutil  # For cleaning up tuner directories
import warnings

# FusionLab imports
from fusionlab.datasets.make import make_multi_feature_time_series
from fusionlab.nn.forecast_tuner import xtft_tuner
from fusionlab.nn.transformers import XTFT  # For type context
from fusionlab.nn.losses import combined_quantile_loss  # For loss definition
from fusionlab.nn.utils import reshape_xtft_data
import keras_tuner as kt  # Keras Tuner

# Suppress warnings and TF logs for cleaner output
warnings.filterwarnings('ignore')
tf.get_logger().setLevel('ERROR')
if hasattr(tf, 'autograph'):
    tf.autograph.set_verbosity(0)

# Configuration for outputs
output_dir_tuning = "./gallery_tuning_output"
# Clean up previous run if it exists, for a fresh start
if os.path.exists(output_dir_tuning):
    shutil.rmtree(output_dir_tuning)
os.makedirs(output_dir_tuning, exist_ok=True)

print("Libraries imported and setup complete for tuning example.")

Step 2: Generate Synthetic Data for XTFT

We use make_multi_feature_time_series() to create a dataset with static, dynamic, and future features. This data will be used for training and validation during the tuning process.

# Data generation parameters
N_SERIES_TUNE = 2
N_TIMESTEPS_TUNE = 60  # 5 years of monthly data
FREQ_TUNE = 'MS'
SEED_TUNE = 42

# Generate data as a Bunch object
data_bunch = make_multi_feature_time_series(
    n_series=N_SERIES_TUNE,
    n_timesteps=N_TIMESTEPS_TUNE,
    freq=FREQ_TUNE,
    seasonality_period=12,  # Yearly seasonality for monthly data
    seed=SEED_TUNE,
    as_frame=False  # Return a Bunch so column names are easy to access
)
df_for_tuning = data_bunch.frame
print(f"Generated data for tuning. Shape: {df_for_tuning.shape}")

# --- Prepare data for reshape_xtft_data ---
# A real workflow would scale and encode features first
# (e.g., with load_processed_subsidence_data or similar).
# Here we assume the synthetic data is already numeric and ready.

dt_col_tune = data_bunch.dt_col
target_col_tune = data_bunch.target_col
static_cols_tune = data_bunch.static_features
dynamic_cols_tune = data_bunch.dynamic_features
future_cols_tune = data_bunch.future_features
spatial_cols_tune = [data_bunch.spatial_id_col]

# Reshape data into sequences
time_steps_tune = 12  # 1-year lookback
forecast_horizon_tune = 6  # Predict 6 months ahead

s_data, d_data, f_data, t_data = reshape_xtft_data(
    df=df_for_tuning, dt_col=dt_col_tune, target_col=target_col_tune,
    dynamic_cols=dynamic_cols_tune, static_cols=static_cols_tune,
    future_cols=future_cols_tune, spatial_cols=spatial_cols_tune,
    time_steps=time_steps_tune, forecast_horizons=forecast_horizon_tune,
    verbose=0
)

# XTFT requires all three inputs, ordered [Static, Dynamic, Future].
# Check before accessing .shape so a missing array fails with a clear message.
if s_data is None or d_data is None or f_data is None:
    raise ValueError("XTFT requires static, dynamic, and future inputs.")

print("\nReshaped data for tuning:")
print(f"  Static : {s_data.shape}, Dynamic: {d_data.shape}")
print(f"  Future : {f_data.shape}, Target : {t_data.shape}")

train_inputs_tune = [
    tf.constant(s_data, dtype=tf.float32),
    tf.constant(d_data, dtype=tf.float32),
    tf.constant(f_data, dtype=tf.float32)
]
y_train_tune = tf.constant(t_data, dtype=tf.float32)
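Conceptually, the reshaping step cuts each series into overlapping windows: a lookback of time_steps_tune rows as input and the next forecast_horizon_tune rows as the target. The NumPy sketch below illustrates that idea for a single target column only. sliding_windows is a hypothetical helper, and the exact sample count produced by reshape_xtft_data (which also builds the static and future arrays) may differ from this simplified version.

```python
import numpy as np

def sliding_windows(series, time_steps, forecast_horizon):
    """Hypothetical helper: split a 1-D series into (lookback, target) pairs."""
    series = np.asarray(series, dtype=np.float32)
    n_windows = len(series) - time_steps - forecast_horizon + 1
    if n_windows <= 0:
        raise ValueError("Series too short for the requested windows.")
    inputs = np.stack([series[i:i + time_steps] for i in range(n_windows)])
    targets = np.stack([
        series[i + time_steps:i + time_steps + forecast_horizon]
        for i in range(n_windows)
    ])
    return inputs, targets

# Two random series of 60 steps, mirroring N_SERIES_TUNE and N_TIMESTEPS_TUNE
# (random values, used only for the shape arithmetic).
rng = np.random.default_rng(42)
all_x, all_y = [], []
for _ in range(2):
    x, y = sliding_windows(rng.normal(size=60), time_steps=12, forecast_horizon=6)
    all_x.append(x)
    all_y.append(y)

X = np.concatenate(all_x)[..., np.newaxis]  # (samples, time_steps, 1)
Y = np.concatenate(all_y)[..., np.newaxis]  # (samples, horizon, 1)
print(X.shape, Y.shape)  # (86, 12, 1) (86, 6, 1)
```

With 60 steps per series, each series yields 60 - 12 - 6 + 1 = 43 windows under this scheme, so two series give 86 samples.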

Step 3: Define Hyperparameter Search Space and Case Info

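We define a small search space (custom_param_space_tune) covering a few hyperparameters. case_info_tune supplies the fixed parameters the model builder requires, such as the quantiles, the forecast horizon, and the input dimensions.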

# Define quantiles for probabilistic forecast
quantiles_tune = [0.1, 0.5, 0.9]

# Custom search space (subset of DEFAULT_PS in forecast_tuner)
custom_param_space_tune = {
    'hidden_units': [16, 32],       # Try these hidden unit sizes
    'num_heads': [1, 2],            # Try 1 or 2 attention heads
    'lstm_units': [16],             # Fix LSTM units for this demo
    'dropout_rate': [0.05, 0.1],
    'learning_rate': [5e-4, 1e-3]   # Try two learning rates
}

# Case info provides fixed parameters for the model builder
# It must include all required dimensions for the model
case_info_tune = {
    'quantiles': quantiles_tune,
    'forecast_horizon': forecast_horizon_tune,
    'output_dim': y_train_tune.shape[-1],  # Should be 1 for this example
    'static_input_dim': train_inputs_tune[0].shape[-1],
    'dynamic_input_dim': train_inputs_tune[1].shape[-1],
    'future_input_dim': train_inputs_tune[2].shape[-1],
    # Pass other fixed XTFT params if not tuning them:
    'embed_dim': 16,  # Example fixed value
    'max_window_size': time_steps_tune,
    'memory_size': 20,
    'attention_units': 16,
    'recurrent_dropout_rate': 0.0,
    'use_residuals_choices': [True],  # Fix use_residuals to True
    'final_agg': 'last',
    'multi_scale_agg': 'last',
    'scales_options': ['no_scales'],  # Fix scales to None
    'use_batch_norm_choices': [False],  # Fix use_batch_norm
    'verbose_build': 0  # Suppress model builder logs
}
print("\nHyperparameter search space and case info defined.")
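The quantiles in quantiles_tune are typically trained with the pinball (quantile) loss. The NumPy sketch below shows the standard formulation, averaged over quantiles and samples. It is not necessarily identical to how combined_quantile_loss reduces its terms; pinball_loss and the toy arrays are hypothetical.

```python
import numpy as np

def pinball_loss(y_true, y_pred, quantiles):
    """Hypothetical helper: mean pinball loss over samples and quantiles.

    y_true has shape (..., 1); y_pred has shape (..., n_quantiles).
    """
    y_true = np.asarray(y_true, dtype=float)
    y_pred = np.asarray(y_pred, dtype=float)
    q = np.asarray(quantiles, dtype=float)  # broadcasts along the last axis
    error = y_true - y_pred
    # Under-prediction (error > 0) costs q * error;
    # over-prediction (error < 0) costs (q - 1) * error, which is positive.
    loss = np.maximum(q * error, (q - 1.0) * error)
    return float(loss.mean())

quantiles = [0.1, 0.5, 0.9]
y_true = np.array([[1.0], [2.0]])    # 2 samples, 1 target each
y_pred = np.array([[0.5, 1.0, 1.5],  # predicted q10, q50, q90
                   [2.5, 2.0, 3.0]])
print(pinball_loss(y_true, y_pred, quantiles))
```

For each quantile q, under-prediction is weighted by q and over-prediction by 1 - q, so the 0.9 output is pushed toward the upper part of the distribution and the 0.1 output toward the lower part.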

Step 4: Run the XTFT Tuner

Call xtft_tuner() with the prepared data, search space, and tuning configurations. We use a small number of max_trials and epochs for a quick demonstration.
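Because max_trials is small, it is worth checking how much of custom_param_space_tune a run can explore. The sketch below enumerates the discrete grid with itertools.product and draws max_trials configurations uniformly at random. The uniform draw is only an illustration of coverage, not Keras Tuner's internal sampling logic.

```python
import itertools
import random

# Mirrors custom_param_space_tune from Step 3.
param_space = {
    "hidden_units": [16, 32],
    "num_heads": [1, 2],
    "lstm_units": [16],
    "dropout_rate": [0.05, 0.1],
    "learning_rate": [5e-4, 1e-3],
}

keys = list(param_space)
grid = [dict(zip(keys, values)) for values in itertools.product(*param_space.values())]
print(len(grid))  # 16 = 2 * 2 * 1 * 2 * 2

# With max_trials=2, at most 2 of the 16 configurations are evaluated
# per batch size; this uniform draw just illustrates one possible pick.
max_trials = 2
sampled = random.Random(42).sample(grid, k=max_trials)
for cfg in sampled:
    print(cfg)
coverage = max_trials / len(grid)
print(f"Coverage: {coverage:.1%}")  # Coverage: 12.5%
```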

project_name_tune = "XTFT_Gallery_Quantile_Tuning"
# Clean up previous project directory if it exists
project_path = os.path.join(output_dir_tuning, project_name_tune)
if os.path.exists(project_path):
    shutil.rmtree(project_path)

print("\nStarting XTFT hyperparameter tuning...")
best_hps, best_model, tuner = xtft_tuner(
    inputs=train_inputs_tune,
    y=y_train_tune,
    param_space=custom_param_space_tune,
    # forecast_horizon and quantiles also appear in case_info for the
    # model builder; the tuner uses these arguments for its own defaults
    forecast_horizon=forecast_horizon_tune,
    quantiles=quantiles_tune,
    case_info=case_info_tune,  # Crucial for model instantiation
    max_trials=2,          # Number of HP combinations to try per batch size
    objective='val_loss',
    epochs=3,              # Epochs for FULL training of best HP per batch
    batch_sizes=[8],       # Test with a single small batch size for demo
    validation_split=0.3,  # Use 30% of data for validation during search
    tuner_dir=output_dir_tuning,
    project_name=project_name_tune,
    tuner_type='random',   # 'random' or 'bayesian'
    model_name="xtft",     # Use the default XTFT builder; set
                           # model_name='super_xtft' to tune SuperXTFT
    verbose=1              # Show some tuner progress
)
print("\nXTFT Tuning complete.")
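If xtft_tuner forwards validation_split to Keras fit() (an assumption here), Keras holds out the last fraction of the samples, selected before shuffling, as validation data. Under that assumption, the sketch below estimates the split and the number of training steps per epoch at batch size 8. split_sizes is a hypothetical helper, and the sample count of 86 is an assumed value (43 windows per series times 2 series), not an output of reshape_xtft_data.

```python
import math

def split_sizes(n_samples, validation_split, batch_size):
    """Hypothetical helper: Keras-style split where the tail is held out."""
    n_train = int(math.floor(n_samples * (1.0 - validation_split)))
    n_val = n_samples - n_train
    steps_per_epoch = math.ceil(n_train / batch_size)
    return n_train, n_val, steps_per_epoch

# Replace 86 with d_data.shape[0] from Step 2 for your actual run.
print(split_sizes(86, validation_split=0.3, batch_size=8))  # (60, 26, 8)
```

Since the held-out samples are the last ones in the array, whether they cover the latest time period or mostly a single series depends on how the reshaped samples are ordered.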

Step 5: Display Results

The tuner returns the best hyperparameters found, the corresponding fully trained model, and the Keras Tuner object.

if best_hps:
    print("\n--- Best Hyperparameters Found ---")
    for param, value in best_hps.items():
        print(f"  {param}: {value}")
    print(f"\nOptimal Batch Size (among tested): "
          f"{best_hps.get('batch_size', 'N/A')}")

    print("\n--- Summary of Best Model Architecture ---")
    if best_model:
        best_model.summary(line_length=100)
    else:
        print("Best model was not returned from tuning.")
else:
    print("Tuning did not yield best hyperparameters (e.g., all trials failed).")

# For more details, you can inspect the tuner object:
# if tuner:
#     tuner.results_summary()

# The `best_model` can now be used for forecasting or saved.
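To reuse the tuning outcome later, the best hyperparameters can be saved alongside the model. The following is a minimal standard-library sketch, assuming best_hps is a plain dict as the loop above suggests. save_best_hps and load_best_hps are hypothetical helpers, and the values in the usage lines are illustrative, not tuning results.

```python
import json
import os
import tempfile

def _to_json_value(obj):
    # NumPy scalars (e.g., np.float32) expose .item(); fall back to str otherwise.
    return obj.item() if hasattr(obj, "item") else str(obj)

def save_best_hps(best_hps, path):
    """Hypothetical helper: write a hyperparameter dict to a JSON file."""
    os.makedirs(os.path.dirname(path) or ".", exist_ok=True)
    with open(path, "w", encoding="utf-8") as fh:
        json.dump(best_hps, fh, indent=2, sort_keys=True, default=_to_json_value)

def load_best_hps(path):
    """Hypothetical helper: read the hyperparameter dict back."""
    with open(path, encoding="utf-8") as fh:
        return json.load(fh)

# Illustrative values only; in this example you would pass best_hps and a
# path such as os.path.join(output_dir_tuning, "best_hps.json").
example_hps = {"hidden_units": 32, "learning_rate": 0.001, "batch_size": 8}
path = os.path.join(tempfile.mkdtemp(), "best_hps.json")
save_best_hps(example_hps, path)
print(load_best_hps(path) == example_hps)  # True
```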