fusionlab.nn.utils.reshape_xtft_data

fusionlab.nn.utils.reshape_xtft_data(df, dt_col, target_col, dynamic_cols, static_cols=None, future_cols=None, spatial_cols=None, time_steps=4, forecast_horizon=1, to_datetime=None, model='xtft', error='raise', savefile=None, forecast_horizons=None, verbose=1, **kw)
Reshapes time series data into rolling sequences for models like the Temporal Fusion Transformer (TFT) and the Extreme Temporal Fusion Transformer (XTFT).

This function transforms a pandas DataFrame into aligned arrays of sequences suitable for training or evaluating sequence-to-sequence models. It handles static, dynamic (past observed), and known future covariates, and can process data grouped by spatial or other identifiers. Sequences are built from rolling windows, and future features are extracted for a combined period covering both the lookback (time_steps) and prediction (forecast_horizon) windows.
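As a rough illustration of the rolling-window idea, and not the library's implementation, NumPy's sliding_window_view can cut a single series into overlapping windows of length time_steps + forecast_horizon. Each window is then split into a lookback part (dynamic inputs), a horizon part (targets), and the full span (future covariates). The series x, y, z and the names T, H below are hypothetical.

```python
import numpy as np
from numpy.lib.stride_tricks import sliding_window_view

T, H = 5, 3   # time_steps and forecast horizon (illustrative values)
n = 50
rng = np.random.default_rng(0)
y = np.arange(n, dtype=float)   # hypothetical target series
x = rng.random(n)               # hypothetical dynamic (past observed) feature
z = rng.random(n)               # hypothetical known-future feature

# Each row of a windowed view covers T + H consecutive time steps.
y_win = sliding_window_view(y, T + H)   # shape (n - T - H + 1, T + H)
x_win = sliding_window_view(x, T + H)
z_win = sliding_window_view(z, T + H)

dynamic = x_win[:, :T, np.newaxis]   # lookback only: (N, T, 1)
future = z_win[:, :, np.newaxis]     # lookback + horizon: (N, T + H, 1)
target = y_win[:, T:, np.newaxis]    # horizon only: (N, H, 1)

print(dynamic.shape, future.shape, target.shape)
# (43, 5, 1) (43, 8, 1) (43, 3, 1)
```

The windowed views share memory with the input arrays and are read-only, so call .copy() before modifying them.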

The core process involves:

  1. Validating the input DataFrame and the specified columns.

  2. Optionally grouping the data by spatial_cols. If spatial_cols is not provided, the entire DataFrame is treated as a single group.

  3. Sorting the data within each group by the datetime column (dt_col).

  4. A two-pass system for efficiency:

     a. First pass: iterates through the groups to determine the total number of valid sequences that can be generated from time_steps and forecast_horizon. Groups too short to form even one complete input-target pair are skipped.

     b. Pre-allocation: empty NumPy arrays are created with the final dimensions to store the static, dynamic, future, and target sequences.

     c. Second pass: iterates through the valid groups again, creating rolling windows and populating the pre-allocated arrays.

  5. For each window:

     • Static features (static_cols): extracted once per group if spatial_cols is used, or from the first time step of the current input window if there is no spatial grouping.

     • Dynamic features (dynamic_cols): extracted for the lookback period of time_steps.

     • Future features (future_cols): extracted for a combined period covering both the time_steps lookback and the forecast_horizon. This gives the model known future information relevant to both the input window and the prediction window.

     • Target features (target_col): extracted for the forecast_horizon immediately following the input window.

  6. Optionally saving the generated sequence arrays and metadata to a .joblib file.
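The steps above can be sketched as follows. reshape_sketch is a hypothetical, simplified function written only for illustration: it omits validation, datetime conversion, error handling, and saving, and its horizon argument stands in for forecast_horizon / forecast_horizons. It does show the count, pre-allocate, fill structure and the per-window slicing rules.

```python
import numpy as np
import pandas as pd


def reshape_sketch(df, dt_col, target_col, dynamic_cols, static_cols=None,
                   future_cols=None, spatial_cols=None, time_steps=4, horizon=1):
    """Hypothetical, simplified version of the documented two-pass process."""
    T, H = time_steps, horizon
    # Steps 2-3: split into groups (or keep one group) and sort each by time.
    if spatial_cols:
        groups = [g for _, g in df.groupby(spatial_cols, sort=True)]
    else:
        groups = [df]
    groups = [g.sort_values(dt_col) for g in groups]

    # First pass: windows per group; groups shorter than T + H contribute 0.
    counts = [max(0, len(g) - T - H + 1) for g in groups]
    total = sum(counts)
    if total == 0:
        raise ValueError("No group is long enough to form one sequence.")

    # Pre-allocation with the final shapes.
    dyn = np.empty((total, T, len(dynamic_cols)))
    tgt = np.empty((total, H, 1))
    stat = np.empty((total, len(static_cols))) if static_cols else None
    fut = np.empty((total, T + H, len(future_cols))) if future_cols else None

    # Second pass: fill the arrays window by window.
    k = 0
    for g, n_seq in zip(groups, counts):
        d = g[dynamic_cols].to_numpy(dtype=float)
        y = g[[target_col]].to_numpy(dtype=float)
        s = g[static_cols].to_numpy(dtype=float) if static_cols else None
        z = g[future_cols].to_numpy(dtype=float) if future_cols else None
        for j in range(n_seq):
            dyn[k] = d[j:j + T]
            tgt[k] = y[j + T:j + T + H]
            if stat is not None:
                # Group-level row when grouped, else first row of the window.
                stat[k] = s[0] if spatial_cols else s[j]
            if fut is not None:
                fut[k] = z[j:j + T + H]
            k += 1
    return stat, dyn, fut, tgt


n = 50
rng = np.random.default_rng(0)
df = pd.DataFrame({
    "Date": pd.date_range("2023-01-01", periods=n),
    "Target": np.arange(n, dtype=float),
    "Dyn": rng.random(n),
    "Stat": rng.random(n),
    "Fut": rng.random(n),
})
s, d, f, t = reshape_sketch(df, "Date", "Target", ["Dyn"], ["Stat"], ["Fut"],
                            time_steps=5, horizon=3)
print(s.shape, d.shape, f.shape, t.shape)
# (43, 1) (43, 5, 1) (43, 8, 1) (43, 3, 1)
```

Counting first lets every output array be allocated once at its final size instead of growing a Python list and stacking it at the end.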

Parameters:
  • df (pandas.DataFrame) – The input DataFrame containing time series data. It must include a datetime column specified by dt_col and the target_col.

  • dt_col (str) – The name of the datetime column in df. This column is processed to ensure proper datetime formatting.

  • target_col (str) – The column in df holding the target values for forecasting.

  • dynamic_cols (List[str]) – A list of column names representing dynamic features that vary over time (e.g., past observed values, exogenous variables).

  • static_cols (List[str], optional) – A list of column names representing static features that are time-invariant for each group (if spatial_cols are used) or per sequence (if no spatial_cols). If None, no static data is generated. Default is None.

  • future_cols (List[str], optional) – A list of column names representing known future covariates. If None, no future data is generated. Default is None.

  • spatial_cols (str or List[str], optional) – Column name(s) used to group the DataFrame, typically by a spatial identifier (e.g., ‘location_id’, [‘longitude’, ‘latitude’]). If None, the entire DataFrame is treated as a single group. Default is None.

  • time_steps (int, default 4) – The number of past time steps to include in each input sequence (lookback window).

  • forecast_horizon (int, default 1) – The number of future time steps to predict for each input sequence.

  • to_datetime (str, optional) – Specifies a conversion rule for dt_col if it’s not already in datetime format (e.g., “auto”, “Y”, “M”, “D”). Passed to ts_validator(). Default is None.

  • model (str, default "xtft") – Indicates the target model type. Currently, this parameter is primarily for forward compatibility or specific internal checks and does not alter the core reshaping logic for numerical data. Supported: {“xtft”, “tft”, “any”, “lstm”, None}.

  • error (str, default 'raise') – Error handling strategy if required columns are missing. Options: {‘raise’, ‘warn’, ‘ignore’}.

  • savefile (str, optional) – If provided, path to save the generated sequence arrays and metadata as a .joblib file. Default is None.

  • verbose (int, default 1) – Verbosity level for logging and status messages:

    • 0: Silent.

    • 1: Basic information (e.g., grouping, total sequences, final shapes, save messages).

    • 2: More detailed processing steps (e.g., per-group sequence counts if spatial_cols is used).

    • 3 (or higher): Debug-level information, including internal shapes during sequence generation (not typically used).

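  • forecast_horizons (int or None, optional) – Number of future time steps to predict. The Examples pass the horizon through this argument, and the Returns, Raises, and Notes sections are expressed in terms of it; when provided, it appears to take precedence over forecast_horizon. Default is None.

  • **kw – Placeholder for forward-compatible extensions.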

Returns:

  • static_data_arr (numpy.ndarray or None) – Array of static feature sequences. Shape: \((\text{TotalSequences}, \text{NumStaticFeatures})\). Returns None if static_cols is not provided.

  • dynamic_data_arr (numpy.ndarray) – Array of dynamic feature sequences. Shape: \((\text{TotalSequences}, \text{time\_steps}, \text{NumDynamicFeatures})\).

  • future_data_arr (numpy.ndarray or None) – Array of future covariate sequences. Shape: \((\text{TotalSequences}, \text{time\_steps} + \text{forecast\_horizons}, \text{NumFutureFeatures})\). Returns None if future_cols is not provided.

  • target_data_arr (numpy.ndarray) – Array of target value sequences. Shape: \((\text{TotalSequences}, \text{forecast\_horizons}, 1)\).
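The shapes above can be predicted before calling the function. expected_shapes is a hypothetical helper, not part of fusionlab, and it assumes each group of n rows yields max(0, n - time_steps - horizon + 1) windows, which matches the shapes printed in the Examples.

```python
def expected_shapes(group_lengths, time_steps, horizon,
                    n_static=0, n_dynamic=1, n_future=0):
    """Hypothetical helper: output shapes implied by the Returns section."""
    # Each group of n rows is assumed to yield max(0, n - T - H + 1) windows.
    total = sum(max(0, n - time_steps - horizon + 1) for n in group_lengths)
    static = (total, n_static) if n_static else None
    dynamic = (total, time_steps, n_dynamic)
    future = (total, time_steps + horizon, n_future) if n_future else None
    target = (total, horizon, 1)
    return static, dynamic, future, target


print(expected_shapes([50], 5, 3, n_static=1, n_future=1))
# ((43, 1), (43, 5, 1), (43, 8, 1), (43, 3, 1))
print(expected_shapes([30, 30], 6, 4, n_static=1, n_future=1))
# ((42, 1), (42, 6, 1), (42, 10, 1), (42, 4, 1))
```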

Raises:
  • ValueError – If time_steps or forecast_horizons is not a positive integer; if no group contains enough data points to form at least one sequence; or if columns listed in the *_cols arguments are missing and error='raise'.

  • KeyError – If target_col or columns in dynamic_cols (or other provided *_cols) do not exist in the DataFrame and error='raise'.

Return type:

Tuple[ndarray | None, ndarray, ndarray | None, ndarray]

Notes

  • The function sorts data by spatial_cols (if any) and then by dt_col before generating sequences. Ensure dt_col represents a sortable time progression.

  • For non-spatial data (spatial_cols=None), static features are extracted from the first time step of each generated input window.

  • The future_data_arr is constructed to span both the input lookback window (time_steps) and the prediction window (forecast_horizons), providing the model with all known future information relevant to the current input and target sequences.
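The sorting note matters when dt_col holds strings: non-zero-padded date strings sort lexicographically, so "2023-01-10" lands before "2023-01-8". The pandas-only sketch below (hypothetical data, independent of fusionlab) shows the pitfall and the group-then-time ordering the Notes describe once the column is parsed to real datetimes.

```python
import pandas as pd

df = pd.DataFrame({
    "GroupID": ["B", "A", "B", "A", "A", "B"],
    "Date": ["2023-01-10", "2023-01-9", "2023-01-9",
             "2023-01-10", "2023-01-8", "2023-01-8"],
    "Target": [6.0, 2.0, 5.0, 3.0, 1.0, 4.0],
})

# Lexicographic order: ['2023-01-10', '2023-01-8', '2023-01-9']
print(sorted(df["Date"].unique()))

# Parse to real datetimes first, then sort by group and time.
df["Date"] = pd.to_datetime(df["Date"], format="%Y-%m-%d")
ordered = df.sort_values(["GroupID", "Date"])
print(ordered["Target"].tolist())  # [1.0, 2.0, 3.0, 4.0, 5.0, 6.0]
```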

The function constructs rolling windows. For an input sequence starting at index \(j\) within a group:

  • Dynamic input: \(\mathbf{X}^{(j)}_{dyn} = [\mathbf{x}_{j}, \ldots, \mathbf{x}_{j+T-1}]\)

  • Future input: \(\mathbf{X}^{(j)}_{fut} = [\mathbf{z}_{j}, \ldots, \mathbf{z}_{j+T+H-1}]\)

  • Static input: \(\mathbf{X}^{(j)}_{stat}\) (constant for the sequence)

  • Target: \(\mathbf{Y}^{(j)} = [y_{j+T}, \ldots, y_{j+T+H-1}]^T\)

where \(T\) is time_steps and \(H\) is forecast_horizons.
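A short derivation, inferred from the window definitions above rather than stated by the library, of how many windows a group with \(n_g\) rows yields. A window is valid only if its last target index stays inside the group; the resulting counts agree with the shapes printed in the Examples.

```latex
% Last target index of window j must not exceed the last row of the group
j + T + H - 1 \le n_g - 1
\quad\Longleftrightarrow\quad
j \in \{0, 1, \ldots, n_g - T - H\}

N_g = \max(0,\; n_g - T - H + 1), \qquad
\text{TotalSequences} = \sum_{g} N_g

% Example 1: one group, n_g = 50, T = 5, H = 3
N = 50 - 5 - 3 + 1 = 43

% Example 2: two groups, n_g = 30 each, T = 6, H = 4
N = 2\,(30 - 6 - 4 + 1) = 42
```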

Examples

>>> import pandas as pd
>>> import numpy as np
>>> from fusionlab.nn.utils import reshape_xtft_data
>>> # Example 1: Basic usage without spatial grouping
>>> n_points = 50
>>> df1 = pd.DataFrame({
...     'Date': pd.to_datetime(pd.date_range('2023-01-01', periods=n_points)),
...     'Target': np.arange(n_points),
...     'Dynamic1': np.random.rand(n_points),
...     'Static1_val': np.random.rand(n_points) * 10, # Will be sequence-static
...     'Future1': np.random.rand(n_points) + 5
... })
>>> s, d, f, t = reshape_xtft_data(
...     df1, dt_col='Date', target_col='Target',
...     dynamic_cols=['Dynamic1'], static_cols=['Static1_val'],
...     future_cols=['Future1'], time_steps=5, forecast_horizons=3,
...     verbose=0
... )
>>> print(f"Example 1 Shapes: S={s.shape}, D={d.shape}, F={f.shape}, T={t.shape}")
Example 1 Shapes: S=(43, 1), D=(43, 5, 1), F=(43, 8, 1), T=(43, 3, 1)
>>> # Example 2: With spatial grouping
>>> df_list = []
>>> for group_id in ['A', 'B']:
...     group_df = pd.DataFrame({
...         'Date': pd.to_datetime(pd.date_range('2023-01-01', periods=30)),
...         'Target': np.random.rand(30) + (10 if group_id == 'A' else 20),
...         'Dynamic1': np.random.rand(30),
...         'Static_Group': 100 if group_id == 'A' else 200, # Truly static per group
...         'Future1': np.random.rand(30) + 5,
...         'GroupID': group_id
...     })
...     df_list.append(group_df)
>>> df2 = pd.concat(df_list, ignore_index=True)
>>> s, d, f, t = reshape_xtft_data(
...     df2, dt_col='Date', target_col='Target',
...     dynamic_cols=['Dynamic1'], static_cols=['Static_Group'],
...     future_cols=['Future1'], spatial_cols=['GroupID'],
...     time_steps=6, forecast_horizons=4, verbose=0
... )
>>> print(f"Example 2 Shapes (Spatial): S={s.shape}, D={d.shape}, F={f.shape}, T={t.shape}")
Example 2 Shapes (Spatial): S=(42, 1), D=(42, 6, 1), F=(42, 10, 1), T=(42, 4, 1)

See also

fusionlab.utils.ts_utils.ts_validator

Validates and converts datetime columns.

fusionlab.core.handlers.columns_manager

Formats and validates column lists.

fusionlab.core.checks.exist_features

Checks for column existence.

fusionlab.utils.io_utils.save_job

Utility for saving processed data.

fusionlab.nn.utils.create_sequences

Simpler sequence creation for basic models.

References