Implementing Transformer multi-head attention in different frameworks (TensorFlow + PyTorch)

Xixi Moyo 2020-11-19 03:20:14


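Multi-head attention is usually illustrated with a diagram: the inputs are projected into several groups (heads) of Q, K and V, each head computes scaled dot-product attention, and the outputs of all heads are concatenated and projected once more.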
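The same computation in equation form, using the standard notation of the original Transformer paper (assumed here, since the article does not write the formulas out), with h heads and d_k = embed_dim / h:

```latex
\begin{aligned}
\mathrm{Attention}(Q, K, V) &= \mathrm{softmax}\!\left(\frac{QK^{\top}}{\sqrt{d_k}}\right) V \\
\mathrm{head}_i &= \mathrm{Attention}(Q W_i^{Q},\, K W_i^{K},\, V W_i^{V}), \quad i = 1, \dots, h \\
\mathrm{MultiHead}(Q, K, V) &= \mathrm{Concat}(\mathrm{head}_1, \dots, \mathrm{head}_h)\, W^{O}
\end{aligned}
```

The softmax is taken over the key axis. In the hand-written PyTorch class of section 2, the per-head projections of all heads are stored together in one matrix each (w_q, w_k, w_v), and W^O corresponds to the final output layer (fc).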

1. Using PyTorch's built-in implementation

```python
torch.nn.MultiheadAttention(embed_dim, num_heads, dropout=0.0, bias=True, add_bias_kv=False, add_zero_attn=False, kdim=None, vdim=None)
```

The parameters are:

  • embed_dim: the total dimension of the model, i.e. the dimension of the projected Q, K and V and of the final output. It must match the dimension of the word vectors (for key and value as well, unless kdim and vdim are set).

  • num_heads: the number of attention heads. If set to 1, a single attention is used. For any other value, embed_dim must be divisible by num_heads.

  • dropout: the dropout probability applied to the attention weights (after the softmax).

Why must embed_dim be divisible by num_heads? The hidden vector of each word is split evenly among the heads, so each head gets embed_dim / num_heads dimensions. This way all heads can be stored in one matrix, and the attention of every head can be computed in parallel.
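The split described above can be sketched in PyTorch as follows, assuming batch-first tensors of shape (N, L, E) with E = 300 and 6 heads (the numbers used later in the article). split_heads and merge_heads are hypothetical helper names for illustration, not PyTorch APIs.

```python
import torch


def split_heads(x: torch.Tensor, num_heads: int) -> torch.Tensor:
    """(N, L, E) -> (N, num_heads, L, E // num_heads)."""
    n, length, embed_dim = x.shape
    # This is exactly where divisibility is required
    assert embed_dim % num_heads == 0, "embed_dim must be divisible by num_heads"
    head_dim = embed_dim // num_heads
    # Cut the feature axis into num_heads contiguous chunks, then move the head axis forward
    return x.view(n, length, num_heads, head_dim).permute(0, 2, 1, 3)


def merge_heads(x: torch.Tensor) -> torch.Tensor:
    """(N, num_heads, L, head_dim) -> (N, L, num_heads * head_dim)."""
    n, num_heads, length, head_dim = x.shape
    return x.permute(0, 2, 1, 3).reshape(n, length, num_heads * head_dim)


x = torch.rand(64, 10, 300)
heads = split_heads(x, num_heads=6)
print(heads.shape)  # torch.Size([64, 6, 10, 50])
# Head 1 sees features 50..99 of every word
assert torch.equal(heads[:, 1], x[:, :, 50:100])
# Concatenating the heads restores the original tensor exactly
assert torch.equal(merge_heads(heads), x)
```

Because each head is just a slice of one big tensor, a single batched matmul handles all heads at once.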

After creating a MultiheadAttention object, call it with the following arguments:

```python
forward(query, key, value, key_padding_mask=None, need_weights=True, attn_mask=None)
```

  • query: the Query matrix, with shape (L,N,E), where L is the target (output) sequence length, N is the batch size and E is the word-vector dimension

  • key: the Key matrix, with shape (S,N,E), where S is the source (input) sequence length, N is the batch size and E is the word-vector dimension

  • value: the Value matrix, with shape (S,N,E), where S is the source (input) sequence length, N is the batch size and E is the word-vector dimension

  • key_padding_mask: if provided, the specified padding elements of the key are ignored when computing the attention scores and do not take part in the attention. Its shape is (N,S), where N is the batch size and S is the source sequence length.

    • If key_padding_mask is a ByteTensor, positions with non-zero values are ignored
    • If key_padding_mask is a BoolTensor, positions with True are ignored
  • attn_mask: prevents attention to certain positions when computing the output. Its shape can be 2D (L,S) or 3D (N*num_heads,L,S), where L is the target sequence length, S is the source sequence length and N is the batch size.

    • If attn_mask is a ByteTensor, positions with non-zero values are ignored
    • If attn_mask is a BoolTensor, positions with True are ignored
    • If attn_mask is a FloatTensor, it is added to the attention scores
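A small sketch of both masks, assuming toy sizes (L=4, S=5, N=2, E=8, 2 heads) and BoolTensor masks (newer PyTorch releases deprecate ByteTensor masks). The second sequence has two padding keys, and attn_mask additionally blocks every query from key 0, purely as an example. The returned weights, averaged over heads with shape (N, L, S), are exactly 0 at the ignored positions.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
L, S, N, E, num_heads = 4, 5, 2, 8, 2  # toy sizes chosen for illustration
mha = nn.MultiheadAttention(E, num_heads)  # dropout defaults to 0.0
mha.eval()

query = torch.rand(L, N, E)  # (L, N, E)
key = torch.rand(S, N, E)    # (S, N, E)
value = torch.rand(S, N, E)  # (S, N, E)

# key_padding_mask: (N, S), True = padding position to ignore.
# Sample 1 has only 3 real tokens; positions 3 and 4 are padding.
key_padding_mask = torch.tensor([
    [False, False, False, False, False],
    [False, False, False, True, True],
])

# attn_mask: (L, S), True = query i may not attend to key j.
# Example: forbid every query from attending to key 0.
attn_mask = torch.zeros(L, S, dtype=torch.bool)
attn_mask[:, 0] = True

out, weights = mha(query, key, value,
                   key_padding_mask=key_padding_mask,
                   attn_mask=attn_mask)
print(out.shape)      # torch.Size([4, 2, 8]) -> (L, N, E)
print(weights.shape)  # torch.Size([2, 4, 5]) -> (N, L, S)
```

Masked scores are filled with -inf before the softmax, so their weights come out as exactly 0 while each row still sums to 1. A row whose keys are all masked would produce NaN, so every query needs at least one visible key.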

Note that in practice K and V always have the same sequence length, while the sequence length of Q can differ.

This happens in the encoder-decoder attention layer of the decoder: Q comes from the previous decoder layer, while K and V come from the encoder output.

Code example:

```python
import torch
import torch.nn as nn

# nn.MultiheadAttention expects the sequence length in dimension 0: (L, N, E)
# batch_size is 64, 12 words, each word's Query vector is 300-dimensional
query = torch.rand(12, 64, 300)
# batch_size is 64, 10 words, each word's Key vector is 300-dimensional
key = torch.rand(10, 64, 300)
# batch_size is 64, 10 words, each word's Value vector is 300-dimensional
value = torch.rand(10, 64, 300)
embed_dim = 300  # must equal the word-vector dimension
num_heads = 1
# Calling the module returns (attn_output, attn_output_weights)
multihead_attn = nn.MultiheadAttention(embed_dim, num_heads)
attn_output = multihead_attn(query, key, value)[0]
# output: torch.Size([12, 64, 300])
# batch_size is 64, 12 words, each output vector is 300-dimensional
print(attn_output.shape)
```

2. Implementing multi-head attention by hand

In PyTorch's MultiheadAttention, the first dimension is the sentence length and the second is the batch size. In our implementation, the first dimension is the batch size and the second is the sentence length. The code also shows how to use matrix operations to compute several attention heads in parallel, with detailed comments throughout.
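Because of this layout difference, batch-first tensors like the ones used in the class below must be transposed before being passed to nn.MultiheadAttention with its default settings. A minimal sketch, reusing the article's shapes (newer PyTorch releases, 1.9 and later, also accept batch_first=True in the constructor, which avoids the transposes):

```python
import torch
import torch.nn as nn

batch_first_query = torch.rand(64, 12, 300)  # (N, L, E), layout of the class below
batch_first_key = torch.rand(64, 10, 300)    # (N, S, E)

mha = nn.MultiheadAttention(embed_dim=300, num_heads=6)

# Default nn.MultiheadAttention wants (L, N, E): swap dims 0 and 1 on the way in...
out, _ = mha(batch_first_query.transpose(0, 1),
             batch_first_key.transpose(0, 1),
             batch_first_key.transpose(0, 1))
# ...and swap them back on the way out
out = out.transpose(0, 1)
print(out.shape)  # torch.Size([64, 12, 300])
```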

```python
import math

import torch
import torch.nn as nn


class MultiheadAttention(nn.Module):
    # hid_dim: dimension of each word's output vector
    # n_heads: number of attention heads
    def __init__(self, hid_dim, n_heads, dropout):
        super(MultiheadAttention, self).__init__()
        self.hid_dim = hid_dim
        self.n_heads = n_heads
        # hid_dim must be divisible by n_heads
        assert hid_dim % n_heads == 0
        # W_q matrix
        self.w_q = nn.Linear(hid_dim, hid_dim)
        # W_k matrix
        self.w_k = nn.Linear(hid_dim, hid_dim)
        # W_v matrix
        self.w_v = nn.Linear(hid_dim, hid_dim)
        # Output projection applied after concatenating the heads
        self.fc = nn.Linear(hid_dim, hid_dim)
        self.do = nn.Dropout(dropout)
        # Scaling factor sqrt(d_k) with d_k = hid_dim // n_heads
        # (a plain float, so it works on any device)
        self.scale = math.sqrt(hid_dim // n_heads)

    def forward(self, query, key, value, mask=None):
        # Q: [64,12,300], batch_size 64, 12 words, each Query vector is 300-dim
        # K: [64,10,300], batch_size 64, 10 words, each Key vector is 300-dim
        # V: [64,10,300], batch_size 64, 10 words, each Value vector is 300-dim
        bsz = query.shape[0]
        Q = self.w_q(query)
        K = self.w_k(key)
        V = self.w_v(value)
        # Split Q, K, V into n_heads groups, giving 4-D tensors.
        # The last dimension, self.hid_dim // self.n_heads, is the vector
        # length of each head: 300 / 6 = 50.
        # 64 = batch size, 6 = heads, 10 = words, 50 = per-head vector length
        # K: [64,10,300] -split-> [64,10,6,50] -transpose-> [64,6,10,50]
        # V: [64,10,300] -split-> [64,10,6,50] -transpose-> [64,6,10,50]
        # Q: [64,12,300] -split-> [64,12,6,50] -transpose-> [64,6,12,50]
        # The transpose moves the head dimension (6) in front of the word (10)
        # and feature (50) dimensions, which makes the matmul straightforward
        Q = Q.view(bsz, -1, self.n_heads, self.hid_dim // self.n_heads).permute(0, 2, 1, 3)
        K = K.view(bsz, -1, self.n_heads, self.hid_dim // self.n_heads).permute(0, 2, 1, 3)
        V = V.view(bsz, -1, self.n_heads, self.hid_dim // self.n_heads).permute(0, 2, 1, 3)
        # Step 1: Q times K transposed, divided by scale
        # [64,6,12,50] * [64,6,50,10] = [64,6,12,10]
        # attention: [64,6,12,10]
        attention = torch.matmul(Q, K.permute(0, 1, 3, 2)) / self.scale
        # If a mask is given, set the scores where mask == 0 to -1e10.
        # The mask must broadcast to [64,6,12,10], e.g. [64,1,1,10] for key padding
        if mask is not None:
            attention = attention.masked_fill(mask == 0, -1e10)
        # Step 2: softmax over the last dimension (the input/key sequence),
        # then dropout, giving the attention weights
        # attention: [64,6,12,10]
        attention = self.do(torch.softmax(attention, dim=-1))
        # Step 3: multiply the attention weights by V to get each head's result
        # [64,6,12,10] * [64,6,10,50] = [64,6,12,50]
        # x: [64,6,12,50]
        x = torch.matmul(attention, V)
        # query has 12 words, so move 12 forward and 6 and 50 to the back,
        # which makes concatenating the heads easy
        # x: [64,6,12,50] -transpose-> [64,12,6,50]
        x = x.permute(0, 2, 1, 3).contiguous()
        # Concatenate the results of all heads
        # x: [64,12,6,50] -> [64,12,300]
        x = x.view(bsz, -1, self.n_heads * (self.hid_dim // self.n_heads))
        x = self.fc(x)
        return x


# batch_size 64, 12 words, each Query vector is 300-dim
query = torch.rand(64, 12, 300)
# batch_size 64, 10 words, each Key vector is 300-dim
key = torch.rand(64, 10, 300)
# batch_size 64, 10 words, each Value vector is 300-dim
value = torch.rand(64, 10, 300)
attention = MultiheadAttention(hid_dim=300, n_heads=6, dropout=0.1)
output = attention(query, key, value)
# output: torch.Size([64, 12, 300])
print(output.shape)
```
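One way to check the hand-written steps is to repeat them with the weights of a PyTorch module and compare against its own output. The sketch below is an illustration, not the article's method: it reads in_proj_weight, in_proj_bias and out_proj from nn.MultiheadAttention (these attributes exist when key and value dims equal embed_dim), then projects, splits into heads, scales, applies softmax, multiplies by V, concatenates and projects, exactly like the class above. The batch size is reduced to 4 to keep it fast.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
L, S, N, E, h = 12, 10, 4, 300, 6  # toy sizes (batch reduced to 4)
d = E // h  # per-head dimension, 50
mha = nn.MultiheadAttention(E, h)  # dropout=0.0 by default
mha.eval()

query = torch.rand(L, N, E)
key = torch.rand(S, N, E)
value = torch.rand(S, N, E)

with torch.no_grad():
    expected, _ = mha(query, key, value)

    # in_proj_weight stacks W_q, W_k, W_v row-wise: shape (3E, E)
    w_q, w_k, w_v = mha.in_proj_weight.chunk(3)
    b_q, b_k, b_v = mha.in_proj_bias.chunk(3)
    # Work batch-first, like the hand-written class: (N, len, E)
    q = F.linear(query.transpose(0, 1), w_q, b_q)
    k = F.linear(key.transpose(0, 1), w_k, b_k)
    v = F.linear(value.transpose(0, 1), w_v, b_v)
    # Split heads: (N, len, E) -> (N, h, len, d)
    q = q.view(N, L, h, d).permute(0, 2, 1, 3)
    k = k.view(N, S, h, d).permute(0, 2, 1, 3)
    v = v.view(N, S, h, d).permute(0, 2, 1, 3)
    # Scaled dot product, softmax over the key axis
    scores = q @ k.transpose(-2, -1) / d ** 0.5   # (N, h, L, S)
    weights = scores.softmax(dim=-1)
    x = weights @ v                                 # (N, h, L, d)
    x = x.permute(0, 2, 1, 3).reshape(N, L, E)      # concatenate heads
    manual = mha.out_proj(x).transpose(0, 1)        # back to (L, N, E)

print(torch.allclose(manual, expected, atol=1e-4))  # should print True
```

Agreement up to float rounding indicates that the per-head slicing of the class matches PyTorch's: head i uses features i*d to (i+1)*d of each projection.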

3. Multi-head attention in TensorFlow

```python
import tensorflow as tf  # TensorFlow 1.x: uses tf.layers and tf.contrib


def _multiheadAttention(rawKeys, queries, keys, numUnits=None, causality=False, scope="multiheadAttention"):
    # rawKeys is only used to compute the mask: keys already include the
    # position embedding, so padding positions in keys are no longer 0
    # numUnits: output dimension of the Q/K/V projections, e.g. 300
    numHeads = 6
    keepProb = 1
    if numUnits is None:  # if not given, use the last dimension of the data, i.e. the embedding size
        numUnits = queries.get_shape().as_list()[-1]  # 300
    # tf.layers.dense projects the last dimension of a multi-dimensional tensor.
    # This is the per-head weight projection of Multi-Head Attention in the paper;
    # here we project first and split afterwards, which is equivalent.
    # (This implementation adds a ReLU; the paper uses plain linear projections.)
    # Q, K, V all have shape [batch_size, sequence_length, embedding_size]
    Q = tf.layers.dense(queries, numUnits, activation=tf.nn.relu)  # [64,10,300]
    K = tf.layers.dense(keys, numUnits, activation=tf.nn.relu)  # [64,10,300]
    V = tf.layers.dense(keys, numUnits, activation=tf.nn.relu)  # [64,10,300]
    # Split the last dimension into numHeads pieces and stack them along the first dimension
    # Q_, K_, V_ have shape [batch_size * numHeads, sequence_length, embedding_size / numHeads]
    Q_ = tf.concat(tf.split(Q, numHeads, axis=-1), axis=0)  # [64*6,10,50]
    K_ = tf.concat(tf.split(K, numHeads, axis=-1), axis=0)  # [64*6,10,50]
    V_ = tf.concat(tf.split(V, numHeads, axis=-1), axis=0)  # [64*6,10,50]
    # Dot product between queries and keys: [batch_size * numHeads, queries_len, keys_len]
    similary = tf.matmul(Q_, tf.transpose(K_, [0, 2, 1]))  # [64*6,10,10]
    # Divide by the square root of the per-head vector length
    scaledSimilary = similary / (K_.get_shape().as_list()[-1] ** 0.5)  # [64*6,10,10]
    # The input sequences contain padding tokens that should not affect the result.
    # If padding inputs are all 0, their weights would also be 0, but the Transformer
    # adds position vectors, after which padding values are no longer 0, so the mask
    # must be computed from the input before the position vectors are added.
    # Queries also contain padding, but in self-attention queries = keys, so
    # masking the keys is enough here.
    # More on the key mask: https://github.com/Kyubyong/transformer/issues/3
    # Sum each position's vector: padding positions (all zeros) sum to 0.
    # tf.sign() maps values < 0 to -1, values > 0 to 1 and 0 to 0.
    # rawKeys: [64,10,300]
    keyMasks = tf.sign(tf.abs(tf.reduce_sum(rawKeys, axis=-1)))  # [batch_size, time_step] = [64,10]
    # Use tf.tile to repeat the mask for every head: [batch_size * numHeads, keys_len]
    keyMasks = tf.tile(keyMasks, [numHeads, 1])  # [64*6,10]
    # Add a dimension and expand: [batch_size * numHeads, queries_len, keys_len]
    keyMasks = tf.tile(tf.expand_dims(keyMasks, 1), [1, tf.shape(queries)[1], 1])  # [64*6,10,10]
    # Same shape as scaledSimilary, filled with a large negative number (stand-in for -inf)
    paddings = tf.ones_like(scaledSimilary) * (-2 ** (32 + 1))  # [64*6,10,10]
    # tf.where(condition, x, y): take x where condition is True, else y (all share one shape).
    # Positions where keyMasks is 0 are replaced with paddings.
    maskedSimilary = tf.where(tf.equal(keyMasks, 0), paddings, scaledSimilary)  # [batch_size * numHeads, queries_len, keys_len]
    # Causal mask: each word only attends to the words before it, not after it.
    # Used in the Transformer decoder, a generative model mainly used for language
    # generation; text classification can use the Transformer encoder alone.
    if causality:
        diagVals = tf.ones_like(maskedSimilary[0, :, :])  # [queries_len, keys_len]
        tril = tf.contrib.linalg.LinearOperatorTriL(diagVals).to_dense()  # [queries_len, keys_len]
        masks = tf.tile(tf.expand_dims(tril, 0), [tf.shape(maskedSimilary)[0], 1, 1])  # [batch_size * numHeads, queries_len, keys_len]
        paddings = tf.ones_like(masks) * (-2 ** (32 + 1))
        maskedSimilary = tf.where(tf.equal(masks, 0), paddings, maskedSimilary)  # [batch_size * numHeads, queries_len, keys_len]
    # softmax gives the weights: [batch_size * numHeads, queries_len, keys_len]
    weights = tf.nn.softmax(maskedSimilary)
    # Weighted sum: [batch_size * numHeads, sequence_length, embedding_size / numHeads]
    outputs = tf.matmul(weights, V_)
    # Reassemble the heads into [batch_size, sequence_length, embedding_size]
    outputs = tf.concat(tf.split(outputs, numHeads, axis=0), axis=2)
    outputs = tf.nn.dropout(outputs, keep_prob=keepProb)
    # Residual connection for each sublayer: H(x) = F(x) + x
    outputs += queries
    # Layer normalization
    # outputs = self._layerNormalization(outputs)
    return outputs
```
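The TensorFlow code splits heads with split/concat instead of view/permute. A NumPy sketch (NumPy stands in for TensorFlow here; np.split and np.concatenate behave like tf.split and tf.concat for this purpose) shows that both give the same per-head data, with the TensorFlow version ordered head-major along the first axis. That ordering is also why keyMasks is repeated with tf.tile(keyMasks, [numHeads, 1]).

```python
import numpy as np

rng = np.random.default_rng(0)
batch, seq_len, num_units, num_heads = 64, 10, 300, 6
Q = rng.standard_normal((batch, seq_len, num_units))

# TF-style: tf.concat(tf.split(Q, numHeads, axis=-1), axis=0)
Q_tf = np.concatenate(np.split(Q, num_heads, axis=-1), axis=0)  # (6*64, 10, 50)

# PyTorch-style: view(bsz, -1, h, d).permute(0, 2, 1, 3) -> (64, 6, 10, 50)
Q_pt = Q.reshape(batch, seq_len, num_heads, num_units // num_heads).transpose(0, 2, 1, 3)

# Head-major merge: row i*batch + b holds head i of sample b
Q_pt_head_major = Q_pt.transpose(1, 0, 2, 3).reshape(num_heads * batch, seq_len, -1)
print(Q_tf.shape)                              # (384, 10, 50)
print(np.array_equal(Q_tf, Q_pt_head_major))   # True

# Inverse used at the end of _multiheadAttention:
# tf.concat(tf.split(outputs, numHeads, axis=0), axis=2)
restored = np.concatenate(np.split(Q_tf, num_heads, axis=0), axis=2)
print(np.array_equal(restored, Q))             # True
```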

The input is self.embeddedWords = self.wordEmbedded + self.positionEmbedded, i.e. word embedding plus position embedding.

Using the same dimensions as the PyTorch example: self.wordEmbedded has shape [64,10,300] and self.positionEmbedded has shape [64,10,300].

It is called like this:

```python
multiHeadAtt = self._multiheadAttention(rawKeys=self.wordEmbedded, queries=self.embeddedWords,
                                        keys=self.embeddedWords)
```

For example (with simplified inputs):

```python
import numpy as np
import tensorflow as tf  # TensorFlow 1.x

wordEmbedded = tf.Variable(np.ones((64, 10, 300)))
positionEmbedded = tf.Variable(np.ones((64, 10, 300)))
embeddedWords = wordEmbedded + positionEmbedded
multiHeadAtt = _multiheadAttention(rawKeys=wordEmbedded, queries=embeddedWords, keys=embeddedWords, numUnits=300)
```

Note that rawKeys must be the word embedding: in embeddedWords the position embedding has already been added, so the padding positions are no longer zero and the positions that need masking can no longer be found.
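The key-mask trick relies on padding positions being all-zero word vectors. A NumPy sketch (NumPy standing in for TensorFlow; the random embeddings, toy sizes and zero-padding convention are assumptions) shows that the sign-of-sum mask finds the padding in the word embeddings but not after the position embedding is added, and then expands the mask like the two tf.tile calls.

```python
import numpy as np

rng = np.random.default_rng(0)
batch, seq_len, dim, num_heads = 2, 5, 4, 2
# Word embeddings; padding positions are all-zero vectors (assumed convention)
word_embedded = rng.standard_normal((batch, seq_len, dim))
word_embedded[0, 3:] = 0.0  # sample 0: last 2 tokens are padding
word_embedded[1, 4:] = 0.0  # sample 1: last token is padding
position_embedded = rng.standard_normal((seq_len, dim))  # stand-in position vectors
embedded_words = word_embedded + position_embedded


def key_mask(raw_keys):
    # Mirrors tf.sign(tf.abs(tf.reduce_sum(rawKeys, axis=-1))): 1 = real token, 0 = padding
    return np.sign(np.abs(raw_keys.sum(axis=-1)))


print(key_mask(word_embedded))
# [[1. 1. 1. 0. 0.]
#  [1. 1. 1. 1. 0.]]
print(key_mask(embedded_words))  # padding rows no longer sum to 0, so no zeros remain

# tf.tile(keyMasks, [numHeads, 1]), then expand to (batch*numHeads, queries_len, keys_len)
mask = np.tile(key_mask(word_embedded), (num_heads, 1))
mask = np.tile(mask[:, None, :], (1, seq_len, 1))
print(mask.shape)  # (4, 5, 5)
```

One caveat of this design: a real token whose embedding happens to sum to exactly 0 would also be masked; a mask built from the token ids avoids that.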

The PyTorch examples above correspond to the decoder side of the model, which is also where the `if causality` branch is used. In the encoder, Q = K = V (all with the same shape). In the decoder, Q comes from the decoder input and can have shape [64,12,300], while K and V come from the encoder output with shape [64,10,300]; this is encoder-decoder attention. When Q, K and V all come from the same input, the layer is self-attention. The causal mask itself belongs to the decoder's self-attention, where each position may only look at earlier positions; encoder-decoder attention only needs the key padding mask.
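The effect of the `if causality` branch can be reproduced in NumPy. This sketch (NumPy in place of TensorFlow; random scores and toy sizes are assumptions) builds the lower-triangular matrix that LinearOperatorTriL(...).to_dense() produces, fills the masked scores with the same large negative constant as the TF code, and applies softmax: every weight above the diagonal becomes exactly 0, so word i only attends to words 0..i.

```python
import numpy as np

rng = np.random.default_rng(0)
heads_times_batch, seq_len = 3, 4
scores = rng.standard_normal((heads_times_batch, seq_len, seq_len))  # scaled similarities

# Lower-triangular ones, the role of LinearOperatorTriL(diagVals).to_dense()
tril = np.tril(np.ones((seq_len, seq_len)))
masks = np.tile(tril[None], (heads_times_batch, 1, 1))
# Same constant as the TF code: -2 ** (32 + 1) == -(2 ** 33), a stand-in for -inf
paddings = np.full_like(scores, -2 ** (32 + 1))
masked = np.where(masks == 0, paddings, scores)


def softmax(x, axis=-1):
    # Subtract the row max for numerical stability
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)


weights = softmax(masked)
print(np.round(weights[0], 3))  # upper triangle is 0
```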

 

Reference: https://mp.weixin.qq.com/s/cJqhESxTMy5cfj0EXh9s4w

Copyright notice
This article was written by [Xixi Moyo]. Please include a link to the original when reposting. Thank you.
