PyTorch also ships its own Transformer model (nn.Transformer), distinct from the Hugging Face implementations and others. Its mask parameters are hard to understand, even with the documentation, so here are some supplementary explanations. (By the way, this Transformer leaves the position embedding to you; don't just happily feed in your data and run.)
>>> transformer_model = nn.Transformer(nhead=16, num_encoder_layers=12)
>>> src = torch.rand((10, 32, 512))
>>> tgt = torch.rand((20, 32, 512))
>>> out = transformer_model(src, tgt)
# No positional embedding is applied here, and you have to supply the mask mechanism
# yourself; otherwise this is not the Transformer you think it is.
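For reference, here is a minimal sketch of the sinusoidal positional encoding you would add yourself (essentially the version from the PyTorch tutorials; the class name, the max_len default, and the (seq_len, batch, d_model) layout are my own choices matching nn.Transformer's default input layout):

import math
import torch
import torch.nn as nn

class PositionalEncoding(nn.Module):
    def __init__(self, d_model: int, max_len: int = 5000):
        super().__init__()
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)   # (max_len, 1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float()
                             * (-math.log(10000.0) / d_model))
        pe = torch.zeros(max_len, d_model)
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        self.register_buffer('pe', pe.unsqueeze(1))                           # (max_len, 1, d_model)

    def forward(self, x):
        # x: (seq_len, batch, d_model), the default layout expected by nn.Transformer
        return x + self.pe[:x.size(0)]

pos_enc = PositionalEncoding(512)
src = pos_enc(torch.rand(10, 32, 512))   # add positions before feeding src into the model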
First, let's look at the parameters as documented on the official website.
The key distinction is between the *_mask arguments and the *_key_padding_mask arguments; whether the * is src, tgt, or memory hardly matters. If the attention module sits in the encoder, the prefix is src; in the decoder, it is tgt; and when the second sub-layer of each decoder block does cross attention with the encoder output, it is memory.
*_mask corresponds to the attn_mask argument, and *_key_padding_mask corresponds to the key_padding_mask argument.
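To make the mapping concrete, here is a sketch of a full forward call with every mask argument spelled out, continuing the snippet above (the all-zero masks are placeholders that mask nothing; shapes follow the example dimensions):

src = torch.rand(10, 32, 512)    # (S, N, E)
tgt = torch.rand(20, 32, 512)    # (L, N, E)

out = transformer_model(
    src, tgt,
    src_mask=None,                                       # attn_mask of encoder self-attention, (S, S)
    tgt_mask=torch.zeros(20, 20),                        # attn_mask of decoder self-attention, (L, L)
    memory_mask=None,                                    # attn_mask of decoder-encoder cross attention, (L, S)
    src_key_padding_mask=torch.zeros(32, 10, dtype=torch.bool),     # (N, S)
    tgt_key_padding_mask=torch.zeros(32, 20, dtype=torch.bool),     # (N, L)
    memory_key_padding_mask=torch.zeros(32, 10, dtype=torch.bool),  # (N, S)
)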
Let's look at how the MultiheadAttention module in torch/nn/modules/activation.py documents these two arguments:
def forward(self, query, key, value, key_padding_mask=None,
            need_weights=True, attn_mask=None):
    # type: (Tensor, Tensor, Tensor, Optional[Tensor], bool, Optional[Tensor]) -> Tuple[Tensor, Optional[Tensor]]
    r"""
    Args:
        query, key, value: map a query and a set of key-value pairs to an output.
            See "Attention Is All You Need" for more details.
        key_padding_mask: if provided, specified padding elements in the key will
            be ignored by the attention. When given a binary mask and a value is True,
            the corresponding value on the attention layer will be ignored. When given
            a byte mask and a value is non-zero, the corresponding value on the
            attention layer will be ignored
        need_weights: output attn_output_weights.
        attn_mask: 2D or 3D mask that prevents attention to certain positions. A 2D mask will be
            broadcasted for all the batches while a 3D mask allows to specify a different mask for
            the entries of each batch.

    Shape:
        - Inputs:
        - query: :math:`(L, N, E)` where L is the target sequence length, N is the batch size,
          E is the embedding dimension.
        - key: :math:`(S, N, E)`, where S is the source sequence length, N is the batch size,
          E is the embedding dimension.
        - value: :math:`(S, N, E)` where S is the source sequence length, N is the batch size,
          E is the embedding dimension.
        - key_padding_mask: :math:`(N, S)` where N is the batch size, S is the source sequence length.
          If a ByteTensor is provided, the non-zero positions will be ignored while the position
          with the zero positions will be unchanged. If a BoolTensor is provided, the positions with the
          value of ``True`` will be ignored while the position with the value of ``False`` will be unchanged.
        - attn_mask: 2D mask :math:`(L, S)` where L is the target sequence length, S is the source
          sequence length. 3D mask :math:`(N*num_heads, L, S)` where N is the batch size, L is the
          target sequence length, S is the source sequence length. attn_mask ensure that position i
          is allowed to attend the unmasked positions. If a ByteTensor is provided, the non-zero
          positions are not allowed to attend while the zero positions will be unchanged. If a
          BoolTensor is provided, positions with ``True`` is not allowed to attend while ``False``
          values will be unchanged. If a FloatTensor is provided, it will be added to the attention weight.

        - Outputs:
        - attn_output: :math:`(L, N, E)` where L is the target sequence length, N is the batch size,
          E is the embedding dimension.
        - attn_output_weights: :math:`(N, L, S)` where N is the batch size, L is the target sequence
          length, S is the source sequence length.
    """
Here, N is the batch size, L is the target sequence length, and S is the source sequence length. This module appears in all three orange regions of the figure above, so "the target sequence" does not necessarily mean the decoder's input sequence, and "the source sequence" does not necessarily mean the encoder's input sequence.
A better way to think about it: the target sequence is the q (query) sequence of multi-head attention, and the source sequence provides the k (key) and v (value) sequences. For example, when the decoder does self-attention, the target and source sequences are both the decoder's own input, so L = S, and both equal the length of the decoder's input sequence.
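A small sketch of this "target = query, source = key/value" reading, using cross attention so that L and S differ (all dimensions here are made up for illustration):

mha = nn.MultiheadAttention(embed_dim=512, num_heads=8)
tgt_seq = torch.rand(20, 32, 512)    # decoder-side sequence, L = 20
memory  = torch.rand(10, 32, 512)    # encoder output,        S = 10

out, weights = mha(query=tgt_seq, key=memory, value=memory)
print(out.shape)       # torch.Size([20, 32, 512])  -> (L, N, E)
print(weights.shape)   # torch.Size([32, 20, 10])   -> (N, L, S)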
Here is a simple example:
Suppose we have a batch with batch_size = 3 and sequence length 4, whose tokens look like this:
[['a', 'b', 'c', '<PAD>'],
 ['a', 'b', 'c', 'd'],
 ['a', 'b', '<PAD>', '<PAD>']]
Now suppose we want to compute self-attention (whether in the encoder or the decoder). Take the third row as an example: when 'a' takes part in the qkv computation as a query, it sees 'b', '<PAD>', '<PAD>'. But we don't want 'a' to attend to '<PAD>', because those tokens carry no meaning, so we need key_padding_mask to mask them out.
key_padding_mask has shape (N, S), so for this example it takes the following form, with key_padding_mask.shape = (3, 4):
[[False, False, False, True],
 [False, False, False, False],
 [False, False, True, True]]
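One way to build this mask is to compare the padded id tensor against the padding index (the ids and pad_idx = 0 below are made-up values for illustration):

pad_idx = 0
ids = torch.tensor([[11, 12, 13,  0],
                    [11, 12, 13, 14],
                    [11, 12,  0,  0]])      # (N, S) = (3, 4)
key_padding_mask = (ids == pad_idx)         # True where the key should be ignored
# tensor([[False, False, False,  True],
#         [False, False, False, False],
#         [False, False,  True,  True]])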
Note that key_padding_mask essentially masks out the value at those key positions (their attention weights are zeroed), but the <PAD> token itself still takes part in the qkv computation as a query. Take the third position of the third row: its q is the embedding of <PAD>, its k and v come from the first token 'a' and the second token 'b', and it still produces an output embedding.
So when you compute the loss on the transformer's final output during training, you also need to specify ignore_index=pad_index. Take the third row as an example: its supervision signal is [3205, 1890, 0, 0] with pad_index = 0. That way, even though the transformer still wastes effort doing qkv at the meaningless <PAD> positions and outputs embeddings there, those positions contribute nothing to the loss, and we simply let them output whatever they like.
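A sketch of that loss setup for a single sequence (the vocabulary size and logits are made up; only the shapes and ignore_index matter):

criterion = nn.CrossEntropyLoss(ignore_index=0)        # pad_index = 0
logits = torch.randn(4, 4000)                          # (seq_len, vocab_size) for one sequence
target = torch.tensor([3205, 1890, 0, 0])              # the two <PAD> positions are ignored
loss = criterion(logits, target)                       # only the first two positions contribute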
When I first saw the two mask parameters, I was also confused, not least because their shapes differ. Where exactly is attn_mask used?
When the decoder does self-attention, unlike the encoder, each position may only see the positions before it. key_padding_mask has shape (batch_size, source_length), which means every query position sees the same key_padding_mask (even though the masked columns can differ across rows of the batch). That does not satisfy the requirement of the following module:
the decoder's masked multi-head attention module
The mask needed here looks like this:
Yellow marks the visible part and purple the invisible part; different positions need different parts masked.
And PyTorch's nn.Transformer already provides a function that builds it for us:
def generate_square_subsequent_mask(self, sz: int) -> Tensor:
    r"""Generate a square mask for the sequence. The masked positions are filled with float('-inf').
        Unmasked positions are filled with float(0.0).
    """
    mask = (torch.triu(torch.ones(sz, sz)) == 1).transpose(0, 1)
    mask = mask.float().masked_fill(mask == 0, float('-inf')).masked_fill(mask == 1, float(0.0))
    return mask
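A usage sketch, reusing transformer_model from the snippet at the top: for sz = 4 the additive float mask looks like this, and it is what you pass as tgt_mask (i.e. the attn_mask of the decoder's self-attention):

mask = transformer_model.generate_square_subsequent_mask(4)
print(mask)
# tensor([[0., -inf, -inf, -inf],
#         [0.,   0., -inf, -inf],
#         [0.,   0.,   0., -inf],
#         [0.,   0.,   0.,   0.]])
out = transformer_model(src, tgt,
                        tgt_mask=transformer_model.generate_square_subsequent_mask(tgt.size(0)))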
Take the first row of the example above, ['a', 'b', 'c', '<PAD>'] (suppose we are using the decoder for generation and looking at the first sub-layer of a block, i.e. its self-attention). Then:
Recall attn_mask's shape requirements: (L, S) when 2D, (N*num_heads, L, S) when 3D. Here, q, k and v all come from the same sequence (the decoder's input), so L = S; and since every row of the batch uses the same masking rule, namely that position i may only see earlier positions, our attn_mask only needs to be two-dimensional; the internal implementation broadcasts the mask matrix to every row of the batch.
Generally speaking, unless you want to modify the transformer itself, for example to let different heads see different information, the two-dimensional matrix is enough.
I think it's best to use the two masks according to the convention above. In practice they act together on the same attention matrix, so could you use attn_mask instead of key_padding_mask to mask out <PAD>? Certainly; it just increases your workload.
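For completeness, here is a sketch of what that extra work looks like: folding the padding information into a 3D attn_mask of shape (N*num_heads, L, S), assuming the 3-row batch above, decoder self-attention (so L = S = 4), and num_heads = 16 to match the nhead of transformer_model (variable names continue from the earlier sketches):

num_heads = 16                                                          # must match the model's nhead
causal = transformer_model.generate_square_subsequent_mask(4)           # (L, S) additive float mask
pad_additive = torch.zeros(3, 4).masked_fill(ids == pad_idx, float('-inf'))   # (N, S)
combined = causal.unsqueeze(0) + pad_additive.unsqueeze(1)              # (N, L, S)
attn_mask = combined.repeat_interleave(num_heads, dim=0)                # (N*num_heads, L, S)
# attn_mask alone now blocks both the future positions and the <PAD> keys,
# at the cost of building a per-batch, per-head mask by hand.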
This article is from the WeChat official account "Deep Learning Natural Language Processing" (zenRRan).
Originally published: 2021-04-04.