# Source code for fusionlab.utils.forecast_utils

# -*- coding: utf-8 -*-
# Author: LKouadio <etanoyau@gmail.com>
# License: BSD-3-Clause

"""
Forecast utilities.
"""
from __future__ import annotations
import os 
import re
import logging 
from collections.abc import Mapping, Sequence
from typing import ( 
    Dict, 
    Iterable, 
    List, 
    Union, 
    Any, 
    Optional,
    Tuple , 
    Callable
)
import pandas as pd

from .._fusionlog import fusionlog 
from ..core.handlers import columns_manager 
from ..core.checks import check_spatial_columns, check_empty  
from ..core.io import is_data_readable  
from ..decorators import isdf 

from .generic_utils import vlog 
from .validator import is_frame 

# Module-level logger obtained from the package logging helper. In this
# module it is used to report a failed CSV save in
# ``pivot_forecast_dataframe``.
logger = fusionlog().get_fusionlab_logger(__name__)

# Names exported by ``from <module> import *``.
__all__= [ 
     'detect_forecast_type',
     'format_forecast_dataframe',
     'get_value_prefixes',
     'get_value_prefixes_in',
     'pivot_forecast_dataframe', 
     'get_step_names'
     ]

# Matches one run of consecutive digits. ``_to_int_key`` concatenates all
# matches found in a key, so "step 3" becomes the integer 3.
_DIGIT_RE = re.compile(r"\d+")


def get_step_names(
    forecast_steps: Iterable[int],
    step_names: Optional[
        Union[Mapping[Any, str], Sequence[str], None]] = None,
    default_name: str = "",
) -> Dict[int, str]:
    """
    Return a ``{step: label}`` dictionary for a list of forecast steps.

    Every entry of ``forecast_steps`` is converted with ``int`` and looked
    up in the aliases supplied through ``step_names``. Steps without an
    alias receive ``default_name`` or, when that is empty, the step number
    rendered as a string.

    Parameters
    ----------
    forecast_steps : Iterable[int]
        Steps to label, e.g. ``[1, 2, 3]``. The value is first normalised
        with ``columns_manager(..., empty_as_none=False)``.
    step_names : dict | list | tuple | None, default=None
        Source of custom labels.

        * **mapping** - each key is reduced to an integer by
          ``_to_int_key`` (integers are used as is, otherwise the digits
          found in ``str(key)`` are joined), so ``"1"``, ``1`` and
          ``"step 1"`` all address step ``1``. Keys that yield no integer
          are ignored.
        * **sequence** (not ``str``/``bytes``) - positional: the first
          element labels step ``1``, the second step ``2``, and so on.
        * **None** - no custom labels.
    default_name : str, default=""
        Label given to steps that have no alias. When empty, ``str(step)``
        is used instead.

    Returns
    -------
    dict[int, str]
        One entry per distinct element of ``forecast_steps``, in order of
        first appearance.

    Raises
    ------
    TypeError
        If ``step_names`` is neither a mapping, a non-string sequence,
        nor ``None``.

    Notes
    -----
    When several mapping keys reduce to the same integer, the label of the
    last one encountered is kept.

    Examples
    --------
    >>> from fusionlab.utils.forecast_utils import get_step_names
    >>> get_step_names(
    ...     forecast_steps=[1, 2, 3],
    ...     step_names={"1": "Year 2021", 2: "2022", "step 3": "2023"},
    ... )
    {1: 'Year 2021', 2: '2022', 3: '2023'}

    >>> get_step_names(
    ...     forecast_steps=[1, 2, 3, 4],
    ...     step_names={"1": "2021", "2": "2022"},
    ... )
    {1: '2021', 2: '2022', 3: '3', 4: '4'}

    >>> get_step_names(
    ...     [1, 2, 3, 4],
    ...     step_names=None,
    ...     default_name="step with no name",
    ... )
    {1: 'step with no name', 2: 'step with no name',
     3: 'step with no name', 4: 'step with no name'}
    """
    # Materialise the steps first so that they can be traversed after the
    # alias table has been built.
    forecast_steps = columns_manager(forecast_steps, empty_as_none=False)
    ordered_steps: List[int] = [int(item) for item in forecast_steps]

    aliases: Dict[int, str] = {}
    if isinstance(step_names, Mapping):
        for raw_key, label in step_names.items():
            step_no = _to_int_key(raw_key)
            if step_no is None:
                # The key carries no usable step number.
                continue
            aliases[step_no] = str(label)
    elif step_names is not None:
        is_plain_sequence = (
            isinstance(step_names, Sequence)
            and not isinstance(step_names, (str, bytes))
        )
        if not is_plain_sequence:
            raise TypeError(
                "`step_names` must be a mapping, a sequence, or None "
                f"(got {type(step_names).__name__})."
            )
        aliases = {
            position: str(label)
            for position, label in enumerate(step_names, start=1)
        }

    # An alias wins; otherwise fall back to the default label, and to the
    # step number itself when no default label was given.
    return {
        step: aliases.get(step, default_name or str(step))
        for step in ordered_steps
    }

