Advanced Forecasting with XTFT

This example demonstrates using the more advanced XTFT model for a multi-step quantile forecasting task. XTFT is designed to handle complex scenarios involving static features (e.g., item ID, location attributes), dynamic historical features (e.g., past sales, sensor readings), and known future inputs (e.g., planned promotions, future calendar events).

We will walk through the process step-by-step:

  1. Generate synthetic multivariate time series data for multiple items.

  2. Define static, dynamic, future, and target features.

  3. Scale numerical features.

  4. Use the reshape_xtft_data() utility to prepare sequences suitable for XTFT.

  5. Split the data into training and validation sets.

  6. Define and compile an XTFT model with quantile outputs.

  7. Train the model.

  8. Make predictions and inverse transform them.

  9. Visualize the quantile predictions for a sample item.

Prerequisites

Ensure you have fusionlab-learn and its dependencies installed:

pip install fusionlab-learn matplotlib scikit-learn joblib

Step 1: Imports and Setup

First, we import the necessary libraries, including TensorFlow, Pandas, NumPy, scikit-learn for scaling, Matplotlib for plotting, and the required components from fusionlab. We also suppress common warnings and logs for cleaner output.

import numpy as np
import pandas as pd
import tensorflow as tf
import matplotlib.pyplot as plt
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
import os
import joblib # For saving/loading scalers

# FusionLab imports
from fusionlab.nn.transformers import XTFT
from fusionlab.nn.utils import reshape_xtft_data
from fusionlab.nn.losses import combined_quantile_loss

# Suppress warnings and TF logs for cleaner output
import warnings
warnings.filterwarnings('ignore')
tf.get_logger().setLevel('ERROR')
if hasattr(tf, 'autograph'): # Check for autograph availability
    tf.autograph.set_verbosity(0)

# Configuration for outputs
output_dir_xtft = "./xtft_advanced_example_output"
os.makedirs(output_dir_xtft, exist_ok=True)

print("Libraries imported and TensorFlow logs configured.")

Step 2: Generate Synthetic Data

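We create a sample dataset simulating monthly sales for multiple items over three years. It includes a static feature (ItemID), dynamic features (Month, Temperature, PrevMonthSales), known future features (PlannedPromotion, plus Month, which is known in advance), and the target (Sales). Because PrevMonthSales is a one-month lag, the first month of each item has no lag value and is dropped, leaving 35 rows per item.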

n_items = 3
n_timesteps = 36 # 3 years of monthly data
rng_seed = 42
np.random.seed(rng_seed) # For reproducibility

date_rng = pd.date_range(
    start='2020-01-01', periods=n_timesteps, freq='MS'  # Month Start
)
df_list = []

for item_id in range(n_items):
    time_idx = np.arange(n_timesteps)
    # Base sales with trend, seasonality, and item-specific factor
    sales = (
        100 + item_id * 50 + time_idx * (2 + item_id * 0.5) +
        20 * np.sin(2 * np.pi * time_idx / 12) + # Yearly seasonality
        np.random.normal(0, 10, n_timesteps) # Noise
    )
    # Simulated temperature (dynamic)
    temp = (15 + 10 * np.sin(2 * np.pi * (time_idx % 12) / 12 + np.pi) +
            np.random.normal(0, 2, n_timesteps))
    # Simulated planned promotion (future known)
    promo = np.random.randint(0, 2, n_timesteps)

    item_df = pd.DataFrame({
        'Date': date_rng,
        'ItemID': f'item_{item_id}', # String ItemID for grouping
        'Month': date_rng.month,     # Can be dynamic & future
        'Temperature': temp,
        'PlannedPromotion': promo,
        'Sales': sales
    })
    # Create lagged sales (dynamic history)
    item_df['PrevMonthSales'] = item_df['Sales'].shift(1)
    df_list.append(item_df)

df_raw = pd.concat(df_list).dropna().reset_index(drop=True)
print(f"Generated raw data shape: {df_raw.shape}")
print("Sample of generated data:")
print(df_raw.head())
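As a quick sanity check on the lag feature, the sketch below rebuilds a tiny frame with the same shift(1) and dropna() pattern and counts the rows that survive per item. The frame `toy` and its placeholder values are made up for illustration; only the pandas calls mirror the step above.

```python
import numpy as np
import pandas as pd

n_items, n_timesteps = 3, 36
frames = []
for item_id in range(n_items):
    item_df = pd.DataFrame({
        "ItemID": f"item_{item_id}",
        "Sales": np.arange(n_timesteps, dtype=float),  # placeholder values
    })
    # Shifting inside each item's own frame keeps lags from leaking across items
    item_df["PrevMonthSales"] = item_df["Sales"].shift(1)
    frames.append(item_df)

# Each item loses exactly its first row, which has no previous month
toy = pd.concat(frames).dropna().reset_index(drop=True)
rows_per_item = toy.groupby("ItemID").size()
print(rows_per_item.to_dict())
```

Computing the lag per item before concatenating matters: shifting the concatenated frame instead would copy the last sales value of one item into the first row of the next.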

Step 3: Define Features and Scale Numerics

We explicitly define which columns serve as static, dynamic past, known future, and target features. The continuous columns (Temperature, PrevMonthSales, Sales) are standardized with StandardScaler, one scaler per column. All fitted scalers are stored in a dictionary and saved with joblib; the target's scaler is needed later to inverse transform the predictions.

target_col = 'Sales'
dt_col = 'Date' # Datetime column for reshaping
# ItemID is the primary static identifier for grouping
static_cols = ['ItemID']
# Dynamic features: Month, Temperature, and lagged sales
dynamic_cols = ['Month', 'Temperature', 'PrevMonthSales']
# Future features: Planned promotions and Month (known ahead)
future_cols = ['PlannedPromotion', 'Month']
# Column for grouping sequences by item
spatial_cols = ['ItemID']

# Scale numerical features (excluding ItemID, Month, PlannedPromotion)
# Target 'Sales' is also scaled.
scalers = {} # To store scalers for different columns
num_cols_to_scale = ['Temperature', 'PrevMonthSales', 'Sales']

df_scaled = df_raw.copy()
for col in num_cols_to_scale:
    if col in df_scaled.columns:
        scaler = StandardScaler()
        df_scaled[col] = scaler.fit_transform(df_scaled[[col]])
        scalers[col] = scaler # Store the fitted scaler
        print(f"Scaled column: {col}")
    else:
        print(f"Warning: Column '{col}' not found for scaling.")

# Save scalers (important for inference)
scalers_path = os.path.join(output_dir_xtft, "xtft_scalers.joblib")
joblib.dump(scalers, scalers_path)
print(f"\nScalers saved to {scalers_path}")
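To make the scaling step concrete, here is a minimal round-trip sketch: fit a StandardScaler on one column, confirm it matches the manual z-score, save it with joblib, reload it, and recover the original values. The sales numbers and the file name `demo_scalers.joblib` are made up for illustration.

```python
import os
import tempfile

import joblib
import numpy as np
from sklearn.preprocessing import StandardScaler

sales = np.array([[100.0], [120.0], [140.0], [160.0]])  # made-up values

scaler = StandardScaler()
sales_scaled = scaler.fit_transform(sales)

# StandardScaler uses the population standard deviation (ddof=0),
# which is also NumPy's default for std()
manual = (sales - sales.mean()) / sales.std()
assert np.allclose(sales_scaled, manual)

# Persist and reload, as one would at inference time
with tempfile.TemporaryDirectory() as tmp_dir:
    path = os.path.join(tmp_dir, "demo_scalers.joblib")  # hypothetical file name
    joblib.dump({"Sales": scaler}, path)
    reloaded = joblib.load(path)["Sales"]

sales_restored = reloaded.inverse_transform(sales_scaled)
print(np.allclose(sales_restored, sales))
```

In this tutorial the scalers are fit on the full dataset for brevity. In a real project it is common to fit them on the training portion only, so that validation statistics do not influence the transformation.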

Step 4: Prepare Sequences using reshape_xtft_data

The reshape_xtft_data() utility transforms the processed DataFrame into the specific input arrays required by XTFT. It creates rolling windows, groups by spatial_cols (ItemID), and separates features into static, dynamic, future, and target arrays.

from sklearn.preprocessing import LabelEncoder

time_steps = 12         # Use 1 year of history as lookback
forecast_horizons = 6   # Predict next 6 months

# 'ItemID' is a string, so it cannot be passed as a static feature as-is.
# It is still used unchanged for grouping via spatial_cols. To also use it
# as a static feature, label encode it into integers first. Any other
# numerical static features would be listed alongside the encoded column.
df_scaled['ItemID_Encoded'] = LabelEncoder().fit_transform(df_scaled['ItemID'])
processed_static_cols = ['ItemID_Encoded']

static_data, dynamic_data, future_data, target_data = reshape_xtft_data(
    df=df_scaled,
    dt_col=dt_col,
    target_col=target_col,
    dynamic_cols=dynamic_cols,
    static_cols=processed_static_cols, # Numerically encoded static features
    future_cols=future_cols,
    spatial_cols=spatial_cols, # Group by ItemID
    time_steps=time_steps,
    forecast_horizons=forecast_horizons,
    verbose=1 # Show resulting shapes
)
# target_data from reshape_xtft_data is (N, H, 1)
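The rolling-window idea behind this step can be sketched in plain NumPy. This is not the library's implementation: `make_windows` is a hypothetical helper, and it assumes a stride of one and a single group. It only illustrates how a lookback of 12 and a horizon of 6 turn one item's 35 rows into paired windows.

```python
import numpy as np

def make_windows(values, time_steps, horizon):
    """Slice a (T, F) array into lookback and horizon windows (illustrative helper)."""
    n_windows = len(values) - time_steps - horizon + 1
    if n_windows <= 0:
        raise ValueError("Series is too short for the requested window sizes.")
    # Lookback window i covers rows [i, i + time_steps)
    past = np.stack([values[i:i + time_steps] for i in range(n_windows)])
    # Horizon window i covers the rows immediately after lookback window i
    ahead = np.stack([
        values[i + time_steps:i + time_steps + horizon] for i in range(n_windows)
    ])
    return past, ahead

# 35 rows per item remain after dropping the first month (no lag value)
series = np.arange(35, dtype=float).reshape(-1, 1)
past, ahead = make_windows(series, time_steps=12, horizon=6)
print(past.shape, ahead.shape)  # (18, 12, 1) (18, 6, 1)
```

Under these assumptions each item yields 35 - 12 - 6 + 1 = 18 windows, so three items would give 54 sequences in total. The shapes printed by reshape_xtft_data with verbose=1 are the authoritative reference.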

Step 5: Train/Validation Split of Sequences

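The generated sequence arrays are split into training and validation sets by position: the first 80% of sequences go to training and the remainder to validation, with no shuffling. Because the sequences are grouped by ItemID, this split is only approximately chronological; if the sequences are emitted item by item, the validation slice will come mostly from the last item. Inputs for the model are packaged into lists in the order [static, dynamic, future].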

val_split_fraction = 0.2
# Check if any data was generated
if target_data is None or target_data.shape[0] == 0:
    raise ValueError("No sequences were generated. Check data and parameters.")

n_samples = target_data.shape[0]
split_idx = int(n_samples * (1 - val_split_fraction))

# Handle cases where static_data might be None
X_train_static = static_data[:split_idx] if static_data is not None else None
X_val_static = static_data[split_idx:] if static_data is not None else None

X_train_dynamic, X_val_dynamic = dynamic_data[:split_idx], dynamic_data[split_idx:]
X_train_future, X_val_future = future_data[:split_idx], future_data[split_idx:]
y_train, y_val = target_data[:split_idx], target_data[split_idx:]

train_inputs = [X_train_static, X_train_dynamic, X_train_future]
val_inputs = [X_val_static, X_val_dynamic, X_val_future]

print("\nData split into Train/Validation sequences:")
print(f"  Train samples: {len(y_train)}")
print(f"  Validation samples: {len(y_val)}")
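If every item should contribute its most recent windows to validation, one option is to split within each group rather than across the whole array. The sketch below assumes you can reconstruct a group label for every sequence (here, 18 windows per item stored item by item, which is an assumption about the layout, not something the reshape step is shown to return). `per_group_split` is a hypothetical helper.

```python
import numpy as np

def per_group_split(group_ids, val_fraction):
    """Return train/val index arrays, holding out the last windows of each group."""
    group_ids = np.asarray(group_ids)
    train_idx, val_idx = [], []
    for group in np.unique(group_ids):
        idx = np.flatnonzero(group_ids == group)  # positions keep their original order
        n_val = max(1, int(len(idx) * val_fraction))
        train_idx.extend(idx[:-n_val])
        val_idx.extend(idx[-n_val:])
    return np.array(train_idx), np.array(val_idx)

# Assumed layout: 18 windows per item, stored item by item
group_ids = np.repeat([0, 1, 2], 18)
train_idx, val_idx = per_group_split(group_ids, val_fraction=0.2)
print(len(train_idx), len(val_idx))  # 45 9
```

The index arrays can then be applied to each of the static, dynamic, future, and target arrays. Note that overlapping windows still share observations across the boundary, so this reduces but does not eliminate overlap between training and validation data.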

Step 6: Define XTFT Model for Quantile Forecast

Instantiate the XTFT model. Input dimensions are derived from the prepared data arrays. Configure for quantile forecasting and set relevant XTFT hyperparameters.

quantiles_to_predict = [0.1, 0.5, 0.9]
output_dim_model = 1 # Predicting univariate 'Sales'

# Determine input dimensions for the model
s_dim = X_train_static.shape[-1] if X_train_static is not None else 0
d_dim = X_train_dynamic.shape[-1]
f_dim = X_train_future.shape[-1] if X_train_future is not None else 0

model = XTFT(
    static_input_dim=s_dim,
    dynamic_input_dim=d_dim,
    future_input_dim=f_dim,
    forecast_horizon=forecast_horizons,
    quantiles=quantiles_to_predict,
    output_dim=output_dim_model,
    # Example XTFT Hyperparameters (these should be tuned)
    embed_dim=16,
    lstm_units=32,
    attention_units=16,
    hidden_units=32,
    num_heads=2, # Reduced for speed
    dropout_rate=0.1,
    max_window_size=time_steps, # Can be different from time_steps
    memory_size=20, # Reduced for speed
    scales=[1, 3]   # Example multi-scale config
)
print("\nXTFT model instantiated for quantile forecast.")

Step 7: Compile and Train the Model

Compile the model with the Adam optimizer and combined_quantile_loss(), then train for a few epochs for this demonstration.

loss_fn = combined_quantile_loss(quantiles=quantiles_to_predict)
model.compile(
    optimizer=tf.keras.optimizers.Adam(learning_rate=0.005),
    loss=loss_fn
)
print("XTFT model compiled with quantile loss.")

# Dummy call to build model and print summary (optional)
# Ensure inputs are correctly structured (list of 3, Nones allowed if dims are 0)
dummy_s = tf.zeros((1, s_dim)) if s_dim > 0 else None
dummy_d = tf.zeros((1, time_steps, d_dim))
dummy_f = tf.zeros((1, time_steps + forecast_horizons, f_dim)) if f_dim > 0 else None
# model([dummy_s, dummy_d, dummy_f])
# model.summary(line_length=100)

print("\nStarting XTFT model training (few epochs for demo)...")
history = model.fit(
    train_inputs, # List [Static, Dynamic, Future]
    y_train,      # Targets
    validation_data=(val_inputs, y_val),
    epochs=5,     # Increase for real training
    batch_size=16,  # Adjust based on memory and dataset size
    verbose=1
)
print("Training finished.")
if history and history.history.get('val_loss'):
    print(f"Final validation loss: {history.history['val_loss'][-1]:.4f}")
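For intuition about what a quantile loss optimizes, the following is a generic NumPy sketch of the pinball loss averaged over quantiles. The exact reduction used inside combined_quantile_loss is not shown in this example, so treat `pinball_loss` as a hypothetical stand-in rather than the library's implementation.

```python
import numpy as np

def pinball_loss(y_true, y_pred, quantiles):
    """Mean pinball loss. y_true: (N, H, 1); y_pred: (N, H, Q)."""
    q = np.asarray(quantiles, dtype=float).reshape(1, 1, -1)
    error = y_true - y_pred  # broadcasts to (N, H, Q)
    # Under-prediction (error > 0) costs q per unit,
    # over-prediction (error < 0) costs (1 - q) per unit
    loss = np.maximum(q * error, (q - 1.0) * error)
    return loss.mean()

y_true = np.ones((1, 1, 1))
y_pred = np.zeros((1, 1, 3))  # every quantile under-predicts by one unit
print(pinball_loss(y_true, y_pred, [0.1, 0.5, 0.9]))
```

The asymmetry is what pushes each output toward its quantile: the 0.9 head is penalized nine times more for under-predicting than for over-predicting, so it settles above most of the observations.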

Step 8: Make Predictions and Inverse Transform

Use the trained model to predict on the validation set. Then, inverse transform the scaled predictions and actuals back to their original units.

print("\nMaking quantile predictions on validation set...")
predictions_scaled = model.predict(val_inputs, verbose=0)
# Shape: (NumValSamples, Horizon, NumQuantiles) if output_dim=1

# Inverse transform predictions and actuals using the scaler
# fitted on the 'Sales' (target) column
target_scaler = scalers.get(target_col)
if target_scaler is None:
    print("Warning: Target scaler not found. Plotting scaled values.")
    predictions_final = predictions_scaled
    y_val_final = y_val
elif output_dim_model == 1:
    num_val_samples = X_val_dynamic.shape[0]
    num_q = len(quantiles_to_predict)

    # The scaler was fit on a single column, so flatten to one column,
    # inverse transform, then restore (Samples, Horizon, Quantiles).
    predictions_final = target_scaler.inverse_transform(
        predictions_scaled.reshape(-1, 1)
    ).reshape(num_val_samples, forecast_horizons, num_q)

    # Inverse transform actuals the same way
    y_val_final = target_scaler.inverse_transform(
        y_val.reshape(-1, 1)
    ).reshape(num_val_samples, forecast_horizons, output_dim_model)
    print("Predictions and actuals inverse transformed.")
else:
    # output_dim > 1: inverse transform is more complex and not shown
    print("Inverse transform for multi-output quantiles not shown, plotting scaled.")
    predictions_final = predictions_scaled
    y_val_final = y_val
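Before plotting, two simple numerical checks can help judge the quantile predictions: how often the actuals fall inside the predicted interval, and how often the quantiles cross. The sketch below uses made-up arrays with the same layout as the prediction and target arrays in this step; `interval_coverage` and `crossing_rate` are hypothetical helpers, not part of fusionlab.

```python
import numpy as np

def interval_coverage(y_true, y_pred):
    """Share of actuals inside [lowest, highest] predicted quantile.

    y_true: (N, H, 1); y_pred: (N, H, Q) with quantiles in ascending order.
    """
    actual = y_true[..., 0]
    inside = (actual >= y_pred[..., 0]) & (actual <= y_pred[..., -1])
    return inside.mean()

def crossing_rate(y_pred):
    """Share of forecast steps where a lower quantile exceeds a higher one."""
    return (np.diff(y_pred, axis=-1) < 0).any(axis=-1).mean()

# Made-up numbers: 1 sample, 4 steps, quantiles (0.1, 0.5, 0.9)
y_true = np.array([[[10.0], [12.0], [9.0], [20.0]]])
y_pred = np.array([[[8.0, 10.0, 12.0],
                    [9.0, 11.0, 13.0],
                    [10.0, 11.0, 12.0],
                    [15.0, 14.0, 25.0]]])
print(interval_coverage(y_true, y_pred), crossing_rate(y_pred))  # 0.75 0.25
```

For a 0.1 to 0.9 interval the nominal coverage is 80%. With only a handful of validation sequences and five training epochs, the empirical figure will be noisy, so it is best read as a rough diagnostic.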

Step 9: Visualize Forecast for One Item

Plot the actual sales and the predicted quantiles (median line plus shaded interval) for one sample item from the validation set.

# Select an item and its first sequence in the validation set for plotting
# This requires ItemID to be part of X_val_static if it was numerically encoded
# For simplicity, we'll plot the first validation sequence.
sample_to_plot_idx = 0

actual_vals_item = y_val_final[sample_to_plot_idx, :, 0] # Assuming output_dim=1
pred_quantiles_item = predictions_final[sample_to_plot_idx, :, :]

# Create an approximate time axis for the forecast period
# This needs the last date of the training data corresponding to this sequence
# For a generic plot, use forecast steps
forecast_steps_axis = np.arange(1, forecast_horizons + 1)

plt.figure(figsize=(12, 6))
plt.plot(forecast_steps_axis, actual_vals_item,
         label='Actual Sales', marker='o', linestyle='--')
plt.plot(forecast_steps_axis, pred_quantiles_item[:, 1], # Median (0.5 quantile)
         label='Median Forecast (q=0.5)', marker='x')
plt.fill_between(
    forecast_steps_axis,
    pred_quantiles_item[:, 0], # Lower quantile (q=0.1)
    pred_quantiles_item[:, 2], # Upper quantile (q=0.9)
    color='gray', alpha=0.3,
    label='Prediction Interval (q=0.1 to q=0.9)'
)
plt.title(f'XTFT Quantile Forecast (Validation Sample {sample_to_plot_idx})')
plt.xlabel('Forecast Step into Horizon')
plt.ylabel(f'{target_col} (Units after Inverse Transform if applied)')
plt.legend()
plt.grid(True)
plt.tight_layout()
# To save the figure:
# fig_path = os.path.join(output_dir_xtft, "advanced_xtft_quantile_forecast.png")
# plt.savefig(fig_path)
# print(f"Plot saved to {fig_path}")
plt.show()
print("\nAdvanced XTFT quantile forecasting example complete.")

Example Output Plot:

Figure: Advanced XTFT Quantile Forecast. Visualization of the XTFT quantile forecast (median and interval) against actual validation data for a sample item.
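The plot in Step 9 uses forecast steps on the x-axis. If the last date of a sequence's lookback window is known, a calendar axis can be built with pandas instead. The value of `last_observed` below is hypothetical; the reshape step as shown in this example does not return dates, so you would need to track them yourself.

```python
import pandas as pd

forecast_horizons = 6
last_observed = pd.Timestamp("2022-06-01")  # hypothetical end of a lookback window

# Month-start stamps for the six months that follow the lookback window
forecast_dates = pd.date_range(
    start=last_observed + pd.offsets.MonthBegin(1),
    periods=forecast_horizons,
    freq="MS",
)
print(forecast_dates[0].date(), forecast_dates[-1].date())
```

Passing `forecast_dates` in place of the step axis to plt.plot and plt.fill_between would label the forecast by month, matching the 'MS' frequency used when the data was generated.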