Quickstart

This guide walks through a minimal example of training a basic forecasting model with fusionlab-learn.

Prerequisites

Make sure you have installed fusionlab-learn and its core dependencies, including TensorFlow. If not, please follow the Installation guide first.

Steps

Let’s train a simple TemporalFusionTransformer for point forecasting using only dynamic (past) inputs.

  1. Import Libraries. We need TensorFlow, NumPy, and the model class.

    import tensorflow as tf
    import numpy as np
    from fusionlab.nn import TemporalFusionTransformer

    # Optional: Suppress TensorFlow warnings for cleaner output
    tf.get_logger().setLevel('ERROR')
    tf.autograph.set_verbosity(0)
    
  2. Prepare Dummy Data. We'll create random data that simulates dynamic features and a target variable.

    # Define data dimensions
    batch_size = 16
    num_past_timesteps = 20  # Length of historical input sequence
    dynamic_feature_dim = 3  # Number of dynamic features
    forecast_horizon = 5     # Number of steps to predict

    # Generate random dynamic (past) input data
    # Shape: (batch_size, num_past_timesteps, dynamic_feature_dim)
    X_dynamic = np.random.rand(
        batch_size, num_past_timesteps, dynamic_feature_dim
    ).astype(np.float32)

    # Generate random target data (what we want to predict)
    # Shape: (batch_size, forecast_horizon, 1) -> Point forecast (1 value per step)
    y_target = np.random.rand(
        batch_size, forecast_horizon, 1
    ).astype(np.float32)

    print(f"Dynamic Input Shape: {X_dynamic.shape}")
    print(f"Target Output Shape: {y_target.shape}")
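
Real data rarely arrives in this shape. As a rough sketch of how the two arrays above could be produced from an actual time series, the snippet below slices a (time, features) array into overlapping windows with plain NumPy. The helper name `make_windows` and its `target_col` parameter are hypothetical and used only for illustration; they are not part of fusionlab-learn.

```python
import numpy as np


def make_windows(series, num_past_timesteps, forecast_horizon, target_col=0):
    """Slice a (time, features) array into (X, y) training windows.

    Hypothetical helper for illustration; not part of fusionlab-learn.
    """
    series = np.asarray(series, dtype=np.float32)
    window = num_past_timesteps + forecast_horizon
    num_samples = series.shape[0] - window + 1
    if num_samples <= 0:
        raise ValueError("Series is too short for the requested window sizes.")

    X, y = [], []
    for start in range(num_samples):
        split = start + num_past_timesteps
        X.append(series[start:split])                        # past features
        y.append(series[split:start + window, target_col])   # future target values

    X = np.stack(X)                    # (num_samples, num_past_timesteps, num_features)
    y = np.stack(y)[..., np.newaxis]   # (num_samples, forecast_horizon, 1)
    return X, y


# Toy series: 60 time steps, 3 features; column 0 doubles as the target
toy_series = np.arange(60 * 3, dtype=np.float32).reshape(60, 3)
X_windows, y_windows = make_windows(
    toy_series, num_past_timesteps=20, forecast_horizon=5
)
print(X_windows.shape, y_windows.shape)
```

Each sample pairs 20 past steps with the 5 target values that immediately follow them, which matches the shapes used in the rest of this guide.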
    
  3. Instantiate the Model. Create an instance of TemporalFusionTransformer. Because we use only dynamic inputs, dynamic_input_dim is the only input dimension we need to specify. We also set forecast_horizon and omit quantiles, which gives a point forecast.

    model = TemporalFusionTransformer(
        dynamic_input_dim=dynamic_feature_dim,
        forecast_horizon=forecast_horizon,
        # Using default values for other parameters like:
        # static_input_dim=None,
        # future_input_dim=None,
        # hidden_units=32,
        # num_heads=4,
        # quantiles=None, # Default is point forecast
        # etc.
    )

    # Optional: Build the model by passing a sample input shape or data
    # This is needed before summary() or plotting can work.
    # Note: Pass one shape per input as a sequence, even with only one input.
    model.build(input_shape=[(None, num_past_timesteps, dynamic_feature_dim)])
    model.summary()
    
  4. Compile the Model. Specify the optimizer and loss function. For point forecasting, mean squared error ('mse') is a common choice.

    model.compile(optimizer='adam', loss='mse')
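
To make the loss concrete, here is a small NumPy sketch of the quantity that 'mse' measures: the mean of the squared differences between targets and predictions. The arrays are tiny hand-written values chosen for illustration, and the snippet assumes no sample weighting, in which case the result should correspond to the average over all elements.

```python
import numpy as np

# Two samples, two forecast steps, one value per step: shape (2, 2, 1)
y_true = np.array([[[1.0], [2.0]], [[3.0], [4.0]]], dtype=np.float32)
y_pred = np.array([[[1.5], [2.0]], [[2.0], [4.0]]], dtype=np.float32)

# Squared error per element, then the mean over every element
squared_errors = (y_true - y_pred) ** 2
mse = float(squared_errors.mean())
print(f"MSE: {mse}")
```

Large errors dominate this loss because each difference is squared before averaging.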
    
  5. Train the Model. Fit the model to the dummy data for a few epochs.

    print("\nTraining the model...")
    history = model.fit(
        x=(X_dynamic,), # Input must be a tuple
        y=y_target,
        epochs=3,       # Use few epochs for a quick demo
        batch_size=4,
        verbose=1       # Show progress
    )
    print("Training complete.")
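
Training on all samples gives no signal about generalization. One possible extension, sketched below under stated assumptions, is to hold out the last part of the data for validation. The split itself is plain NumPy; the commented-out `fit` call assumes the model accepts validation inputs in the same one-element tuple format as the training inputs, which is not confirmed here.

```python
import numpy as np

# Stand-in arrays with the same shapes as X_dynamic and y_target above
rng = np.random.default_rng(0)
X_dynamic = rng.random((16, 20, 3), dtype=np.float32)
y_target = rng.random((16, 5, 1), dtype=np.float32)

# Hold out the last 25% of samples for validation
val_fraction = 0.25
num_val = int(len(X_dynamic) * val_fraction)
X_train, X_val = X_dynamic[:-num_val], X_dynamic[-num_val:]
y_train, y_val = y_target[:-num_val], y_target[-num_val:]
print(X_train.shape, X_val.shape)

# With the model from the steps above, the split could be passed as
# (assumption: validation inputs use the same tuple format):
# history = model.fit(
#     x=(X_train,),
#     y=y_train,
#     validation_data=((X_val,), y_val),
#     epochs=3,
#     batch_size=4,
# )
```

Taking the held-out samples from the end, rather than shuffling, keeps later windows out of the training set when the samples are ordered in time.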
    
  6. Make Predictions. Use the trained model to generate forecasts. In this example we reuse the training inputs; in practice you would pass new data.

    print("\nMaking predictions...")
    # Use the same input data for prediction in this example
    predictions = model.predict((X_dynamic,))

    print(f"Predictions output shape: {predictions.shape}")
    # Expected shape: (batch_size, forecast_horizon, 1)
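
Once predictions have the shape (batch_size, forecast_horizon, 1), they can be compared with the targets step by step. The sketch below uses random stand-in arrays in place of the real `predictions` and `y_target`, so the printed numbers carry no meaning; only the shape handling is the point.

```python
import numpy as np

# Stand-ins for y_target and the output of model.predict(...)
rng = np.random.default_rng(42)
y_true = rng.random((16, 5, 1), dtype=np.float32)
predictions = rng.random((16, 5, 1), dtype=np.float32)

# Mean absolute error for each forecast step, averaged over samples
abs_err = np.abs(predictions - y_true)
mae_per_step = abs_err.mean(axis=(0, 2))   # shape: (forecast_horizon,)

# Drop the trailing axis to get one row of forecasts per sample
pred_2d = predictions.squeeze(axis=-1)      # shape: (batch_size, forecast_horizon)

for step, mae in enumerate(mae_per_step, start=1):
    print(f"t+{step}: MAE = {mae:.3f}")
```

Reporting the error per step shows whether accuracy degrades toward the end of the forecast horizon.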
    

Conclusion

This quickstart demonstrated the basic workflow: preparing data, instantiating a model, compiling it, training it, and making predictions.

For more advanced use cases involving static/future features, quantile forecasts, anomaly detection, or the XTFT model, please refer to the User Guide and the API Reference.