Exercise: Forecasting with Stricter TFT (All Inputs Required)
Welcome to this exercise on using the stricter version of the
Temporal Fusion Transformer (TFT) available in fusionlab-learn.
This model implementation requires that static, dynamic (past
observed), and known future features all be provided as inputs.
We will perform a single-step point forecast to illustrate the specific data preparation and model interaction for this TFT variant.
Learning Objectives:
Generate synthetic multi-item time series data with distinct static, dynamic, and future features.
Understand how to define feature roles, numerically encode categorical static features (like item identifiers), and scale numerical data.
Utilize the reshape_xtft_data() utility to prepare the three separate input arrays (static, dynamic, future) and targets.
Correctly structure the input list [static, dynamic, future] for training and prediction with the stricter TFT.
Define, compile, and train the TFT model.
Make predictions and visualize the results, including inverse transformation of scaled values.
Let’s get started!
Prerequisites
Ensure you have fusionlab-learn and its common dependencies
installed. For visualizations, matplotlib is also needed.
pip install fusionlab-learn matplotlib scikit-learn
Step 1: Imports and Setup
First, we import all necessary libraries.
import numpy as np
import pandas as pd
import tensorflow as tf
import matplotlib.pyplot as plt
from sklearn.preprocessing import StandardScaler, LabelEncoder
from sklearn.model_selection import train_test_split
import warnings
import os

# FusionLab imports
from fusionlab.nn.transformers import TFT  # The stricter TFT class
from fusionlab.nn.utils import reshape_xtft_data
# Import for Keras to recognize custom loss if model was saved with it
from fusionlab.nn.losses import combined_quantile_loss

# Suppress warnings and TF logs for cleaner output
warnings.filterwarnings('ignore')
tf.get_logger().setLevel('ERROR')
if hasattr(tf, 'autograph'):  # Check for autograph availability
    tf.autograph.set_verbosity(0)

# Directory for saving any output images from this exercise
exercise_output_dir_tft_strict = "./tft_strict_exercise_outputs"
os.makedirs(exercise_output_dir_tft_strict, exist_ok=True)

print("Libraries imported and setup complete for stricter TFT exercise.")
Expected Output 1.1:
Libraries imported and setup complete for stricter TFT exercise.
Step 2: Generate Synthetic Multi-Feature Data
We’ll create a synthetic dataset for multiple items. Each item will have:
* Static features: ItemID_str (a string identifier) and Category (a numerical category).
* Dynamic past features: DayOfWeek and ValueLag1 (lagged target).
* Known future features: FutureEvent (a binary indicator) and DayOfWeek.
* Target: Value.
n_items_ex_strict = 2
n_timesteps_per_item_ex_strict = 50
rng_seed_ex_strict = 42
np.random.seed(rng_seed_ex_strict)
tf.random.set_seed(rng_seed_ex_strict)

date_rng_ex_strict = pd.date_range(
    start='2021-01-01',
    periods=n_timesteps_per_item_ex_strict, freq='D'
)
df_list_ex_strict = []

for item_id_num in range(n_items_ex_strict):
    time_idx = np.arange(n_timesteps_per_item_ex_strict)
    value = (50 + item_id_num * 10 + time_idx * 0.5 +
             np.sin(time_idx / 7) * 5 +  # Slow cycle (period 2*pi*7, about 44 days)
             np.random.normal(0, 2, n_timesteps_per_item_ex_strict))
    static_category_val = item_id_num + 1
    future_event_val = (date_rng_ex_strict.dayofweek >= 5).astype(int)  # Weekend

    item_df = pd.DataFrame({
        'Date': date_rng_ex_strict,
        'ItemID_str': f'item_{item_id_num}',  # String ID
        'Category': static_category_val,      # Numerical static
        'DayOfWeek': date_rng_ex_strict.dayofweek,
        'FutureEvent': future_event_val,
        'Value': value
    })
    # The first row of each item has no lag and is dropped below
    item_df['ValueLag1'] = item_df['Value'].shift(1)
    df_list_ex_strict.append(item_df)

df_raw_ex_strict = pd.concat(
    df_list_ex_strict).dropna().reset_index(drop=True)
print(f"Generated raw data shape: {df_raw_ex_strict.shape}")
print("Sample of generated data:")
print(df_raw_ex_strict.head(3))
Expected Output 2.2:
Generated raw data shape: (98, 7)
Sample of generated data:
        Date ItemID_str  Category  DayOfWeek  FutureEvent      Value  ValueLag1
0 2021-01-02     item_0         1          5            1  50.935330  50.993428
1 2021-01-03     item_0         1          6            1  53.704591  50.935330
2 2021-01-04     item_0         1          0            0  56.623919  53.704591
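The sine term in the generator, np.sin(time_idx / 7), has an angular frequency of 1/7, so it completes a cycle every 2π·7 ≈ 44 time steps rather than every 7. The NumPy sketch below is not part of the original exercise; it compares that term with a strictly weekly alternative, np.sin(2 * np.pi * t / 7), which is shown only for contrast.

```python
import numpy as np

t = np.arange(50)

# Term used in the exercise: angular frequency 1/7, so period = 2*pi*7 steps.
slow_cycle = np.sin(t / 7)
period_slow = 2 * np.pi * 7

# A strictly weekly term would use angular frequency 2*pi/7 instead.
weekly_cycle = np.sin(2 * np.pi * t / 7)

print(f"Period of sin(t / 7): {period_slow:.1f} steps")
# Compare each series with itself shifted by 7 steps.
weekly_repeats = np.allclose(weekly_cycle[:-7], weekly_cycle[7:])
slow_repeats = np.allclose(slow_cycle[:-7], slow_cycle[7:])
print("weekly term repeats every 7 steps:", weekly_repeats)
print("exercise term repeats every 7 steps:", slow_repeats)
```

Swapping the weekly term into the generator would change the values in the expected outputs of this exercise.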
Step 3: Define Features, Encode Static, and Scale Numerics
We assign columns to their roles. Since the stricter TFT model (and reshape_xtft_data) expects numerical inputs for static features, we’ll LabelEncode the string-based ItemID_str. Then, we scale relevant numerical features.
target_col_strict = 'Value'
dt_col_strict = 'Date'

# Initial column definitions
static_cols_def_strict = ['ItemID_str', 'Category']
dynamic_cols_def_strict = ['DayOfWeek', 'ValueLag1']
future_cols_def_strict = ['FutureEvent', 'DayOfWeek']
# For reshape_xtft_data, spatial_cols are used for grouping
spatial_cols_for_grouping = ['ItemID_str']

df_processed_strict = df_raw_ex_strict.copy()

# --- Encode ItemID_str (Categorical Static Feature) ---
le_item_id_ex_strict = LabelEncoder()
df_processed_strict['ItemID_encoded'] = \
    le_item_id_ex_strict.fit_transform(df_processed_strict['ItemID_str'])
print(f"\nEncoded 'ItemID_str' into 'ItemID_encoded'. "
      f"Classes: {le_item_id_ex_strict.classes_}")

# --- Update static_cols to use the encoded version for the model ---
static_cols_for_model_strict = ['ItemID_encoded', 'Category']
# For reshape_xtft_data, grouping can still use original string ID,
# or you can group by the encoded ID if preferred.
# If grouping by encoded, ensure it's in df_processed_strict.
# Here, we'll pass the original string ItemID for grouping to reshape,
# but use ItemID_encoded as a static *feature*.

# --- Scale Numerical Features ---
scaler_strict = StandardScaler()
num_cols_to_scale_strict = ['Value', 'ValueLag1']
# Ensure columns exist
num_cols_to_scale_strict = [
    c for c in num_cols_to_scale_strict if c in df_processed_strict.columns
]
if num_cols_to_scale_strict:
    df_processed_strict[num_cols_to_scale_strict] = \
        scaler_strict.fit_transform(
            df_processed_strict[num_cols_to_scale_strict]
        )
    print("\nNumerical features scaled.")
else:
    print("\nNo numerical features found for scaling.")
Expected Output 3.3:
Encoded 'ItemID_str' into 'ItemID_encoded'. Classes: ['item_0' 'item_1']
Numerical features scaled.
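To make the encoding step concrete: LabelEncoder assigns each distinct label its position in the sorted array of classes (the classes_ attribute printed above). The sketch below reproduces that mapping with NumPy alone, as an illustration of the idea rather than scikit-learn's actual code; the sample item_ids array is made up.

```python
import numpy as np

# Made-up labels, deliberately out of order.
item_ids = np.array(["item_1", "item_0", "item_1", "item_0", "item_0"])

# np.unique returns the sorted unique labels and, with return_inverse=True,
# the position of each original label within that sorted array.
classes, codes = np.unique(item_ids, return_inverse=True)
print(classes)
print(codes)

# Decoding is a lookup back into the sorted class array.
decoded = classes[codes]
print((decoded == item_ids).all())
```

Because the codes are arbitrary integers, the model sees ItemID_encoded as an ordinary numeric feature; the ordering between codes carries no meaning of its own.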
Step 4: Prepare Sequences with reshape_xtft_data
Use reshape_xtft_data() to transform the
DataFrame. It will use spatial_cols_for_grouping (original ItemID_str)
for grouping and static_cols_for_model_strict (including
ItemID_encoded) to create the static_data array.
time_steps_strict = 7
forecast_horizon_strict = 1  # Single-step point forecast

static_data_s, dynamic_data_s, future_data_s, target_data_s = \
    reshape_xtft_data(
        df=df_processed_strict,  # Contains ItemID_encoded
        dt_col=dt_col_strict,
        target_col=target_col_strict,
        dynamic_cols=dynamic_cols_def_strict,
        static_cols=static_cols_for_model_strict,  # Use encoded static
        future_cols=future_cols_def_strict,
        spatial_cols=spatial_cols_for_grouping,  # Group by original ItemID_str
        time_steps=time_steps_strict,
        forecast_horizons=forecast_horizon_strict,
        verbose=0
    )
targets_s = target_data_s.astype(np.float32)  # Already (N,H,1)

print("\nReshaped Data Shapes for Stricter TFT:")
print(f" Static : {static_data_s.shape}")
print(f" Dynamic: {dynamic_data_s.shape}")
print(f" Future : {future_data_s.shape}")
print(f" Target : {targets_s.shape}")
Expected Output 4.4:
(Shapes depend on n_items, n_timesteps, time_steps, forecast_horizon)
Reshaped Data Shapes for Stricter TFT:
Static : (84, 2)
Dynamic: (84, 7, 2)
Future : (84, 8, 2)
Target : (84, 1, 1)
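The shapes above follow from sliding-window arithmetic: each item contributes 49 rows (50 days minus the row lost to the lag), and a window of 7 past steps plus a 1-step horizon fits 49 - 7 - 1 + 1 = 42 times per item, giving 84 sequences for two items. The sketch below reproduces only these shapes with a hypothetical helper, window_group; it is not the implementation of reshape_xtft_data(), and the exact row alignment inside the library may differ.

```python
import numpy as np

def window_group(dynamic, future, target, time_steps, horizon):
    """Slide a window over one item's rows (hypothetical helper)."""
    n_seq = len(target) - time_steps - horizon + 1
    dyn, fut, tgt = [], [], []
    for start in range(n_seq):
        end = start + time_steps
        dyn.append(dynamic[start:end])           # past window
        fut.append(future[start:end + horizon])  # past window plus horizon
        tgt.append(target[end:end + horizon])    # values to predict
    # Add a trailing axis so targets are shaped (N, H, 1).
    return np.array(dyn), np.array(fut), np.array(tgt)[..., None]

n_items, rows_per_item = 2, 49
time_steps, horizon = 7, 1

parts = []
for _ in range(n_items):
    dynamic = np.zeros((rows_per_item, 2))   # placeholder feature values
    future = np.zeros((rows_per_item, 2))
    target = np.arange(rows_per_item, dtype=float)
    parts.append(window_group(dynamic, future, target, time_steps, horizon))

# Stack the per-item sequences along the sample axis.
dyn_all, fut_all, tgt_all = (np.concatenate(p) for p in zip(*parts))
print(dyn_all.shape, fut_all.shape, tgt_all.shape)
```

The future array spans time_steps + horizon = 8 steps because known future inputs cover both the lookback window and the forecast step.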
Step 5: Train/Validation Split of Sequences
Split the generated sequence arrays. The input for the model will be a list of three non-None arrays: [X_static, X_dynamic, X_future].
val_split_s_frac = 0.2
n_samples_s_total = static_data_s.shape[0]
split_idx_s_val = int(n_samples_s_total * (1 - val_split_s_frac))

X_s_train_s, X_s_val_s = static_data_s[:split_idx_s_val], static_data_s[split_idx_s_val:]
X_d_train_s, X_d_val_s = dynamic_data_s[:split_idx_s_val], dynamic_data_s[split_idx_s_val:]
X_f_train_s, X_f_val_s = future_data_s[:split_idx_s_val], future_data_s[split_idx_s_val:]
y_t_train_s, y_t_val_s = targets_s[:split_idx_s_val], targets_s[split_idx_s_val:]

# Package inputs as the REQUIRED list [static, dynamic, future]
train_inputs_strict = [X_s_train_s, X_d_train_s, X_f_train_s]
val_inputs_strict = [X_s_val_s, X_d_val_s, X_f_val_s]

print("\nSequence data split for stricter TFT.")
print(f" Train samples: {len(y_t_train_s)}")
print(f" Validation samples: {len(y_t_val_s)}")
Expected Output 5.5:
Sequence data split for stricter TFT.
Train samples: 67
Validation samples: 17
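A positional slice takes the validation sequences from wherever the last 20% of the stacked arrays happen to sit. If the sequences are stored item by item (an assumption about the array order, which the exercise does not state), that tail belongs entirely to the last item. As one possible alternative, the sketch below holds out the tail of each item separately; split_per_item is a hypothetical helper, and item_ids stands in for the ItemID_encoded column of the static array.

```python
import numpy as np

def split_per_item(item_ids, val_frac=0.2):
    """Return train/val index arrays, holding out the tail of each item."""
    train_idx, val_idx = [], []
    for item in np.unique(item_ids):
        idx = np.flatnonzero(item_ids == item)  # positions in stored order
        cut = int(len(idx) * (1 - val_frac))
        train_idx.append(idx[:cut])
        val_idx.append(idx[cut:])
    return np.concatenate(train_idx), np.concatenate(val_idx)

# Stand-in for the encoded item column: 42 sequences per item, item by item.
item_ids = np.repeat([0, 1], 42)

# Positional split, as in the exercise.
cut = int(len(item_ids) * 0.8)
print("positional val items:", np.unique(item_ids[cut:]))

# Per-item split.
train_idx, val_idx = split_per_item(item_ids)
print("per-item val items:", np.unique(item_ids[val_idx]))
print("train/val sizes:", len(train_idx), len(val_idx))
```

The same index arrays can then be applied to the static, dynamic, future, and target arrays alike, for example static_data_s[train_idx].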
Step 6: Define and Compile Stricter TFT Model
Instantiate the TFT class. All
three input dimensions (static_input_dim, dynamic_input_dim,
future_input_dim) must be provided and must be > 0.
model_strict_ex = TFT(  # Using the stricter TFT class
    static_input_dim=static_data_s.shape[-1],
    dynamic_input_dim=dynamic_data_s.shape[-1],
    future_input_dim=future_data_s.shape[-1],
    forecast_horizon=forecast_horizon_strict,
    output_dim=1,  # Predicting a single value
    hidden_units=16, num_heads=2,
    num_lstm_layers=1, lstm_units=16,
    quantiles=None  # Point forecast
)
print("\nStricter TFT model instantiated for point forecast.")

model_strict_ex.compile(optimizer='adam', loss='mse')
print("Model compiled successfully.")
Expected Output 6.6:
Stricter TFT model instantiated for point forecast.
Model compiled successfully.
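Since all three inputs are mandatory, it can help to verify the arrays before building the model. The sketch below is a hypothetical pre-flight check, check_tft_inputs, which is not part of fusionlab-learn; it confirms the expected ranks, matching sample counts, and a nonzero feature dimension, and returns the three values that would be passed as the *_input_dim arguments.

```python
import numpy as np

def check_tft_inputs(static, dynamic, future):
    """Validate a [static, dynamic, future] triple and return its feature dims.

    Hypothetical helper, not part of fusionlab-learn.
    """
    if static.ndim != 2:
        raise ValueError(f"static must be 2D (N, F_s), got {static.shape}")
    if dynamic.ndim != 3 or future.ndim != 3:
        raise ValueError("dynamic and future must be 3D (N, T, F)")
    counts = {static.shape[0], dynamic.shape[0], future.shape[0]}
    if len(counts) != 1:
        raise ValueError(f"sample counts differ: {sorted(counts)}")
    dims = (static.shape[-1], dynamic.shape[-1], future.shape[-1])
    if min(dims) < 1:
        raise ValueError(f"every input needs at least one feature, got {dims}")
    return dims

# Placeholder arrays with the shapes produced in Step 4.
static = np.zeros((84, 2), dtype=np.float32)
dynamic = np.zeros((84, 7, 2), dtype=np.float32)
future = np.zeros((84, 8, 2), dtype=np.float32)

dims = check_tft_inputs(static, dynamic, future)
print(dims)
```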
Step 7: Train the Stricter TFT Model
print("\nStarting stricter TFT model training...")
history_strict_ex = model_strict_ex.fit(
    train_inputs_strict,  # Pass the list [static, dynamic, future]
    y_t_train_s,
    validation_data=(val_inputs_strict, y_t_val_s),
    epochs=5, batch_size=16, verbose=1
)
print("Training finished.")
if history_strict_ex and history_strict_ex.history.get('val_loss'):
    val_loss = history_strict_ex.history['val_loss'][-1]
    print(f"Final validation loss: {val_loss:.4f}")
Expected Output 7.7:
(Output will show Keras training progress)
Starting stricter TFT model training...
Epoch 1/5
5/5 [==============================] - 13s 511ms/step - loss: 1.5969 - val_loss: 0.8108
Epoch 2/5
5/5 [==============================] - 0s 16ms/step - loss: 0.7010 - val_loss: 1.9081
Epoch 3/5
5/5 [==============================] - 0s 17ms/step - loss: 0.4777 - val_loss: 1.8109
Epoch 4/5
5/5 [==============================] - 0s 16ms/step - loss: 0.4485 - val_loss: 1.0865
Epoch 5/5
5/5 [==============================] - 0s 17ms/step - loss: 0.4132 - val_loss: 0.7321
Training finished.
Final validation loss: 0.7321
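Two details of this log can be checked by hand. The 5/5 counter is the number of batches per epoch, which follows from 67 training sequences and a batch size of 16. The validation loss also rises and falls between epochs instead of decreasing steadily, which is unsurprising with only 17 validation sequences. The short sketch below redoes that arithmetic using the values copied from the log; it needs no model.

```python
import math

# Batches per epoch: the last batch holds the 3 leftover sequences.
n_train, batch_size = 67, 16
steps_per_epoch = math.ceil(n_train / batch_size)
print(steps_per_epoch)

# Validation losses copied from the log above, one per epoch.
val_loss = [0.8108, 1.9081, 1.8109, 1.0865, 0.7321]
best_epoch = min(range(len(val_loss)), key=val_loss.__getitem__) + 1
print(best_epoch, val_loss[best_epoch - 1])
```

Exact loss values will vary between runs and environments, so the numbers in the log are indicative only.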
Step 8: Make Predictions and Visualize
Use the trained model to predict and then visualize the results after inverse transforming.
print("\nMaking predictions with stricter TFT...")
val_predictions_scaled_s = model_strict_ex.predict(
    val_inputs_strict, verbose=0
)

# Inverse transform predictions and actuals.
# scaler_strict (Step 3) was fitted jointly on num_cols_to_scale_strict,
# so the target values are placed in their column of a dummy array first.
target_scaler_s = scaler_strict
if target_scaler_s is not None:
    dummy_pred_s = np.zeros((len(val_predictions_scaled_s.flatten()),
                             len(num_cols_to_scale_strict)))
    target_idx_s = num_cols_to_scale_strict.index(target_col_strict)
    dummy_pred_s[:, target_idx_s] = val_predictions_scaled_s.flatten()
    val_pred_inv_s = target_scaler_s.inverse_transform(
        dummy_pred_s)[:, target_idx_s]
    val_pred_final_s = val_pred_inv_s.reshape(val_predictions_scaled_s.shape)

    dummy_actual_s = np.zeros((len(y_t_val_s.flatten()),
                               len(num_cols_to_scale_strict)))
    dummy_actual_s[:, target_idx_s] = y_t_val_s.flatten()
    val_actual_inv_s = target_scaler_s.inverse_transform(
        dummy_actual_s)[:, target_idx_s]
    val_actual_final_s = val_actual_inv_s.reshape(y_t_val_s.shape)
    print("Predictions and actuals inverse transformed.")
else:
    print("Warning: Target scaler not found. Plotting scaled values.")
    val_pred_final_s = val_predictions_scaled_s
    val_actual_final_s = y_t_val_s

# --- Visualization (for the first item in validation set) ---
item_col_idx_s = static_cols_for_model_strict.index('ItemID_encoded')
first_val_item_id_enc = X_s_val_s[0, item_col_idx_s]
item_mask_val_s = (X_s_val_s[:, item_col_idx_s] == first_val_item_id_enc)

item_preds_s = val_pred_final_s[item_mask_val_s, 0, 0]
item_actuals_s = val_actual_final_s[item_mask_val_s, 0, 0]

plt.figure(figsize=(12, 6))
plt.plot(item_actuals_s,
         label=f'Actual (Item Encoded: {int(first_val_item_id_enc)})',
         marker='o', linestyle='--')
plt.plot(item_preds_s,
         label=f'Predicted (Item Encoded: {int(first_val_item_id_enc)})',
         marker='x')
plt.title('Stricter TFT Point Forecast (Validation Item - Inverse Transformed)')
plt.xlabel('Sequence Index in Validation Set for this Item')
plt.ylabel('Value (Inverse Transformed)')
plt.legend(); plt.grid(True); plt.tight_layout()
# fig_path_strict_ex = os.path.join(
#     exercise_output_dir_tft_strict,
#     "exercise_tft_required_inputs.png"
# )
# plt.savefig(fig_path_strict_ex)
# print(f"\nPlot saved to {fig_path_strict_ex}")
plt.show()
print("Plot generated for stricter TFT.")
Expected Plot 8.8:
Visualization of the point forecast from the stricter TFT model against actual validation data for a specific item.
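The dummy-array approach in Step 8 is needed because the scaler was fitted on two columns at once. Standardization is applied column by column, however, so the target alone can also be restored as scaled * scale + mean for its column. The sketch below shows this with plain NumPy on made-up data; with a fitted StandardScaler, the mean and scale arrays correspond to its mean_ and scale_ attributes.

```python
import numpy as np

rng = np.random.default_rng(0)
# Made-up data: two columns standing in for ['Value', 'ValueLag1'].
raw = rng.normal(loc=[60.0, 59.5], scale=[8.0, 8.0], size=(98, 2))

# Standardize as StandardScaler does: per-column mean and
# population standard deviation (ddof=0).
mean = raw.mean(axis=0)
scale = raw.std(axis=0)
scaled = (raw - mean) / scale

# Invert only the target column (index 0) for values shaped (N, H, 1).
target_idx = 0
scaled_targets = scaled[:17, target_idx].reshape(-1, 1, 1)
restored = scaled_targets * scale[target_idx] + mean[target_idx]

print(restored.shape)
print(np.allclose(restored[:, 0, 0], raw[:17, target_idx]))
```

This avoids building placeholder columns and keeps the (N, H, 1) shape of the predictions without an extra reshape.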
Discussion of Exercise
In this exercise, you learned how to:
* Prepare a multi-item dataset with distinct static, dynamic, and future features.
* Numerically encode categorical static identifiers like ItemID using LabelEncoder.
* Use reshape_xtft_data() to generate the three required input arrays (static_data, dynamic_data, future_data) for the stricter TFT model.
* Instantiate and train the stricter TFT, ensuring all three *_input_dim parameters are provided.
* Correctly structure the input to fit and predict as a list [static_array, dynamic_array, future_array].
This example highlights the data preparation and usage pattern for the TFT model variant that mandates all three types of input features.