Version 0.2.3
(Release Date: May 25, 2025)
Focus: Object-Oriented Hyperparameter Tuning
This release significantly improves how users perform hyperparameter
tuning in fusionlab-learn. We introduce a new class-based interface
to the forecast_tuner module that offers better structure, reusability,
and flexibility. The previous function-based approach remains available
for backward compatibility, but these new classes are the recommended
way to optimize your forecasting models going forward.
Enhancements & Improvements
- New Class-Based Tuners:
  - New: Introduced BaseTuner, an internal base class that encapsulates the core Keras Tuner logic, including input validation, model building, the tuning loop, and logging. It provides a consistent and extensible foundation for the concrete tuners.
  - New: Introduced XTFTTuner, a dedicated tuner class inheriting from _BaseTuner for optimizing XTFT and SuperXTFT models. It simplifies the setup and execution of tuning experiments for these architectures.
  - New: Introduced TFTTuner, a dedicated tuner class for TFT (the stricter variant) and TemporalFusionTransformer (tft_flex) models, providing clear validation and setup for both variants.
- Improved Tuning Workflow:
  - Enhancement: The new class-based approach separates tuner configuration (__init__) from tuning execution (fit). Users can instantiate a tuner once and call fit multiple times with different datasets or task parameters (such as forecast_horizon and quantiles), promoting code reuse.
  - Enhancement: The internal _model_builder_factory() (now part of BaseTuner) remains the default mechanism for building models during tuning, but users can still supply a custom model_builder for maximum control.
  - Enhancement: Improved handling and validation of input tensors within the base class, including automatic creation of dummy tensors when needed for tft_flex.
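The split between configuration (__init__) and execution (fit) can be illustrated with a plain-Python sketch. It deliberately avoids fusionlab and Keras Tuner: SketchBaseTuner, SketchXTFTTuner, SketchTFTTuner, toy_objective and the random-search loop are hypothetical stand-ins, and the accepted model names are assumptions. The base class validates model_name once at construction and runs a generic search in fit, while task parameters such as forecast_horizon are supplied per call, so a single instance can be reused.

```python
import random
from typing import Any, Callable, Dict, List, Optional, Tuple


class SketchBaseTuner:
    """Hypothetical base class: validates config once, runs a generic search in fit()."""

    # Subclasses declare which model names they accept (names are assumptions).
    supported_models: Tuple[str, ...] = ()

    def __init__(self, model_name: str, max_trials: int = 3,
                 param_space: Optional[Dict[str, List[Any]]] = None, seed: int = 0):
        name = model_name.lower()
        if name not in self.supported_models:
            raise ValueError(f"{type(self).__name__} supports "
                             f"{self.supported_models}, got {model_name!r}")
        self.model_name = name
        self.max_trials = max_trials
        self.param_space = param_space or {
            "learning_rate": [1e-3, 1e-2], "batch_size": [8, 16]}
        self.seed = seed

    def fit(self, objective: Callable[[Dict[str, Any], int], float],
            forecast_horizon: int) -> Tuple[Optional[Dict[str, Any]], float]:
        """Random search; task parameters arrive per call, not at construction."""
        rng = random.Random(self.seed)
        best_hps, best_score = None, float("inf")
        for _ in range(self.max_trials):
            hps = {k: rng.choice(v) for k, v in self.param_space.items()}
            score = objective(hps, forecast_horizon)  # lower is better
            if score < best_score:
                best_hps, best_score = hps, score
        return best_hps, best_score


class SketchXTFTTuner(SketchBaseTuner):
    supported_models = ("xtft", "superxtft")


class SketchTFTTuner(SketchBaseTuner):
    supported_models = ("tft", "tft_flex")


def toy_objective(hps: Dict[str, Any], horizon: int) -> float:
    # Stand-in for "build, train and return validation loss".
    return hps["learning_rate"] * horizon + 1.0 / hps["batch_size"]


# Configure once, then reuse the same instance for different horizons.
tuner = SketchXTFTTuner("xtft", max_trials=4)
best_6, loss_6 = tuner.fit(toy_objective, forecast_horizon=6)
best_12, loss_12 = tuner.fit(toy_objective, forecast_horizon=12)
```

In this sketch each subclass only declares the model names it accepts and inherits everything else, which is the kind of division of labour the notes attribute to the base class.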
Code Example (New Class-Based Approach):
```python
import numpy as np
from fusionlab.nn.forecast_tuner import XTFTTuner

# 1. Prepare Dummy Data
B, T_past, H_out = 8, 12, 6
D_s, D_d, D_f = 3, 5, 2
T_future_total = T_past + H_out
X_s = np.random.rand(B, D_s).astype(np.float32)
X_d = np.random.rand(B, T_past, D_d).astype(np.float32)
X_f = np.random.rand(B, T_future_total, D_f).astype(np.float32)
y = np.random.rand(B, H_out, 1).astype(np.float32)
train_inputs = [X_s, X_d, X_f]

# 2. Instantiate the Tuner
tuner = XTFTTuner(
    model_name="xtft",
    max_trials=3,      # Keep low for demo
    epochs=2,          # Keep low for demo
    batch_sizes=[8],   # Single batch-size candidate for demo
    tuner_dir="./xtft_class_tuning_v023",
    verbose=0          # Suppress detailed Keras Tuner logs
)

# 3. Run the Tuning
print("Starting XTFT tuning with new class-based approach...")
best_hps, best_model, _ = tuner.fit(
    inputs=train_inputs,
    y=y,
    forecast_horizon=H_out
)

# 4. Use results
if best_hps:
    print("Tuning successful!")
    print(f"Best Batch Size: {best_hps.get('batch_size')}")
    print(f"Best Learning Rate: {best_hps.get('learning_rate')}")
else:
    print("Tuning did not find a best model.")
```
Fixes
Fix: Improved the robustness of _model_builder_factory by using _get_valid_kwargs so that only parameters accepted by the specific model’s __init__ are passed during instantiation.
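The idea behind this kind of keyword filtering can be sketched with the standard-library inspect module. filter_valid_kwargs and ToyModel are hypothetical names, and this is not the actual _get_valid_kwargs implementation. Hyperparameters that belong to training rather than to the model (here learning_rate and batch_size) are dropped before the constructor is called, so they cannot trigger an unexpected-keyword error.

```python
import inspect
from typing import Any, Dict


def filter_valid_kwargs(cls: type, params: Dict[str, Any]) -> Dict[str, Any]:
    """Hypothetical helper: keep only keys that cls.__init__ names explicitly."""
    sig = inspect.signature(cls.__init__)
    # Only explicitly named parameters count; a **kwargs catch-all is ignored
    # so unknown keys are dropped even if the class forwards **kwargs upward.
    accepted = {
        name for name, p in sig.parameters.items()
        if name != "self" and p.kind in (inspect.Parameter.POSITIONAL_OR_KEYWORD,
                                         inspect.Parameter.KEYWORD_ONLY)
    }
    return {k: v for k, v in params.items() if k in accepted}


class ToyModel:
    def __init__(self, hidden_units: int, forecast_horizon: int,
                 dropout_rate: float = 0.1):
        self.hidden_units = hidden_units
        self.forecast_horizon = forecast_horizon
        self.dropout_rate = dropout_rate


# A trial's values mix model arguments with training-only settings.
hp_values = {"hidden_units": 32, "forecast_horizon": 6, "dropout_rate": 0.2,
             "learning_rate": 1e-3, "batch_size": 16}
model_kwargs = filter_valid_kwargs(ToyModel, hp_values)
model = ToyModel(**model_kwargs)  # no TypeError for learning_rate/batch_size
```

Ignoring a **kwargs catch-all is a design choice of this sketch; a real helper might treat such signatures differently.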
Fix: Enhanced validation within _prepare_inputs to provide clearer error messages for missing or incorrectly shaped inputs, especially where the requirements of tft and tft_flex differ.
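The checks described for _prepare_inputs, together with the dummy-tensor creation for tft_flex mentioned under Enhancements, can be sketched as follows. prepare_inputs_sketch is hypothetical, and the placeholder shapes (a single all-zero feature, with a future span of T_past + forecast_horizon steps as in the demo data) are assumptions rather than fusionlab's actual behavior.

```python
from typing import List, Optional, Sequence

import numpy as np


def prepare_inputs_sketch(inputs: Sequence[Optional[np.ndarray]], model_name: str,
                          forecast_horizon: int) -> List[np.ndarray]:
    """Hypothetical validation of [static, dynamic, future] inputs."""
    if len(inputs) != 3:
        raise ValueError("Expected [static, dynamic, future]; "
                         "static/future may be None only for 'tft_flex'.")
    x_s, x_d, x_f = inputs
    if x_d is None or x_d.ndim != 3:
        raise ValueError("Dynamic input must be 3D: (batch, time_steps, features).")
    batch, t_past = x_d.shape[0], x_d.shape[1]
    if model_name != "tft_flex" and (x_s is None or x_f is None):
        raise ValueError(f"{model_name!r} requires static, dynamic and future inputs.")
    # Assumed placeholders for tft_flex: a single all-zero feature.
    if x_s is None:
        x_s = np.zeros((batch, 1), dtype=np.float32)
    if x_f is None:
        x_f = np.zeros((batch, t_past + forecast_horizon, 1), dtype=np.float32)
    if x_s.ndim != 2 or x_f.ndim != 3:
        raise ValueError("Static input must be 2D and future input must be 3D.")
    if x_s.shape[0] != batch or x_f.shape[0] != batch:
        raise ValueError("All inputs must share the same batch size.")
    return [x_s, x_d, x_f]


# Only dynamic features are available; tft_flex gets placeholder tensors.
x_d = np.random.rand(8, 12, 5).astype(np.float32)
x_s, x_d, x_f = prepare_inputs_sketch([None, x_d, None], "tft_flex", forecast_horizon=6)
```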
Tests
Tests: Added a comprehensive suite of unit tests for the new _BaseTuner, XTFTTuner, and TFTTuner classes, covering initialization, input preparation, fit execution, and result retrieval.
Tests: Added tests to ensure that model_name validation works correctly in XTFTTuner and TFTTuner.
Documentation
Docs: Added a new User Guide page, /user_guide/forecast_tuner/forecast_tuner_class_based, detailing the new object-oriented approach to hyperparameter tuning with XTFTTuner and TFTTuner, including code examples.
Docs: Updated the existing /user_guide/forecast_tuner/forecast_tuner page to acknowledge the new class-based approach and link to it, while retaining the documentation for the function-based method (which remains available in v0.2.3).
Contributors
Laurent Kouadio (Lead Developer)