fusionlab.nn.utils.prepare_model_inputs
- fusionlab.nn.utils.prepare_model_inputs(dynamic_input, static_input=None, future_input=None, model_type='strict', forecast_horizon=None, verbose=0, **kwargs)
Prepares a list of input tensors for a model’s call method.
This function standardizes the creation of the input list [static, dynamic, future] expected by many models in fusionlab. When static or future inputs are None, it creates dummy tensors with zero features if model_type is 'strict', and passes None through unchanged if model_type is 'flexible'.
- Parameters:
  - dynamic_input (np.ndarray or tf.Tensor) – The dynamic (past observed) features. This input is always required and must be a valid tensor or array. Expected shape: (batch_size, past_time_steps, num_dynamic_features).
  - static_input (np.ndarray or tf.Tensor, optional) – The static (time-invariant) features. Expected shape: (batch_size, num_static_features). If None and model_type is 'strict', a dummy tensor with 0 static features is created. Default is None.
  - future_input (np.ndarray or tf.Tensor, optional) – The known future features. Expected shape: (batch_size, future_time_span, num_future_features). If None and model_type is 'strict', a dummy tensor with 0 future features is created. Its time span is past_time_steps (from dynamic_input) plus forecast_horizon if provided, otherwise just past_time_steps. Default is None.
  - model_type ({'strict', 'flexible'}, default 'strict') – Determines how None inputs for static and future features are handled:
    - 'strict': If static_input or future_input is None, a dummy tensor with a feature dimension of 0 is created and included in the output list. This is for models that expect a 3-element list of tensors, even if some paths are unused.
    - 'flexible': If static_input or future_input is None, None itself is placed in the corresponding position of the output list. This is for models that can handle None inputs for optional feature types internally.
  - forecast_horizon (int, optional) – The forecast horizon. Used only when model_type='strict' and future_input is None, to set the time dimension of the dummy future tensor to past_time_steps + forecast_horizon. If not provided in this scenario, the dummy future tensor's time dimension matches the past_time_steps of dynamic_input. Default is None.
  - verbose (int, default 0) – Verbosity level. If > 0, prints information about dummy tensor creation.
    - 0: Silent.
    - 1: Basic info on dummy creation.
    - 2: More details on shapes.
- Returns:
  A list containing three elements in the order [processed_static_input, processed_dynamic_input, processed_future_input]. Elements are TensorFlow tensors, or None when model_type='flexible' and the original input was None. All returned tensors are cast to tf.float32.
- Return type:
  List[Optional[tf.Tensor]]
- Raises:
  - ValueError – If dynamic_input is None; if dynamic_input is not at least 2D (a batch dimension is required); if static_input (when provided) is not 2D; or if future_input (when provided) is not 3D.
  - TypeError – If inputs cannot be converted to TensorFlow tensors.
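The ValueError conditions listed under Raises amount to simple rank checks, which can be restated over plain shape tuples. The sketch below uses a hypothetical helper, validate_input_ranks, that is not part of fusionlab; it mirrors only the documented ValueError rules and does not attempt the tensor conversion behind the TypeError case.

```python
from typing import Optional, Sequence


def validate_input_ranks(
    dynamic_shape: Optional[Sequence[int]],
    static_shape: Optional[Sequence[int]] = None,
    future_shape: Optional[Sequence[int]] = None,
) -> None:
    """Hypothetical helper (not part of fusionlab).

    Raises ValueError under the conditions documented for
    prepare_model_inputs, using shape tuples instead of tensors.
    """
    # dynamic_input is always required.
    if dynamic_shape is None:
        raise ValueError("dynamic_input is required and cannot be None.")
    # dynamic_input needs at least a batch axis plus one more axis.
    if len(dynamic_shape) < 2:
        raise ValueError("dynamic_input must be at least 2D (needs a batch dimension).")
    # static_input, when given, must be (batch_size, num_static_features).
    if static_shape is not None and len(static_shape) != 2:
        raise ValueError("static_input must be 2D: (batch_size, num_static_features).")
    # future_input, when given, must be (batch, time, features).
    if future_shape is not None and len(future_shape) != 3:
        raise ValueError(
            "future_input must be 3D: "
            "(batch_size, future_time_span, num_future_features)."
        )
```

For instance, validate_input_ranks((2, 10, 4), (2, 2), (2, 13, 1)) returns silently, while passing (2, 13) as the future shape raises ValueError because future inputs must be 3D.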
Examples
>>> import tensorflow as tf
>>> import numpy as np
>>> from fusionlab.nn.utils import prepare_model_inputs
>>> B, T, H = 2, 10, 3
>>> D_s, D_d, D_f = 2, 4, 1
>>> dyn_in = tf.random.normal((B, T, D_d))
>>> stat_in = tf.random.normal((B, D_s))
>>> fut_in = tf.random.normal((B, T + H, D_f))

>>> # Strict mode, all inputs provided
>>> s, d, f = prepare_model_inputs(dyn_in, stat_in, fut_in, model_type='strict')
>>> print(f"S: {s.shape}, D: {d.shape}, F: {f.shape}")
S: (2, 2), D: (2, 10, 4), F: (2, 13, 1)

>>> # Strict mode, static is None
>>> s, d, f = prepare_model_inputs(dyn_in, static_input=None, future_input=fut_in,
...                                model_type='strict', forecast_horizon=H)
>>> print(f"S: {s.shape}, D: {d.shape}, F: {f.shape}")
S: (2, 0), D: (2, 10, 4), F: (2, 13, 1)

>>> # Flexible mode, static and future are None
>>> s, d, f = prepare_model_inputs(dyn_in, static_input=None, future_input=None,
...                                model_type='flexible')
>>> print(f"S: {s is None}, D: {d.shape}, F: {f is None}")
S: True, D: (2, 10, 4), F: True
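The examples do not show strict mode with both optional inputs missing. Assuming the documented rules (a dummy static tensor of shape (batch_size, 0), and a dummy future tensor spanning past_time_steps plus forecast_horizon when it is given), the output shapes can be predicted with the hypothetical helper expected_shapes below. It is a sketch over shape tuples only, not fusionlab's implementation, and it assumes a 3D dynamic input.

```python
from typing import Optional, Tuple

Shape = Tuple[int, ...]


def expected_shapes(
    dynamic_shape: Shape,
    static_shape: Optional[Shape] = None,
    future_shape: Optional[Shape] = None,
    model_type: str = "strict",
    forecast_horizon: Optional[int] = None,
) -> Tuple[Optional[Shape], Shape, Optional[Shape]]:
    """Hypothetical helper (not part of fusionlab).

    Predicts the shapes returned by prepare_model_inputs from the
    documented rules. Assumes dynamic_shape is
    (batch_size, past_time_steps, num_dynamic_features).
    """
    batch_size, past_time_steps = dynamic_shape[0], dynamic_shape[1]
    strict = model_type == "strict"

    if static_shape is None and strict:
        # Dummy static tensor: same batch size, zero static features.
        static_shape = (batch_size, 0)

    if future_shape is None and strict:
        # Dummy future tensor: past_time_steps, plus forecast_horizon if given.
        span = past_time_steps + (forecast_horizon or 0)
        future_shape = (batch_size, span, 0)

    # In 'flexible' mode, missing inputs are left as None.
    return static_shape, tuple(dynamic_shape), future_shape
```

With B, T, H = 2, 10, 3 as in the examples, expected_shapes((2, 10, 4), forecast_horizon=3) gives ((2, 0), (2, 10, 4), (2, 13, 0)); dropping forecast_horizon shortens the dummy future span to 10.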