Exercise: Advanced Forecasting with BaseAttentive
Welcome to this exercise on using the BaseAttentive class, the
foundational forecasting engine within fusionlab-learn. This
exercise will demonstrate how to leverage its powerful, purely
data-driven architecture to perform a complex, multi-horizon,
probabilistic forecast.
We will focus on two key aspects:
1. Using the architecture_config dictionary to build a custom model structure (e.g., a pure transformer).
2. Performing a quantile forecast that produces not just a single point prediction but several quantiles of the forecast distribution, which quantify its uncertainty.
Learning Objectives:
- Generate a synthetic dataset with static, dynamic past, and known future features.
- Use the architecture_config dictionary to define the model's internal structure.
- Prepare and reshape the data into the three separate input arrays required by BaseAttentive.
- Instantiate, compile, and train the model for probabilistic (quantile) forecasting.
- Visualize the multi-step forecast results, including the uncertainty bounds.
Let’s begin!
Prerequisites
Ensure you have fusionlab-learn, TensorFlow, and the plotting and preprocessing libraries used below installed.

```bash
pip install fusionlab-learn tensorflow matplotlib scikit-learn
```
Step 1: Imports and Setup
First, we import all necessary libraries and set up our environment for reproducibility.
```python
import os
import numpy as np
import pandas as pd
import tensorflow as tf
import matplotlib.pyplot as plt
from sklearn.preprocessing import StandardScaler

# FusionLab imports
from fusionlab.nn.models import BaseAttentive
from fusionlab.nn.utils import reshape_xtft_data

# Suppress warnings and TF logs for cleaner output
import warnings
warnings.filterwarnings('ignore')
tf.get_logger().setLevel('ERROR')

# Directory for saving any output images
EXERCISE_OUTPUT_DIR = "./base_attentive_exercise_outputs"
os.makedirs(EXERCISE_OUTPUT_DIR, exist_ok=True)

print("Libraries imported and setup complete for BaseAttentive exercise.")
```
Expected Output:
Libraries imported and setup complete for BaseAttentive exercise.
Step 2: Generate Synthetic Data
We'll create a synthetic dataset of multiple time series (one per item), each with
static, dynamic, and known-future features. This is exactly the structure that
BaseAttentive is designed to handle.
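The generation code below can be read as a simple formula. For item i (0 to 4) at day index t (0 to 119), the target, the promotion flag, and the lagged feature are:

```latex
\begin{aligned}
y_{i,t} &= 50 + 15\,i + 0.2\,t + 20\sin\!\left(\tfrac{t}{14}\right) + \varepsilon_{i,t},
\qquad \varepsilon_{i,t} \sim \mathcal{N}\!\left(0,\, 4^2\right),\\
\mathrm{promo}_{t} &= \mathbf{1}\!\left[\sin\!\left(\tfrac{t}{7}\right) > 0.95\right],\\
\mathrm{lag}_{i,t} &= y_{i,t-1}.
\end{aligned}
```

Because the lag is undefined at t = 0, the first row of each item is dropped, leaving 5 × 119 = 595 rows.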
```python
# Configuration
N_ITEMS = 5
N_TIMESTEPS_PER_ITEM = 120
SEED = 42
np.random.seed(SEED)
tf.random.set_seed(SEED)

# --- Generate Data ---
df_list = []
date_rng = pd.date_range(start='2023-01-01', periods=N_TIMESTEPS_PER_ITEM, freq='D')

for item_id in range(N_ITEMS):
    time_idx = np.arange(N_TIMESTEPS_PER_ITEM)
    # Base signal with trend and seasonality
    value = (
        50 + item_id * 15 + time_idx * 0.2
        + np.sin(time_idx / 14) * 20
        + np.random.normal(0, 4, N_TIMESTEPS_PER_ITEM)
    )
    # Known future feature (e.g., special event)
    future_promo = np.sin(time_idx / 7) > 0.95

    item_df = pd.DataFrame({
        'Date': date_rng,
        'ItemID': item_id,
        'Value': value,
        'DayOfWeek': date_rng.dayofweek,
        'Month': date_rng.month,
        'FuturePromo': future_promo.astype(int)
    })
    # Dynamic feature (lagged value); the first row of each item becomes NaN
    item_df['ValueLag1'] = item_df['Value'].shift(1)
    df_list.append(item_df)

df_raw = pd.concat(df_list).dropna().reset_index(drop=True)
print(f"Generated raw data shape: {df_raw.shape}")
print("Sample of generated data:")
print(df_raw.head())
```
Expected Output:
```text
Generated raw data shape: (595, 7)
Sample of generated data:
        Date  ItemID      Value  DayOfWeek  Month  FuturePromo  ValueLag1
0 2023-01-02       0  51.074300          0      1            0  51.986857
1 2023-01-03       0  55.838189          1      1            0  51.074300
2 2023-01-04       0  60.945110          2      1            0  55.838189
3 2023-01-05       0  55.500244          3      1            0  60.945110
4 2023-01-06       0  57.055428          4      1            0  55.500244
```
Step 3: Preprocess and Reshape Data
We define the roles for each column and then use the
reshape_xtft_data utility to transform the flat dataframe into
the three sequence arrays required by the model’s tft_like mode.
```python
# Define feature roles
static_cols = ['ItemID']
dynamic_cols = ['ValueLag1']
future_cols = ['DayOfWeek', 'Month', 'FuturePromo']
target_col = 'Value'

# Scale numerical features. The target keeps its own scaler so that
# predictions can be mapped back to the original units later.
# For simplicity the scalers are fit on the full dataset, which leaks
# validation statistics into training; fit on training rows only in practice.
df_processed = df_raw.copy()
scaler_val = StandardScaler()
scaler_lag1 = StandardScaler()

df_processed[target_col] = scaler_val.fit_transform(
    df_processed[[target_col]]
)
df_processed['ValueLag1'] = scaler_lag1.fit_transform(
    df_processed[['ValueLag1']]
)

# Reshape data into sequences
TIME_STEPS = 21        # Lookback window
FORECAST_HORIZON = 7   # Prediction window

static_data, dynamic_data, future_data, targets = reshape_xtft_data(
    df=df_processed,
    dt_col='Date',
    target_col=target_col,
    dynamic_cols=dynamic_cols,
    static_cols=static_cols,
    future_cols=future_cols,
    spatial_cols=['ItemID'],  # Group by item
    time_steps=TIME_STEPS,
    forecast_horizons=FORECAST_HORIZON
)

print("\nReshaped Data Shapes for 'tft_like' mode:")
print(f"  Static data: {static_data.shape}")
print(f"  Dynamic data: {dynamic_data.shape}")
# Note: future_data length = TIME_STEPS + FORECAST_HORIZON
print(f"  Future data: {future_data.shape}")
print(f"  Target data: {targets.shape}")
```
Expected Output:
```text
[INFO] Reshaping time‑series data into rolling sequences...
[INFO] Data grouped by ['ItemID'] into 5 groups.
[INFO] Total valid sequences to be generated: 460
[INFO] Final data shapes after reshaping:
[DEBUG] Static Data : (460, 1)
[DEBUG] Dynamic Data: (460, 21, 1)
[DEBUG] Future Data : (460, 28, 3)
[DEBUG] Target Data : (460, 7, 1)
[INFO] Time‑series data successfully reshaped into rolling sequences.

Reshaped Data Shapes for 'tft_like' mode:
  Static data: (460, 1)
  Dynamic data: (460, 21, 1)
  Future data: (460, 28, 3)
  Target data: (460, 7, 1)
```
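These shapes follow from simple windowing arithmetic. Each item keeps 119 rows after the lag is dropped, and one window needs TIME_STEPS past rows plus FORECAST_HORIZON target rows, so each item yields 119 - 21 - 7 + 1 = 92 windows, or 460 across five items. The sketch below reproduces that count with NumPy index arrays. make_windows is a hypothetical helper for illustration only, not how reshape_xtft_data is implemented, and it assumes windows slide one step at a time within each item.

```python
import numpy as np

def make_windows(n_rows, time_steps, horizon):
    """Return (past, future, target) row-index arrays for one item's series.

    Hypothetical helper: illustrates the windowing arithmetic only.
    """
    n_windows = n_rows - time_steps - horizon + 1
    starts = np.arange(n_windows)[:, None]                   # one row per window
    past_idx = starts + np.arange(time_steps)                # lookback rows
    future_idx = starts + np.arange(time_steps + horizon)    # lookback + horizon rows
    target_idx = starts + time_steps + np.arange(horizon)    # rows to forecast
    return past_idx, future_idx, target_idx

ROWS_PER_ITEM = 119  # 120 generated days minus the row dropped by the lag
past, future, target = make_windows(ROWS_PER_ITEM, time_steps=21, horizon=7)
print(past.shape, future.shape, target.shape)
print("total sequences for 5 items:", 5 * past.shape[0])
```

The future window spans both the lookback and the horizon (28 rows), which matches the Future Data shape in the log: known-future features are available for the past steps as well as the steps being forecast.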
Step 4: Define, Compile, and Train the Model
Now we instantiate BaseAttentive. We use architecture_config to
specify a pure transformer architecture and set quantiles to
enable probabilistic forecasting.
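Quantile outputs are trained with the pinball (quantile) loss. For a quantile level q and error e = y - ŷ, the loss is max(q·e, (q - 1)·e), so a high quantile such as 0.9 is penalized far more for under-predicting than for over-predicting, which pushes it toward the upper part of the distribution. A minimal pure-Python sketch of that arithmetic (the function name pinball is illustrative):

```python
def pinball(y_true, y_pred, q):
    """Pinball (quantile) loss for a single observation and quantile level q."""
    e = y_true - y_pred
    return max(q * e, (q - 1) * e)

# For q = 0.9, under-predicting by 2 costs about 9x more than over-predicting by 2.
under = pinball(10.0, 8.0, 0.9)
over = pinball(10.0, 12.0, 0.9)
print(under, over)

# Averaging over the three quantile levels used in this exercise.
quantiles = [0.1, 0.5, 0.9]
preds = [8.0, 10.0, 12.0]
mean_loss = sum(pinball(10.0, p, q) for p, q in zip(preds, quantiles)) / len(quantiles)
print(mean_loss)
```

The training loss below applies the same elementwise formula to every horizon step and quantile before averaging.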
```python
# Split data into training and validation sets.
# Note: this holds out the last 50 stacked sequences, regardless of item.
val_split = -50
train_inputs = [arr[:val_split] for arr in [static_data, dynamic_data, future_data]]
val_inputs = [arr[val_split:] for arr in [static_data, dynamic_data, future_data]]
train_targets, val_targets = targets[:val_split], targets[val_split:]

# Define a pure transformer architecture
tfmr_config = {
    'encoder_type': 'transformer',
    'decoder_attention_stack': ['cross', 'hierarchical'],
    'feature_processing': 'dense'
}
# Define quantiles for probabilistic forecast
output_quantiles = [0.1, 0.5, 0.9]

# Instantiate the model
model = BaseAttentive(
    static_input_dim=static_data.shape[-1],
    dynamic_input_dim=dynamic_data.shape[-1],
    future_input_dim=future_data.shape[-1],
    output_dim=1,
    forecast_horizon=FORECAST_HORIZON,
    max_window_size=TIME_STEPS,
    mode='tft_like',
    quantiles=output_quantiles,
    architecture_config=tfmr_config,
    hidden_units=32,
    attention_units=32
)

# Compile with a quantile (pinball) loss. y_true has shape (batch, horizon, 1);
# y_pred is expected to have shape (batch, horizon, n_quantiles), so the
# error broadcasts across the quantile axis before averaging.
def quantile_loss(y_true, y_pred):
    q = tf.constant(output_quantiles, dtype=y_pred.dtype)
    e = tf.cast(y_true, y_pred.dtype) - y_pred
    return tf.reduce_mean(tf.maximum(q * e, (q - 1.0) * e), axis=-1)

model.compile(optimizer='adam', loss=quantile_loss)

# Train the model
print("\nStarting BaseAttentive model training...")
history = model.fit(
    train_inputs, train_targets,
    validation_data=(val_inputs, val_targets),
    epochs=20, batch_size=64, verbose=0
)
print("Training complete.")
print(f"Final validation loss: {history.history['val_loss'][-1]:.4f}")
```
Expected Output:
Starting BaseAttentive model training...
Training complete.
Final validation loss: 0.5504
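The split above holds out the last 50 sequences of the stacked arrays. If reshape_xtft_data stacks each item's windows contiguously in item order (an assumption here; the log only says the data were grouped into 5 groups), all 50 validation sequences belong to the last item. The NumPy sketch below checks this under that assumption and shows a hypothetical per-item alternative that holds out the last 10 windows of every item.

```python
import numpy as np

N_ITEMS, WINDOWS_PER_ITEM = 5, 92
# Assumed ordering: all windows of item 0, then item 1, and so on.
item_of_window = np.repeat(np.arange(N_ITEMS), WINDOWS_PER_ITEM)

# The exercise's split: the last 50 stacked windows.
print("items in validation set:", np.unique(item_of_window[-50:]))

# Hypothetical alternative: hold out the last 10 windows of *each* item.
HOLDOUT_PER_ITEM = 10
position_in_item = np.tile(np.arange(WINDOWS_PER_ITEM), N_ITEMS)
val_mask = position_in_item >= WINDOWS_PER_ITEM - HOLDOUT_PER_ITEM
print("validation windows per item:", np.bincount(item_of_window[val_mask]))

# With the real arrays this mask would be applied as, for example:
# train_inputs = [arr[~val_mask] for arr in (static_data, dynamic_data, future_data)]
```

Because neighbouring windows overlap, a training window and the next validation window share up to FORECAST_HORIZON - 1 target rows; dropping that many windows between the two sets removes the overlap.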
Step 5: Visualize the Probabilistic Forecast
This is the most exciting part. We'll make predictions on the validation set and plot one sample, showing the median (0.5-quantile) forecast along with the 80% prediction interval (the band between the 0.1 and 0.9 quantiles).
```python
# Make predictions on the validation set.
# val_preds is indexed as (sample, horizon step, quantile).
val_preds = model.predict(val_inputs)

# Select a single sequence from the validation set to plot
idx_to_plot = 10
median_pred = val_preds[idx_to_plot, :, 1].ravel()  # 0.5 quantile is at index 1
lower_bound = val_preds[idx_to_plot, :, 0].ravel()  # 0.1 quantile is at index 0
upper_bound = val_preds[idx_to_plot, :, 2].ravel()  # 0.9 quantile is at index 2
actuals = val_targets[idx_to_plot, :, 0].ravel()

# --- Visualization ---
plt.figure(figsize=(12, 6))
# Plot uncertainty bounds
plt.fill_between(
    range(FORECAST_HORIZON), lower_bound, upper_bound,
    color='orange', alpha=0.3, label='80% Prediction Interval'
)
# Plot actuals and median forecast
plt.plot(actuals, label='Actual Values', marker='o', linestyle='--')
plt.plot(median_pred, label='Median Forecast (p50)', marker='x')

plt.title('Probabilistic Forecast vs. Actual (Validation Sample)')
plt.xlabel(f'Forecast Step (Horizon = {FORECAST_HORIZON} steps)')
plt.ylabel('Normalized Value')
plt.legend()
plt.grid(True, linestyle=':')
plt.tight_layout()
plt.show()
```
Expected Plot:
A plot showing the actual values, the median (p50) forecast, and the shaded 80% prediction interval. This visualizes not just what the model predicts, but also its confidence in that prediction.
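A single plotted sample cannot show whether the band is calibrated: across the validation set, roughly 80% of actual values should fall between the 0.1 and 0.9 quantiles. A minimal NumPy sketch of that coverage check follows. It assumes predictions shaped (samples, horizon, 3) in the quantile order [0.1, 0.5, 0.9] and targets shaped (samples, horizon, 1), as the plotting code above does. The function name interval_coverage is hypothetical, and the toy arrays stand in for val_preds and val_targets.

```python
import numpy as np

def interval_coverage(preds, targets, lower_idx=0, upper_idx=2):
    """Fraction of target values lying inside the [lower, upper] quantile band."""
    actual = targets[..., 0]            # (samples, horizon)
    lower = preds[..., lower_idx]       # e.g. the 0.1 quantile
    upper = preds[..., upper_idx]       # e.g. the 0.9 quantile
    inside = (actual >= lower) & (actual <= upper)
    return float(inside.mean())

# Toy example: 2 samples, 3-step horizon, quantiles [0.1, 0.5, 0.9].
toy_preds = np.array([
    [[0.0, 1.0, 2.0], [0.0, 1.0, 2.0], [0.0, 1.0, 2.0]],
    [[1.0, 2.0, 3.0], [1.0, 2.0, 3.0], [1.0, 2.0, 3.0]],
])
toy_targets = np.array([
    [[1.5], [2.5], [0.5]],   # 2 of 3 inside [0, 2]
    [[2.0], [0.5], [3.0]],   # 2 of 3 inside [1, 3]
])
print(interval_coverage(toy_preds, toy_targets))
```

On the exercise's data the call would be interval_coverage(val_preds, val_targets); a value well below 0.8 suggests the interval is too narrow. Because StandardScaler's inverse is an increasing affine map, coverage is the same in standardized and original units.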
Discussion of Exercise
Congratulations! You have successfully built, trained, and visually
evaluated an advanced forecasting model using the BaseAttentive engine.
In this exercise, you have learned to:
- Structure a complex dataset with static, dynamic, and future
features for a sophisticated model.
- Use the architecture_config dictionary to flexibly define the model's
internal structure (e.g., as a pure transformer).
- Implement a probabilistic forecast by configuring output quantiles
and using a corresponding loss function.
- Visualize and interpret a probabilistic forecast, including its
uncertainty bounds.
This powerful, data-driven workflow forms the foundation for tackling some of the most challenging time series forecasting problems.