Quickstart
This guide provides a minimal example to get you started with
training a basic forecasting model using fusionlab-learn.
Prerequisites
Make sure you have installed fusionlab-learn and its core
dependencies, including TensorFlow. If not, please follow the
Installation guide first.
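If you are unsure whether the prerequisites are in place, a quick check like the one below can help. It is a minimal sketch that only tests whether each top-level package can be located, using the standard-library function `importlib.util.find_spec`; it does not check versions. The helper name `check_prerequisites` is hypothetical, and the module names `numpy`, `tensorflow`, and `fusionlab` are taken from the imports used later in this guide.

```python
import importlib.util


def check_prerequisites(modules=("numpy", "tensorflow", "fusionlab")):
    """Map each top-level module name to True if Python can locate it.

    Hypothetical helper: find_spec() locates a top-level module without
    importing it, so this does not verify versions or import success.
    """
    return {name: importlib.util.find_spec(name) is not None for name in modules}


if __name__ == "__main__":
    for name, found in check_prerequisites().items():
        print(f"{name:<12} {'found' if found else 'MISSING'}")
```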
Steps
Let’s train a simple TemporalFusionTransformer for point forecasting using only dynamic (past) inputs.
Import Libraries: We need TensorFlow, NumPy, and the model class.

```python
import tensorflow as tf
import numpy as np
from fusionlab.nn import TemporalFusionTransformer

# Optional: Suppress TensorFlow warnings for cleaner output
tf.get_logger().setLevel('ERROR')
tf.autograph.set_verbosity(0)
```
Prepare Dummy Data: We'll create some random data simulating dynamic features and a target variable.

```python
# Define data dimensions
batch_size = 16           # Number of samples in the dummy dataset (not the training batch size)
num_past_timesteps = 20   # Length of historical input sequence
dynamic_feature_dim = 3   # Number of dynamic features
forecast_horizon = 5      # Number of steps to predict

# Generate random dynamic (past) input data
# Shape: (batch_size, num_past_timesteps, dynamic_feature_dim)
X_dynamic = np.random.rand(
    batch_size, num_past_timesteps, dynamic_feature_dim
).astype(np.float32)

# Generate random target data (what we want to predict)
# Shape: (batch_size, forecast_horizon, 1) -> Point forecast (1 value per step)
y_target = np.random.rand(
    batch_size, forecast_horizon, 1
).astype(np.float32)

print(f"Dynamic Input Shape: {X_dynamic.shape}")
print(f"Target Output Shape: {y_target.shape}")
```
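With real data, arrays of these shapes are usually cut from one longer time series using a sliding window. The sketch below shows one way to do that with NumPy. The helper `make_windows`, the choice of column 0 as the target, and the random 40-step series are assumptions for illustration; none of them is part of fusionlab-learn.

```python
import numpy as np


def make_windows(series, num_past_timesteps, forecast_horizon, target_col=0):
    """Slice a (time, features) array into past-input / future-target pairs.

    Hypothetical helper. Returns:
      X: (num_samples, num_past_timesteps, num_features)
      y: (num_samples, forecast_horizon, 1), taken from column `target_col`
    """
    series = np.asarray(series, dtype=np.float32)
    total = num_past_timesteps + forecast_horizon
    num_samples = series.shape[0] - total + 1
    if num_samples < 1:
        raise ValueError("Series is too short for the requested window sizes.")
    # Each window: the past block as input, the following block as target
    X = np.stack([series[i:i + num_past_timesteps] for i in range(num_samples)])
    y = np.stack([
        series[i + num_past_timesteps:i + total, target_col:target_col + 1]
        for i in range(num_samples)
    ])
    return X, y


# Example: 40 time steps of 3 features -> 40 - (20 + 5) + 1 = 16 windows
series = np.random.rand(40, 3)
X_windows, y_windows = make_windows(series, num_past_timesteps=20, forecast_horizon=5)
print(X_windows.shape, y_windows.shape)  # (16, 20, 3) (16, 5, 1)
```

Each window starts one step after the previous one, so a series of length T yields T - (num_past_timesteps + forecast_horizon) + 1 samples.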
Instantiate the Model: Create an instance of TemporalFusionTransformer. Because this example uses only dynamic inputs, dynamic_input_dim is the only input dimension we need to specify. We also set forecast_horizon and omit quantiles, so the model produces point forecasts.

```python
model = TemporalFusionTransformer(
    dynamic_input_dim=dynamic_feature_dim,
    forecast_horizon=forecast_horizon,
    # Using default values for other parameters like:
    # static_input_dim=None,
    # future_input_dim=None,
    # hidden_units=32,
    # num_heads=4,
    # quantiles=None,  # Default is point forecast
    # etc.
)

# Optional: build the model explicitly from the input shapes.
# summary() and plotting only work once the model has been built.
# Note: input_shape is a list with one shape per input, even when
# there is only one input (here, the dynamic input).
model.build(input_shape=[(None, num_past_timesteps, dynamic_feature_dim)])
model.summary()
```
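Because the constructor arguments must agree with the arrays prepared in the previous step, it can be worth checking them before building and training the model. The function `check_forecast_shapes` below is a hypothetical helper written for this guide, not a fusionlab-learn API. It assumes the layouts used above: inputs of shape `(samples, num_past_timesteps, dynamic_input_dim)` and point-forecast targets of shape `(samples, forecast_horizon, 1)`.

```python
import numpy as np


def check_forecast_shapes(X_dynamic, y_target, dynamic_input_dim, forecast_horizon):
    """Raise ValueError if the arrays do not match the constructor arguments.

    Hypothetical helper; assumes the point-forecast layout used in this guide.
    """
    if X_dynamic.ndim != 3:
        raise ValueError(f"X_dynamic must be 3-D, got shape {X_dynamic.shape}")
    if X_dynamic.shape[-1] != dynamic_input_dim:
        raise ValueError(
            f"Last axis of X_dynamic is {X_dynamic.shape[-1]}, "
            f"but dynamic_input_dim={dynamic_input_dim}"
        )
    # One target value per forecast step for every input sample
    expected_y = (X_dynamic.shape[0], forecast_horizon, 1)
    if y_target.shape != expected_y:
        raise ValueError(f"y_target has shape {y_target.shape}, expected {expected_y}")


# Shapes matching the dummy data above pass silently
check_forecast_shapes(
    np.zeros((16, 20, 3)), np.zeros((16, 5, 1)),
    dynamic_input_dim=3, forecast_horizon=5,
)
```

With the quickstart variables, the call would be `check_forecast_shapes(X_dynamic, y_target, dynamic_feature_dim, forecast_horizon)`.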
Compile the Model: Specify the optimizer and loss function. For point forecasting, Mean Squared Error ('mse') is a common choice.

```python
model.compile(optimizer='adam', loss='mse')
```
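For reference, the 'mse' loss is the mean of the squared differences between targets and predictions. The NumPy sketch below computes it over all elements of arrays shaped like `y_target`; because the last axis has length 1, averaging over every element gives the same value as averaging per step first. The function name `mse` is ours, not a fusionlab-learn or Keras API.

```python
import numpy as np


def mse(y_true, y_pred):
    """Mean squared error over all elements of two equally shaped arrays."""
    y_true = np.asarray(y_true, dtype=np.float64)
    y_pred = np.asarray(y_pred, dtype=np.float64)
    return float(np.mean((y_true - y_pred) ** 2))


# Worked example: targets of 0 and predictions of 0.5 at every step
# give a squared error of 0.25 everywhere, so the mean is 0.25.
y_true = np.zeros((2, 5, 1))
y_pred = np.full((2, 5, 1), 0.5)
print(mse(y_true, y_pred))  # 0.25
```

Squaring the errors penalizes large misses more heavily than small ones, which is one reason MSE is a common default for point forecasts.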
Train the Model: Fit the model to the dummy data for a few epochs.

```python
print("\nTraining the model...")
history = model.fit(
    x=(X_dynamic,),  # Input must be a tuple
    y=y_target,
    epochs=3,        # Use few epochs for a quick demo
    batch_size=4,
    verbose=1        # Show progress
)
print("Training complete.")
```
Make Predictions: Use the trained model to generate forecasts. In practice you would pass new data; for simplicity, this example reuses the training inputs.

```python
print("\nMaking predictions...")
# Use the same input data for prediction in this example
predictions = model.predict((X_dynamic,))

print(f"Predictions output shape: {predictions.shape}")
# Expected shape: (batch_size, forecast_horizon, 1)
```
Conclusion
This quickstart demonstrated the basic workflow: preparing data, instantiating a model, compiling it, training it, and making predictions.
For more advanced use cases involving static/future features, quantile forecasts, anomaly detection, or the XTFT model, please refer to the User Guide and the API Reference.