fusionlab.nn.components.create_causal_mask

fusionlab.nn.components.create_causal_mask(size)

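Creates a causal attention mask of shape [1, 1, seq_len, seq_len], with seq_len = size, where mask[0, 0, i, j] = 1.0 if j > i else 0.0. The 1.0 entries therefore lie strictly above the main diagonal and mark the future positions (key index j later than query index i) that causal attention should block. The diagonal and everything below it are 0.0.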
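The definition can be checked with a small pure-Python sketch. `causal_mask_reference` is a hypothetical helper written only for illustration and is not part of fusionlab. It reproduces the documented formula with nested lists so the shape and values are easy to inspect, and it makes no claim about how fusionlab builds the tensor internally.

```python
def causal_mask_reference(size: int) -> list[list[list[list[float]]]]:
    """Hypothetical pure-Python sketch of the documented mask.

    Shape [1, 1, size, size], with mask[0][0][i][j] = 1.0 if j > i else 0.0.
    """
    if size < 0:
        raise ValueError("size must be non-negative")
    # 1.0 strictly above the diagonal (future keys), 0.0 on and below it.
    rows = [[1.0 if j > i else 0.0 for j in range(size)] for i in range(size)]
    # Wrap twice to add the two leading dimensions of size 1: [1, 1, size, size].
    return [[rows]]


mask = causal_mask_reference(3)
for row in mask[0][0]:
    print(row)
# [0.0, 1.0, 1.0]
# [0.0, 0.0, 1.0]
# [0.0, 0.0, 0.0]
```

If you work in TensorFlow, `1.0 - tf.linalg.band_part(tf.ones((size, size)), -1, 0)` produces the same [size, size] pattern before the two leading axes are added; whether fusionlab constructs the mask this way is not stated on this page.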

Parameters:

size (int | Tensor) – Sequence length seq_len, given as a Python int or a scalar Tensor.

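Returns:

The causal mask, of shape [1, 1, size, size], holding 1.0 above the main diagonal and 0.0 elsewhere.

Return type:

Tensor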
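This page does not say how the mask is consumed. A common convention for 1.0-means-blocked masks is to add a large negative value (here -1e9) to the masked attention logits before the softmax, so blocked positions receive effectively zero weight. The pure-Python sketch below assumes that convention; `masked_softmax_row` and `causal_attention_weights` are hypothetical helpers, not fusionlab functions, so check the consuming layer before relying on this behavior.

```python
import math


def masked_softmax_row(scores, mask_row, neg=-1e9):
    # Assumed convention: where the mask is 1.0, add a large negative logit
    # so softmax assigns that position (numerically) zero weight.
    logits = [s + m * neg for s, m in zip(scores, mask_row)]
    peak = max(logits)  # subtract the max for numerical stability
    exps = [math.exp(x - peak) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def causal_attention_weights(scores):
    """scores: size x size query-key logits; returns row-wise attention weights
    after applying a causal mask (1.0 where key index j > query index i)."""
    size = len(scores)
    mask = [[1.0 if j > i else 0.0 for j in range(size)] for i in range(size)]
    return [masked_softmax_row(scores[i], mask[i]) for i in range(size)]


weights = causal_attention_weights([[0.0, 0.0, 0.0]] * 3)
for row in weights:
    print([round(w, 3) for w in row])
# [1.0, 0.0, 0.0]
# [0.5, 0.5, 0.0]
# [0.333, 0.333, 0.333]
```

With uniform scores, each query spreads its weight evenly over itself and earlier positions only, which is the behavior a causal mask is meant to enforce.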