fusionlab.nn.forecast_tuner.TFTTuner

class fusionlab.nn.forecast_tuner.TFTTuner

Bases: BaseTuner

Hyper‑parameter optimiser for Temporal Fusion Transformer (TFT) architectures, covering both the strict TFT ('tft') and the flexible TemporalFusionTransformer ('tft_flex').

The tuner searches over LSTM depth, number of attention heads, hidden size, dropout rate, batch‑norm toggles, and the optimiser learning rate. Each candidate batch size runs its own search loop, followed by a refit of that loop's best trial; the best refitted model across all batch sizes is returned.
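The per‑batch‑size workflow can be sketched in plain Python. This is not fusionlab's implementation: `search_over_batch_sizes` and the `run_search` callback are hypothetical names, `run_search` stands in for "search, then refit the best trial" at one batch size, and the sketch assumes the overall champion is the run with the lowest validation loss.

```python
from typing import Callable, Dict, List, Tuple

def search_over_batch_sizes(
    batch_sizes: List[int],
    run_search: Callable[[int], Tuple[Dict[str, float], float]],
) -> Tuple[int, Dict[str, float], float]:
    """Run one (hypothetical) search + refit per batch size; keep the champion."""
    best = None
    for bs in batch_sizes:
        hps, val_loss = run_search(bs)  # independent search loop for this batch size
        if best is None or val_loss < best[2]:
            best = (bs, hps, val_loss)  # lower validation loss wins
    if best is None:
        raise ValueError("batch_sizes must not be empty")
    return best

# Fake per-batch-size results standing in for real tuning runs.
fake = {32: ({"learning_rate": 1e-3}, 0.41), 64: ({"learning_rate": 3e-4}, 0.37)}
bs, hps, loss = search_over_batch_sizes([32, 64], lambda b: fake[b])
print(bs, hps, loss)  # 64 {'learning_rate': 0.0003} 0.37
```

Because each batch size gets a full search of its own, the total number of trials grows linearly with the length of `batch_sizes`.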

Objective

\[\theta^* \;=\; \arg\min_{\theta\in\Theta}\; L_{\text{val}}\bigl(f_{\theta}(\mathbf X),\,\mathbf y\bigr)\]

with \(\Theta\) the joint hyper‑parameter space and \(L_{\text{val}}\) the validation loss.
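For a finite \(\Theta\) the objective reduces to evaluating \(L_{\text{val}}\) on every candidate and keeping the minimiser. The sketch below uses a made‑up two‑parameter grid and a synthetic loss in place of actual training; the key `dropout_rate` is illustrative and not necessarily the name used by fusionlab's default space. Random and Bayesian search approximate the same argmin while evaluating only `max_trials` candidates.

```python
import itertools

# A toy, discrete search space Θ: every combination of these values.
space = {
    "learning_rate": [1e-2, 1e-3, 1e-4],
    "dropout_rate": [0.0, 0.1, 0.2],
}

def val_loss(theta):
    # Synthetic stand-in for L_val(f_θ(X), y); a real tuner trains f_θ
    # and evaluates it on held-out data instead.
    return (theta["learning_rate"] - 1e-3) ** 2 + (theta["dropout_rate"] - 0.1) ** 2

# Enumerate Θ as dicts, then pick θ* = argmin_θ L_val(θ).
candidates = [dict(zip(space, values)) for values in itertools.product(*space.values())]
theta_star = min(candidates, key=val_loss)
print(theta_star)  # {'learning_rate': 0.001, 'dropout_rate': 0.1}
```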

Supported model aliases

  • "tft"

  • "tft_flex"

param model_name:

Identifier of the model variant to tune. Must match one of the aliases accepted by the respective tuner class (for TFTTuner, "tft" or "tft_flex"); case is ignored. Defaults to "xtft" for XTFTTuner and "tft" for TFTTuner. The name is validated before the base‑class initialiser is called.

type model_name:

str, optional

param param_space:

Dictionary mapping hyper‑parameter names to search options understood by Keras Tuner (e.g. lists, ranges, Int/Float distributions). When None, a built‑in default space is used.

type param_space:

dict, optional

param max_trials:

Upper bound on the number of trial configurations that the tuner explores. Must be a positive integer.

type max_trials:

int, default 10

param objective:

Metric name that the tuner seeks to minimise (or maximise if prefixed with 'max'). Any Keras history key is valid.

type objective:

str, default 'val_loss'

param epochs:

Number of training epochs for the refit phase carried out on the best hyper‑parameters of each batch‑size loop.

type epochs:

int, default 10

param batch_sizes:

Batch sizes to iterate over. A separate tuning run is executed for each value.

type batch_sizes:

list[int], default [32]

param validation_split:

Fraction of the training data reserved for validation inside both the search and refit stages. Must fall in (0, 1).

type validation_split:

float, default 0.2

param tuner_dir:

Root directory where Keras Tuner artefacts are written (trial summaries, checkpoints, logs). A path within the current working directory is autogenerated if omitted.

type tuner_dir:

str, optional

param project_name:

Folder name under tuner_dir used to isolate results of one tuning job. Defaults to a slug derived from the model type and run description.

type project_name:

str, optional

param tuner_type:

Search strategy. 'random' draws configurations uniformly from the search space; 'bayesian' performs Bayesian optimisation, using a probabilistic model of the objective to propose promising configurations.

type tuner_type:

{'random', 'bayesian'}, default 'random'

param callbacks:

Extra Keras callbacks active during both the search and refit phases. When None, a default EarlyStopping callback is injected automatically.

type callbacks:

list[keras.callbacks.Callback], optional

param model_builder:

Custom factory returning a compiled Keras model from a hyper‑parameter set. If omitted, an internal builder covering the canonical search space is used.

type model_builder:

Callable[[kt.HyperParameters], Model], optional

param verbose:

Controls console logging produced by the tuner wrapper: 0 = silent · 1 = high‑level · 2 = per‑step details · ≥3 = debug.

type verbose:

int, default 1

param **kws:

Extra keyword arguments forwarded to the base tuner.

type **kws:

Any
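With tuner_type='random', each trial draws a configuration uniformly from param_space, and max_trials caps how many configurations are evaluated. The standard‑library sketch below illustrates that loop only; the real sampling is done by Keras Tuner. The option format (lists for discrete choices, {"min", "max"} dicts for ranges), the key names `num_heads` and `dropout_rate`, and the helpers `sample_config` and `random_search` are all hypothetical.

```python
import random
from typing import Any, Callable, Dict, Sequence, Tuple, Union

# Assumed option format: a sequence of discrete choices, or {"min", "max"} bounds.
Option = Union[Sequence[Any], Dict[str, float]]

def sample_config(param_space: Dict[str, Option], rng: random.Random) -> Dict[str, Any]:
    """Draw one configuration uniformly from every option."""
    config: Dict[str, Any] = {}
    for name, option in param_space.items():
        if isinstance(option, dict):
            # Continuous range: uniform draw between the bounds.
            config[name] = rng.uniform(option["min"], option["max"])
        else:
            # Discrete choices: each value is equally likely.
            config[name] = rng.choice(list(option))
    return config

def random_search(
    param_space: Dict[str, Option],
    objective_fn: Callable[[Dict[str, Any]], float],
    max_trials: int = 10,
    seed: int = 0,
) -> Tuple[Dict[str, Any], float]:
    """Evaluate max_trials random configurations and return the lowest-scoring one."""
    if not isinstance(max_trials, int) or max_trials < 1:
        raise ValueError("max_trials must be a positive integer")
    rng = random.Random(seed)
    trials = [sample_config(param_space, rng) for _ in range(max_trials)]
    scores = [objective_fn(cfg) for cfg in trials]
    best = min(range(max_trials), key=lambda i: scores[i])
    return trials[best], scores[best]

space = {
    "num_heads": [1, 2, 4],                    # discrete choice
    "dropout_rate": {"min": 0.0, "max": 0.3},  # continuous range
}
# Synthetic objective standing in for "train, then read val_loss".
best_cfg, best_score = random_search(
    space, lambda c: c["dropout_rate"] + 0.01 * c["num_heads"], max_trials=4
)
print(best_cfg, best_score)
```

A Bayesian strategy keeps the same outer loop but replaces the independent uniform draws with proposals informed by the scores of earlier trials.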

Example

Create synthetic data, instantiate the tuner, and run a full search.

>>> import numpy as np
>>> from fusionlab.nn.forecast_tuner import TFTTuner
>>> B, F, Ns, Nd, Nf, O = 128, 6, 3, 4, 2, 1
>>> rng = np.random.default_rng(42)
>>> X_static  = rng.normal(size=(B, Ns)).astype("float32")
>>> X_dynamic = rng.normal(size=(B, F, Nd)).astype("float32")
>>> X_future  = rng.normal(size=(B, F, Nf)).astype("float32")
>>> y         = rng.normal(size=(B, F, O)).astype("float32")
>>>
>>> tuner = TFTTuner(
...     model_name="tft_flex",
...     max_trials=4,
...     epochs=30,
...     batch_sizes=[32, 64],
...     tuner_type="bayesian",
...     verbose=2,
... )
>>> best_hps, best_model, kt_obj = tuner.fit(
...     inputs=[X_static, X_dynamic, X_future],
...     y=y,
...     forecast_horizon=F,
... )
>>> print("Best learning‑rate:",
...       f"{best_hps['learning_rate']:.3g}")
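Keras documents that `validation_split` in `Model.fit` takes the validation samples from the end of the arrays, before shuffling. Assuming the tuner forwards `validation_split` to Keras in that way (this page does not say), the example's B = 128 samples would be split as in the sketch below; `split_last_fraction` is a hypothetical helper, not part of fusionlab.

```python
import math
import numpy as np

def split_last_fraction(arrays, validation_split):
    """Hold out the trailing fraction of samples from every array, Keras-style."""
    if not 0.0 < validation_split < 1.0:
        raise ValueError("validation_split must fall in (0, 1)")
    n = arrays[0].shape[0]
    split_at = int(math.floor(n * (1.0 - validation_split)))
    train = [a[:split_at] for a in arrays]
    val = [a[split_at:] for a in arrays]
    return train, val

# Same shapes as the example: B samples, F forecast steps.
rng = np.random.default_rng(42)
B, F = 128, 6
X_static = rng.normal(size=(B, 3)).astype("float32")
X_dynamic = rng.normal(size=(B, F, 4)).astype("float32")
X_future = rng.normal(size=(B, F, 2)).astype("float32")
y = rng.normal(size=(B, F, 1)).astype("float32")

train, val = split_last_fraction([X_static, X_dynamic, X_future, y], 0.2)
print(train[0].shape[0], val[0].shape[0])  # 102 26
```

Because the split is positional, samples ordered chronologically keep their most recent rows for validation.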

See also

XTFTTuner

Companion tuner for Extreme TFT variants.

fusionlab.nn.transformers.TFT

Strict reference implementation.

fusionlab.nn.transformers.TemporalFusionTransformer

Flexible implementation accepting missing input blocks.


__init__(model_name='tft', param_space=None, max_trials=10, objective='val_loss', epochs=10, batch_sizes=[32], validation_split=0.2, tuner_dir=None, project_name=None, tuner_type='random', callbacks=None, model_builder=None, verbose=1, **kws)
Parameters:
  • model_name (str)

  • param_space (Dict[str, Any] | None)

  • max_trials (int)

  • objective (str)

  • epochs (int)

  • batch_sizes (List[int])

  • validation_split (float)

  • tuner_dir (str | None)

  • project_name (str | None)

  • tuner_type (str)

  • callbacks (List[Callable] | None)

  • model_builder (Callable | None)

  • verbose (int)

  • kws (Any)

Methods

__init__([model_name, param_space, ...])

cast_multiple_bool_params(params, ...)

Cast several boolean hyperparameters at once.

fit(inputs, y[, forecast_horizon, ...])

Execute the complete tuning workflow.
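The methods table only states that `cast_multiple_bool_params` casts several boolean hyper‑parameters at once; its remaining arguments are elided. The sketch below shows one plausible reading as a standalone function: toggles such as batch‑norm switches may come back from a search space as 0/1 or as strings and are normalised to Python bools. The name `cast_bool_params`, the `keys` argument, and the key names are hypothetical; this is not fusionlab's implementation.

```python
from typing import Any, Dict, Iterable

_TRUE = {"true", "1", "yes", "y", "t"}
_FALSE = {"false", "0", "no", "n", "f"}

def cast_bool_params(params: Dict[str, Any], keys: Iterable[str]) -> Dict[str, Any]:
    """Return a copy of params with the listed keys coerced to bool."""
    out = dict(params)  # never mutate the caller's dict
    for key in keys:
        if key not in out:
            continue  # absent keys are left untouched
        value = out[key]
        if isinstance(value, str):
            token = value.strip().lower()
            if token in _TRUE:
                out[key] = True
            elif token in _FALSE:
                out[key] = False
            else:
                raise ValueError(f"cannot interpret {value!r} as a boolean for {key!r}")
        else:
            out[key] = bool(value)  # ints, numpy bools, etc.
    return out

hps = {"use_batch_norm": 1, "use_residuals": "False", "learning_rate": 1e-3}
print(cast_bool_params(hps, ["use_batch_norm", "use_residuals"]))
# {'use_batch_norm': True, 'use_residuals': False, 'learning_rate': 0.001}
```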
