Point Forecasting with Stricter TFT (Required Inputs)
This example demonstrates how to use the stricter TFT class implementation. Unlike the more flexible TemporalFusionTransformer, this version strictly requires static, dynamic (past), and known future features as inputs during initialization and for model calls.
We will perform a single-step point forecast, showcasing the specific data preparation and model interaction for this TFT variant.
The workflow includes:
Generating synthetic data with distinct static, dynamic, and future features for multiple items.
Defining feature roles, encoding categorical static features, and scaling numerical data.
Using the reshape_xtft_data() utility to prepare the three separate input arrays (static, dynamic, future) and targets.
Splitting the sequence data into training and validation sets.
Defining, compiling, and training the stricter TFT model for point forecasting.
Making predictions using the mandatory three-part input structure.
Visualizing the results.
Prerequisites
Ensure you have fusionlab-learn and its dependencies installed:
pip install fusionlab-learn matplotlib scikit-learn
Step 1: Imports and Setup
Import standard libraries and the necessary components from fusionlab, including the stricter TFT model and reshape_xtft_data.
import numpy as np
import pandas as pd
import tensorflow as tf
import matplotlib.pyplot as plt
from sklearn.preprocessing import StandardScaler, LabelEncoder
from sklearn.model_selection import train_test_split
import warnings
import os

# FusionLab imports
from fusionlab.nn.transformers import TFT  # The stricter TFT class
from fusionlab.nn.utils import reshape_xtft_data
# For Keras to recognize custom loss if model was saved with it
from fusionlab.nn.losses import combined_quantile_loss

# Suppress warnings and TF logs for cleaner output
warnings.filterwarnings('ignore')
tf.get_logger().setLevel('ERROR')
if hasattr(tf, 'autograph'):
    tf.autograph.set_verbosity(0)

print("Libraries imported and TensorFlow logs configured.")
Step 2: Generate Synthetic Data
We create a synthetic dataset for multiple items, including static features (ItemID_str, Category), dynamic past features (DayOfWeek, ValueLag1), and known future features (FutureEvent, DayOfWeek). The item identifier is generated as a string, and DayOfWeek is used both as a past and as a known future feature.
n_items = 2
n_timesteps_per_item = 50
rng_seed = 42
np.random.seed(rng_seed)

date_rng = pd.date_range(
    start='2021-01-01', periods=n_timesteps_per_item, freq='D'
)
df_list = []

for item_id_num in range(n_items):  # Use numerical id for generation
    time_idx = np.arange(n_timesteps_per_item)
    value = (50 + item_id_num * 10 + time_idx * 0.5 +
             np.sin(time_idx / 7) * 5 +
             np.random.normal(0, 2, n_timesteps_per_item))
    static_category_val = item_id_num + 1
    future_event_val = (date_rng.dayofweek >= 5).astype(int)

    item_df = pd.DataFrame({
        'Date': date_rng,
        'ItemID_str': f'item_{item_id_num}',  # String ID for raw data
        'Category': static_category_val,      # Numerical static
        'DayOfWeek': date_rng.dayofweek,
        'FutureEvent': future_event_val,
        'Value': value
    })
    item_df['ValueLag1'] = item_df['Value'].shift(1)
    df_list.append(item_df)

df_raw = pd.concat(df_list).dropna().reset_index(drop=True)
print(f"Generated raw data shape: {df_raw.shape}")
print("Sample of generated data:")
print(df_raw.head())
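The lag feature is computed per item before concatenation, so the first row of each item has no previous value and is removed by dropna(). The following is a minimal sketch on a toy frame (the values are made up for illustration) that shows the effect, using a groupby-based form that gives the same result on an already concatenated frame:

```python
import pandas as pd

# Toy data: two items with three time steps each (illustrative values).
toy = pd.DataFrame({
    "ItemID_str": ["item_0"] * 3 + ["item_1"] * 3,
    "Value": [10.0, 11.0, 12.0, 20.0, 21.0, 22.0],
})

# Shift within each item so a lag never crosses an item boundary.
toy["ValueLag1"] = toy.groupby("ItemID_str")["Value"].shift(1)

# The first row of every item has no lag value and is dropped.
toy_clean = toy.dropna().reset_index(drop=True)

print(toy_clean)
print(len(toy), "->", len(toy_clean))
```

With the settings used in this step (2 items, 50 time steps each), the same logic leaves 98 rows in df_raw, one fewer per item.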
Step 3: Define Features, Encode Categorical Static, & Scale Numerics
Assign columns to their roles. Crucially, encode string-based static features like `ItemID_str` into numerical representations before scaling and reshaping.
target_col = 'Value'
dt_col = 'Date'

# Initial column definitions
# ItemID_str is categorical, Category is already numerical static
static_cols_def = ['ItemID_str', 'Category']
dynamic_cols_def = ['DayOfWeek', 'ValueLag1']
future_cols_def = ['FutureEvent', 'DayOfWeek']
spatial_cols_def = ['ItemID_str']  # Group by original string ID

df_processed = df_raw.copy()

# --- Encode ItemID_str (Categorical Static Feature) ---
le_item_id = LabelEncoder()
# Create a new numerical column for ItemID
df_processed['ItemID_encoded'] = le_item_id.fit_transform(
    df_processed['ItemID_str']
)
print(f"\nEncoded 'ItemID_str' into 'ItemID_encoded'. "
      f"Classes: {le_item_id.classes_}")

# --- Update static_cols to use the encoded version ---
# 'Category' is already numeric. We'll use 'ItemID_encoded'.
static_cols_for_model = ['ItemID_encoded', 'Category']
# Update spatial_cols if grouping should now be by the encoded ID
# For reshape_xtft_data, spatial_cols are used for grouping and
# are often also part of static_cols if they are static identifiers.
# If ItemID_encoded is the primary key for grouping sequences:
spatial_cols_for_model = ['ItemID_encoded']


# --- Scale Numerical Features ---
# Target 'Value' and 'ValueLag1' are scaled.
# 'Category', 'DayOfWeek', 'FutureEvent', 'ItemID_encoded' are not scaled here
# as they are categorical or already identifiers.
scaler = StandardScaler()
num_cols_to_scale = ['Value', 'ValueLag1']
# Ensure these columns exist before trying to scale
num_cols_to_scale = [c for c in num_cols_to_scale if c in df_processed.columns]

if num_cols_to_scale:
    df_processed[num_cols_to_scale] = scaler.fit_transform(
        df_processed[num_cols_to_scale]
    )
    print("\nNumerical features scaled.")
else:
    print("\nNo numerical features specified or found for scaling.")
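To make the two transformations in this step concrete, the sketch below reproduces their arithmetic with plain NumPy on toy values. It illustrates the idea and is not the scikit-learn implementation: integer codes follow the sorted order of the unique labels, and standardization subtracts the mean and divides by the population standard deviation (ddof=0), which matches the default behavior of StandardScaler.

```python
import numpy as np

labels = np.array(["item_1", "item_0", "item_1", "item_0"])
# Sorted unique labels play the role of LabelEncoder.classes_;
# `codes` holds the position of each label in that sorted array.
classes, codes = np.unique(labels, return_inverse=True)
print(classes, codes)

values = np.array([50.0, 52.0, 54.0, 56.0])
mean, std = values.mean(), values.std()  # population std (ddof=0)
scaled = (values - mean) / std           # forward transform
restored = scaled * std + mean           # inverse transform
print(scaled, restored)
```

The inverse step is the same arithmetic that is later needed to bring predictions back to the original units of Value.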
Step 4: Prepare Sequences with reshape_xtft_data
Use reshape_xtft_data() with the processed DataFrame (which now has ItemID_encoded) and the updated column lists.
time_steps = 7
forecast_horizon = 1

# Use the updated column lists for model input features
static_data, dynamic_data, future_data, target_data = reshape_xtft_data(
    df=df_processed,                      # Use the DataFrame with ItemID_encoded
    dt_col=dt_col,
    target_col=target_col,
    dynamic_cols=dynamic_cols_def,        # Original dynamic cols
    static_cols=static_cols_for_model,    # Use encoded static cols
    future_cols=future_cols_def,          # Original future cols
    spatial_cols=spatial_cols_for_model,  # Group by encoded ItemID
    time_steps=time_steps,
    forecast_horizons=forecast_horizon,
    verbose=0
)
targets = target_data.astype(np.float32)

print("\nReshaped Data Shapes:")
print(f" Static : {static_data.shape if static_data is not None else 'None'}")
print(f" Dynamic: {dynamic_data.shape if dynamic_data is not None else 'None'}")
print(f" Future : {future_data.shape if future_data is not None else 'None'}")
print(f" Target : {targets.shape if targets is not None else 'None'}")
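For intuition about what a sequence-building step like this does, here is a minimal sketch of a generic sliding window over one item's rows. The helper make_windows is hypothetical and is not part of fusionlab; the exact windowing used by reshape_xtft_data() may differ, so treat the resulting counts as an illustration under this assumed scheme.

```python
import numpy as np

def make_windows(dynamic, target, time_steps, horizon):
    """Slice one item's series into (lookback, horizon) pairs.

    Hypothetical helper for illustration; not part of fusionlab.
    dynamic: (n_rows, n_features), target: (n_rows,)
    """
    n_windows = len(target) - time_steps - horizon + 1
    X = np.stack([dynamic[i:i + time_steps] for i in range(n_windows)])
    y = np.stack([
        target[i + time_steps:i + time_steps + horizon]
        for i in range(n_windows)
    ])
    return X, y[..., np.newaxis]  # targets as (n_windows, horizon, 1)

rows, feats = 49, 2  # 49 rows per item remain after dropna()
dynamic = np.arange(rows * feats, dtype=np.float32).reshape(rows, feats)
target = np.arange(rows, dtype=np.float32)

X, y = make_windows(dynamic, target, time_steps=7, horizon=1)
print(X.shape, y.shape)
```

Each window pairs 7 past rows with the value that immediately follows them, which is the single-step setup used in this example.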
Step 5: Train/Validation Split of Sequences
Split the generated sequence arrays. The input for the model will be [X_static, X_dynamic, X_future].
val_split_fraction = 0.2
if any(arr is None for arr in
       (static_data, dynamic_data, future_data, targets)):
    raise ValueError("Data reshaping did not produce all required arrays.")

n_samples = static_data.shape[0]
split_idx = int(n_samples * (1 - val_split_fraction))

X_train_static, X_val_static = static_data[:split_idx], static_data[split_idx:]
X_train_dynamic, X_val_dynamic = dynamic_data[:split_idx], dynamic_data[split_idx:]
X_train_future, X_val_future = future_data[:split_idx], future_data[split_idx:]
y_train, y_val = targets[:split_idx], targets[split_idx:]

train_inputs = [X_train_static, X_train_dynamic, X_train_future]
val_inputs = [X_val_static, X_val_dynamic, X_val_future]

print("\nSequence data split into Train/Validation sets.")
print(f" Train samples: {len(y_train)}")
print(f" Validation samples: {len(y_val)}")
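A positional split like the one above cuts the stacked sequence arrays at a single index. If the sequences are stored item by item (an assumption about the array layout, not something this example verifies), the validation set may then contain sequences from the last item only. The sketch below illustrates this on a toy array of item labels and shows one possible alternative that holds out the tail of every item; the counts of 42 sequences per item are assumed for illustration.

```python
import numpy as np

# Assumed layout: 42 sequences for item 0 followed by 42 for item 1.
item_ids = np.repeat([0, 1], 42)
val_split_fraction = 0.2

# Positional split, as in the tutorial code.
split_idx = int(len(item_ids) * (1 - val_split_fraction))
print("items in validation:", np.unique(item_ids[split_idx:]))

# Alternative: hold out the last 20% of sequences of every item.
train_idx, val_idx = [], []
for item in np.unique(item_ids):
    idx = np.flatnonzero(item_ids == item)
    cut = int(len(idx) * (1 - val_split_fraction))
    train_idx.extend(idx[:cut])
    val_idx.extend(idx[cut:])
train_idx, val_idx = np.array(train_idx), np.array(val_idx)
print("items in validation:", np.unique(item_ids[val_idx]))
```

The index arrays from the alternative can be used to slice the static, dynamic, future, and target arrays consistently, for example static_data[val_idx].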
Step 6: Define and Compile Stricter TFT Model
Instantiate the TFT class. All three input dimensions must be provided.
model = TFT(
    static_input_dim=static_data.shape[-1],
    dynamic_input_dim=dynamic_data.shape[-1],
    future_input_dim=future_data.shape[-1],
    forecast_horizon=forecast_horizon,
    output_dim=1,
    hidden_units=16, num_heads=2,
    num_lstm_layers=1, lstm_units=16,
    quantiles=None  # Point forecast
)
print("\nStricter TFT model instantiated.")
model.compile(optimizer='adam', loss='mse')
print("Model compiled.")
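Because all three inputs are mandatory for this model, it can help to validate the arrays before deriving the input dimensions from them. The sketch below uses a hypothetical helper, check_tft_inputs, which is not part of fusionlab. The expected ranks (static 2-D, dynamic and future 3-D) are an assumption based on how the array shapes are used in this example.

```python
import numpy as np

def check_tft_inputs(static, dynamic, future):
    """Hypothetical pre-flight check; not part of fusionlab."""
    arrays = {"static": static, "dynamic": dynamic, "future": future}
    missing = [name for name, arr in arrays.items() if arr is None]
    if missing:
        raise ValueError(f"Missing required inputs: {missing}")
    # Assumed ranks: (samples, features) and (samples, time, features).
    if static.ndim != 2 or dynamic.ndim != 3 or future.ndim != 3:
        raise ValueError("Expected static 2-D, dynamic and future 3-D.")
    batch_sizes = {arr.shape[0] for arr in arrays.values()}
    if len(batch_sizes) != 1:
        raise ValueError(f"Batch sizes differ: {sorted(batch_sizes)}")
    return {
        "static_input_dim": static.shape[-1],
        "dynamic_input_dim": dynamic.shape[-1],
        "future_input_dim": future.shape[-1],
    }

# Toy arrays; the shapes are illustrative only.
dims = check_tft_inputs(
    np.zeros((8, 2)), np.zeros((8, 7, 2)), np.zeros((8, 8, 2))
)
print(dims)
```

The returned dictionary uses the same keyword names as the constructor call above, so under these assumptions it could be unpacked into it alongside the remaining arguments.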
Step 7: Train the Model
Train using the 3-element train_inputs list.
print("\nStarting model training...")
history = model.fit(
    train_inputs, y_train,
    validation_data=(val_inputs, y_val),
    epochs=5, batch_size=16, verbose=1
)
print("Training finished.")
if history and history.history.get('val_loss'):
    print(f"Final validation loss: {history.history['val_loss'][-1]:.4f}")
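Since the target was standardized in Step 3, the reported MSE is in scaled units. The following sketch, on synthetic numbers that are not results from this example, shows the relationship used to express the error in original units: the square root of the scaled MSE multiplied by the standard deviation used for scaling equals the RMSE on the original scale.

```python
import numpy as np

rng = np.random.default_rng(0)
y_true = 60 + 5 * rng.standard_normal(100)      # original units
y_pred = y_true + rng.normal(0, 1.5, size=100)  # imperfect forecast

# Standardize both with the statistics of the true values.
mean, std = y_true.mean(), y_true.std()
y_true_s = (y_true - mean) / std
y_pred_s = (y_pred - mean) / std

mse_scaled = np.mean((y_true_s - y_pred_s) ** 2)  # what 'mse' measures
rmse_original = np.sqrt(np.mean((y_true - y_pred) ** 2))

print(np.sqrt(mse_scaled) * std, rmse_original)
```

In this example the corresponding standard deviation would be the entry of scaler.scale_ at the position of 'Value' in num_cols_to_scale.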
Step 8: Make Predictions and Visualize
Predict on the validation set and visualize. Inverse transform for interpretable results.
print("\nMaking predictions on the validation set...")
val_predictions_scaled = model.predict(val_inputs, verbose=0)

# Inverse transform (simplified for target only)
# Create a dummy array matching the shape scaler was fit on
# (assuming scaler was fit on multiple columns from num_cols_to_scale)
dummy_for_inv_transform = np.zeros(
    (len(val_predictions_scaled.flatten()), len(num_cols_to_scale))
)

# Find the index of the target column in the original list of scaled columns
target_idx_in_scaler = num_cols_to_scale.index(target_col)

# Populate the target column in the dummy array for inverse transform
dummy_for_inv_transform[:, target_idx_in_scaler] = val_predictions_scaled.flatten()
val_predictions_inv = scaler.inverse_transform(
    dummy_for_inv_transform
)[:, target_idx_in_scaler]
val_predictions_final = val_predictions_inv.reshape(val_predictions_scaled.shape)

# Inverse transform actuals
dummy_for_inv_transform_actual = np.zeros(
    (len(y_val.flatten()), len(num_cols_to_scale))
)
dummy_for_inv_transform_actual[:, target_idx_in_scaler] = y_val.flatten()
val_actuals_inv = scaler.inverse_transform(
    dummy_for_inv_transform_actual
)[:, target_idx_in_scaler]
val_actuals_final = val_actuals_inv.reshape(y_val.shape)

print("Predictions and actuals inverse transformed.")

# --- Visualization (for the first item ID in validation set) ---
# Get the encoded ItemID from the validation static data
item_id_col = static_cols_for_model.index('ItemID_encoded')
first_val_item_id_encoded = X_val_static[0, item_id_col]
# Convert back to original string ID for display if desired
# original_item_id_str = le_item_id.inverse_transform([int(first_val_item_id_encoded)])[0]

item_mask_val = (X_val_static[:, item_id_col] == first_val_item_id_encoded)
item_preds = val_predictions_final[item_mask_val, 0, 0]
item_actuals = val_actuals_final[item_mask_val, 0, 0]

plt.figure(figsize=(12, 6))
plt.plot(item_actuals,
         label=f'Actual (Item Encoded: {int(first_val_item_id_encoded)})',
         marker='o', linestyle='--')
plt.plot(item_preds,
         label=f'Predicted (Item Encoded: {int(first_val_item_id_encoded)})',
         marker='x')
plt.title('Stricter TFT Point Forecast (Validation Item - Inverse Transformed)')
plt.xlabel('Sequence Index in Validation Set for this Item')
plt.ylabel('Value (Inverse Transformed)')
plt.legend()
plt.grid(True)
plt.tight_layout()
# plt.savefig("docs/source/images/forecasting_tft_required_inputs.png")
plt.show()
print("Plot generated.")
Example Output Plot:
Visualization of the point forecast against actual validation data using the stricter TFT model.
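The dummy-array inverse transform in Step 8 is needed because the scaler was fit on two columns and expects two columns back. As a possible simplification, the same result can be obtained from the fitted per-column statistics of StandardScaler (mean_ and scale_). The sketch below compares both approaches on synthetic data; the numbers and the column order are assumptions for illustration, not values from this example.

```python
import numpy as np
from sklearn.preprocessing import StandardScaler

rng = np.random.default_rng(42)
# Synthetic stand-in for the two scaled columns ['Value', 'ValueLag1'].
data = rng.normal(loc=[60.0, 59.0], scale=[8.0, 8.0], size=(98, 2))
scaler = StandardScaler().fit(data)

target_idx = 0  # assumed position of 'Value' among the scaled columns
preds_scaled = np.array([[[0.5]], [[-1.0]], [[0.0]]])  # shape (n, 1, 1)

# Dummy-array approach, as in the tutorial code.
dummy = np.zeros((preds_scaled.size, data.shape[1]))
dummy[:, target_idx] = preds_scaled.ravel()
via_dummy = scaler.inverse_transform(dummy)[:, target_idx]

# Direct approach using the fitted statistics of the target column.
direct = preds_scaled * scaler.scale_[target_idx] + scaler.mean_[target_idx]

print(np.allclose(via_dummy, direct.ravel()))
```

The direct form keeps the original array shape, so no flatten and reshape round trip is required.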