Source code for fusionlab.params

# -*- coding: utf-8 -*-
# License: BSD-3-Clause
# Author: LKouadio <etanoyau@gmail.com>

"""
Module ``fusionlab.params`` provides small, self-documenting helpers
for scalar physical hyperparameters used in PINNs and physics-guided
nets.

Two families of descriptors are provided:

- Learnable scalars (subclasses of :class:`BaseLearnable`) such as
  :class:`LearnableK`, :class:`LearnableSs`, :class:`LearnableMV`,
  :class:`LearnableKappa`, etc.  These wrap a trainable scalar, often
  stored in log-space to enforce positivity.

- Fixed scalars (subclasses of :class:`BaseFixed`) such as
  :class:`FixedGammaW` and :class:`FixedHRef`, used for constants in
  the physics block.

Legacy descriptors :class:`LearnableC`, :class:`FixedC` and
:class:`DisabledC` are kept for backwards-compatible handling of a
generic positive coefficient :math:`C`.

In the GeoPriorSubsNet consolidation model specifically, the spatial
fields :math:`K(x,y)`, :math:`S_s(x,y)`, :math:`H(x,y)` and the
relaxation time :math:`\\tau(x,y)` are represented as *effective
fields* built from covariates and the neural network, whereas the
scalar wrappers here represent global hyperparameters such as the
effective compressibility :math:`m_v`, the consistency factor
:math:`\\bar\\kappa`, the unit weight of water :math:`\\gamma_w`
and the reference head :math:`h_{\\mathrm{ref}}`.
"""

from __future__ import annotations
import importlib
from typing import Any, Union, Optional, Dict, Type, Literal
from abc import ABC, abstractmethod

# Attempt to import TensorFlow, else fall
# back to NumPy
_tf_spec = importlib.util.find_spec(
    "tensorflow"
)
if _tf_spec is not None:
    import tensorflow as tf
    _BACKEND = "tensorflow"
    Tensor = tf.Tensor
    Variable = tf.Variable
else:
    import numpy as np

    _BACKEND = "numpy"
    class _DummyTF:
        pass

    class tf:
        Tensor   = _DummyTF
        Variable = _DummyTF
    # Fallback types for type hinting
    Tensor = Any
    Variable = Any


# Keras serialisable base-class
if _BACKEND =='tensorflow': 
    from tensorflow.keras.saving import register_keras_serializable
else:         # TF missing → no serialisation
    def register_keras_serializable(*_a, **_kw):            # type: ignore
        def decorator(cls):                                 # pragma: no cover
            return cls
        return decorator
    

__all__ = [
    "LearnableC", "FixedC", "DisabledC", 
    "LearnableK", "LearnableSs", "LearnableQ",
    # --- New Parameters for Revised Manuscript ---
    "LearnableMV", "LearnableKappa", "FixedGammaW", "FixedHRef"
]


@register_keras_serializable("fusionlab.params", name="_BaseC")
class _BaseC(ABC):
    r"""
    Parent class for :math:`C` descriptors.

    Each subclass provides :pyattr:`value`
    (``float`` in NumPy mode, ``tf.Variable`` in TF mode)
    and declares whether it is *trainable*.

    The class supports Keras JSON round-trip via
    :py:meth:`get_config` / :py:meth:`from_config`.
    """

    trainable: bool = False      #: overridden by concrete classes

    
    def __init__(self, **kwargs: Any):
        self.value = self._make_value(**kwargs)

    # Keras (de)serialisation 
    def get_config(self) -> Dict[str, Any]:
        cfg: Dict[str, Any] = dict(self._export_kw)          # type: ignore
        cfg["class_name"] = self.__class__.__name__
        return cfg

    @classmethod
    def from_config(cls: Type["_BaseC"], cfg: Dict[str, Any]) -> "_BaseC":
        cfg = dict(cfg)
        cfg.pop("class_name", None)
        return cls(**cfg)

    #  utilities -
    def __repr__(self) -> str:                               # noqa: D401
        nm = self.__class__.__name__
        return f"<{nm} trainable={self.trainable}, value={self.value!r}>"

    # - Implemented by subclasses -
    @abstractmethod
    def _make_value(self, **kwargs: Any) -> Any:             # noqa: D401
        ...