def _to_int_key(key: Any) -> int | None:
    """Coerce a mapping key to an integer step number.

    Parameters
    ----------
    key : Any
        Key taken from a user-supplied ``step_names`` mapping.

    Returns
    -------
    int or None
        * ``key`` itself when it is already an ``int``.
        * ``int(key)`` when it is a float holding a whole number
          (``2.0`` -> ``2``); ``None`` for any other float (fractional,
          ``nan``, ``inf``).
        * Otherwise the integer formed by joining every digit run found
          in ``str(key)`` (``"step 3"`` -> ``3``), or ``None`` when the
          text contains no digit.
    """
    if isinstance(key, int):
        return key
    if isinstance(key, float):
        # Going through ``str`` would turn 1.0 into "1.0" and then into
        # 10, silently attaching the label to the wrong step.
        return int(key) if key.is_integer() else None
    digits = "".join(_DIGIT_RE.findall(str(key)))
    return int(digits) if digits else None


@isdf
def format_forecast_dataframe(
    df: pd.DataFrame,
    to_wide: bool = True,
    time_col: str = 'coord_t',
    spatial_cols: Tuple[str, ...] = ('coord_x', 'coord_y'),
    value_prefixes: Optional[List[str]] = None,
    _logger: Optional[Union[logging.Logger, Callable[[str], None]]] = None,
    **pivot_kwargs
) -> Union[pd.DataFrame, str]:
    """Detect the layout of a forecast DataFrame and optionally pivot it.

    The frame is classified as ``'long'`` when it contains ``time_col``,
    as ``'wide'`` when at least one column matches
    ``<prefix>_<4 digits>...`` for one of the value prefixes, and as
    ``'unknown'`` otherwise. With ``to_wide=True`` a long frame is handed
    to :func:`pivot_forecast_dataframe`; wide and unknown frames are
    returned unchanged.

    Parameters
    ----------
    df : pd.DataFrame
        The DataFrame to inspect and possibly transform.
    to_wide : bool, default True
        - If ``True``, return a DataFrame: the pivoted frame for long
          input, the input itself otherwise.
        - If ``False``, only run the detection and return the string
          ``'wide'``, ``'long'`` or ``'unknown'``.
    time_col : str, default 'coord_t'
        Name of the time-step column. Its presence marks the frame as
        long format.
    spatial_cols : tuple of str, default ('coord_x', 'coord_y')
        Names of the spatial coordinate columns. When non-empty they are
        validated with ``check_spatial_columns``. They are excluded from
        prefix inference and, together with ``'sample_idx'``, form the
        default ``id_vars`` used for pivoting.
    value_prefixes : list of str, optional
        Prefixes of the value columns (e.g. ``['subsidence', 'GWL']``).
        If ``None`` they are inferred as the text before the first
        underscore of every column that is not ``'sample_idx'``,
        ``'forecast_step'``, ``time_col`` or a spatial column.
    _logger : logging.Logger or callable, optional
        Sink passed to ``vlog`` for progress messages and forwarded to
        :func:`pivot_forecast_dataframe`.
    **pivot_kwargs
        Extra keyword arguments forwarded to
        :func:`pivot_forecast_dataframe` (e.g. ``id_vars``,
        ``static_actuals_cols``, ``verbose``). ``verbose`` is also read
        here to control this function's own messages.

    Returns
    -------
    pd.DataFrame or str
        - If `to_wide` is ``True``, the (possibly pivoted) DataFrame.
        - If `to_wide` is ``False``, one of ``'wide'``, ``'long'`` or
          ``'unknown'``.

    See Also
    --------
    pivot_forecast_dataframe : Performs the actual pivot operation.

    Examples
    --------
    >>> # df_long is a typical long-format forecast output
    >>> df_long.columns
    Index(['sample_idx', 'forecast_step', 'coord_t', 'coord_x', ...])
    >>> format_forecast_dataframe(df_long, to_wide=False)
    'long'
    >>> df_wide = format_forecast_dataframe(
    ...     df_long,
    ...     to_wide=True,
    ...     id_vars=['sample_idx', 'coord_x', 'coord_y'],
    ...     value_prefixes=['subsidence', 'GWL'],
    ...     static_actuals_cols=['subsidence_actual']
    ... )
    """
    # --- Format detection ---
    # Heuristic 1: the presence of the time column marks a long frame.
    is_long_format = time_col in df.columns

    # Keep a list version of the spatial columns (possibly empty) for the
    # set/list expansions below; validate only when some were given.
    _spatial_cols = columns_manager(spatial_cols, empty_as_none=False)
    if spatial_cols:
        spatial_cols = columns_manager(spatial_cols)
        check_spatial_columns(df, spatial_cols=spatial_cols)

    if value_prefixes is None:
        # Infer the prefixes from every column that is not an ID column.
        non_value_cols = {
            'sample_idx', *_spatial_cols, 'forecast_step', time_col
        }
        value_prefixes = sorted({
            c.split('_')[0] for c in df.columns
            if c not in non_value_cols
        })

    # Heuristic 2: wide frames carry columns such as 'prefix_YYYY_suffix'.
    # With no prefix at all the alternation would be empty and the regex
    # would accept any column starting with '_YYYY', so skip the check.
    has_wide_columns = False
    if value_prefixes:
        wide_col_pattern = re.compile(
            r'(' + '|'.join(re.escape(p) for p in value_prefixes) +
            r')_(\d{4})_?.*'
        )
        has_wide_columns = any(
            wide_col_pattern.match(col) for col in df.columns
        )

    if is_long_format:
        detected_format = 'long'
    elif has_wide_columns:
        detected_format = 'wide'
    else:
        detected_format = 'unknown'

    verbose = pivot_kwargs.get('verbose', 0)
    vlog(f"Auto-detected DataFrame format: '{detected_format}'",
         level=1, verbose=verbose, logger=_logger)

    # --- Action based on mode ---
    if not to_wide:
        return detected_format

    if detected_format == 'long':
        vlog("`to_wide` is True and format is 'long'. "
             "Pivoting DataFrame...", level=1, verbose=verbose,
             logger=_logger)
        pivot_args = {
            'time_col': time_col,
            'value_prefixes': value_prefixes,
            # Keep the pivot step logging to the same sink as this one.
            '_logger': _logger,
            **pivot_kwargs,
        }
        if 'id_vars' not in pivot_args:
            # Default identifiers: sample index plus spatial coordinates,
            # restricted to the columns that actually exist.
            pivot_args['id_vars'] = [
                c for c in ['sample_idx', *_spatial_cols]
                if c in df.columns
            ]
            vlog(f"Using default id_vars: {pivot_args['id_vars']}",
                 level=2, verbose=verbose, logger=_logger)

        return pivot_forecast_dataframe(df.copy(), **pivot_args)

    if detected_format == 'wide':
        vlog("`to_wide` is True but DataFrame is already in wide "
             "format. Returning as is.", level=1, verbose=verbose,
             logger=_logger)
        return df

    # 'unknown': nothing sensible to pivot, hand the input back.
    vlog("Warning: DataFrame format is 'unknown'. "
         "Cannot pivot. Returning original DataFrame.",
         level=1, verbose=verbose, logger=_logger)
    return df

