Exercise: Forecasting with Stricter TFT (All Inputs Required)

Welcome to this exercise on using the stricter version of the Temporal Fusion Transformer (TFT) available in fusionlab-learn. This implementation requires all three input types: static features, dynamic (past observed) features, and known future features.

We will perform a single-step point forecast to illustrate the specific data preparation and model interaction for this TFT variant.

Learning Objectives:

  • Generate synthetic multi-item time series data with distinct static, dynamic, and future features.

  • Understand how to define feature roles, numerically encode categorical static features (like item identifiers), and scale numerical data.

  • Utilize the reshape_xtft_data() utility to prepare the three separate input arrays (static, dynamic, future) and targets.

  • Correctly structure the input list [static, dynamic, future] for training and prediction with the stricter TFT.

  • Define, compile, and train the TFT model.

  • Make predictions and visualize the results, including inverse transformation of scaled values.

Let’s get started!

Prerequisites

Ensure you have fusionlab-learn and its common dependencies installed. For visualizations, matplotlib is also needed.

pip install fusionlab-learn matplotlib scikit-learn

Step 1: Imports and Setup

First, we import all necessary libraries.

import numpy as np
import pandas as pd
import tensorflow as tf
import matplotlib.pyplot as plt
from sklearn.preprocessing import StandardScaler, LabelEncoder
from sklearn.model_selection import train_test_split
import warnings
import os

# FusionLab imports
from fusionlab.nn.transformers import TFT # The stricter TFT class
from fusionlab.nn.utils import reshape_xtft_data
# Import for Keras to recognize custom loss if model was saved with it
from fusionlab.nn.losses import combined_quantile_loss

# Suppress warnings and TF logs for cleaner output
warnings.filterwarnings('ignore')
tf.get_logger().setLevel('ERROR')
if hasattr(tf, 'autograph'): # Check for autograph availability
    tf.autograph.set_verbosity(0)

# Directory for saving any output images from this exercise
exercise_output_dir_tft_strict = "./tft_strict_exercise_outputs"
os.makedirs(exercise_output_dir_tft_strict, exist_ok=True)

print("Libraries imported and setup complete for stricter TFT exercise.")

Expected Output 1.1:

Libraries imported and setup complete for stricter TFT exercise.

Step 2: Generate Synthetic Multi-Feature Data

We’ll create a synthetic dataset for multiple items. Each item will have:

  • Static features: ItemID_str (a string identifier) and Category (a numerical category).

  • Dynamic past features: DayOfWeek and ValueLag1 (lagged target).

  • Known future features: FutureEvent (a binary indicator) and DayOfWeek.

  • Target: Value.

n_items_ex_strict = 2
n_timesteps_per_item_ex_strict = 50
rng_seed_ex_strict = 42
np.random.seed(rng_seed_ex_strict)
tf.random.set_seed(rng_seed_ex_strict)

date_rng_ex_strict = pd.date_range(
    start='2021-01-01',
    periods=n_timesteps_per_item_ex_strict, freq='D'
    )
df_list_ex_strict = []

for item_id_num in range(n_items_ex_strict):
    time_idx = np.arange(n_timesteps_per_item_ex_strict)
    value = (50 + item_id_num * 10 + time_idx * 0.5 +
             np.sin(time_idx / 7) * 5 + # Slow cycle (period 2*pi*7, about 44 days)
             np.random.normal(0, 2, n_timesteps_per_item_ex_strict))
    static_category_val = item_id_num + 1
    future_event_val = (date_rng_ex_strict.dayofweek >= 5).astype(int) # Weekend

    item_df = pd.DataFrame({
        'Date': date_rng_ex_strict,
        'ItemID_str': f'item_{item_id_num}', # String ID
        'Category': static_category_val,    # Numerical static
        'DayOfWeek': date_rng_ex_strict.dayofweek,
        'FutureEvent': future_event_val,
        'Value': value
    })
    # Lag is computed per item, so it never crosses item boundaries
    item_df['ValueLag1'] = item_df['Value'].shift(1)
    df_list_ex_strict.append(item_df)

# dropna() removes the first row of each item (no lag available)
df_raw_ex_strict = pd.concat(
    df_list_ex_strict).dropna().reset_index(drop=True)
print(f"Generated raw data shape: {df_raw_ex_strict.shape}")
print("Sample of generated data:")
print(df_raw_ex_strict.head(3))

Expected Output 2.2:

Generated raw data shape: (98, 7)
Sample of generated data:
       Date ItemID_str  Category  DayOfWeek  FutureEvent      Value  ValueLag1
0 2021-01-02     item_0         1          5            1  50.935330  50.993428
1 2021-01-03     item_0         1          6            1  53.704591  50.935330
2 2021-01-04     item_0         1          0            0  56.623919  53.704591
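
A note on the sinusoidal term in the generator: np.sin(time_idx / 7) completes one cycle every 2π · 7 ≈ 44 time steps, so over 50 days it contributes roughly one slow cycle, not a seven-day pattern. The sketch below uses only the standard library to check that arithmetic and to show the form a strictly weekly term would take. The helper name cycle_length is hypothetical and not part of fusionlab-learn.

```python
import math

def cycle_length(divisor):
    """Number of time steps for sin(t / divisor) to complete one full cycle."""
    return 2 * math.pi * divisor

# Term used in the generator above: sin(t / 7)
print(f"sin(t / 7) repeats every {cycle_length(7):.1f} steps")

# A term that repeats exactly every 7 steps would be sin(2 * pi * t / 7)
weekly = [math.sin(2 * math.pi * t / 7) for t in range(15)]
print("weekly term repeats after 7 steps:",
      math.isclose(weekly[1], weekly[8], abs_tol=1e-9))
```

Switching the generator to the weekly form would change the random-free part of Value and therefore the expected outputs shown in this exercise, so treat the sketch as an illustration only.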

Step 3: Define Features, Encode Static, and Scale Numerics

We assign columns to their roles. Since the stricter TFT model (and reshape_xtft_data) expects numerical inputs for static features, we’ll LabelEncode the string-based ItemID_str. Then, we scale relevant numerical features.

target_col_strict = 'Value'
dt_col_strict = 'Date'

# Initial column definitions
static_cols_def_strict = ['ItemID_str', 'Category']
dynamic_cols_def_strict = ['DayOfWeek', 'ValueLag1']
future_cols_def_strict = ['FutureEvent', 'DayOfWeek']
# For reshape_xtft_data, spatial_cols are used for grouping
spatial_cols_for_grouping = ['ItemID_str']

df_processed_strict = df_raw_ex_strict.copy()

# --- Encode ItemID_str (Categorical Static Feature) ---
le_item_id_ex_strict = LabelEncoder()
df_processed_strict['ItemID_encoded'] = \
    le_item_id_ex_strict.fit_transform(df_processed_strict['ItemID_str'])
print(f"\nEncoded 'ItemID_str' into 'ItemID_encoded'. "
      f"Classes: {le_item_id_ex_strict.classes_}")

# --- Update static_cols to use the encoded version for the model ---
static_cols_for_model_strict = ['ItemID_encoded', 'Category']
# For reshape_xtft_data, grouping can still use original string ID,
# or you can group by the encoded ID if preferred.
# If grouping by encoded, ensure it's in df_processed_strict.
# Here, we'll pass the original string ItemID for grouping to reshape,
# but use ItemID_encoded as a static *feature*.

# --- Scale Numerical Features ---
scaler_strict = StandardScaler()
num_cols_to_scale_strict = ['Value', 'ValueLag1']
# Ensure columns exist
num_cols_to_scale_strict = [
    c for c in num_cols_to_scale_strict if c in df_processed_strict.columns
    ]
if num_cols_to_scale_strict:
    df_processed_strict[num_cols_to_scale_strict] = \
        scaler_strict.fit_transform(
            df_processed_strict[num_cols_to_scale_strict]
            )
    print("\nNumerical features scaled.")
else:
    print("\nNo numerical features found for scaling.")

Expected Output 3.3:

Encoded 'ItemID_str' into 'ItemID_encoded'. Classes: ['item_0' 'item_1']

Numerical features scaled.
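
To make the encoding step concrete, here is a minimal pure-Python sketch of the kind of mapping LabelEncoder produces: each distinct label receives an integer code according to its sorted order. The function names fit_label_codes and decode_labels are hypothetical. In the exercise itself, le_item_id_ex_strict.inverse_transform() plays the role of the decoding step.

```python
def fit_label_codes(labels):
    """Map each distinct label to an integer code in sorted label order."""
    classes = sorted(set(labels))
    return {label: code for code, label in enumerate(classes)}

def decode_labels(codes, mapping):
    """Translate integer codes back to their original labels."""
    reverse = {code: label for label, code in mapping.items()}
    return [reverse[int(c)] for c in codes]

item_ids = ["item_1", "item_0", "item_1", "item_0"]
mapping = fit_label_codes(item_ids)
encoded = [mapping[x] for x in item_ids]

print(mapping)                          # {'item_0': 0, 'item_1': 1}
print(encoded)                          # [1, 0, 1, 0]
print(decode_labels(encoded, mapping))  # ['item_1', 'item_0', 'item_1', 'item_0']
```

Keep in mind that the codes are identifiers, not magnitudes. Passing them as a numeric static feature is convenient for a two-item demo, but the ordering they imply carries no meaning.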

Step 4: Prepare Sequences with reshape_xtft_data

Use reshape_xtft_data() to transform the DataFrame. It will use spatial_cols_for_grouping (original ItemID_str) for grouping and static_cols_for_model_strict (including ItemID_encoded) to create the static_data array.

time_steps_strict = 7
forecast_horizon_strict = 1 # Single-step point forecast

static_data_s, dynamic_data_s, future_data_s, target_data_s = \
    reshape_xtft_data(
        df=df_processed_strict, # Contains ItemID_encoded
        dt_col=dt_col_strict,
        target_col=target_col_strict,
        dynamic_cols=dynamic_cols_def_strict,
        static_cols=static_cols_for_model_strict, # Use encoded static
        future_cols=future_cols_def_strict,
        spatial_cols=spatial_cols_for_grouping, # Group by original ItemID_str
        time_steps=time_steps_strict,
        forecast_horizons=forecast_horizon_strict,
        verbose=0
    )
targets_s = target_data_s.astype(np.float32) # Already (N,H,1)

print("\nReshaped Data Shapes for Stricter TFT:")
print(f"  Static : {static_data_s.shape}")
print(f"  Dynamic: {dynamic_data_s.shape}")
print(f"  Future : {future_data_s.shape}")
print(f"  Target : {targets_s.shape}")

Expected Output 4.4:

(Shapes depend on n_items, n_timesteps, time_steps, forecast_horizon)

Reshaped Data Shapes for Stricter TFT:
  Static : (84, 2)
  Dynamic: (84, 7, 2)
  Future : (84, 8, 2)
  Target : (84, 1, 1)
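
The shapes above follow from simple window arithmetic: each item has 49 usable rows (50 days minus the row dropped for the missing lag), and a window of 7 past steps plus 1 forecast step fits 49 − 7 − 1 + 1 = 42 times per item, giving 84 sequences for two items. The NumPy sketch below reproduces those shapes for a single item. It is not the implementation of reshape_xtft_data; the helper make_windows and the alignment of the future window (spanning time_steps + horizon steps, as the reported shape suggests) are assumptions made for illustration.

```python
import numpy as np

def make_windows(dyn_feats, fut_feats, target, time_steps, horizon):
    """Slide a fixed-size window over one item's rows (hypothetical helper)."""
    n_seq = len(target) - time_steps - horizon + 1
    # Past window: rows i .. i + time_steps - 1
    dyn = np.stack([dyn_feats[i:i + time_steps] for i in range(n_seq)])
    # Future window: the past rows plus the forecast rows
    fut = np.stack([fut_feats[i:i + time_steps + horizon] for i in range(n_seq)])
    # Target: the `horizon` rows right after the past window
    tgt = np.stack([target[i + time_steps:i + time_steps + horizon]
                    for i in range(n_seq)])
    return dyn, fut, tgt[..., np.newaxis]

rows_per_item = 49                      # 50 days minus the dropped first row
dyn_feats = np.zeros((rows_per_item, 2))
fut_feats = np.zeros((rows_per_item, 2))
target = np.arange(rows_per_item, dtype=np.float32)

dyn, fut, tgt = make_windows(dyn_feats, fut_feats, target,
                             time_steps=7, horizon=1)
print(dyn.shape, fut.shape, tgt.shape)  # (42, 7, 2) (42, 8, 2) (42, 1, 1)
```

Stacking the windows of both items doubles the first dimension from 42 to 84, which matches the output reported above.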

Step 5: Train/Validation Split of Sequences

Split the generated sequence arrays. The input for the model will be a list of three non-None arrays: [X_static, X_dynamic, X_future].

val_split_s_frac = 0.2
n_samples_s_total = static_data_s.shape[0]
split_idx_s_val = int(n_samples_s_total * (1 - val_split_s_frac))

X_s_train_s, X_s_val_s = static_data_s[:split_idx_s_val], static_data_s[split_idx_s_val:]
X_d_train_s, X_d_val_s = dynamic_data_s[:split_idx_s_val], dynamic_data_s[split_idx_s_val:]
X_f_train_s, X_f_val_s = future_data_s[:split_idx_s_val], future_data_s[split_idx_s_val:]
y_t_train_s, y_t_val_s = targets_s[:split_idx_s_val], targets_s[split_idx_s_val:]

# Package inputs as the REQUIRED list [static, dynamic, future]
train_inputs_strict = [X_s_train_s, X_d_train_s, X_f_train_s]
val_inputs_strict = [X_s_val_s, X_d_val_s, X_f_val_s]

print("\nSequence data split for stricter TFT.")
print(f"  Train samples: {len(y_t_train_s)}")
print(f"  Validation samples: {len(y_t_val_s)}")

Expected Output 5.5:

Sequence data split for stricter TFT.
  Train samples: 67
  Validation samples: 17
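
The positional split above takes the last 20% of the stacked sequences. If the sequences are ordered item by item (an assumption about the output order of reshape_xtft_data that is worth checking, for example with np.unique(X_s_val_s[:, 0])), the validation set contains sequences from the last item only. One possible alternative, sketched below with NumPy, holds out the final fraction of each item's sequences. The helper split_per_item is hypothetical and not part of fusionlab-learn.

```python
import numpy as np

def split_per_item(item_ids, val_frac=0.2):
    """Boolean train/val masks that hold out the last val_frac of each item."""
    item_ids = np.asarray(item_ids)
    val_mask = np.zeros(len(item_ids), dtype=bool)
    for item in np.unique(item_ids):
        idx = np.flatnonzero(item_ids == item)   # positions in original order
        n_val = int(round(len(idx) * val_frac))
        if n_val > 0:                            # idx[-0:] would select everything
            val_mask[idx[-n_val:]] = True
    return ~val_mask, val_mask

# Two items with 42 sequences each, stacked item by item
item_ids = np.repeat([0, 1], 42)
train_mask, val_mask = split_per_item(item_ids, val_frac=0.2)

print(int(train_mask.sum()), int(val_mask.sum()))   # 68 16
print(np.unique(item_ids[val_mask]))                # [0 1]
```

The masks can then index all four arrays consistently, for example static_data_s[train_mask] and targets_s[val_mask]. Consecutive windows overlap in time, so neither split style fully isolates validation data from training data.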

Step 6: Define and Compile Stricter TFT Model

Instantiate the TFT class. All three input dimensions (static_input_dim, dynamic_input_dim, future_input_dim) must be provided and must be > 0.

model_strict_ex = TFT( # Using the stricter TFT class
    static_input_dim=static_data_s.shape[-1],
    dynamic_input_dim=dynamic_data_s.shape[-1],
    future_input_dim=future_data_s.shape[-1],
    forecast_horizon=forecast_horizon_strict,
    output_dim=1, # Predicting a single value
    hidden_units=16, num_heads=2,
    num_lstm_layers=1, lstm_units=16,
    quantiles=None # Point forecast
)
print("\nStricter TFT model instantiated for point forecast.")

model_strict_ex.compile(optimizer='adam', loss='mse')
print("Model compiled successfully.")

Expected Output 6.6:

Stricter TFT model instantiated for point forecast.
Model compiled successfully.
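
Because this TFT variant needs all three inputs, it can help to validate the input list before building or fitting the model. The following is a minimal sketch of such a check, assuming the array layouts used in this exercise: static as (N, features), dynamic and future as (N, steps, features). The helper check_strict_inputs is hypothetical and not part of fusionlab-learn; it returns the three feature dimensions that would be passed as static_input_dim, dynamic_input_dim, and future_input_dim.

```python
import numpy as np

def check_strict_inputs(inputs, time_steps):
    """Validate a [static, dynamic, future] list and return its feature dims."""
    if len(inputs) != 3 or any(x is None for x in inputs):
        raise ValueError("Expected three non-None arrays: [static, dynamic, future].")
    static, dynamic, future = (np.asarray(x) for x in inputs)
    if static.ndim != 2 or dynamic.ndim != 3 or future.ndim != 3:
        raise ValueError("Expected static as 2D and dynamic/future as 3D arrays.")
    if not (len(static) == len(dynamic) == len(future)):
        raise ValueError("All three arrays must have the same number of samples.")
    if dynamic.shape[1] != time_steps:
        raise ValueError("Dynamic window length does not match time_steps.")
    dims = (static.shape[-1], dynamic.shape[-1], future.shape[-1])
    if min(dims) < 1:
        raise ValueError("Every input needs at least one feature.")
    return dims

# Arrays shaped like the ones produced in Step 4
demo_inputs = [np.zeros((84, 2)), np.zeros((84, 7, 2)), np.zeros((84, 8, 2))]
print(check_strict_inputs(demo_inputs, time_steps=7))   # (2, 2, 2)
```

In the exercise this would be called as check_strict_inputs(train_inputs_strict, time_steps_strict) before model_strict_ex.fit().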

Step 7: Train the Stricter TFT Model

print("\nStarting stricter TFT model training...")
history_strict_ex = model_strict_ex.fit(
    train_inputs_strict, # Pass the list [static, dynamic, future]
    y_t_train_s,
    validation_data=(val_inputs_strict, y_t_val_s),
    epochs=5, batch_size=16, verbose=1
)
print("Training finished.")
if history_strict_ex and history_strict_ex.history.get('val_loss'):
    val_loss = history_strict_ex.history['val_loss'][-1]
    print(f"Final validation loss: {val_loss:.4f}")

Expected Output 7.7:

(Output will show Keras training progress)

Starting stricter TFT model training...
Epoch 1/5
5/5 [==============================] - 13s 511ms/step - loss: 1.5969 - val_loss: 0.8108
Epoch 2/5
5/5 [==============================] - 0s 16ms/step - loss: 0.7010 - val_loss: 1.9081
Epoch 3/5
5/5 [==============================] - 0s 17ms/step - loss: 0.4777 - val_loss: 1.8109
Epoch 4/5
5/5 [==============================] - 0s 16ms/step - loss: 0.4485 - val_loss: 1.0865
Epoch 5/5
5/5 [==============================] - 0s 17ms/step - loss: 0.4132 - val_loss: 0.7321
Training finished.
Final validation loss: 0.7321
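
The losses above are mean squared errors on the standardized target, so they are expressed in units of squared standard deviations, not in the original units of Value. For a standardized target, the RMSE in original units is the square root of the scaled MSE multiplied by the target's standard deviation. The sketch below illustrates the conversion. The standard deviation of 10.0 is a placeholder, not a value taken from this dataset; in the exercise the real figure would come from scaler_strict.scale_ at the position of 'Value' in num_cols_to_scale_strict.

```python
import math

def rmse_in_original_units(mse_scaled, target_std):
    """Convert an MSE on a standardized target to an RMSE in original units."""
    return math.sqrt(mse_scaled) * target_std

final_val_mse = 0.7321   # final validation loss reported above
assumed_std = 10.0       # placeholder; replace with the fitted scaler's value

rmse = rmse_in_original_units(final_val_mse, assumed_std)
print(f"RMSE in original units (assumed std {assumed_std}): {rmse:.2f}")
```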

Step 8: Make Predictions and Visualize

Use the trained model to predict and then visualize the results after inverse transforming.

print("\nMaking predictions with stricter TFT...")
val_predictions_scaled_s = model_strict_ex.predict(
    val_inputs_strict, verbose=0
    )

# Inverse transform predictions and actuals.
# scaler_strict (Step 3) was fitted jointly on num_cols_to_scale_strict,
# so we fill a dummy array with the same column layout.
target_scaler_s = (scaler_strict
                   if target_col_strict in num_cols_to_scale_strict
                   else None)
if target_scaler_s is not None:
    dummy_pred_s = np.zeros((len(val_predictions_scaled_s.flatten()),
                             len(num_cols_to_scale_strict)))
    target_idx_s = num_cols_to_scale_strict.index(target_col_strict)
    dummy_pred_s[:, target_idx_s] = val_predictions_scaled_s.flatten()
    val_pred_inv_s = target_scaler_s.inverse_transform(
        dummy_pred_s)[:, target_idx_s]
    val_pred_final_s = val_pred_inv_s.reshape(val_predictions_scaled_s.shape)

    dummy_actual_s = np.zeros((len(y_t_val_s.flatten()),
                               len(num_cols_to_scale_strict)))
    dummy_actual_s[:, target_idx_s] = y_t_val_s.flatten()
    val_actual_inv_s = target_scaler_s.inverse_transform(
        dummy_actual_s)[:, target_idx_s]
    val_actual_final_s = val_actual_inv_s.reshape(y_t_val_s.shape)
    print("Predictions and actuals inverse transformed.")
else:
    print("Warning: Target scaler not found. Plotting scaled values.")
    val_pred_final_s = val_predictions_scaled_s
    val_actual_final_s = y_t_val_s

# --- Visualization (for the first item in validation set) ---
first_val_item_id_enc = X_s_val_s[0, static_cols_for_model_strict.index('ItemID_encoded')]
item_mask_val_s = (X_s_val_s[:, static_cols_for_model_strict.index('ItemID_encoded')] == \
                   first_val_item_id_enc)

item_preds_s = val_pred_final_s[item_mask_val_s, 0, 0]
item_actuals_s = val_actual_final_s[item_mask_val_s, 0, 0]

plt.figure(figsize=(12, 6))
plt.plot(item_actuals_s,
         label=f'Actual (Item Encoded: {int(first_val_item_id_enc)})',
         marker='o', linestyle='--')
plt.plot(item_preds_s,
         label=f'Predicted (Item Encoded: {int(first_val_item_id_enc)})',
         marker='x')
plt.title('Stricter TFT Point Forecast (Validation Item - Inverse Transformed)')
plt.xlabel('Sequence Index in Validation Set for this Item')
plt.ylabel('Value (Inverse Transformed)')
plt.legend(); plt.grid(True); plt.tight_layout()
# fig_path_strict_ex = os.path.join(
# exercise_output_dir_tft_strict,
# "exercise_tft_required_inputs.png"
# )
# plt.savefig(fig_path_strict_ex)
# print(f"\nPlot saved to {fig_path_strict_ex}")
plt.show()
print("Plot generated for stricter TFT.")

Expected Plot 8.8:

[Figure: Stricter TFT Point Forecast Exercise Results]

Visualization of the point forecast from the stricter TFT model against actual validation data for a specific item.
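
The dummy-array approach in the code above is needed because a scaler fitted on several columns expects that same number of columns when inverting. An equivalent and shorter route, sketched here with NumPy only, applies the standardization formula x = z · scale + mean to the target column alone. With a fitted StandardScaler, the per-column statistics are available as the mean_ and scale_ attributes. The helper name inverse_scale_column and the small array below are illustrative assumptions.

```python
import numpy as np

def inverse_scale_column(scaled_values, mean, scale):
    """Undo standardization for one column: x = z * scale + mean."""
    return np.asarray(scaled_values) * scale + mean

# Toy data with two columns, standing in for ['Value', 'ValueLag1']
raw = np.array([[50.0, 49.0],
                [60.0, 50.0],
                [70.0, 60.0]])
mean = raw.mean(axis=0)
scale = raw.std(axis=0)            # population std (ddof=0), as StandardScaler uses
scaled = (raw - mean) / scale

# Recover only the first column, whatever the shape of the scaled array
target_idx = 0
recovered = inverse_scale_column(scaled[:, target_idx],
                                 mean[target_idx], scale[target_idx])
print(recovered)
```

Because the operation is elementwise, it can be applied directly to an array of shape (N, 1, 1) such as the model's predictions, with no flattening or reshaping.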

Discussion of Exercise

In this exercise, you learned how to:

  • Prepare a multi-item dataset with distinct static, dynamic, and future features.

  • Numerically encode categorical static identifiers like ItemID using LabelEncoder.

  • Use reshape_xtft_data() to generate the three required input arrays (static_data, dynamic_data, future_data) for the stricter TFT model.

  • Instantiate and train the stricter TFT, ensuring all three *_input_dim parameters are provided.

  • Correctly structure the input to fit and predict as a list [static_array, dynamic_array, future_array].

This example highlights the data preparation and usage pattern for the TFT model variant that mandates all three types of input features.