Source code for fusionlab.nn.components.attention

# -*- coding: utf-8 -*-
# License: BSD-3-Clause
# Author: LKouadio <etanoyau@gmail.com>
"""
Attention-centric layers for FusionLab.
"""

from __future__ import annotations

from typing import Optional
from numbers import Real, Integral

from ...api.property import NNLearner
from ...compat.sklearn import validate_params, Interval
from ...utils.deps_utils import ensure_pkg
from ._config import (                                  
    KERAS_BACKEND, 
    DEP_MSG, 
    _logger,
    LayerNormalization,
    MultiHeadAttention, 
    Dropout, 
    Dense, 
    Layer, 
    Tensor, 
    register_keras_serializable,
    tf_shape, 
    tf_expand_dims, 
    tf_tile, 
    tf_bool, 
    tf_add, 
    tf_cast,
    tf_logical_and, 
    tf_ones_like, 
    tf_ones, 
    tf_autograph 
)
from .gating_norm import GatedResidualNetwork
from .misc import Activation



__all__ = [
    "TemporalAttentionLayer",
    "CrossAttention",
    "MemoryAugmentedAttention",
    "HierarchicalAttention",
    "ExplainableAttention",
    "MultiResolutionAttentionFusion",
]

[docs] @register_keras_serializable( 'fusionlab.nn.components', name="TemporalAttentionLayer" ) class TemporalAttentionLayer(Layer): """Temporal Attention Layer conditioning query with context."""
[docs] @validate_params({ "units": [Interval(Integral, 0, None, closed='left')], "num_heads": [Interval(Integral, 0, None, closed='left')], "dropout_rate": [Interval(Real, 0, 1, closed="both")], "use_batch_norm": [bool], }) @ensure_pkg(KERAS_BACKEND or "keras", extra=DEP_MSG) def __init__( self, units: int, num_heads: int, dropout_rate: float = 0.0, activation: str = 'elu', use_batch_norm: bool = False, **kwargs ): """Initializes the TemporalAttentionLayer.""" super().__init__(**kwargs) self.units = units self.num_heads = num_heads self.dropout_rate = dropout_rate self.use_batch_norm = use_batch_norm self.activation_str = Activation(activation).activation_str # --- Define Internal Layers --- self.multi_head_attention = MultiHeadAttention( num_heads=num_heads, key_dim=units, dropout=dropout_rate, name="mha" ) self.dropout = Dropout(dropout_rate, name="attn_dropout") self.layer_norm1 = LayerNormalization(name="layer_norm_1") # GRN to process the input context_vector # Ensure this is a single instance, passing the activation string self.context_grn = GatedResidualNetwork( units=units, # Output matches main path 'units' dropout_rate=dropout_rate, activation=self.activation_str, use_batch_norm=self.use_batch_norm, name="context_grn" # Note: GRN's internal activation handling should be fixed ) # Final GRN (position-wise feedforward) # Ensure this is also a single instance self.output_grn = GatedResidualNetwork( units=units, dropout_rate=dropout_rate, activation=self.activation_str, use_batch_norm=self.use_batch_norm, name="output_grn" )
[docs] def build(self, input_shape): """Builds internal layers, especially GRNs.""" # input_shape corresponds to the main 'inputs' tensor (B, T, U) if not isinstance(input_shape, (list, tuple)): # If only main input shape is passed (common) main_input_shape = tuple(input_shape) elif len(input_shape) == 2: # [inputs_shape, context_shape] rarelly happended main_input_shape = tuple(input_shape[0]) # Optionally build context_grn if context_shape is known context_shape = tuple(input_shape[1]) if not self.context_grn.built: self.context_grn.build(context_shape) else: raise ValueError( "Unexpected input_shape format for build.") if len(main_input_shape) < 3: raise ValueError( "TemporalAttentionLayer expects input rank >= 3") # Define expected input shape for output_grn # It receives output from layer_norm1, which has same shape as input output_grn_input_shape = main_input_shape # Explicitly build the output GRN if not already built if not self.output_grn.built: self.output_grn.build(output_grn_input_shape) # Developer comment: Explicitly built output_grn. # Build context_grn lazily during call or here # Call the parent build method AFTER building sub-layers super().build(input_shape)
# Developer comment: Layer built status should now be True.
[docs] def call(self, inputs, context_vector=None, training=False): """Forward pass of the temporal attention layer.""" # Input shapes: inputs=(B, T, U), context_vector=(B, U_ctx) query = inputs # Default query processed_context = None # --- Process Context Vector (if provided) --- if context_vector is not None: # Pass context_vector as the main input 'x' to context_grn processed_context = self.context_grn( x=context_vector, context=None, # No nested context for the context_grn itself training=training ) # Output shape: (B, units) # Expand context across time: (B, units) -> (B, 1, units) context_expanded = tf_expand_dims(processed_context, axis=1) # Add to inputs (broadcasting handles time dimension) query = tf_add(inputs, context_expanded) # Comment: Query now incorporates static context. # --- Multi-Head Self-Attention --- attn_output = self.multi_head_attention( query=query, value=inputs, key=inputs, training=training ) # Shape: (B, T, units) # --- Add & Norm (First Residual Connection) --- attn_output_dropout = self.dropout(attn_output, training=training) # Residual connection uses original 'inputs' x_attn = self.layer_norm1(tf_add(inputs, attn_output_dropout)) # Shape: (B, T, units) # --- Position-wise Feedforward (Final GRN) --- # This GRN takes the output of the attention block as input 'x' # It does not receive the external 'context_vector' here. # --- DEBUG lines --- _logger.debug("\nDEBUG>> About to call self.output_grn") _logger.debug( "DEBUG>> Type of self.output_grn:" f" {type(self.output_grn)}") _logger.debug( "DEBUG>> Is self.output_grn callable:" f" {callable(self.output_grn)}") try: # Try accessing an attribute expected on a Keras layer _logger.debug( "DEBUG>> self.output_grn name:" f" {self.output_grn.name}") _logger.debug( "DEBUG>> self.output_grn built status:" f" {self.output_grn.built}") except AttributeError as ae: _logger.debug( "DEBUG>> Failed to access attributes" f" of self.output_grn: {ae}") _logger.debug( f"DEBUG>> Input x_attn shape: {tf_shape(x_attn)}\n") # --- End DEBUG lines --- output = self.output_grn( x=x_attn, context=None, # No external context for the final GRN training=training ) # Shape: (B, T, units) return output
[docs] def get_config(self): """Returns the layer configuration.""" config = super().get_config() config.update({ 'units': self.units, 'num_heads': self.num_heads, 'dropout_rate': self.dropout_rate, 'activation': self.activation_str, 'use_batch_norm': self.use_batch_norm, }) return config
[docs] @classmethod def from_config(cls, config): """Creates layer from its config.""" return cls(**config)
@register_keras_serializable( 'fusionlab.nn.components', name="CrossAttention_" ) class CrossAttention_(Layer, NNLearner): r""" CrossAttention layer that attends one source sequence to another [1]_. This layer transforms two input sources, ``source1`` and ``source2``, into a shared dimensionality via separate dense layers, then applies multi-head attention using ``source1`` as the query and ``source2`` as both key and value. The output shape depends on the specified ``units``. .. math:: \mathbf{H}_{\text{out}} = \text{MHA}( \mathbf{W}_{1}\,\mathbf{S}_1,\, \mathbf{W}_{2}\,\mathbf{S}_2,\, \mathbf{W}_{2}\,\mathbf{S}_2 ) where :math:`\mathbf{S}_1` and :math:`\mathbf{S}_2` are the two source sequences. Parameters ---------- units : int Dimensionality for the internal projections of the query/key/value in multi-head attention. num_heads : int Number of attention heads. Notes ----- Cross attention is particularly useful when focusing on how one sequence (the query) relates to another (the key/value). For example, in multi-modal time series settings, one might attend dynamic covariates to static ones or vice versa. Methods ------- call(`inputs`, training=False) Forward pass of the cross-attention layer. get_config() Returns the configuration dictionary for serialization. from_config(`config`) Creates a new layer from the given config. Examples -------- >>> from fusionlab.nn.components import CrossAttention >>> import tensorflow as tf >>> # Two sequences of shape (batch_size, time_steps, features) >>> source1 = tf.random.normal((32, 10, 64)) >>> source2 = tf.random.normal((32, 10, 64)) >>> # Instantiate the CrossAttention layer >>> cross_attn = CrossAttention(units=64, num_heads=4) >>> # Forward pass >>> outputs = cross_attn([source1, source2]) See Also -------- HierarchicalAttention Another attention-based layer focusing on short/long-term sequences. MemoryAugmentedAttention Uses a learned memory matrix to enhance representations. References ---------- .. [1] Vaswani, A., Shazeer, N., Parmar, N., Uszkoreit, J., Jones, L., Gomez, A. N., Kaiser, L., & Polosukhin, I. (2017). "Attention is all you need." In *Advances in Neural Information Processing Systems* (pp. 5998-6008). """ @ensure_pkg(KERAS_BACKEND or "keras", extra=DEP_MSG) def __init__(self, units: int, num_heads: int): r""" Initialize the CrossAttention layer. Parameters ---------- units : int Number of output units for the internal Dense projections and multi-head attention dimension. num_heads : int Number of attention heads to use in the multi-head attention module. """ super().__init__() self.units = units # Dense layers to project each source self.source1_dense = Dense(units) self.source2_dense = Dense(units) # Multi-head attention self.cross_attention = MultiHeadAttention( num_heads=num_heads, key_dim=units ) @tf_autograph.experimental.do_not_convert def call(self, inputs, training=False): r""" Forward pass of CrossAttention. Parameters ---------- ``inputs`` : list of tf.Tensor A list [source1, source2], each of shape (batch_size, time_steps, features). training : bool, optional Indicates if the layer is in training mode (for dropout, if any). Defaults to ``False``. Returns ------- tf.Tensor A tensor of shape (batch_size, time_steps, units) representing cross-attended features. """ source1, source2 = inputs # Project each source source1 = self.source1_dense(source1) source2 = self.source2_dense(source2) # Apply cross attention return self.cross_attention( query=source1, value=source2, key=source2 ) def get_config(self): r""" Returns configuration dictionary for this layer. Returns ------- dict Configuration dictionary, including 'units'. """ config = super().get_config().copy() config.update({'units': self.units}) return config @classmethod def from_config(cls, config): r""" Create a new CrossAttention layer from the given config dictionary. Parameters ---------- ``config`` : dict Configuration as returned by ``get_config``. Returns ------- CrossAttention A new instance of CrossAttention. """ return cls(**config)
[docs] @register_keras_serializable( 'fusionlab.nn.components', name="CrossAttention" ) class CrossAttention(Layer, NNLearner): r""" CrossAttention that attends ``source1`` (query) to ``source2`` (key/value) with optional masks. attention_mask : Tensor, optional Bool / 0‑1 mask broadcastable to (B, Tq, Tv). Passed directly to Keras ``MultiHeadAttention``. query_mask, value_mask : Tensor, optional 1D/2D masks (B, Tq) or (B, Tv). If provided and ``attention_mask`` is None, they are combined to form (B, Tq, Tv). use_causal_mask : bool Forwarded to MHA. Default False. """
[docs] @ensure_pkg(KERAS_BACKEND or "keras", extra=DEP_MSG) def __init__(self, units: int, num_heads: int): super().__init__() self.units = units self.source1_dense = Dense(units) self.source2_dense = Dense(units) self.cross_attention = MultiHeadAttention( num_heads=num_heads, key_dim=units )
[docs] @tf_autograph.experimental.do_not_convert def call( self, inputs, training: bool = False, *, attention_mask: Optional[Tensor] = None, query_mask: Optional[Tensor] = None, value_mask: Optional[Tensor] = None, use_causal_mask: bool = False, **kwargs, ): r""" Forward pass of CrossAttention. Parameters ---------- ``inputs`` : list of tf.Tensor A list [source1, source2], each of shape (batch_size, time_steps, features). training : bool, optional Indicates if the layer is in training mode (for dropout, if any). Defaults to ``False``. attention_mask : Tensor, optional Bool / 0‑1 mask broadcastable to (B, Tq, Tv). Passed directly to Keras ``MultiHeadAttention``. query_mask, value_mask : Tensor, optional 1D/2D masks (B, Tq) or (B, Tv). If provided and ``attention_mask`` is None, they are combined to form (B, Tq, Tv). use_causal_mask : bool Forwarded to MHA. Default False. Returns ------- tf.Tensor A tensor of shape (batch_size, time_steps, units) representing cross-attended features. """ source1, source2 = inputs # shapes: (B, Tq, Fq), (B, Tv, Fv) # Project to common dim q = self.source1_dense(source1) kv = self.source2_dense(source2) # Build attention_mask if needed if attention_mask is None and (query_mask is not None or value_mask is not None): # default to all True if one side is None if query_mask is None: query_mask = tf_ones_like( source1[..., 0], dtype=tf_bool) if value_mask is None: value_mask = tf_ones_like( source2[..., 0], dtype=tf_bool) qm = tf_expand_dims(tf_cast(query_mask, tf_bool), axis=-1) vm = tf_expand_dims(tf_cast(value_mask, tf_bool), axis=1) # (B, Tq, 1) & (B, 1, Tv) -> (B, Tq, Tv) attention_mask = tf_logical_and(qm, vm) return self.cross_attention( query=q, key=kv, value=kv, attention_mask=attention_mask, use_causal_mask=use_causal_mask, training=training, )
[docs] def get_config(self): cfg = super().get_config().copy() cfg.update({'units': self.units}) return cfg
[docs] @classmethod def from_config(cls, config): return cls(**config)
[docs] @register_keras_serializable( 'fusionlab.nn.components', name="MemoryAugmentedAttention" ) class MemoryAugmentedAttention(Layer, NNLearner): r"""Memory-augmented attention with optional masking."""
[docs] @ensure_pkg(KERAS_BACKEND or "keras", extra=DEP_MSG) def __init__(self, units: int, memory_size: int, num_heads: int): super().__init__() self.units = units self.memory_size = memory_size self.attention = MultiHeadAttention( num_heads=num_heads, key_dim=units )
[docs] def build(self, input_shape): self.memory = self.add_weight( name="memory", shape=(self.memory_size, self.units), initializer="zeros", trainable=True, ) super().build(input_shape)
[docs] @tf_autograph.experimental.do_not_convert def call( self, inputs, training: bool = False, *, attention_mask: Optional[Tensor] = None, query_mask: Optional[Tensor] = None, value_mask: Optional[Tensor] = None, use_causal_mask: bool = False, **kwargs, ): # inputs: (B, T, U) batch_size = tf_shape(inputs)[0] mem = tf_expand_dims(self.memory, 0) # (1, M, U) mem = tf_tile(mem, [batch_size, 1, 1]) # (B, M, U) # Build attention_mask if only per-sequence masks given if attention_mask is None and (query_mask is not None or value_mask is not None): if query_mask is None: query_mask = tf_ones_like(inputs[..., 0], dtype=tf_bool) if value_mask is None: value_mask = tf_ones( (batch_size, self.memory_size), dtype=tf_bool ) qm = tf_expand_dims(tf_cast(query_mask, tf_bool), -1) # (B,T,1) vm = tf_expand_dims(tf_cast(value_mask, tf_bool), 1) # (B,1,M) attention_mask = tf_logical_and(qm, vm) # (B,T,M) mem_att = self.attention( query=inputs, key=mem, value=mem, attention_mask=attention_mask, use_causal_mask=use_causal_mask, training=training, ) return mem_att + inputs
[docs] def get_config(self): cfg = super().get_config().copy() cfg.update({'units': self.units, 'memory_size': self.memory_size}) return cfg
[docs] @classmethod def from_config(cls, config): return cls(**config)
@register_keras_serializable( 'fusionlab.nn.components', name="HierarchicalAttention_" ) class HierarchicalAttention_(Layer, NNLearner): r""" Hierarchical Attention layer that processes short-term and long-term sequences separately using multi-head attention, then combines their outputs [1]_. This allows the model to focus on different aspects of the data in short-term and long-term contexts and aggregate the attention outputs for a more comprehensive representation. .. math:: \mathbf{Z} = \text{MHA}(\mathbf{X}_{s}) + \text{MHA}(\mathbf{X}_{l}) where :math:`\mathbf{X}_{s}` and :math:`\mathbf{X}_{l}` are the short- and long-term sequences, respectively. Parameters ---------- units : int Dimensionality of the projection for the attention keys, queries, and values. num_heads : int Number of attention heads to use in each multi-head attention sub-layer. Notes ----- The output shape depends on the last dimension in the short and long sequences, projected to `units`. The final output is the sum of the short-term attention output and the long-term attention output. Methods ------- call(`inputs`, training=False) Forward pass. Expects a list `[short_term, long_term]` with shapes (B, T, D_s) and (B, T, D_l). get_config() Returns configuration dictionary for serialization. from_config(`config`) Recreates the layer from a config dict. Examples -------- >>> from fusionlab.nn.components import HierarchicalAttention >>> import tensorflow as tf >>> # Suppose short_term and long_term have ... # shape (batch_size, time_steps, features). >>> short_term = tf.random.normal((32, 10, 64)) >>> long_term = tf.random.normal((32, 10, 64)) >>> # Instantiate hierarchical attention >>> ha = HierarchicalAttention(units=64, num_heads=4) >>> # Forward pass >>> outputs = ha([short_term, long_term]) See Also -------- MultiModalEmbedding Can precede attention by embedding multiple sources of input. LearnedNormalization Can be applied to short_term and long_term sequences prior to attention. References ---------- .. [1] Vaswani, A., Shazeer, N., Parmar, N., Uszkoreit, J., Jones, L., Gomez, A. N., Kaiser, L., & Polosukhin, I. (2017). "Attention is all you need." In *Advances in Neural Information Processing Systems* (pp. 5998-6008). """ @ensure_pkg(KERAS_BACKEND or "keras", extra=DEP_MSG) def __init__(self, units: int, num_heads: int): super().__init__() self.units = units # Dense layers for short/long sequences self.short_term_dense = Dense(units) self.long_term_dense = Dense(units) # Multi-head attention for short/long self.short_term_attention = MultiHeadAttention( num_heads=num_heads, key_dim=units ) self.long_term_attention = MultiHeadAttention( num_heads=num_heads, key_dim=units ) @tf_autograph.experimental.do_not_convert def call(self, inputs, training=False): r""" Forward pass of the HierarchicalAttention. Parameters ---------- ``inputs`` : list of tf.Tensor A list `[short_term, long_term]`. Each tensor should have shape :math:`(B, T, D)`. training : bool, optional Indicates whether the layer is in training mode. Defaults to ``False``. Returns ------- tf.Tensor A tensor of shape :math:`(B, T, U)`, where `U = units`, representing the combined attention outputs. """ short_term, long_term = inputs # Linear projections to unify # dimensionality short_term = self.short_term_dense( short_term ) long_term = self.long_term_dense( long_term ) # Multi-head attention on short_term short_term_attention = ( self.short_term_attention( short_term, short_term ) ) # Multi-head attention on long_term long_term_attention = ( self.long_term_attention( long_term, long_term ) ) # Combine return short_term_attention + long_term_attention def get_config(self): r""" Returns a dictionary of config parameters for serialization. Returns ------- dict Dictionary with 'units', 'short_term_dense' config, and 'long_term_dense' config. """ config = super().get_config().copy() config.update({ 'units': self.units, 'short_term_dense': self.short_term_dense.get_config(), 'long_term_dense': self.long_term_dense.get_config() }) return config @classmethod def from_config(cls, config): r""" Recreates the HierarchicalAttention layer from a config dictionary. Parameters ---------- ``config`` : dict Configuration dictionary. Returns ------- HierarchicalAttention A new instance with the specified configuration. """ return cls(**config)
[docs] @register_keras_serializable( 'fusionlab.nn.components', name="HierarchicalAttention" ) class HierarchicalAttention(Layer, NNLearner): r"""Short/long-term MHA with optional masks."""
[docs] @ensure_pkg(KERAS_BACKEND or "keras", extra=DEP_MSG) def __init__(self, units: int, num_heads: int): super().__init__() self.units = units self.short_term_dense = Dense(units) self.long_term_dense = Dense(units) self.short_term_attention = MultiHeadAttention( num_heads=num_heads, key_dim=units ) self.long_term_attention = MultiHeadAttention( num_heads=num_heads, key_dim=units )
[docs] @tf_autograph.experimental.do_not_convert def call( self, inputs, training: bool = False, *, short_mask: Optional[Tensor] = None, long_mask: Optional[Tensor] = None, use_causal_mask: bool = False, **kwargs, ): # inputs: [short_term, long_term] short_term, long_term = inputs s = self.short_term_dense(short_term) l = self.long_term_dense(long_term) # Build masks to (B, T, T) if provided as (B,T) def _expand_mask(m): if m is None: return None m = tf_cast(m, tf_bool) qm = tf_expand_dims(m, 1) # (B,1,T) vm = tf_expand_dims(m, 2) # (B,T,1) return tf_logical_and(vm, qm) # (B,T,T) s_mask = _expand_mask(short_mask) l_mask = _expand_mask(long_mask) s_att = self.short_term_attention( query=s, key=s, value=s, attention_mask=s_mask, use_causal_mask=use_causal_mask, training=training, ) l_att = self.long_term_attention( query=l, key=l, value=l, attention_mask=l_mask, use_causal_mask=use_causal_mask, training=training, ) return s_att + l_att
[docs] def get_config(self): cfg = super().get_config().copy() cfg.update({'units': self.units, 'short_term_dense': self.short_term_dense.get_config(), 'long_term_dense': self.long_term_dense.get_config()}) return cfg
[docs] @classmethod def from_config(cls, config): return cls(**config)
[docs] @register_keras_serializable( 'fusionlab.nn.components', name="ExplainableAttention" ) class ExplainableAttention(Layer, NNLearner): r""" ExplainableAttention layer that returns attention scores from multi-head attention [1]_. This layer is useful for interpretability, providing insight into how the attention mechanism focuses on different time steps. .. math:: \mathbf{A} = \text{MHA}(\mathbf{X},\,\mathbf{X}) \rightarrow \text{attention\_scores} Here, :math:`\mathbf{X}` is an input tensor, and ``attention_scores`` is the matrix capturing attention weights. Parameters ---------- num_heads : int Number of heads for multi-head attention. key_dim : int Dimensionality of the query/key projections. Notes ----- Unlike standard layers that return the transformation output, this layer specifically returns the attention score matrix for interpretability. Methods ------- call(`inputs`, training=False) Forward pass that outputs only the attention scores. get_config() Returns the configuration for serialization. from_config(`config`) Creates a new instance from the given config. Examples -------- >>> from fusionlab.nn.components import ExplainableAttention >>> import tensorflow as tf >>> # Suppose we have input of shape (batch_size, time_steps, features) >>> x = tf.random.normal((32, 10, 64)) >>> # Instantiate explainable attention >>> ea = ExplainableAttention(num_heads=4, key_dim=64) >>> # Forward pass returns attention scores: (B, num_heads, T, T) >>> scores = ea(x) See Also -------- CrossAttention Another attention variant for cross-sequence contexts. MultiResolutionAttentionFusion For fusing features via multi-head attention. References ---------- .. [1] Vaswani, A., Shazeer, N., Parmar, N., Uszkoreit, J., Jones, L., Gomez, A. N., Kaiser, L., & Polosukhin, I. (2017). "Attention is all you need." In *Advances in Neural Information Processing Systems* (pp. 5998-6008). """
[docs] @ensure_pkg(KERAS_BACKEND or "keras", extra=DEP_MSG) def __init__(self, num_heads: int, key_dim: int): r""" Initialize the ExplainableAttention layer. Parameters ---------- num_heads : int Number of attention heads. key_dim : int Dimensionality of query/key projections in multi-head attention. """ super().__init__() self.num_heads = num_heads self.key_dim = key_dim # MultiHeadAttention, focusing on returning # the attention scores self.attention = MultiHeadAttention( num_heads=num_heads, key_dim=key_dim )
[docs] @tf_autograph.experimental.do_not_convert def call(self, inputs, training=False): r""" Forward pass that returns only the attention scores. Parameters ---------- ``inputs`` : tf.Tensor Tensor of shape (B, T, D). training : bool, optional Indicates training mode; not used in this layer. Defaults to ``False``. Returns ------- tf.Tensor Attention scores of shape (B, num_heads, T, T). """ _, attention_scores = self.attention( inputs, inputs, return_attention_scores=True ) return attention_scores
[docs] def get_config(self): r""" Returns the layer configuration. Returns ------- dict Dictionary containing 'num_heads' and 'key_dim'. """ config = super().get_config().copy() config.update({ 'num_heads': self.num_heads, 'key_dim': self.key_dim }) return config
[docs] @classmethod def from_config(cls, config): r""" Creates a new instance from the config dictionary. Parameters ---------- ``config`` : dict Configuration dictionary. Returns ------- ExplainableAttention A new instance of this layer. """ return cls(**config)
[docs] @register_keras_serializable( 'fusionlab.nn.components', name="MultiResolutionAttentionFusion" ) class MultiResolutionAttentionFusion(Layer, NNLearner): r""" MultiResolutionAttentionFusion layer applying multi-head attention fusion over features [1]_. This layer merges or fuses features at different resolutions or sources via multi-head attention. The input is projected to shape `(B, T, D)`, and the output shares the same shape. .. math:: \mathbf{Z} = \text{MHA}(\mathbf{X}, \mathbf{X}) Parameters ---------- units : int Dimension of the key, query, and value projections. num_heads : int Number of attention heads. Notes ----- Typically used in multi-resolution contexts where time steps or multiple feature sets are merged. Methods ------- call(`inputs`, training=False) Forward pass of the multi-head attention layer. get_config() Returns config for serialization. from_config(`config`) Reconstructs the layer from a config. Examples -------- >>> from fusionlab.nn.components import MultiResolutionAttentionFusion >>> import tensorflow as tf >>> x = tf.random.normal((32, 10, 64)) >>> # Instantiate multi-resolution attention >>> mraf = MultiResolutionAttentionFusion( ... units=64, ... num_heads=4 ... ) >>> # Forward pass => (32, 10, 64) >>> y = mraf(x) See Also -------- HierarchicalAttention Combines short and long-term sequences with attention. ExplainableAttention Another attention layer returning attention scores. References ---------- .. [1] Vaswani, A., Shazeer, N., Parmar, N., Uszkoreit, J., Jones, L., Gomez, A. N., Kaiser, L., & Polosukhin, I. (2017). "Attention is all you need." In *Advances in Neural Information Processing Systems* (pp. 5998-6008). """
[docs] @ensure_pkg(KERAS_BACKEND or "keras", extra=DEP_MSG) def __init__(self, units: int, num_heads: int): r""" Initialize the MultiResolutionAttentionFusion layer. Parameters ---------- units : int Dimensionality for the attention projections. num_heads : int Number of heads for multi-head attention. """ super().__init__() self.units = units self.num_heads = num_heads # MultiHeadAttention instance self.attention = MultiHeadAttention( num_heads=num_heads, key_dim=units )
[docs] @tf_autograph.experimental.do_not_convert def call(self, inputs, training=False): r""" Forward pass applying multi-head attention to fuse features. Parameters ---------- ``inputs`` : tf.Tensor Tensor of shape (B, T, D). training : bool, optional Indicates training mode. Defaults to ``False``. Returns ------- tf.Tensor Tensor of shape (B, T, D), representing fused features. """ return self.attention(inputs, inputs)
[docs] def get_config(self): r""" Returns configuration dictionary with 'units' and 'num_heads'. Returns ------- dict Configuration for serialization. """ config = super().get_config().copy() config.update({ 'units': self.units, 'num_heads': self.num_heads }) return config
[docs] @classmethod def from_config(cls, config): r""" Instantiate a new MultiResolutionAttentionFusion layer from config. Parameters ---------- ``config`` : dict Configuration dictionary. Returns ------- MultiResolutionAttentionFusion A new instance of this layer. """ return cls(**config)