@isdf
def get_value_prefixes(
    df: pd.DataFrame,
    exclude_cols: Optional[List[str]] = None,
    spatial_cols: Tuple[str, str] = ('coord_x', 'coord_y'),
    time_col: str = 'coord_t'
) -> List[str]:
    """
    Infer the value-column prefixes present in a forecast DataFrame.

    The prefix of a column is the text before its first underscore
    (``'subsidence_2022_q90'`` -> ``'subsidence'``). Columns listed as
    excluded are skipped, so the same call works for long and wide
    forecast frames.

    Parameters
    ----------
    df : pd.DataFrame
        The DataFrame whose column names are inspected.
    exclude_cols : list of str, optional
        Columns to ignore. When given, it fully replaces the default
        exclusion set. When ``None``, the default set is
        ``'sample_idx'``, ``'forecast_step'``, ``time_col`` and every
        name in ``spatial_cols``.
    spatial_cols : tuple of str, default ('coord_x', 'coord_y')
        Spatial coordinate columns added to the default exclusion set.
    time_col : str, default 'coord_t'
        Time column added to the default exclusion set.

    Returns
    -------
    list of str
        The unique prefixes, sorted.

    Examples
    --------
    >>> from fusionlab.utils.forecast_utils import get_value_prefixes
    >>> # For a long-format DataFrame
    >>> long_cols = ['sample_idx', 'coord_t', 'subsidence_q50', 'GWL_q50']
    >>> df_long = pd.DataFrame(columns=long_cols)
    >>> get_value_prefixes(df_long)
    ['GWL', 'subsidence']

    >>> # For a wide-format DataFrame
    >>> wide_cols = ['sample_idx', 'coord_x', 'subsidence_2022_q90', 'GWL_2022_q50']
    >>> df_wide = pd.DataFrame(columns=wide_cols)
    >>> get_value_prefixes(df_wide)
    ['GWL', 'subsidence']
    """
    # Either the caller's explicit list or the built-in ID/coordinate set.
    skipped = (
        {'sample_idx', 'forecast_step', time_col, *spatial_cols}
        if exclude_cols is None
        else set(exclude_cols)
    )

    # Keep the part before the first underscore of every remaining column.
    found = {
        name.split('_')[0]
        for name in df.columns
        if name not in skipped
    }
    return sorted(found)

