Allows the model to jointly attend to information from different representation subspaces. See reference: Attention Is All You Need
nn_multihead_attention(
embed_dim,
num_heads,
dropout = 0,
bias = TRUE,
add_bias_kv = FALSE,
add_zero_attn = FALSE,
kdim = NULL,
vdim = NULL,
batch_first = FALSE
)
Arguments:
embed_dim: total dimension of the model.
num_heads: parallel attention heads. Note that embed_dim will be split across num_heads (i.e. each head will have dimension embed_dim %/% num_heads).
dropout: a Dropout layer on attn_output_weights. Default: 0.0.
bias: add bias as module parameter. Default: TRUE.
add_bias_kv: add bias to the key and value sequences at dim=0.
add_zero_attn: add a new batch of zeros to the key and value sequences at dim=1.
kdim: total number of features in key. Default: NULL.
vdim: total number of features in value. Default: NULL. Note: if kdim and vdim are NULL, they will be set to embed_dim so that query, key, and value have the same number of features.
batch_first: if TRUE, the input and output tensors are \((N, S, E)\) instead of \((S, N, E)\), where N is the batch size, S is the sequence length, and E is the embedding dimension.
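To make the kdim, vdim, and batch_first arguments concrete, here is a minimal sketch. The feature sizes (16, 24, 32), sequence lengths, and batch size are arbitrary illustrative values, not taken from the documentation; keys and values with their own feature sizes are projected internally to embed_dim.
library(torch)
embed_dim <- 16
num_heads <- 4
attn <- nn_multihead_attention(
  embed_dim, num_heads,
  kdim = 24, vdim = 32,
  batch_first = TRUE  # tensors are (N, S, E) instead of (S, N, E)
)
query <- torch_randn(2, 10, embed_dim)  # (N, L, embed_dim)
key   <- torch_randn(2, 12, 24)         # (N, S, kdim)
value <- torch_randn(2, 12, 32)         # (N, S, vdim)
out <- attn(query, key, value)          # out[[1]] has shape (N, L, embed_dim)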
Inputs:
query: \((L, N, E)\) where L is the target sequence length, N is the batch size, and E is the embedding dimension (but see the batch_first argument).
key: \((S, N, E)\) where S is the source sequence length, N is the batch size, and E is the embedding dimension (but see the batch_first argument).
value: \((S, N, E)\) where S is the source sequence length, N is the batch size, and E is the embedding dimension (but see the batch_first argument).
key_padding_mask: \((N, S)\) where N is the batch size and S is the source sequence length. If a ByteTensor is provided, the non-zero positions will be ignored while the zero positions will be unchanged. If a BoolTensor is provided, the positions with the value of TRUE will be ignored while the positions with the value of FALSE will be unchanged. (A mask-construction sketch follows this list.)
attn_mask: 2D mask \((L, S)\) where L is the target sequence length and S is the source sequence length, or 3D mask \((N*num_heads, L, S)\) where N is the batch size, L is the target sequence length, and S is the source sequence length. attn_mask ensures that position i is only allowed to attend to the unmasked positions. If a ByteTensor is provided, the non-zero positions are not allowed to attend while the zero positions will be unchanged. If a BoolTensor is provided, positions with TRUE are not allowed to attend while FALSE values will be unchanged. If a FloatTensor is provided, it will be added to the attention weight.
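The sketch below makes the two mask inputs concrete, assuming they are passed to the module's forward call under the names used in this section. The sizes, the padding pattern (last two source positions ignored), and the causal attn_mask are illustrative assumptions.
library(torch)
embed_dim <- 8
num_heads <- 2
N <- 3; L <- 5; S <- 5
attn <- nn_multihead_attention(embed_dim, num_heads)
query <- torch_randn(L, N, embed_dim)
key   <- torch_randn(S, N, embed_dim)
value <- torch_randn(S, N, embed_dim)
# TRUE marks source positions to ignore: here the last two positions of every sequence
key_padding_mask <- torch_tensor(matrix(c(FALSE, FALSE, FALSE, TRUE, TRUE),
                                        nrow = N, ncol = S, byrow = TRUE))
# causal mask: TRUE above the diagonal, so position i cannot attend to positions > i
attn_mask <- torch_triu(torch_ones(L, S), diagonal = 1) > 0
out <- attn(query, key, value,
            key_padding_mask = key_padding_mask,
            attn_mask = attn_mask)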
Outputs:
attn_output: \((L, N, E)\) where L is the target sequence length, N is the batch size, and E is the embedding dimension (but see the batch_first argument).
attn_output_weights:
if avg_weights is TRUE (the default), the output attention weights are averaged over the attention heads, giving a tensor of shape \((N, L, S)\) where N is the batch size, L is the target sequence length, and S is the source sequence length.
if avg_weights is FALSE, the attention weight tensor is output as-is, with shape \((N, H, L, S)\), where H is the number of attention heads.
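A short sketch of the two weight shapes, assuming avg_weights is supplied as a named argument in the forward call (the sizes are arbitrary illustrative values; the expected shapes in the comments follow from the description above):
library(torch)
embed_dim <- 8; num_heads <- 2
attn <- nn_multihead_attention(embed_dim, num_heads)
query <- torch_randn(5, 3, embed_dim)  # (L, N, E)
key   <- torch_randn(7, 3, embed_dim)  # (S, N, E)
value <- torch_randn(7, 3, embed_dim)  # (S, N, E)
out_avg      <- attn(query, key, value)                        # averaged weights
out_per_head <- attn(query, key, value, avg_weights = FALSE)   # per-head weights
dim(out_avg[[2]])       # (N, L, S)    -> 3 5 7
dim(out_per_head[[2]])  # (N, H, L, S) -> 3 2 5 7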
$$ \mbox{MultiHead}(Q, K, V) = \mbox{Concat}(head_1, \dots, head_h) W^O \quad \mbox{where} \quad head_i = \mbox{Attention}(QW_i^Q, KW_i^K, VW_i^V) $$
library(torch)
if (torch_is_installed()) {
  # illustrative sizes: target length 10, source length 12, batch size 2
  embed_dim <- 16
  num_heads <- 4
  multihead_attn <- nn_multihead_attention(embed_dim, num_heads)
  query <- torch_randn(10, 2, embed_dim)  # (L, N, E)
  key   <- torch_randn(12, 2, embed_dim)  # (S, N, E)
  value <- torch_randn(12, 2, embed_dim)  # (S, N, E)
  out <- multihead_attn(query, key, value)
  attn_output <- out[[1]]          # (L, N, E)
  attn_output_weights <- out[[2]]  # (N, L, S), averaged over heads
}