Multi-head attention can be described by the following diagram:
1. Implementation using the PyTorch library
torch.nn.MultiheadAttention(embed_dim, num_heads, dropout=0.0, bias=True, add_bias_kv=False, add_zero_attn=False, kdim=None, vdim=None)
The parameters are described as follows:

embed_dim: the dimension of the final output K, Q, and V matrices. This dimension must be the same as the word-vector dimension.

num_heads: the number of attention heads. If set to 1, only a single set of attention is used. If set to any other value, embed_dim must be divisible by num_heads.

dropout: this dropout is applied after the attention scores are computed.
Now let's explain why embed_dim must be divisible by num_heads. The hidden vector of each word is divided equally among the heads, so that all heads can be packed into a single matrix and the attention of multiple heads can be computed in parallel.
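The even split described above can be illustrated with a small NumPy sketch. The sizes (10 words, 300 dimensions, 6 heads) are illustrative assumptions, not values required by PyTorch.

```python
import numpy as np

embed_dim, num_heads = 300, 6
assert embed_dim % num_heads == 0, "embed_dim must be divisible by num_heads"
head_dim = embed_dim // num_heads  # 50 dimensions per head

# One sentence of 10 words, each word a 300-dimensional vector.
x = np.random.rand(10, embed_dim)

# Split the last dimension evenly into (num_heads, head_dim), then move heads first.
heads = x.reshape(10, num_heads, head_dim).transpose(1, 0, 2)
print(heads.shape)  # (6, 10, 50)
```

Head h simply receives columns h*50 to (h+1)*50 of every word vector, which is why the split only works when the division is exact.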
After a MultiheadAttention object has been defined, the parameters passed in when calling it are as follows.
forward(query, key, value, key_padding_mask=None, need_weights=True, attn_mask=None)

query: corresponds to the Query matrix, with shape (L, N, E), where L is the length of the output sequence, N is the batch size, and E is the dimension of the word vector.

key: corresponds to the Key matrix, with shape (S, N, E), where S is the length of the input sequence, N is the batch size, and E is the dimension of the word vector.

value: corresponds to the Value matrix, with shape (S, N, E), where S is the length of the input sequence, N is the batch size, and E is the dimension of the word vector.

key_padding_mask: if this parameter is provided, the padding elements in the Key matrix are ignored when the attention scores are computed and do not take part in attention. The shape is (N, S), where N is the batch size and S is the length of the input sequence.
 If key_padding_mask is a ByteTensor, the positions of non-zero elements are ignored.
 If key_padding_mask is a BoolTensor, the positions where the value is True are ignored.

attn_mask: ignores certain positions when computing the output. The shape can be 2D (L, S) or 3D (N*num_heads, L, S), where L is the length of the output sequence, S is the length of the input sequence, and N is the batch size.
 If attn_mask is a ByteTensor, the positions of non-zero elements are ignored.
 If attn_mask is a BoolTensor, the positions where the value is True are ignored.
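As a hedged illustration of the BoolTensor case, the sketch below passes a key_padding_mask to nn.MultiheadAttention and inspects the returned attention weights. The small sizes (embed_dim=8, 2 heads, batch of 2) and the mask values are assumptions chosen only to keep the output readable.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
mha = nn.MultiheadAttention(embed_dim=8, num_heads=2)

query = torch.rand(3, 2, 8)  # (L, N, E)
key = torch.rand(4, 2, 8)    # (S, N, E)
value = torch.rand(4, 2, 8)  # (S, N, E)

# Shape (N, S): True marks padding positions in the key sequence to be ignored.
key_padding_mask = torch.tensor([[False, False, False, True],
                                 [False, False, True, True]])

attn_output, attn_weights = mha(query, key, value,
                                key_padding_mask=key_padding_mask)
print(attn_output.shape)   # (L, N, E)
print(attn_weights.shape)  # (N, L, S), averaged over the heads
print(attn_weights[1])     # the last two columns should be zero
```

The masked key positions receive zero weight, and each row of weights still sums to 1 over the remaining positions.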
Note that in practice the K and V matrices have the same sequence length, while the sequence length of the Q matrix can be different.
This happens in the Encoder-Decoder Attention layer of the decoder, where the Q matrix comes from the lower layer of the decoder and the K and V matrices come from the output of the encoder.
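To see why Q may have a different length from K and V, here is a minimal single-head NumPy sketch of scaled dot-product attention. The function name and the sizes (12 query positions, 10 key positions, 50 dimensions) are assumptions made for illustration.

```python
import numpy as np

def scaled_dot_product_attention(q, k, v):
    """q: (L, d), k: (S, d), v: (S, d) -> output (L, d), weights (L, S)."""
    scores = q @ k.T / np.sqrt(q.shape[-1])               # (L, S)
    scores = scores - scores.max(axis=-1, keepdims=True)  # numerical stability
    weights = np.exp(scores)
    weights = weights / weights.sum(axis=-1, keepdims=True)  # softmax over S
    return weights @ v, weights

rng = np.random.default_rng(0)
q = rng.random((12, 50))  # 12 decoder positions
k = rng.random((10, 50))  # 10 encoder positions
v = rng.random((10, 50))

out, w = scaled_dot_product_attention(q, k, v)
print(out.shape, w.shape)  # (12, 50) (12, 10)
```

The score matrix has one row per query position and one column per key position, so only K and V need to agree on length; the output always has the length of Q.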
Code example:
import torch
import torch.nn as nn

# Dimension 0 of the nn.MultiheadAttention input is the sequence length

# batch_size is 64, there are 12 words, and each word's Query vector has 300 dimensions
query = torch.rand(12, 64, 300)
# batch_size is 64, there are 10 words, and each word's Key vector has 300 dimensions
key = torch.rand(10, 64, 300)
# batch_size is 64, there are 10 words, and each word's Value vector has 300 dimensions
value = torch.rand(10, 64, 300)

# embed_dim must match the last dimension of query, key and value
embed_dim = 300
num_heads = 1
# The output is (attn_output, attn_output_weights)
multihead_attn = nn.MultiheadAttention(embed_dim, num_heads)
attn_output = multihead_attn(query, key, value)[0]
# output: torch.Size([12, 64, 300])
# batch_size is 64, there are 12 words, and each word's vector has 300 dimensions
print(attn_output.shape)
2. Manual implementation of multi-head attention
In the MultiheadAttention provided by PyTorch, the first dimension is the sentence length and the second dimension is the batch size. In our implementation here, the first dimension is the batch size and the second dimension is the sentence length. The code also shows how to use matrix operations to compute multiple attention heads in parallel. Detailed comments and explanations are included in the code.
import torch
import torch.nn as nn


class MultiheadAttention(nn.Module):
    # n_heads: the number of attention heads
    # hid_dim: the dimension of the output vector of each word
    def __init__(self, hid_dim, n_heads, dropout):
        super(MultiheadAttention, self).__init__()
        self.hid_dim = hid_dim
        self.n_heads = n_heads

        # hid_dim must be divisible by n_heads
        assert hid_dim % n_heads == 0
        # Define the W_q matrix
        self.w_q = nn.Linear(hid_dim, hid_dim)
        # Define the W_k matrix
        self.w_k = nn.Linear(hid_dim, hid_dim)
        # Define the W_v matrix
        self.w_v = nn.Linear(hid_dim, hid_dim)
        self.fc = nn.Linear(hid_dim, hid_dim)
        self.do = nn.Dropout(dropout)
        # Scaling factor: square root of the per-head vector length
        self.scale = torch.sqrt(torch.FloatTensor([hid_dim // n_heads]))

    def forward(self, query, key, value, mask=None):
        # K: [64,10,300], batch_size is 64, 10 words, each word's Key vector has 300 dimensions
        # V: [64,10,300], batch_size is 64, 10 words, each word's Value vector has 300 dimensions
        # Q: [64,12,300], batch_size is 64, 12 words, each word's Query vector has 300 dimensions
        bsz = query.shape[0]
        Q = self.w_q(query)
        K = self.w_k(key)
        V = self.w_v(value)
        # Split the K, Q, V matrices into multiple heads, giving 4-dimensional tensors
        # The last dimension is self.hid_dim // self.n_heads, the vector length of each head: 300/6=50
        # 64 is the batch size, 6 is the number of heads, 10 is the number of words,
        # and 50 is the vector length of each word within one head
        # K: [64,10,300] split into heads > [64,10,6,50] transpose > [64,6,10,50]
        # V: [64,10,300] split into heads > [64,10,6,50] transpose > [64,6,10,50]
        # Q: [64,12,300] split into heads > [64,12,6,50] transpose > [64,6,12,50]
        # The transpose moves the number of heads (6) to the front and keeps 10 and 50 at the back,
        # which is convenient for the computation below
        Q = Q.view(bsz, -1, self.n_heads, self.hid_dim // self.n_heads).permute(0, 2, 1, 3)
        K = K.view(bsz, -1, self.n_heads, self.hid_dim // self.n_heads).permute(0, 2, 1, 3)
        V = V.view(bsz, -1, self.n_heads, self.hid_dim // self.n_heads).permute(0, 2, 1, 3)

        # Step 1: multiply Q by the transpose of K and divide by scale
        # [64,6,12,50] * [64,6,50,10] = [64,6,12,10]
        # attention: [64,6,12,10]
        attention = torch.matmul(Q, K.permute(0, 1, 3, 2)) / self.scale

        # If mask is not None, set the attention score to -1e10 wherever mask is 0
        if mask is not None:
            attention = attention.masked_fill(mask == 0, -1e10)

        # Step 2: apply softmax to the result of the previous step, then dropout, to get attention.
        # Note that softmax is taken over the last dimension, i.e. over the input sequence
        # attention: [64,6,12,10]
        attention = self.do(torch.softmax(attention, dim=-1))

        # Step 3: multiply attention by V to get the multi-head attention result
        # [64,6,12,10] * [64,6,10,50] = [64,6,12,50]
        # x: [64,6,12,50]
        x = torch.matmul(attention, V)

        # The query has 12 words, so move 12 to the front and keep 6 and 50 at the back,
        # which makes it easy to concatenate the heads below
        # x: [64,6,12,50] transpose > [64,12,6,50]
        x = x.permute(0, 2, 1, 3).contiguous()
        # This reshape concatenates the results of all heads
        # The final result is [64,12,300]
        # x: [64,12,6,50] > [64,12,300]
        x = x.view(bsz, -1, self.n_heads * (self.hid_dim // self.n_heads))
        x = self.fc(x)
        return x


# batch_size is 64, there are 12 words, and each word's Query vector has 300 dimensions
query = torch.rand(64, 12, 300)
# batch_size is 64, there are 10 words, and each word's Key vector has 300 dimensions
key = torch.rand(64, 10, 300)
# batch_size is 64, there are 10 words, and each word's Value vector has 300 dimensions
value = torch.rand(64, 10, 300)
attention = MultiheadAttention(hid_dim=300, n_heads=6, dropout=0.1)
output = attention(query, key, value)
# output: torch.Size([64, 12, 300])
print(output.shape)
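The layout difference mentioned above (sequence length first in PyTorch's module, batch size first in the manual version) only requires swapping the first two axes. A minimal NumPy sketch, with array names chosen for illustration; on PyTorch tensors, tensor.transpose(0, 1) performs the same swap.

```python
import numpy as np

# (L, N, E): sequence length first, the layout described for nn.MultiheadAttention
seq_first = np.random.rand(12, 64, 300)

# (N, L, E): batch size first, the layout used by the manual implementation
batch_first = seq_first.transpose(1, 0, 2)
print(batch_first.shape)  # (64, 12, 300)
```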
3. Multi-head attention implemented in TensorFlow
import tensorflow as tf  # this code uses TensorFlow 1.x APIs (tf.layers, tf.contrib)


def _multiheadAttention(rawKeys, queries, keys, numUnits=None, causality=False, scope="multiheadAttention"):
    # rawKeys is used when computing the mask, because keys already include the position embedding
    # and therefore contain no padding values equal to 0
    # numUnits = 50
    numHeads = 6
    keepProb = 1

    if numUnits is None:  # If no value is passed in, use the last dimension of the data, i.e. the embedding size.
        numUnits = queries.get_shape().as_list()[-1]  # 300

    # tf.layers.dense can apply a nonlinear mapping to multidimensional tensors. When computing
    # self-attention, these three values must be mapped nonlinearly.
    # This corresponds to the step in the paper's Multi-Head Attention that applies weight mappings to
    # the split data. Here we map first and split afterwards, which is the same in principle.
    # Q, K, V all have dimensions [batch_size, sequence_length, embedding_size]
    Q = tf.layers.dense(queries, numUnits, activation=tf.nn.relu)  # [64,10,300]
    K = tf.layers.dense(keys, numUnits, activation=tf.nn.relu)  # [64,10,300]
    V = tf.layers.dense(keys, numUnits, activation=tf.nn.relu)  # [64,10,300]

    # Split the data into numHeads parts along the last dimension, then concatenate along the first dimension
    # Q, K, V all have dimensions [batch_size * numHeads, sequence_length, embedding_size/numHeads]
    Q_ = tf.concat(tf.split(Q, numHeads, axis=-1), axis=0)  # [64*6,10,50]
    K_ = tf.concat(tf.split(K, numHeads, axis=-1), axis=0)  # [64*6,10,50]
    V_ = tf.concat(tf.split(V, numHeads, axis=-1), axis=0)  # [64*6,10,50]

    # Compute the dot product between keys and queries, dimension [batch_size * numHeads, queries_len, key_len];
    # the last two dimensions are the sequence lengths of queries and keys
    similary = tf.matmul(Q_, tf.transpose(K_, [0, 2, 1]))  # [64*6,10,10]

    # Scale the dot product by dividing by the square root of the vector length
    scaledSimilary = similary / (K_.get_shape().as_list()[-1] ** 0.5)  # [64*6,10,10]

    # The input sequences contain padding tokens, which should not contribute to the final result.
    # In principle, when the padding input is all 0, the computed weight should also be 0. But the
    # transformer adds a position vector, and after that addition the value is no longer 0, so the mask
    # has to be computed from the values before the position vector is added.
    # The queries also contain such padding tokens, but in principle the model's result depends only on
    # the input, and in self-attention queries = keys, so as long as one side is 0 the computed weight is 0.
    # For details about the key mask, see: https://github.com/Kyubyong/transformer/issues/3

    # Use tf.tile to expand the tensor to [batch_size * numHeads, keys_len], where keys_len is the sequence length of keys
    # Sum the values of the vector at each time step
    # rawKeys: [64,10,300]
    keyMasks = tf.sign(tf.abs(tf.reduce_sum(rawKeys, axis=-1)))  # dimension [batch_size, time_step] [64,10]
    # tf.sign() maps values less than 0 to -1, values greater than 0 to 1, and values equal to 0 to 0
    keyMasks = tf.tile(keyMasks, [numHeads, 1])  # [64*6,10]

    # Find the padding positions
    # Add a dimension and expand to get [batch_size * numHeads, queries_len, keys_len]
    keyMasks = tf.tile(tf.expand_dims(keyMasks, 1), [1, tf.shape(queries)[1], 1])  # [64*6,10,10]
    print(keyMasks.shape)

    # tf.ones_like creates a tensor of ones with the same dimensions as scaledSimilary,
    # which is then multiplied by a very large negative value
    paddings = tf.ones_like(scaledSimilary) * (-2 ** (32 + 1))  # [64*6,10,10]

    # tf.where(condition, x, y): the elements of condition are booleans; where True the element of x is
    # used, where False the element of y is used, so condition, x and y must have the same dimensions.
    # Below, wherever keyMasks is 0 the value is replaced by the one in paddings
    maskedSimilary = tf.where(tf.equal(keyMasks, 0), paddings, scaledSimilary)  # dimension [batch_size * numHeads, queries_len, key_len]

    # When computing the current word, consider only the preceding context and not the following one.
    # This appears in the Transformer Decoder. For text classification, the Transformer Encoder alone is enough.
    # The Decoder is a generative model, mainly used for language generation
    if causality:
        diagVals = tf.ones_like(maskedSimilary[0, :, :])  # [queries_len, keys_len]
        tril = tf.contrib.linalg.LinearOperatorTriL(diagVals).to_dense()  # [queries_len, keys_len]
        masks = tf.tile(tf.expand_dims(tril, 0), [tf.shape(maskedSimilary)[0], 1, 1])  # [batch_size * numHeads, queries_len, keys_len]

        paddings = tf.ones_like(masks) * (-2 ** (32 + 1))
        maskedSimilary = tf.where(tf.equal(masks, 0), paddings, maskedSimilary)  # [batch_size * numHeads, queries_len, keys_len]

    # Compute the weight coefficients with softmax, dimension [batch_size * numHeads, queries_len, keys_len]
    weights = tf.nn.softmax(maskedSimilary)

    # Weighted sum to get the output, dimension [batch_size * numHeads, sequence_length, embedding_size/numHeads]
    outputs = tf.matmul(weights, V_)

    # Restore the multi-head attention output to the original dimension [batch_size, sequence_length, embedding_size]
    outputs = tf.concat(tf.split(outputs, numHeads, axis=0), axis=2)

    outputs = tf.nn.dropout(outputs, keep_prob=keepProb)

    # Add a residual connection for each sub-layer, i.e. H(x) = F(x) + x
    outputs += queries
    # Layer normalization
    # outputs = self._layerNormalization(outputs)
    return outputs
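The split-and-concatenate step applied to Q_, K_ and V_ in the function above can be checked in isolation. The following is a minimal NumPy sketch, assuming the split is made along the last (embedding) axis as the code comments describe; np.split and np.concatenate stand in for tf.split and tf.concat.

```python
import numpy as np

num_heads = 6
Q = np.random.rand(64, 10, 300)  # [batch_size, sequence_length, embedding_size]

# Split the last axis into heads, then stack the heads along the batch axis.
Q_ = np.concatenate(np.split(Q, num_heads, axis=-1), axis=0)
print(Q_.shape)  # (384, 10, 50)

# Inverse step used for the outputs: split the batch axis, concatenate on the last axis.
restored = np.concatenate(np.split(Q_, num_heads, axis=0), axis=2)
print(restored.shape)  # (64, 10, 300)
```

Rows 64 to 127 of Q_ hold the second head of every sample, so the inverse step recovers the original tensor exactly.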
The input is: self.embeddedWords = self.wordEmbedded + self.positionEmbedded, that is, word embedding + position embedding.
Using the same input dimensions as the PyTorch example: self.wordEmbedded has dimensions [64,10,300] and self.positionEmbedded has dimensions [64,10,300].
It is called like this:
multiHeadAtt = self._multiheadAttention(rawKeys=self.wordEmbedded, queries=self.embeddedWords,
keys=self.embeddedWords)
For example (with simplified input):
import numpy as np
import tensorflow as tf  # TensorFlow 1.x

wordEmbedded = tf.Variable(np.ones((64, 10, 300), dtype=np.float32))
positionEmbedded = tf.Variable(np.ones((64, 10, 300), dtype=np.float32))
embeddedWords = wordEmbedded + positionEmbedded
multiHeadAtt = _multiheadAttention(rawKeys=wordEmbedded, queries=embeddedWords, keys=embeddedWords, numUnits=300)
Note that rawKeys is the word embedding alone. Once the position embedding has been added, the zero-valued padding positions in embeddedWords are covered by the position values, so the positions that need to be masked can no longer be found.
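A small NumPy sketch can make this concrete. The embedding values below are made up for illustration, and key_mask is a hypothetical helper that mirrors the sign(abs(sum)) computation used for keyMasks.

```python
import numpy as np

# Two sentences, 4 positions each, embedding size 3. All-zero rows are padding.
word_embedded = np.array([
    [[0.2, -0.1, 0.5], [0.3, 0.4, -0.2], [0.0, 0.0, 0.0], [0.0, 0.0, 0.0]],
    [[-0.6, 0.1, 0.9], [0.0, 0.0, 0.0], [0.0, 0.0, 0.0], [0.0, 0.0, 0.0]],
])
position_embedded = np.full_like(word_embedded, 0.1)  # made-up position values

def key_mask(x):
    # 1 where the vector at a position has a non-zero sum, 0 where it sums to zero.
    return np.sign(np.abs(x.sum(axis=-1)))

print(key_mask(word_embedded))                      # padding positions show up as 0
print(key_mask(word_embedded + position_embedded))  # every position is now 1
```

Computed on the raw word embeddings, the mask marks the padding positions; computed after the position values are added, it no longer does.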
The PyTorch example above actually corresponds to the code under if causality. In the encoding phase, Q = K = V (they all have the same dimensions). In the decoding phase, Q comes from the decoder input and can have shape [64,12,300], while K and V come from the encoder output and both have shape [64,10,300]. This is Encoder-Decoder Attention. When Q, K and V all come from the same input, it is self-attention.
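The effect of the causality branch can be sketched with NumPy: a lower-triangular matrix keeps the positions a query is allowed to see, and everything above the diagonal is replaced by a very large negative number before softmax. The sequence length of 4 and the random scores are assumptions for illustration; np.tril stands in for the TensorFlow lower-triangular operator.

```python
import numpy as np

seq_len = 4
scores = np.random.rand(seq_len, seq_len)  # [queries_len, keys_len]

# 1 on and below the diagonal (visible positions), 0 above it (future positions).
tril = np.tril(np.ones((seq_len, seq_len)))
masked = np.where(tril == 0, -2.0 ** 32 + 1, scores)

# Softmax over the key axis; the masked entries underflow to a weight of 0.
weights = np.exp(masked - masked.max(axis=-1, keepdims=True))
weights = weights / weights.sum(axis=-1, keepdims=True)
print(np.round(weights, 3))
```

Each row attends only to its own position and earlier ones, so the first query puts all of its weight on the first key.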
Reference: https://mp.weixin.qq.com/s/cJqhESxTMy5cfj0EXh9s4w