@check_empty(['data'])
@is_data_readable
def pivot_forecast_dataframe(
    data: pd.DataFrame,
    id_vars: List[str],
    time_col: str,
    value_prefixes: List[str],
    static_actuals_cols: Optional[List[str]] = None,
    time_col_is_float_year: Union[bool, str] = 'auto',
    round_time_col: bool = False,
    verbose: int = 0,
    savefile: Optional[str] = None,
    _logger: Optional[Union[logging.Logger, Callable[[str], None]]] = None,
    **kws
) -> pd.DataFrame:
    """Transform a long-format forecast DataFrame to a wide format.

    A long frame holds one row per sample and time step; the wide frame
    returned here holds one row per sample, with one column per
    (value column, time step) pair.

    Parameters
    ----------
    data : pd.DataFrame
        The input long-format DataFrame. It must contain the columns
        listed in `id_vars`, `time_col` and `static_actuals_cols`.
    id_vars : list of str
        Columns that identify a sample. They are kept, in the given
        order, as the leading columns of the output. For example:
        ``['sample_idx', 'coord_x', 'coord_y']``.
    time_col : str
        Column holding the time step or year (e.g. ``'coord_t'``). Its
        values become part of the new column names.
    value_prefixes : list of str
        Prefixes of the columns to pivot. Every column whose name starts
        with one of them is pivoted, except the columns named in
        `id_vars`, `time_col` and `static_actuals_cols`.
    static_actuals_cols : list of str, optional
        Columns that are constant per ``sample_idx`` (e.g.
        ``['subsidence_actual']``). They are not pivoted: the first row
        of each ``sample_idx`` is kept and merged back after pivoting.
        Requires ``'sample_idx'`` in both `data` and `id_vars`.
    time_col_is_float_year : bool or 'auto', default 'auto'
        How `time_col` values are rendered in the new column names.

        - ``'auto'``: treat them as float years when the column has a
          float dtype.
        - ``True``: convert values such as ``2018.0`` to ``'2018'``.
        - ``False``: use ``str(value)`` as is.
    round_time_col : bool, default False
        If ``True`` and `time_col` is handled as a float year, round it
        to the nearest integer before pivoting. Otherwise a message is
        logged and rounding is skipped.
    verbose : int, default 0
        Verbosity passed to ``vlog``.
    savefile : str, optional
        If given, the wide DataFrame is also written to this path as CSV
        (parent directory created when needed). A failure to save is
        logged and does not prevent the result from being returned.
    _logger : logging.Logger or callable, optional
        Sink passed to ``vlog`` for every progress message.
    **kws
        Accepted for call compatibility; not used.

    Returns
    -------
    pd.DataFrame
        One row per unique combination of `id_vars`. Columns are the
        `id_vars`, then `static_actuals_cols`, then the pivoted columns
        named ``{prefix}_{time_str}{_suffix}`` (e.g.
        ``'subsidence_2018_q10'``).

    Raises
    ------
    TypeError
        If `data` is not a pandas DataFrame.
    ValueError
        If a required column is missing, or if `static_actuals_cols` is
        given while ``'sample_idx'`` is absent from `data` or `id_vars`.
    RuntimeError
        If ``DataFrame.pivot_table`` fails.

    See Also
    --------
    pandas.pivot_table : The core function used for reshaping data.
    pandas.merge : Used to re-join static columns after pivoting.

    Notes
    -----
    Values are aggregated with ``aggfunc='first'``: when `id_vars` and
    `time_col` do not uniquely identify a row, only the first value of
    each group is kept.

    Examples
    --------
    >>> import pandas as pd
    >>> from fusionlab.utils.forecast_utils import pivot_forecast_dataframe
    >>> data = {
    ...     'sample_idx': [0, 0, 1, 1],
    ...     'coord_t': [2018.0, 2019.0, 2018.0, 2019.0],
    ...     'coord_x': [0.1, 0.1, 0.5, 0.5],
    ...     'coord_y': [0.2, 0.2, 0.6, 0.6],
    ...     'subsidence_q50': [-8, -9, -13, -14],
    ...     'subsidence_actual': [-8.5, -8.5, -13.2, -13.2],
    ...     'GWL_q50': [1.2, 1.3, 2.2, 2.3],
    ... }
    >>> df_long_example = pd.DataFrame(data)
    >>> df_wide = pivot_forecast_dataframe(
    ...     data=df_long_example,
    ...     id_vars=['sample_idx', 'coord_x', 'coord_y'],
    ...     time_col='coord_t',
    ...     value_prefixes=['subsidence', 'GWL'],
    ...     static_actuals_cols=['subsidence_actual'],
    ...     verbose=0
    ... )
    >>> print(df_wide.columns)
    Index(['sample_idx', 'coord_x', 'coord_y', 'subsidence_actual',
           'GWL_2018_q50', 'GWL_2019_q50', 'subsidence_2018_q50',
           'subsidence_2019_q50'],
          dtype='object')
    """
    is_frame(data, df_only=True)
    # Validate the type before touching the object.
    if not isinstance(data, pd.DataFrame):
        raise TypeError("`data` must be a pandas DataFrame.")

    df_processed = data.copy()
    vlog(f"Starting pivot operation. Initial shape: {df_processed.shape}",
         level=2, verbose=verbose, logger=_logger)

    # Ordered, duplicate-free identifier list: it fixes both the pivot
    # index and the leading columns of the result.
    id_vars = list(dict.fromkeys(id_vars))
    static_cols = (
        [] if static_actuals_cols is None else list(static_actuals_cols)
    )

    required_cols = [*id_vars, time_col, *static_cols]
    missing_cols = [
        col for col in required_cols if col not in df_processed.columns
    ]
    if missing_cols:
        raise ValueError(
            f"Missing required columns in DataFrame: {missing_cols}"
        )

    # Identifier, time and static columns are never pivoted, even when
    # their name starts with a value prefix ('subsidence_actual').
    reserved = set(required_cols)
    value_cols_to_pivot = [
        col for col in df_processed.columns
        if col not in reserved
        and any(col.startswith(prefix) for prefix in value_prefixes)
    ]

    # Determine if the time column should be treated as a float year.
    is_float_year = False
    if time_col_is_float_year == 'auto':
        if pd.api.types.is_float_dtype(df_processed[time_col]):
            is_float_year = True
            vlog(f"'{time_col}' auto-detected as float year.",
                 level=2, verbose=verbose, logger=_logger)
    elif time_col_is_float_year is True:
        is_float_year = True

    # Round the time column before pivoting if requested.
    if round_time_col and is_float_year:
        vlog(f"Rounding time column '{time_col}'.",
             level=1, verbose=verbose, logger=_logger)
        df_processed[time_col] = df_processed[time_col].round().astype(int)
        # The column now holds integers: no float-year formatting needed.
        is_float_year = False
    elif round_time_col and not is_float_year:
        vlog(f"Warning: `round_time_col` is True but '{time_col}' "
             "is not a float year. Skipping rounding.",
             level=1, verbose=verbose, logger=_logger)

    static_df = None
    if static_cols:
        if 'sample_idx' not in df_processed.columns:
            raise ValueError(
                "'sample_idx' must be in `data` to handle "
                "static_actuals_cols."
            )
        if 'sample_idx' not in id_vars:
            # The static values are merged back on 'sample_idx', which
            # only survives the pivot when it is an identifier.
            raise ValueError(
                "'sample_idx' must be in `id_vars` to handle "
                "static_actuals_cols."
            )
        vlog(f"Extracting static columns: {static_cols}",
             level=2, verbose=verbose, logger=_logger)
        static_df = (
            df_processed[['sample_idx', *static_cols]]
            .drop_duplicates(subset=['sample_idx'])
            .set_index('sample_idx')
        )

    pivot_index = id_vars
    vlog(f"Pivoting data with index={pivot_index}, "
         f"columns='{time_col}'.",
         level=2, verbose=verbose, logger=_logger)
    try:
        df_pivoted = df_processed.pivot_table(
            index=pivot_index,
            columns=[time_col],
            values=value_cols_to_pivot,
            aggfunc='first'
        )
    except Exception as e:
        raise RuntimeError(
            "Pandas pivot_table failed. Check if `id_vars` and "
            f"`time_col` uniquely identify rows. Error: {e}"
        ) from e

    vlog("Flattening pivoted column names.",
         level=2, verbose=verbose, logger=_logger)
    new_columns = []
    for value_col, time_val in df_pivoted.columns:
        # 'subsidence_q50' -> prefix 'subsidence', suffix '_q50'.
        prefix, _, rest = value_col.partition('_')
        suffix = f"_{rest}" if rest else ""
        time_str = str(time_val)
        if is_float_year:
            try:
                time_str = str(int(time_val))
            except (ValueError, TypeError):
                pass  # keep the plain string form
        new_columns.append(f"{prefix}_{time_str}{suffix}")

    df_pivoted.columns = new_columns
    df_pivoted = df_pivoted.reset_index()

    if static_df is not None:
        vlog("Merging static columns back into the wide DataFrame.",
             level=2, verbose=verbose, logger=_logger)
        df_wide = pd.merge(
            df_pivoted, static_df, on='sample_idx', how='left'
        )
        leading = list(dict.fromkeys([*id_vars, *static_cols]))
        leading_set = set(leading)
        df_wide = df_wide[
            leading
            + [c for c in df_wide.columns if c not in leading_set]
        ]
    else:
        df_wide = df_pivoted

    vlog(f"Pivot complete. Final shape: {df_wide.shape}",
         level=1, verbose=verbose, logger=_logger)

    if savefile:
        # Saving is best effort: a failure is logged, and the DataFrame
        # is still returned to the caller.
        try:
            vlog(f"Saving DataFrame to '{savefile}'.",
                 level=1, verbose=verbose, logger=_logger)
            save_dir = os.path.dirname(savefile)
            if save_dir:
                os.makedirs(save_dir, exist_ok=True)
            df_wide.to_csv(savefile, index=False)
            vlog("Save successful.",
                 level=2, verbose=verbose, logger=_logger)
        except Exception as e:
            logger.error(f"Failed to save file to '{savefile}': {e}")

    return df_wide
