Exercise: Advanced Forecasting with SuperXTFT

This exercise walks through the advanced SuperXTFT model from fusionlab-learn. SuperXTFT extends the XTFT architecture with additional components, notably input Variable Selection Networks (VSNs) and post-processing Gated Residual Networks (GRNs), that are intended to improve representation learning.

Note

SuperXTFT is a production-ready model and the most feature-rich member of the TFT family. Choose it when you need maximum predictive performance and have the computational resources for a deeper model with more parameters.

Learning Objectives:

  • Leverage the advanced SuperXTFT model for a complex forecasting task.

  • Understand that the data preparation workflow is identical to that for XTFT.

  • Perform a multi-step quantile forecast to generate probabilistic predictions.

  • Visualize the forecast results, including their uncertainty intervals.

Let’s begin!

Prerequisites

Ensure you have fusionlab-learn and its common dependencies installed.

pip install fusionlab-learn matplotlib scikit-learn joblib

Step 1: Imports and Setup

We start by importing necessary libraries and fusionlab components.

import numpy as np
import pandas as pd
import tensorflow as tf
import matplotlib.pyplot as plt
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
import os
import joblib
import warnings

# FusionLab imports
from fusionlab.nn.transformers import SuperXTFT
from fusionlab.nn.utils import reshape_xtft_data
from fusionlab.nn.losses import combined_quantile_loss
from fusionlab.datasets.make import make_multi_feature_time_series

# Suppress warnings and TF logs
warnings.filterwarnings('ignore')
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
tf.get_logger().setLevel('ERROR')
if hasattr(tf, 'autograph'):
    tf.autograph.set_verbosity(0)

exercise_output_dir_super_xtft = "./super_xtft_exercise_outputs"
os.makedirs(exercise_output_dir_super_xtft, exist_ok=True)

print("Libraries imported for SuperXTFT exercise.")

Expected Output 1.1:

Libraries imported for SuperXTFT exercise.

Step 2: Generate Synthetic Multi-Feature Data

We’ll use the same data generation setup as the advanced XTFT exercise, as SuperXTFT also expects static, dynamic, and future inputs.

n_items_sxtft = 2
n_timesteps_sxtft = 36 # Shorter for quicker run
rng_seed_sxtft = 42
np.random.seed(rng_seed_sxtft)
tf.random.set_seed(rng_seed_sxtft)

data_bunch_sxtft = make_multi_feature_time_series(
    n_series=n_items_sxtft, n_timesteps=n_timesteps_sxtft,
    freq='MS', seasonality_period=12,
    seed=rng_seed_sxtft, as_frame=False
)
df_raw_sxtft = data_bunch_sxtft.frame.copy()
print(f"Generated raw data shape for SuperXTFT exercise: {df_raw_sxtft.shape}")
print(df_raw_sxtft.head(3))

Expected Output 2.2:

Generated raw data shape for SuperXTFT exercise: (72, 9)
        date  series_id  base_level  ...  month  future_event     target
0 2020-01-01          0   50.049671  ...      1             1  63.055435
1 2020-02-01          0   50.049671  ...      2             1  68.394497
2 2020-03-01          0   50.049671  ...      3             1  70.075474

[3 rows x 9 columns]

Step 3: Define Feature Roles and Scale Numerical Data

We take the feature lists from data_bunch and standardize the continuous numerical features. series_id is already numerical, so it is used directly as a static feature without scaling.

target_col_sxtft = data_bunch_sxtft.target_col
dt_col_sxtft = data_bunch_sxtft.dt_col
static_cols_sxtft = data_bunch_sxtft.static_features
dynamic_cols_sxtft = data_bunch_sxtft.dynamic_features
future_cols_sxtft = data_bunch_sxtft.future_features
spatial_cols_sxtft = [data_bunch_sxtft.spatial_id_col]

scalers_sxtft = {}
num_cols_to_scale_sxtft = ['base_level', 'dynamic_cov',
                           'target_lag1', target_col_sxtft]
df_scaled_sxtft = df_raw_sxtft.copy()

for col in num_cols_to_scale_sxtft:
    if col in df_scaled_sxtft.columns and \
       pd.api.types.is_numeric_dtype(df_scaled_sxtft[col]):
        scaler = StandardScaler()
        df_scaled_sxtft[col] = scaler.fit_transform(df_scaled_sxtft[[col]])
        scalers_sxtft[col] = scaler
print(f"\nNumerical features scaled: {num_cols_to_scale_sxtft}")

Expected Output 3.3:

Numerical features scaled: ['base_level', 'dynamic_cov', 'target_lag1', 'target']
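
The scaling step can be illustrated without scikit-learn. The sketch below is only an approximation of what StandardScaler does for a single column: it subtracts the mean and divides by the population standard deviation (ddof=0), and the inverse simply reverses those two operations. The helper names and the toy values are hypothetical and are not part of fusionlab-learn.

```python
import numpy as np

def fit_standardizer(values):
    """Return (mean, std) for one column, using the population std (ddof=0)."""
    values = np.asarray(values, dtype=float)
    return values.mean(), values.std(ddof=0)

def standardize(values, mean, std):
    """Map raw values to z-scores."""
    return (np.asarray(values, dtype=float) - mean) / std

def unstandardize(z, mean, std):
    """Map z-scores back to the original scale."""
    return np.asarray(z, dtype=float) * std + mean

# Hypothetical toy target values, chosen only for illustration.
toy_target = np.array([63.0, 68.0, 70.0, 71.0])
mean, std = fit_standardizer(toy_target)
z = standardize(toy_target, mean, std)
restored = unstandardize(z, mean, std)

print("z-scores:", np.round(z, 3))
print("round trip ok:", np.allclose(restored, toy_target))
```

Keeping one (mean, std) pair per column, as the scalers_sxtft dictionary does with fitted scaler objects, is what makes it possible to bring predictions back to the original units in Step 8.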

Step 4: Prepare Sequences using reshape_xtft_data

Transform the DataFrame into structured arrays for SuperXTFT.

time_steps_sxtft = 12
forecast_horizons_sxtft = 6

s_data_sxtft, d_data_sxtft, f_data_sxtft, t_data_sxtft = \
    reshape_xtft_data(
        df=df_scaled_sxtft, dt_col=dt_col_sxtft,
        target_col=target_col_sxtft,
        dynamic_cols=dynamic_cols_sxtft,
        static_cols=static_cols_sxtft, # Includes series_id, base_level
        future_cols=future_cols_sxtft,
        spatial_cols=spatial_cols_sxtft,
        time_steps=time_steps_sxtft,
        forecast_horizons=forecast_horizons_sxtft,
        verbose=0 # Suppress reshape logs for brevity
    )
print("\nReshaped Data Shapes for SuperXTFT:")
print(f"  Static : {s_data_sxtft.shape}")
print(f"  Dynamic: {d_data_sxtft.shape}")
print(f"  Future : {f_data_sxtft.shape}")
print(f"  Target : {t_data_sxtft.shape}")

Expected Output 4.4:

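(With 2 series of 36 time steps each, a lookback of T=12 and a horizon of H=6, each series yields 36 - 12 - 6 + 1 = 19 sequences, for a total of 2 * 19 = 38. The future window spans T + H = 18 steps.)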

Reshaped Data Shapes for SuperXTFT:
  Static : (38, 2)
  Dynamic: (38, 12, 4)
  Future : (38, 18, 3)
  Target : (38, 6, 1)
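
The shapes above can be reproduced with a plain sliding-window sketch. This is not the implementation of reshape_xtft_data. It assumes a stride of one step per window, a future window that covers the lookback plus the horizon, and random placeholder data. The helper names count_windows and make_windows are hypothetical.

```python
import numpy as np

def count_windows(n_timesteps, time_steps, horizon):
    """Number of complete (lookback + horizon) windows in one series."""
    return max(n_timesteps - time_steps - horizon + 1, 0)

def make_windows(dynamic, future, target, time_steps, horizon):
    """Slice one series into rolling windows with a stride of one step.

    dynamic: (T, n_dynamic), future: (T, n_future), target: (T,)
    """
    n_windows = count_windows(len(target), time_steps, horizon)
    dyn, fut, tgt = [], [], []
    for start in range(n_windows):
        split = start + time_steps   # end of lookback, start of horizon
        end = split + horizon        # end of horizon
        dyn.append(dynamic[start:split])   # past covariates only
        fut.append(future[start:end])      # known inputs: lookback + horizon
        tgt.append(target[split:end, None])  # values to forecast, as a column
    return np.stack(dyn), np.stack(fut), np.stack(tgt)

n_series, n_timesteps, T, H = 2, 36, 12, 6
rng = np.random.default_rng(0)

# Placeholder data: 4 dynamic and 3 future features per series.
parts = [
    make_windows(
        rng.normal(size=(n_timesteps, 4)),
        rng.normal(size=(n_timesteps, 3)),
        rng.normal(size=n_timesteps),
        T, H,
    )
    for _ in range(n_series)
]
d_all, f_all, t_all = (np.concatenate(p) for p in zip(*parts))

print(count_windows(n_timesteps, T, H))        # 19
print(d_all.shape, f_all.shape, t_all.shape)   # (38, 12, 4) (38, 18, 3) (38, 6, 1)
```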

Step 5: Train/Validation Split of Sequences

Split the sequence arrays by position: the first 75% of the sequences are used for training and the remaining 25% for validation.

val_split_sxtft_frac = 0.25 # Using a bit more for validation
n_samples_sxtft = s_data_sxtft.shape[0]
split_idx_sxtft = int(n_samples_sxtft * (1 - val_split_sxtft_frac))

X_s_train_sxtft, X_s_val_sxtft = s_data_sxtft[:split_idx_sxtft], s_data_sxtft[split_idx_sxtft:]
X_d_train_sxtft, X_d_val_sxtft = d_data_sxtft[:split_idx_sxtft], d_data_sxtft[split_idx_sxtft:]
X_f_train_sxtft, X_f_val_sxtft = f_data_sxtft[:split_idx_sxtft], f_data_sxtft[split_idx_sxtft:]
y_t_train_sxtft, y_t_val_sxtft = t_data_sxtft[:split_idx_sxtft], t_data_sxtft[split_idx_sxtft:]

train_inputs_sxtft = [X_s_train_sxtft, X_d_train_sxtft, X_f_train_sxtft]
val_inputs_sxtft = [X_s_val_sxtft, X_d_val_sxtft, X_f_val_sxtft]

print(f"\nData split for SuperXTFT. Train: {len(y_t_train_sxtft)}, "
      f"Val: {len(y_t_val_sxtft)}")

Expected Output 5.5:

Data split for SuperXTFT. Train: 28, Val: 10
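
The split arithmetic is easy to check by hand: int(38 * 0.75) = 28, which leaves 10 sequences for validation. The sketch below applies the same positional split to placeholder arrays. It additionally assumes that the sequences are stacked series by series (19 per series), which the exercise does not confirm. Under that assumption, every validation window comes from the second series. The helper positional_split is hypothetical.

```python
import numpy as np

def positional_split(arrays, val_fraction):
    """Split every array at the same index, keeping the original order."""
    n_samples = len(arrays[0])
    split_idx = int(n_samples * (1 - val_fraction))
    train = [a[:split_idx] for a in arrays]
    val = [a[split_idx:] for a in arrays]
    return train, val, split_idx

# Assumption: 19 sequences per series, stacked one series after the other.
series_of_sequence = np.repeat([0, 1], 19)
placeholder_targets = np.zeros((38, 6, 1))

train, val, split_idx = positional_split(
    [series_of_sequence, placeholder_targets], val_fraction=0.25
)

print(split_idx, len(train[0]), len(val[0]))  # 28 28 10
print(np.unique(val[0]))                      # [1]
```

If that ordering holds for your data, a per-series split or a shuffled split may give a more representative validation set.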

Step 6: Define SuperXTFT Model for Quantile Forecast

Instantiate the SuperXTFT model. Its parameters are similar to those of XTFT. We’ll explicitly disable anomaly detection for this exercise.

quantiles_sxtft = [0.1, 0.5, 0.9]
output_dim_sxtft = 1

s_dim_sxtft = X_s_train_sxtft.shape[-1]
d_dim_sxtft = X_d_train_sxtft.shape[-1]
f_dim_sxtft = X_f_train_sxtft.shape[-1]

super_xtft_model_ex = SuperXTFT(
    static_input_dim=s_dim_sxtft,
    dynamic_input_dim=d_dim_sxtft,
    future_input_dim=f_dim_sxtft,
    forecast_horizon=forecast_horizons_sxtft,
    quantiles=quantiles_sxtft,
    output_dim=output_dim_sxtft,
    # Minimal HPs for faster demo
    embed_dim=8, lstm_units=16, attention_units=8,
    hidden_units=16, num_heads=1, dropout_rate=0.0,
    max_window_size=time_steps_sxtft, memory_size=10,
    scales=None,
    anomaly_detection_strategy=None, # Explicitly disable
    anomaly_loss_weight=0.0
)
print("\nSuperXTFT model instantiated (anomaly detection disabled).")

Step 7: Compile and Train the SuperXTFT Model

Compile with quantile loss and train for a few epochs.

loss_fn_sxtft = combined_quantile_loss(quantiles=quantiles_sxtft)
super_xtft_model_ex.compile(
    optimizer=tf.keras.optimizers.Adam(learning_rate=0.005),
    loss=loss_fn_sxtft
    )
print("SuperXTFT model compiled.")

# Optional: Build model with dummy inputs to print summary
# try:
#     dummy_s = tf.zeros((1, s_dim_sxtft))
#     dummy_d = tf.zeros((1, time_steps_sxtft, d_dim_sxtft))
#     dummy_f = tf.zeros((1, time_steps_sxtft + forecast_horizons_sxtft, f_dim_sxtft))
#     super_xtft_model_ex([dummy_s, dummy_d, dummy_f])
#     super_xtft_model_ex.summary(line_length=90)
# except Exception as e:
#     print(f"Model build/summary error: {e}")

print("\nStarting SuperXTFT model training...")
history_sxtft = super_xtft_model_ex.fit(
    train_inputs_sxtft, y_t_train_sxtft,
    validation_data=(val_inputs_sxtft, y_t_val_sxtft),
    epochs=3, batch_size=4, verbose=1 # Short run for demo
)
print("SuperXTFT Training finished.")
if history_sxtft and history_sxtft.history.get('val_loss'):
    val_loss_sxtft = history_sxtft.history['val_loss'][-1]
    print(f"Final validation loss: {val_loss_sxtft:.4f}")

Expected Output 7.7:

(Keras training logs and final validation loss; exact timings and loss values will vary between runs and machines.)

SuperXTFT model compiled.

Starting SuperXTFT model training...
Epoch 1/3
7/7 [==============================] - 17s 329ms/step - loss: 0.4341 - val_loss: 0.5377
Epoch 2/3
7/7 [==============================] - 0s 12ms/step - loss: 0.4233 - val_loss: 0.5354
Epoch 3/3
7/7 [==============================] - 0s 12ms/step - loss: 0.4135 - val_loss: 0.5387
SuperXTFT Training finished.
Final validation loss: 0.5387
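
To make the loss values above easier to interpret, here is a minimal NumPy sketch of the quantile (pinball) loss. It assumes the loss is averaged over samples, horizon steps, and quantiles. Whether combined_quantile_loss uses exactly this reduction is not confirmed by this exercise, and the function name pinball_loss is hypothetical.

```python
import numpy as np

def pinball_loss(y_true, y_pred, quantiles):
    """Mean pinball loss.

    y_true: (batch, horizon, 1)
    y_pred: (batch, horizon, n_quantiles)
    """
    q = np.asarray(quantiles, dtype=float)
    error = y_true - y_pred  # broadcasts the target across the quantile axis
    # Under-prediction is weighted by q, over-prediction by (1 - q).
    per_element = np.maximum(q * error, (q - 1.0) * error)
    return float(per_element.mean())

# One sample, one horizon step, three quantile predictions.
y_true = np.array([[[1.0]]])
y_pred = np.array([[[0.5, 1.0, 2.0]]])
loss = pinball_loss(y_true, y_pred, [0.1, 0.5, 0.9])
print(round(loss, 4))  # 0.05
```

In this example the q=0.1 prediction is too low by 0.5 (penalty 0.05), the median is exact (penalty 0), and the q=0.9 prediction is too high by 1.0 (penalty 0.1), so the mean is 0.05.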

Step 8: Make Predictions and Visualize

Predict on the validation set and visualize the quantile forecast for a sample item, similar to the XTFT example.

print("\nMaking quantile predictions with SuperXTFT...")
val_pred_scaled_sxtft = super_xtft_model_ex.predict(
    val_inputs_sxtft, verbose=0
    )
print(f"Scaled prediction output shape: {val_pred_scaled_sxtft.shape}")

# Inverse transform (assumes the target was scaled in Step 3)
target_scaler_sxtft = scalers_sxtft.get(target_col_sxtft)
if target_scaler_sxtft is not None:
    # The scaler was fitted on a single column, so flatten predictions
    # and actuals to shape (-1, 1) before inverting, then restore shapes.
    pred_flat_sxtft = val_pred_scaled_sxtft.reshape(-1, 1)
    actual_flat_sxtft = y_t_val_sxtft.reshape(-1, 1)

    pred_inv_sxtft = target_scaler_sxtft.inverse_transform(pred_flat_sxtft)
    actual_inv_sxtft = target_scaler_sxtft.inverse_transform(actual_flat_sxtft)

    pred_final_sxtft = pred_inv_sxtft.reshape(val_pred_scaled_sxtft.shape)
    actual_final_sxtft = actual_inv_sxtft.reshape(y_t_val_sxtft.shape)
    print("Predictions and actuals inverse transformed.")
else:
    print("Warning: Target scaler not found. Plotting scaled values.")
    pred_final_sxtft = val_pred_scaled_sxtft
    actual_final_sxtft = y_t_val_sxtft

# --- Visualization for one sample item ---
sample_idx_sxtft = 0 # Plot the first validation sequence
if len(actual_final_sxtft) > sample_idx_sxtft:
    actual_sxtft_item = actual_final_sxtft[sample_idx_sxtft, :, 0]
    pred_q_sxtft_item = pred_final_sxtft[sample_idx_sxtft, :, :]
    steps_axis_sxtft = np.arange(1, forecast_horizons_sxtft + 1)

    plt.figure(figsize=(12, 6))
    plt.plot(steps_axis_sxtft, actual_sxtft_item,
             label='Actual Sales', marker='o', linestyle='--')
    plt.plot(steps_axis_sxtft, pred_q_sxtft_item[:, 1], # Median
             label='Median Forecast (q=0.5)', marker='x')
    plt.fill_between(
        steps_axis_sxtft, pred_q_sxtft_item[:, 0], pred_q_sxtft_item[:, 2],
        color='lightcoral', alpha=0.4,
        label='Interval (q0.1-q0.9)'
    )
    plt.title(f'SuperXTFT Quantile Forecast (Sample {sample_idx_sxtft})')
    plt.xlabel('Forecast Step'); plt.ylabel(target_col_sxtft)
    plt.legend(); plt.grid(True); plt.tight_layout()
    # fig_path_sxtft = os.path.join(
    # exercise_output_dir_super_xtft,
    # "exercise_super_xtft_forecast.png")
    # plt.savefig(fig_path_sxtft)
    plt.show()
    print("\nSuperXTFT quantile forecast plot generated.")
else:
    print("\nNot enough validation samples to plot.")

Expected Plot 8.8:

[Figure: SuperXTFT Quantile Forecast Exercise]

Visualization of the SuperXTFT quantile forecast (median and interval) against actual validation data.
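
Beyond the plot, two quick numerical checks are often useful for quantile forecasts: how often the actual value falls inside the q0.1 to q0.9 band (nominally 80%), and whether any predicted quantiles cross. The sketch below assumes predictions shaped (samples, horizon, quantiles) and actuals shaped (samples, horizon, 1), as used in the plotting code. The values and helper names are hypothetical.

```python
import numpy as np

def interval_coverage(actual, lower, upper):
    """Fraction of actual values that fall inside [lower, upper]."""
    inside = (actual >= lower) & (actual <= upper)
    return float(inside.mean())

def has_quantile_crossing(pred):
    """True if any higher quantile is predicted below a lower one.

    pred: (..., n_quantiles), ordered by increasing quantile level.
    """
    return bool(np.any(np.diff(pred, axis=-1) < 0))

# Hypothetical values: 1 sample, 3 horizon steps, quantiles (0.1, 0.5, 0.9).
pred = np.array([[[60.0, 65.0, 70.0],
                  [61.0, 66.0, 72.0],
                  [62.0, 67.0, 71.0]]])
actual = np.array([[[64.0], [73.0], [66.0]]])

coverage = interval_coverage(actual[..., 0], pred[..., 0], pred[..., -1])
print(f"Coverage of q0.1-q0.9 band: {coverage:.2f}")       # 0.67
print(f"Quantile crossing: {has_quantile_crossing(pred)}")  # False
```

With only a handful of validation sequences and three training epochs, as in this exercise, the empirical coverage is a rough indication at best.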

Discussion of Exercise

Congratulations! You have successfully completed the end-to-end workflow for using the advanced SuperXTFT model.

This exercise has demonstrated several key points:

  • The data preparation steps (feature definition, scaling, and sequence generation) are identical to those for the standard XTFT model. This makes it easy to upgrade your workflow to use this more powerful architecture without changing your data pipeline.

  • Instantiation and compilation follow the same familiar pattern, using a consistent set of core hyperparameters.

  • The key enhancements of SuperXTFT—its input Variable Selection Networks for feature selection and its post-attention Gated Residual Networks for deeper processing—are seamlessly integrated within its internal architecture. This allows you to leverage its additional power without altering your core training and prediction code.
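
To give a feel for the two components named in the last point, the following NumPy sketch shows a heavily simplified Variable Selection Network and Gated Residual Network in the general form used by TFT-style models. It is not the fusionlab-learn implementation. Biases, dropout, and context inputs are omitted, the weights are random rather than trained, and all function names are hypothetical.

```python
import numpy as np

rng = np.random.default_rng(0)

def elu(x):
    return np.where(x > 0, x, np.exp(np.minimum(x, 0.0)) - 1.0)

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

def softmax(x):
    shifted = x - x.max(axis=-1, keepdims=True)
    e = np.exp(shifted)
    return e / e.sum(axis=-1, keepdims=True)

def layer_norm(x, eps=1e-6):
    mean = x.mean(axis=-1, keepdims=True)
    var = x.var(axis=-1, keepdims=True)
    return (x - mean) / np.sqrt(var + eps)

def init_grn(units):
    """Random square weight matrices (biases omitted for brevity)."""
    names = ("w1", "w2", "w_gate", "w_value")
    return {n: rng.normal(scale=0.1, size=(units, units)) for n in names}

def grn(x, p):
    """Gated Residual Network: nonlinear path, gate, skip connection, norm."""
    hidden = elu(x @ p["w1"]) @ p["w2"]
    gated = sigmoid(hidden @ p["w_gate"]) * (hidden @ p["w_value"])
    return layer_norm(x + gated)

def variable_selection(features, w_select):
    """Weight each variable's embedding by a learned softmax score.

    features: (batch, n_vars, units) -> combined (batch, units),
    weights (batch, n_vars).
    """
    batch = features.shape[0]
    weights = softmax(features.reshape(batch, -1) @ w_select)
    combined = (weights[:, :, None] * features).sum(axis=1)
    return combined, weights

batch, n_vars, units = 4, 3, 8
features = rng.normal(size=(batch, n_vars, units))
w_select = rng.normal(scale=0.1, size=(n_vars * units, n_vars))

selected, weights = variable_selection(features, w_select)
refined = grn(selected, init_grn(units))
print(selected.shape, refined.shape)  # (4, 8) (4, 8)
print(np.round(weights.sum(axis=-1), 6))
```

The selection weights sum to one per sample, which is what lets them be read as relative feature importance. The gate in the GRN lets the network suppress the nonlinear path and fall back on the skip connection when the extra processing is not needed.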

You have successfully trained the most powerful model in the fusionlab-learn TFT family. For new projects, it is often a good strategy to start with the standard XTFT as a robust baseline, and then upgrade to SuperXTFT when you need to push for the highest possible performance, especially on datasets with many features or complex underlying dynamics.