[docs] @register_keras_serializable("fusionlab.params", name="LearnableC") class LearnableC(_BaseC): r""" Indicates that the PINN’s physical coefficient :math:`C` should be learned (trainable). We actually learn :math:`\log(C)` to ensure :math:`C > 0`. The user supplies an `initial_value`, and the model initializes: Trainable :math:`C`. In TF mode we keep :math:`\log C` as a :class:`tf.Variable`, ensuring :math:`C>0`. In NumPy mode the coefficient *cannot be trained*, so it degrades gracefully to a fixed float. .. math:: \log C \;=\; \log(\text{initial\_value}). Parameters ---------- initial_value : float Strictly positive initial :math:`C`. Attributes ---------- initial_value : float The positive starting value for :math:`C`. Must be strictly positive. Examples -------- >>> from fusionlab.params import LearnableC >>> # Learn C, starting from C = 0.01 >>> pinn_coeff = LearnableC(initial_value=0.01) >>> # Learn C, starting from C = 0.001 >>> pinn_coeff_small = LearnableC(initial_value=0.001) """
[docs] def __init__(self, initial_value: float = 0.01, **kwargs ): super().__init__( initial_value=initial_value, **kwargs )
def _make_value(self, initial_value: float = 0.01) -> Any: if not isinstance(initial_value, (float, int)): raise TypeError( f"LearnableC.initial_value must be a float, got " f"{type(initial_value).__name__}" ) if initial_value <= 0: raise ValueError( "LearnableC.initial_value must be strictly positive." ) self.initial_value = float(initial_value) self._export_kw = {"initial_value": self.initial_value} # type: ignore if _BACKEND == "tensorflow": self.trainable = True log_c0 = tf.math.log(tf.constant(float(initial_value), tf.float32)) return tf.Variable(log_c0, dtype=tf.float32, name="log_pinn_coefficient_C") # NumPy branch --> behave as a *fixed* coefficient self.trainable = False return float(initial_value)
[docs] @register_keras_serializable("fusionlab.params", name="FixedC") class FixedC(_BaseC): r""" Non-trainable, constant :math:`C`. Indicates that the PINN's physical coefficient :math:`C` should be held fixed (non-trainable) at a specified `value`. .. math:: C = \text{value}, \qquad \text{non-trainable}. Parameters ---------- value : float Constant :math:`C \ge 0`. Attributes ---------- value : float The non-negative, constant value of :math:`C`. Examples -------- >>> from fusionlab.params import FixedC >>> # Use a fixed C = 0.5 >>> pinn_coeff = FixedC(value=0.5) """
[docs] def __init__(self, value: float, **kwargs): super().__init__(value = value, **kwargs)
def _make_value(self, value: float) -> float: if not isinstance(value, (float, int)): raise TypeError( f"FixedC.value must be a float, got {type(value).__name__}" ) if value < 0: raise ValueError( "FixedC.value must be non-negative." ) self._value = float(value) self._export_kw = {"value": self._value} # type: ignore return float(value)
[docs] @register_keras_serializable("fusionlab.params", name="DisabledC") class DisabledC(_BaseC): r""" Disable physics – :math:`C` is ignored. Indicates that physics should be disabled. In practice, :math:`C` is irrelevant (defaults to 1.0 internally, but is never used if `lambda_pde == 0` when compiling). Attributes ---------- None Examples -------- >>> from fusionlab.params import DisabledC >>> pinn_coeff = DisabledC() """
[docs] def __init__(self): # No parameters needed. Presence of this class signals “disable”. super().__init__()
def _make_value(self) -> float: # noqa: D401 self._export_kw = {} # type: ignore return 1.0 # convention, but unused when physics is disabled
@register_keras_serializable("fusionlab.params", name ="BaseLearnable") class BaseLearnable(ABC): """ Abstract base for learnable physical parameters. Parameters ---------- initial_value : float Initial numeric value for the parameter. name : str Unique identifier for the variable. log_transform : bool, optional If True, store in log-space for positivity constraint, by default False. trainable : bool, optional If True, make variable trainable, by default True. Attributes ---------- initial_value : float The original provided value. name : str Variable name in the computation graph. log_transform : bool Whether to apply log transform. trainable : bool Trainable flag for optimization. Examples -------- >>> param = LearnableK(initial_value=0.5) >>> value = param.get_value() """ def __init__( self, initial_value: float, name: str, log_transform: bool = False, trainable: bool = True, **kws # for future extension ): if not isinstance( initial_value, (float, int) ): raise TypeError( f"Initial value for {self.__class__.__name__} " f"must be a float, got {type(initial_value).__name__}" ) if log_transform and initial_value <= 0: raise ValueError( f"{self.__class__.__name__} initial value must be " "strictly positive for log transform." ) self.initial_value = float(initial_value) self.name = name self.log_transform = log_transform self.trainable = trainable self._variable = self._create_variable() def _create_variable(self) -> Union[Variable, Tensor, float]: """ Internal: create tf.Variable or fallback value. Returns ------- Union[Variable, Tensor, float] Configured variable or numeric. """ if _BACKEND == "tensorflow": value = self.initial_value if self.log_transform: value = tf.math.log(value) return tf.Variable( initial_value=tf.cast( value, dtype=tf.float32 ), trainable=self.trainable, name=self.name ) return ( np.log(self.initial_value) if self.log_transform else self.initial_value ) @abstractmethod def get_value( self ) -> Union[Tensor, float]: """ Retrieve parameter value. Returns ------- Union[Tensor, float] Transformed parameter, e.g., :math:`\exp(log\_param)` if log_transform is True. """ pass def get_config(self) -> Dict[str, Any]: """ Return a JSON-serialisable dict for tf.keras. Notes ----- Keras looks for this method during ``model.save()`` and ``keras.saving.serialization_lib.serialize_keras_object``. """ return { "initial_value": self.initial_value, "name": self.name, "log_transform": self.log_transform, "trainable": self.trainable, # we also store the concrete subclass path for clarity "__class_name__": self.__class__.__name__, } @classmethod def from_config(cls, config: Dict[str, Any]) -> "BaseLearnable": """ Re-instantiate from :py:meth:`get_config`. Keras passes *config* exactly as returned above. """ # Guard against stray keys Keras might inject kwargs = { k: v for k, v in config.items() if k in {"initial_value", "name", "log_transform", "trainable"} } return cls(**kwargs) def __repr__(self) -> str: return ( f"{self.__class__.__name__}(initial_value=" f"{self.initial_value}, trainable={self.trainable}, " f"name={self.name})" )
[docs] @register_keras_serializable("fusionlab.params", name ="LearnableK") class LearnableK(BaseLearnable): """ Learnable Hydraulic Conductivity (K). Indicates that the PINN’s hydraulic conductivity :math:`K` should be learned (trainable) if TensorFlow is available; otherwise behaves as a fixed NumPy‐based parameter. We learn :math:`\log(K)` to ensure :math:`K > 0`. The user supplies an `initial_value`, and the object initializes: .. math:: \log K \;=\; \log(\text{initial\_value}). Ensures positivity via log-space. See Also -------- BaseLearnableParam Examples -------- >>> k = LearnableK(1.2) >>> :math:`K = k.get_value()` """
[docs] def __init__( self, initial_value: float = 1.0, log_transform: bool=True, name: Optional[str] =None, trainable: bool=True, **kws ): super().__init__( initial_value=initial_value, log_transform=log_transform, name= name or "learnable_K", trainable= trainable, **kws )
[docs] def get_value( self ) -> Union[Tensor, float]: """ Return :math:`K = \exp(log\_K)`. Returns ------- Union[Tensor, float] Positive conductivity. """ if _BACKEND == "tensorflow": return tf.exp(self._variable) return float( __import__("numpy").exp( self._variable ) )
[docs] @register_keras_serializable("fusionlab.params", name ="LearnableSs") class LearnableSs(BaseLearnable): """ Learnable Specific Storage (Ss). Indicates that the PINN's specific storage coefficient :math:`S_s` should be learned (trainable) if TensorFlow is available; otherwise acts as a fixed NumPy‐based parameter. We learn :math:`\log(S_s)` to ensure :math:`S_s > 0`. The user supplies an `initial_value`, and the object initializes: .. math:: \log S_s \;=\; \log(\text{initial\_value}). Returns positive values via exp transform. Examples -------- >>> ss = LearnableSs(1e-3) >>> value = ss.get_value() """
[docs] def __init__( self, initial_value: float = 1e-4, log_transform: bool=True, name: Optional[str] =None, trainable: bool=True, **kws ): super().__init__( initial_value=initial_value, name= name or "learnable_Ss", log_transform=log_transform, trainable= trainable, **kws )
[docs] def get_value( self ) -> Union[Tensor, float]: """ Return :math:`Ss = \exp(log\_Ss)`. Returns ------- Union[Tensor, float] Positive storage coefficient. """ if _BACKEND == "tensorflow": return tf.exp(self._variable) return float( __import__("numpy").exp( self._variable ) )
[docs] @register_keras_serializable("fusionlab.params", name ="LearnableQ") class LearnableQ(BaseLearnable): """ Learnable Source/Sink term (Q). Indicates that the PINN's source/sink term :math:`Q` should be learned (trainable) if TensorFlow is available; otherwise acts as a fixed NumPy‐based parameter. Unlike K and Ss, Q may be positive or negative, so we learn it directly (no log‐transform). The user supplies an `initial_value`, and the object initializes: .. math:: Q \;=\; \text{initial\_value}. Unconstrained: may be positive or negative. Examples -------- >>> q = LearnableQ(0.0) >>> q.get_value() 0.0 """
[docs] def __init__( self, initial_value: float = 0.0, # log_transform: bool=False, # Q should not be log-transformed name: Optional[str] =None, trainable: bool=True, **kws ): super().__init__( initial_value=initial_value, name= name or "learnable_Q", log_transform=False, # Explicitly set to False trainable= trainable, **kws )
[docs] def get_value( self ) -> Union[Tensor, float]: """ Return raw :math:`Q` value. Returns ------- Union[Tensor, float] Source/sink strength. """ if _BACKEND == "tensorflow": return self._variable # No exp() return float(self._variable) # No exp()
@register_keras_serializable("fusionlab.params", name="LearnableMV") class LearnableMV(BaseLearnable): r""" Learnable effective vertical compressibility (m_v). In GeoPriorSubsNet this is a *global scalar* that links head changes to equilibrium settlement via :math:`s_{\\mathrm{eq}}(h) = m_v\\,\\gamma_w\\,\\Delta h\\,H`, where :math:`H(x,y)` is an effective compressible thickness field. The field :math:`S_s(x,y)` is interpreted as an effective specific storage, with :math:`S_s \\approx m_v\\,\\gamma_w` used as a soft consistency relation rather than a hard identity. Positivity is enforced by learning :math:`\\log(m_v)`. Parameters ---------- initial_value : float, default=1e-7 Initial value for :math:`m_v` [Pa^-1]. Must be positive and typically falls in a geotechnically plausible range (e.g. :math:`10^{-9}–10^{-5}` Pa^-1). name : str, optional Variable name. trainable : bool, default=True Whether the parameter is trainable. """ def __init__( self, initial_value: float = 1e-7, name: Optional[str] = None, trainable: bool = True, log_transform: bool=True, # m_v must be positive **kws ): super().__init__( initial_value=initial_value, name=name or "learnable_mv", log_transform=log_transform, trainable=trainable, **kws ) def get_value(self) -> Union[Tensor, float]: """ Return :math:`m_v = \exp(\log(m_v))` """ if _BACKEND == "tensorflow": return tf.exp(self._variable) return float(np.exp(self._variable)) @register_keras_serializable("fusionlab.params", name="LearnableKappa") class LearnableKappa(BaseLearnable): """ Learnable scalar consistency factor (:math:`\\bar{\\kappa}`). In the revised consolidation prior, :math:`\\bar{\\kappa}` relates the effective relaxation time :math:`\\tau(x,y)` to the Terzaghi-style diffusion time built from the effective fields :math:`K(x,y)`, :math:`S_s(x,y)` and :math:`H(x,y)`. In the manuscript, it collects terms such as drainage-path ratios and leakage / anisotropy factors. It enters a soft prior term of the form .. math:: \\log \\tau_{\\mathrm{prior}}(x,y) \\approx \\log\\left( \\frac{\\bar{\\kappa} H(x,y)^2} {\\pi^2 K(x,y) / S_s(x,y)} \\right), which is penalised against the learned :math:`\\log \\tau(x,y)`. Positivity is enforced via a log-space parametrisation. Parameters ---------- initial_value : float, default=1.0 Initial guess for :math:`\\bar{\\kappa}` name : str, optional Variable name. trainable : bool, default=True Whether the parameter is trainable. """ def __init__( self, initial_value: float = 1.0, name: Optional[str] = None, log_transform: bool=True, # kappa_bar must be positive trainable: bool = True, **kws ): super().__init__( initial_value=initial_value, name=name or "learnable_kappa", log_transform=log_transform, trainable=trainable, **kws ) def get_value(self) -> Union[Tensor, float]: """ Return :math:`\bar{\kappa} = \exp(\log(\bar{\kappa}))` """ if _BACKEND == "tensorflow": return tf.exp(self._variable) return float(np.exp(self._variable)) @register_keras_serializable("fusionlab.params", name="BaseFixed") class BaseFixed(ABC): """ Abstract base for fixed physical parameters. Parameters ---------- value : float Fixed numeric value for the parameter. name : str Unique identifier for the variable. log_transform : bool, optional If True, store in log-space for positivity constraint and apply exp() when retrieving value, by default False. non_negative : bool, optional If True, ensures value cannot be negative, by default True. Only enforced when log_transform=False. Attributes ---------- value : float The fixed parameter value. name : str Variable name in the computation graph. log_transform : bool Whether to apply log transform for positivity. non_negative : bool Whether negative values are allowed. trainable : bool Always False for fixed parameters. Examples -------- >>> param = FixedGammaW(value=9810.0) >>> value = param.get_value() """ def __init__( self, value: float, name: str, log_transform: bool = False, non_negative: bool = True, **kws # for future extension ): if not isinstance(value, (float, int)): raise TypeError( f"Value for {self.__class__.__name__} " f"must be a float, got {type(value).__name__}" ) # Validate constraints if log_transform and value <= 0: raise ValueError( f"{self.__class__.__name__} value must be " "strictly positive for log transform." ) if non_negative and value < 0 and not log_transform: raise ValueError( f"{self.__class__.__name__} value must be " "non-negative when non_negative=True." ) self.value = float(value) self.name = name self.log_transform = log_transform self.non_negative = non_negative self.trainable = False # Fixed parameters are never trainable self._variable = self._create_variable() def _create_variable(self) -> Union[Variable, Tensor, float]: """ Internal: create tf.Variable or fallback value for fixed parameter. Returns ------- Union[Variable, Tensor, float] Configured fixed variable or numeric. """ if _BACKEND == "tensorflow": val = self.value if self.log_transform: val = tf.math.log(val) return tf.Variable( initial_value=tf.cast(val, dtype=tf.float32), trainable=False, # Explicitly non-trainable name=self.name ) # NumPy fallback return ( np.log(self.value) if self.log_transform else self.value ) def get_value(self) -> Union[Tensor, float]: """ Retrieve the fixed parameter value. Returns ------- Union[Tensor, float] The parameter value, with exp() applied if log_transform=True. """ if _BACKEND == "tensorflow": if self.log_transform: return tf.exp(self._variable) return self._variable # NumPy fallback if self.log_transform: return float(np.exp(self._variable)) return float(self._variable) def get_config(self) -> Dict[str, Any]: """ Return a JSON-serialisable dict for tf.keras serialization. """ return { "value": self.value, "name": self.name, "log_transform": self.log_transform, "non_negative": self.non_negative, "__class_name__": self.__class__.__name__, } @classmethod def from_config(cls, config: Dict[str, Any]) -> "BaseFixed": """ Re-instantiate from configuration dict. """ kwargs = { k: v for k, v in config.items() if k in {"value", "name", "log_transform", "non_negative"} } return cls(**kwargs) def __repr__(self) -> str: return ( f"{self.__class__.__name__}(value={self.value}, " f"name={self.name}, log_transform={self.log_transform}, " f"non_negative={self.non_negative})" ) @register_keras_serializable("fusionlab.params", name="FixedGammaW") class FixedGammaW(BaseFixed): """ Fixed scalar for the (effective) unit weight of water :math:`\\gamma_w`. Used in :math:`s_{\\mathrm{eq}}(h) = m_v\\,\\gamma_w\\,\\Delta h\\,H`. Treated as a constant (non-trainable); in most applications :math:`\\gamma_w \\approx 9{,}810\\ \\mathrm{N\\,m^{-3}}`. Internally we keep :math:`\\log(\\gamma_w)` for numerical stability and return :math:`\\gamma_w` via :meth:`get_value`. Parameters ---------- value : float, default=9810.0 Value for :math:`\gamma_w` [N m^-3]. Must be positive. name : str, optional Variable name. non_negative : bool, default=True Ensures the value cannot be negative. """ def __init__( self, value: float = 9810.0, # Approx. 1000 kg/m^3 * 9.81 m/s^2 name: Optional[str] = None, non_negative: bool = True, **kws ): # gamma_w must be positive, so enforce log_transform for stability kws.pop ("log_transform", None) super().__init__( value=value, name=name or "fixed_gamma_w", log_transform=True, # gamma_w must always be positive non_negative=non_negative, **kws ) @register_keras_serializable("fusionlab.params", name="FixedHRef") class FixedHRef(BaseFixed): r""" Reference head configuration :math:`h_{\mathrm{ref}}` for drawdown. Drawdown convention used in GeoPrior: :math:`\Delta h = \max(h_{\mathrm{ref}} - h, 0)`. This is a modelling datum (not a material parameter). In regional hydrogeology it may represent a pre-development head, a long-term mean head, or (recommended here) a rolling baseline derived from the last observed historical head at forecast start. Parameters ---------- value : float or None, default=0.0 Fallback reference head [m] used when mode="auto" cannot be resolved. If None, defaults to 0.0. mode : {"auto", "fixed"}, default="auto" - "auto": derive :math:`h_{\mathrm{ref}}` per batch from the last historical groundwater observation (preferred). - "fixed": always use `value` as a global datum. name : str, optional Variable name. non_negative : bool, default=False Allow negative values since heads can be negative depending on datum. """ def __init__( self, value: Optional[float] = 0.0, mode: Literal["auto", "fixed"] = "auto", name: Optional[str] = None, non_negative: bool = False, **kws, ): kws.pop("log_transform", None) mode = "auto" if mode is None else str(mode).strip().lower() if mode not in ("auto", "fixed"): raise ValueError( f"Invalid mode={mode!r}. Expected 'auto' or 'fixed'." ) if value is None: value = 0.0 self.mode = mode super().__init__( value=float(value), name=name or "fixed_h_ref", log_transform=False, non_negative=non_negative, **kws, ) def get_config(self): cfg = super().get_config() cfg.update({"mode": self.mode}) return cfg
[docs] @register_keras_serializable("fusionlab.params", name="resolve_physical_param") def resolve_physical_param( param: Any, name: Optional[str] = None, *, serialize: bool = False, status: Optional[str] = None, param_type: Optional[str] = None, log_transform: Optional[bool] = None, non_negative: Optional[bool] = None, trainable: Optional[bool] = None, **additional_kwargs ) -> Union[Tensor, float, Dict, BaseLearnable, BaseFixed]: r""" Normalize a physical-parameter descriptor with enhanced flexibility. The helper converts *param* into: - A concrete value (float/tf.Tensor) for runtime use - A parameter wrapper (BaseLearnable/BaseFixed) when appropriate - A JSON-serializable dict when ``serialize=True`` Parameters ---------- param : float | int | BaseLearnable | BaseFixed | str | Dict Raw descriptor. Can be: - Plain number: treated as fixed or learnable based on status - Wrapped parameter (BaseLearnable/BaseFixed): forwarded as-is - String: "learnable" or "fixed" to create wrapper with defaults - Dict: configuration for parameter creation name : str, optional Parameter identifier used for: - Variable naming in TensorFlow backend - Type inference when creating wrappers serialize : bool, default False Return configuration dict instead of concrete value. status : {'learnable', 'fixed', 'auto', None}, optional Global override: - 'learnable': force creation of learnable wrapper - 'fixed': force creation of fixed wrapper - 'auto': infer from param type - None: use param's inherent behavior param_type : str, optional Explicit parameter type. Overrides name-based inference. Options: 'K', 'Ss', 'Q', 'MV', 'Kappa', 'GammaW', 'HRef' log_transform : bool, optional Force log-space transformation for positivity. non_negative : bool, optional Force non-negativity constraint. trainable : bool, optional Override trainable flag (only for learnable params). **additional_kwargs Additional parameters passed to wrapper constructors. Returns ------- Tensor | float | Dict | BaseLearnable | BaseFixed Concrete value, wrapper instance, or serialized configuration. Raises ------ TypeError If param is of unsupported type. ValueError If parameter type cannot be inferred or constraints are violated. Examples -------- >>> from fusionlab.params import resolve_physical_param >>> # Basic usage with type inference from name >>> resolve_physical_param(1e-4, name="K", status="learnable") LearnableK(initial_value=0.0001, trainable=True) >>> # Explicit parameter type >>> resolve_physical_param(0.5, param_type="MV", status="learnable") LearnableMV(initial_value=0.5, trainable=True) >>> # Fixed parameter with custom constraints >>> resolve_physical_param(9810.0, param_type="GammaW", non_negative=True) FixedGammaW(value=9810.0, non_negative=True) >>> # From configuration dict >>> config = {"class": "LearnableK", "initial_value": 0.5, "trainable": True} >>> resolve_physical_param(config) LearnableK(initial_value=0.5, trainable=True) >>> # Serialization >>> k = LearnableK(0.5) >>> resolve_physical_param(k, serialize=True) {'class': 'LearnableK', 'initial_value': 0.5, ...} """ # 1. Serialization Branch if serialize: if isinstance(param, (BaseLearnable, BaseFixed)): config = param.get_config() config["class"] = param.__class__.__name__ return config elif isinstance(param, (float, int)): return { "class": "float", "value": float(param), "learnable": False } elif isinstance(param, dict) and "class" in param: return param # Already serialized else: raise TypeError( f"Cannot serialize parameter of type {type(param).__name__}" ) # 2. Configuration Dict Processing if isinstance(param, dict): if "class" not in param: raise ValueError("Configuration dict must contain 'class' key") class_name = param["class"] config = dict(param) config.pop("class", None) # Map class names to constructors wrapper_classes = { # Learnable parameters "LearnableK": LearnableK, "LearnableSs": LearnableSs, "LearnableQ": LearnableQ, "LearnableMV": LearnableMV, "LearnableKappa": LearnableKappa, # Fixed parameters "FixedGammaW": FixedGammaW, "FixedHRef": FixedHRef, # Legacy parameters "LearnableC": LearnableC, "FixedC": FixedC, "DisabledC": DisabledC } if class_name not in wrapper_classes: # Handle plain float values if class_name == "float": return float(config.get("value", 0.0)) raise ValueError(f"Unknown parameter class: {class_name}") return wrapper_classes[class_name](**config) # 3. String Parameter Processing if isinstance(param, str): if param.lower() in ("learnable", "fixed"): # Use string as status override status = param.lower() param = 1.0 # Default value for wrapper creation else: try: # Try to parse as numeric string param = float(param) except ValueError: raise ValueError( f"String parameter must be numeric or 'learnable'/'fixed', " f"got '{param}'" ) # 4. Type Inference and Wrapper Mapping # Determine parameter type resolved_param_type = param_type or _infer_param_type_from_name(name) # Map parameter types to wrapper classes learnable_wrappers = { "K": LearnableK, "Ss": LearnableSs, "Q": LearnableQ, "MV": LearnableMV, "Kappa": LearnableKappa, "C": LearnableC # Legacy support } fixed_wrappers = { "GammaW": FixedGammaW, "HRef": FixedHRef, "C": FixedC # Legacy support } # 5. Status-Based Processing resolved_status = status or "auto" # Handle already wrapped parameters if isinstance(param, (BaseLearnable, BaseFixed)): if resolved_status == "auto": return param elif resolved_status == "learnable" and isinstance(param, BaseFixed): # Convert fixed to learnable if requested return _convert_fixed_to_learnable( param, resolved_param_type, name, **additional_kwargs) elif resolved_status == "fixed" and isinstance(param, BaseLearnable): # Convert learnable to fixed if requested return _convert_learnable_to_fixed( param, resolved_param_type, name, **additional_kwargs) else: return param # 6. Numeric Parameter Processing if isinstance(param, (float, int)): numeric_value = float(param) # Apply status resolution if resolved_status == "learnable": return _create_learnable_wrapper( numeric_value, resolved_param_type, name, learnable_wrappers, log_transform, non_negative, trainable, **additional_kwargs ) elif resolved_status == "fixed": return _create_fixed_wrapper( numeric_value, resolved_param_type, name, fixed_wrappers, log_transform, non_negative, **additional_kwargs ) else: # auto or None # Return as concrete value if _BACKEND == "tensorflow": return tf.constant(numeric_value, dtype=tf.float32) return numeric_value # 7. Fallback for Unhandled Types raise TypeError( f"Parameter must be float, int, BaseLearnable, BaseFixed, dict, or str; " f"got {type(param).__name__}" )
def _infer_param_type_from_name(name: Optional[str]) -> str: """Infer parameter type from name using flexible matching.""" if not name: return "Unknown" name_upper = name.upper() # Flexible type matching type_patterns = { "K": ["K", "CONDUCTIVITY", "HYDRAULIC_CONDUCTIVITY"], "Ss": ["SS", "SPECIFIC_STORAGE", "STORAGE"], "Q": ["Q", "SOURCE", "SINK", "SOURCE_SINK"], "MV": ["MV", "M_V", "COMPRESSIBILITY", "VOLUME_COMPRESSIBILITY"], "Kappa": ["KAPPA", "CONSISTENCY", "PRIOR"], "GammaW": ["GAMMA_W", "GAMMAW", "UNIT_WEIGHT", "WATER_WEIGHT"], "HRef": ["H_REF", "HREF", "REFERENCE_HEAD", "REF_HEAD"], "C": ["C", "COEFFICIENT", "PHYSICS_COEFF"] # Legacy } for param_type, patterns in type_patterns.items(): if any(pattern in name_upper for pattern in patterns): return param_type return "Unknown" def _create_learnable_wrapper( value: float, param_type: str, name: Optional[str], wrapper_map: Dict[str, Type[BaseLearnable]], log_transform: Optional[bool], non_negative: Optional[bool], trainable: Optional[bool], **kwargs ) -> BaseLearnable: """Create a learnable parameter wrapper.""" if param_type not in wrapper_map: raise ValueError( f"Cannot create learnable wrapper for parameter type '{param_type}'. " f"Available types: {list(wrapper_map.keys())}" ) wrapper_class = wrapper_map[param_type] # Set default parameters based on type default_params = { "K": {"initial_value": value, "log_transform": True, "trainable": True}, "Ss": {"initial_value": value, "log_transform": True, "trainable": True}, "Q": {"initial_value": value, "log_transform": False, "trainable": True}, "MV": {"initial_value": value, "log_transform": True, "trainable": True}, "Kappa": {"initial_value": value, "log_transform": True, "trainable": True}, "C": {"initial_value": value} # Legacy } params = default_params.get(param_type, {"initial_value": value}) # Apply overrides if log_transform is not None: params["log_transform"] = log_transform if trainable is not None: params["trainable"] = trainable if name: params["name"] = name params.update(kwargs) return wrapper_class(**params) def _create_fixed_wrapper( value: float, param_type: str, name: Optional[str], wrapper_map: Dict[str, Type[BaseFixed]], log_transform: Optional[bool], non_negative: Optional[bool], **kwargs ) -> BaseFixed: """Create a fixed parameter wrapper.""" if param_type not in wrapper_map: # For unsupported fixed types, return as concrete value if _BACKEND == "tensorflow": return tf.constant(value, dtype=tf.float32) return value wrapper_class = wrapper_map[param_type] # Set default parameters based on type default_params = { "GammaW": {"value": value, "log_transform": True, "non_negative": True}, "HRef": {"value": value, "log_transform": False, "non_negative": False}, "C": {"value": value} # Legacy } params = default_params.get(param_type, {"value": value}) # Apply overrides if log_transform is not None: params["log_transform"] = log_transform if non_negative is not None: params["non_negative"] = non_negative if name: params["name"] = name params.update(kwargs) return wrapper_class(**params) def _convert_fixed_to_learnable( fixed_param: BaseFixed, param_type: str, name: Optional[str], **kwargs ) -> BaseLearnable: """Convert a fixed parameter to learnable.""" learnable_wrappers = { "K": LearnableK, "Ss": LearnableSs, "Q": LearnableQ, "MV": LearnableMV, "Kappa": LearnableKappa } if param_type not in learnable_wrappers: raise ValueError( "Cannot convert fixed parameter to" f" learnable for type '{param_type}'") wrapper_class = learnable_wrappers[param_type] params = { "initial_value": fixed_param.value, "name": name or fixed_param.name, "trainable": True } params.update(kwargs) return wrapper_class(**params) def _convert_learnable_to_fixed( learnable_param: BaseLearnable, param_type: str, name: Optional[str], **kwargs ) -> BaseFixed: """Convert a learnable parameter to fixed.""" fixed_wrappers = { "GammaW": FixedGammaW, "HRef": FixedHRef } if param_type not in fixed_wrappers: # For unsupported conversions, return as concrete value return learnable_param.get_value() wrapper_class = fixed_wrappers[param_type] params = { "value": learnable_param.initial_value, "name": name or learnable_param.name } params.update(kwargs) return wrapper_class(**params)