Version 0.2.3

(Release Date: May 25, 2025)

Focus: Object-Oriented Hyperparameter Tuning

This release restructures hyperparameter tuning in fusionlab-learn. The forecast_tuner module gains a class-based interface that improves structure, reusability, and flexibility. The previous function-based approach remains available for backward compatibility, but the new classes are the recommended way to tune forecasting models.

Enhancements & Improvements

  • New Class-Based Tuners:
    • New: Introduced _BaseTuner, an internal base class that encapsulates the core Keras Tuner logic, including input validation, model building, the tuning loop, and logging. It gives the concrete tuners a consistent, extensible foundation.

    • New: Introduced XTFTTuner, a dedicated tuner class inheriting from _BaseTuner for optimizing XTFT and SuperXTFT models. It simplifies setting up and running tuning experiments for these architectures.

    • New: Introduced TFTTuner, a dedicated tuner class for the stricter TFT (tft) and the TemporalFusionTransformer (tft_flex) models, with clear validation and setup for each variant.

  • Improved Tuning Workflow:
    • Enhancement: The class-based approach separates tuner configuration (__init__) from tuning execution (fit). A tuner can be instantiated once and fit can then be called multiple times with different datasets or task parameters (such as forecast_horizon and quantiles), which promotes code reuse.

    • Enhancement: The internal _model_builder_factory() (now part of _BaseTuner) remains the default mechanism for building models during tuning; users can still supply a custom_model_builder for full control.

    • Enhancement: Improved handling and validation of input tensors in the base class, including automatic creation of dummy tensors when tft_flex needs them.

  • Code Example (New Class-Based Approach):

    import numpy as np
    from fusionlab.nn.forecast_tuner import XTFTTuner

    # 1. Prepare Dummy Data
    B, T_past, H_out = 8, 12, 6
    D_s, D_d, D_f = 3, 5, 2
    T_future_total = T_past + H_out
    X_s = np.random.rand(B, D_s).astype(np.float32)
    X_d = np.random.rand(B, T_past, D_d).astype(np.float32)
    X_f = np.random.rand(B, T_future_total, D_f).astype(np.float32)
    y = np.random.rand(B, H_out, 1).astype(np.float32)
    train_inputs = [X_s, X_d, X_f]

    # 2. Instantiate the Tuner
    tuner = XTFTTuner(
        model_name="xtft",
        max_trials=3,      # Keep low for demo
        epochs=2,          # Keep low for demo
        batch_sizes=[8],   # Single batch for demo
        tuner_dir="./xtft_class_tuning_v023",
        verbose=0          # Suppress detailed Keras Tuner logs
    )

    # 3. Run the Tuning
    print("Starting XTFT tuning with new class-based approach...")
    best_hps, best_model, _ = tuner.fit(
        inputs=train_inputs,
        y=y,
        forecast_horizon=H_out
    )

    # 4. Use results
    if best_hps:
        print("Tuning successful!")
        print(f"Best Batch Size: {best_hps.get('batch_size')}")
        print(f"Best Learning Rate: {best_hps.get('learning_rate')}")
    else:
        print("Tuning did not find a best model.")
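
The split between configuration (__init__) and execution (fit) described under Improved Tuning Workflow can be pictured with a small stand-in class. The sketch below is not fusionlab-learn code: SketchTuner, its attributes, and the shape rules it checks are assumptions made for illustration, and shapes are passed as plain tuples rather than tensors. It only shows why one configured object can serve several fit calls.

```python
from typing import Dict, Optional, Sequence, Tuple

Shape = Tuple[int, ...]


class SketchTuner:
    """Hypothetical stand-in for a class-based tuner (not fusionlab-learn code)."""

    def __init__(self, model_name: str, max_trials: int = 3, epochs: int = 2):
        # Search configuration is set once, at construction time.
        self.model_name = model_name
        self.max_trials = max_trials
        self.epochs = epochs

    def fit(
        self,
        input_shapes: Sequence[Shape],
        forecast_horizon: int,
        quantiles: Optional[Sequence[float]] = None,
    ) -> Dict[str, object]:
        # Task parameters arrive per call, so one instance can serve many tasks.
        static, dynamic, future = input_shapes
        # Assumed layout: static (B, D_s), dynamic (B, T, D_d), future (B, T', D_f).
        expected = (("static", static, 2), ("dynamic", dynamic, 3), ("future", future, 3))
        for name, shape, rank in expected:
            if len(shape) != rank:
                raise ValueError(f"{name} input must have rank {rank}, got shape {shape}")
        if len({static[0], dynamic[0], future[0]}) != 1:
            raise ValueError("all inputs must share the same batch size")
        # A real tuner would run the search here; the sketch only echoes the task.
        return {
            "model_name": self.model_name,
            "max_trials": self.max_trials,
            "forecast_horizon": forecast_horizon,
            "quantiles": list(quantiles) if quantiles is not None else None,
        }


# One configured tuner, two different tasks.
tuner = SketchTuner(model_name="xtft", max_trials=3)
shapes = [(8, 3), (8, 12, 5), (8, 18, 2)]
point_run = tuner.fit(shapes, forecast_horizon=6)
quantile_run = tuner.fit(shapes, forecast_horizon=3, quantiles=[0.1, 0.5, 0.9])
print(point_run["forecast_horizon"], quantile_run["quantiles"])
```

Because the search settings live on the instance and the task settings are arguments to fit, the second call reuses the same configuration without rebuilding the tuner.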
    

Fixes

  • Fix: _model_builder_factory now uses _get_valid_kwargs so that only parameters accepted by the target model’s __init__ are passed at instantiation.

  • Fix: _prepare_inputs now reports clearer error messages for missing or incorrectly shaped inputs, particularly where the requirements of tft and tft_flex differ.
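
The keyword filtering mentioned in the first fix can be sketched with the standard library's inspect.signature. The helper filter_valid_kwargs and the ToyModel class below are hypothetical names used for illustration only; the library's own _get_valid_kwargs may be implemented differently.

```python
import inspect
from typing import Any, Callable, Dict


def filter_valid_kwargs(target: Callable[..., Any], kwargs: Dict[str, Any]) -> Dict[str, Any]:
    """Keep only the entries of `kwargs` that `target` accepts by name."""
    params = inspect.signature(target).parameters
    # If the callable declares **kwargs, every name is acceptable.
    if any(p.kind is inspect.Parameter.VAR_KEYWORD for p in params.values()):
        return dict(kwargs)
    accepted = {
        name
        for name, p in params.items()
        if p.kind in (inspect.Parameter.POSITIONAL_OR_KEYWORD, inspect.Parameter.KEYWORD_ONLY)
    }
    return {key: value for key, value in kwargs.items() if key in accepted}


class ToyModel:
    """Hypothetical model whose __init__ accepts only two hyperparameters."""

    def __init__(self, hidden_units: int = 32, dropout_rate: float = 0.1):
        self.hidden_units = hidden_units
        self.dropout_rate = dropout_rate


# A hyperparameter set that includes a name ToyModel does not know about.
candidate = {"hidden_units": 64, "dropout_rate": 0.2, "lstm_units": 16}
model = ToyModel(**filter_valid_kwargs(ToyModel, candidate))
print(model.hidden_units, model.dropout_rate)
```

Filtering before instantiation lets a single search space be shared across model variants: a hyperparameter that one variant does not accept is dropped instead of raising a TypeError.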

Tests

  • Tests: Added unit tests for the new _BaseTuner, XTFTTuner, and TFTTuner classes, covering initialization, input preparation, fit execution, and result retrieval.

  • Tests: Added tests that verify model_name validation in XTFTTuner and TFTTuner.

Documentation

  • Docs: Added a new User Guide page, /user_guide/forecast_tuner/forecast_tuner_class_based, which details the object-oriented approach to hyperparameter tuning with XTFTTuner and TFTTuner and includes code examples.

  • Docs: Updated the existing /user_guide/forecast_tuner/forecast_tuner page to acknowledge the new class-based approach and link to it, while retaining the documentation for the function-based method (which remains available in v0.2.3).

  • Docs: Added API references for XTFTTuner and TFTTuner.

Contributors