@isdf
def get_value_prefixes_in(
    df: pd.DataFrame,
    exclude_cols: Optional[List[str]] = None
) -> List[str]:
    """
    Infer the value-column prefixes of a DataFrame.

    The prefix of a column is the text before its first underscore.
    Columns in ``exclude_cols`` are skipped; when ``exclude_cols`` is
    ``None`` the skipped columns are ``'sample_idx'``,
    ``'forecast_step'``, ``'coord_t'``, ``'coord_x'`` and ``'coord_y'``.

    Parameters
    ----------
    df : pd.DataFrame
        The DataFrame whose column names are inspected.
    exclude_cols : list of str, optional
        Columns to ignore; replaces the default set when given.

    Returns
    -------
    list of str
        The unique prefixes, sorted.
    """
    if exclude_cols is None:
        ignored = {
            'sample_idx', 'forecast_step', 'coord_t', 'coord_x', 'coord_y'
        }
    else:
        ignored = set(exclude_cols)

    # Text before the first underscore of every column that is kept.
    return sorted({
        name.split('_')[0] for name in df.columns if name not in ignored
    })


@isdf
def detect_forecast_type(
    df: pd.DataFrame,
    value_prefixes: Optional[List[str]] = None,
) -> str:
    """
    Classify a forecast DataFrame as quantile or deterministic.

    Only column names are inspected, so both long and wide frames are
    supported.

    - ``'quantile'``: some column starts with a value prefix and
      contains ``_q`` followed by digits (``'subsidence_q10'``,
      ``'GWL_2022_q50'``).
    - ``'deterministic'``: no such column exists, but some column is
      exactly a value prefix (``'GWL'``) or starts with one and ends in
      ``_pred`` or ``_actual`` (``'subsidence_pred'``,
      ``'subsidence_2022_actual'``).
    - ``'unknown'``: neither applies, or no value prefix is available.

    Parameters
    ----------
    df : pd.DataFrame
        The DataFrame to inspect.
    value_prefixes : list of str, optional
        Prefixes to restrict the search to (e.g.
        ``['subsidence', 'GWL']``). If ``None`` they are inferred with
        ``get_value_prefixes(df)``.

    Returns
    -------
    str
        One of ``'quantile'``, ``'deterministic'`` or ``'unknown'``.

    Examples
    --------
    >>> import pandas as pd
    >>> from fusionlab.utils.forecast_utils import detect_forecast_type
    >>> df_quant_long = pd.DataFrame(columns=['subsidence_q50', 'GWL_q90'])
    >>> detect_forecast_type(df_quant_long)
    'quantile'
    >>> df_quant_wide = pd.DataFrame(columns=['subsidence_2022_q50'])
    >>> detect_forecast_type(df_quant_wide)
    'quantile'
    >>> df_determ = pd.DataFrame(columns=['subsidence_pred', 'GWL'])
    >>> detect_forecast_type(df_determ)
    'deterministic'
    """
    if value_prefixes is None:
        value_prefixes = get_value_prefixes(df)
    if not value_prefixes:
        return 'unknown'

    # '_q<digits>' may sit anywhere in the name, which covers both the
    # long ('_q50') and the wide ('_2022_q50') layouts.
    quantile_re = re.compile(r'_q\d+')
    # Deterministic markers are only recognised at the end of the name.
    point_re = re.compile(r'_(pred|actual)$')

    def _is_quantile(name) -> bool:
        if not quantile_re.search(name):
            return False
        return any(name.startswith(p) for p in value_prefixes)

    def _is_deterministic(name) -> bool:
        for p in value_prefixes:
            if name == p:
                return True
            if name.startswith(p) and point_re.search(name):
                return True
        return False

    # Quantile evidence takes precedence over deterministic evidence.
    if any(_is_quantile(name) for name in df.columns):
        return 'quantile'
    if any(_is_deterministic(name) for name in df.columns):
        return 'deterministic'
    return 'unknown'