# Source code for fusionlab.nn.components._attention_utils

# -*- coding: utf-8 -*-
# License: BSD-3-Clause
# Author: LKouadio <etanoyau@gmail.com>
"""
Utility helpers for attention masks.

- create_causal_mask(size)
- combine_masks(mask_a, mask_b, *, mode='and', invert_b=False)
"""

from __future__ import annotations
from typing import Union, Optional

from ._config import (
    Tensor,
    tf_cast, tf_int32, tf_range, tf_expand_dims, tf_greater,
    tf_float32, tf_bool, tf_logical_and, tf_logical_or, tf_logical_not,
    tf_linalg, tf_ones, 
)

# Names exported by ``from <this module> import *``; other helpers defined
# below are not part of this list.
__all__ = ["create_causal_mask", "combine_masks"]


def create_causal_mask(size: Union[int, Tensor]) -> Tensor:
    """
    Build a look-ahead (causal) mask of shape ``[1, 1, size, size]``.

    Entry ``[0, 0, i, j]`` equals ``1.0`` when ``j > i`` (a future
    position, i.e. masked) and ``0.0`` otherwise. The two leading
    singleton axes let the mask broadcast over batch and heads.

    Parameters
    ----------
    size : int or Tensor
        Sequence length; it is cast to a 0-D int32 tensor.

    Returns
    -------
    Tensor
        float32 tensor of shape ``[1, 1, size, size]`` with 1.0 marking
        masked positions.
    """
    n = tf_cast(size, tf_int32)
    positions = tf_range(n)  # [0, 1, ..., n-1]

    # Broadcast comparison of key index j (columns) against query
    # index i (rows): [1, n] vs [n, 1] -> [n, n] boolean.
    key_pos = tf_expand_dims(positions, 0)
    query_pos = tf_expand_dims(positions, 1)
    is_future = tf_greater(key_pos, query_pos)

    # True -> 1.0 (masked), False -> 0.0.
    mask = tf_cast(is_future, tf_float32)

    # Prepend the batch and head axes: [n, n] -> [1, n, n] -> [1, 1, n, n].
    mask = tf_expand_dims(mask, 0)
    mask = tf_expand_dims(mask, 0)
    return mask
def create_causal_mask_(size: Union[int, Tensor]) -> Tensor:
    """
    Build a causal mask shaped (1, 1, L, L).

    Entries are 1.0 when j > i (future positions) and 0.0 otherwise.
    The leading singleton axes let the mask broadcast over batch and
    heads.

    NOTE(review): Keras ``MultiHeadAttention`` treats 1/True in
    ``attention_mask`` as "may attend". This mask uses 1.0 for
    "masked", so confirm callers invert it (e.g. ``1 - mask`` or
    ``combine_masks(..., invert_b=True)``) before passing it there.

    Parameters
    ----------
    size : int or Tensor
        Sequence length L.

    Returns
    -------
    Tensor
        Float mask (1.0 = masked) of shape (1, 1, L, L).
    """
    size = tf_cast(size, tf_int32)
    idxs = tf_range(size)                  # [L]
    row = tf_expand_dims(idxs, 1)          # [L, 1] query index i
    col = tf_expand_dims(idxs, 0)          # [1, L] key index j
    mask2d = tf_greater(col, row)          # True where j > i
    mask2d = tf_cast(mask2d, tf_float32)   # [L, L]
    return tf_expand_dims(tf_expand_dims(mask2d, 0), 1)  # [1, 1, L, L]


def combine_masks(
    mask_a: Optional[Tensor],
    mask_b: Optional[Tensor],
    *,
    mode: str = "and",
    invert_b: bool = False,
) -> Optional[Tensor]:
    """
    Combine two boolean/0-1 masks into one boolean mask.

    Parameters
    ----------
    mask_a, mask_b : Tensor or None
        Broadcast-compatible masks, either bool or numeric (cast to
        bool, so any non-zero value counts as True). If exactly one is
        None, the other is returned as bool, with ``invert_b`` still
        applied when the remaining mask is ``mask_b``.
    mode : {'and', 'or', 'xor'}, default 'and'
        Logical operation used to merge the masks. It is validated even
        when one or both masks are None.
    invert_b : bool, default False
        If True, logical-not is applied to ``mask_b`` before combining.

    Returns
    -------
    Tensor or None
        Combined bool mask, or None if both inputs are None.

    Raises
    ------
    ValueError
        If ``mode`` is not one of 'and', 'or', 'xor'.
    """
    # Validate first so a bad ``mode`` is reported even on the
    # short-circuit paths below, where no merge actually happens.
    if mode not in ("and", "or", "xor"):
        raise ValueError("mode must be 'and', 'or', or 'xor'.")

    if mask_a is None and mask_b is None:
        return None

    def _to_bool(x: Tensor) -> Tensor:
        return tf_cast(x, tf_bool)

    # Normalise (and optionally invert) mask_b in a single place.
    b = None
    if mask_b is not None:
        b = _to_bool(mask_b)
        if invert_b:
            b = tf_logical_not(b)

    if mask_a is None:
        return b
    a = _to_bool(mask_a)
    if b is None:
        return a

    if mode == "and":
        return tf_logical_and(a, b)
    if mode == "or":
        return tf_logical_or(a, b)
    # mode == "xor": (a or b) and not (a and b)
    return tf_logical_and(
        tf_logical_or(a, b), tf_logical_not(tf_logical_and(a, b))
    )


def _create_causal_mask(size: Union[int, Tensor]) -> Tensor:
    """
    Return the causal mask produced by ``create_causal_mask_``.

    Kept for backward compatibility. The result is a float32 tensor of
    shape (1, 1, L, L) with 1.0 where j > i, i.e. the same values as
    ``1 - band_part(ones((L, L)), -1, 0)``.
    """
    # Delegate rather than index a tensor with another tensor: TF slicing
    # only accepts ints/slices/scalar tensors, and the shape would have
    # been (1, L, L) rather than the intended (1, 1, L, L).
    return create_causal_mask_(size)