[practice] Understanding the mask arguments of PyTorch's nn.Transformer

zenRRan 2021-04-08 11:54:48


PyTorch ships its own implementation of the Transformer model. Unlike Hugging Face or other libraries, PyTorch's mask parameters are rather hard to understand (even with the documentation), so here are some notes and explanations. (By the way, this Transformer requires you to add the positional embedding yourself; don't just happily feed the data in.)

>>> transformer_model = nn.Transformer(nhead=16, num_encoder_layers=12)
>>> src = torch.rand((10, 32, 512))
>>> tgt = torch.rand((20, 32, 512))
>>> out = transformer_model(src, tgt)  # no positional embedding is applied for you, and you must supply the mask mechanism yourself; otherwise this is not the Transformer you have in mind
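Since nn.Transformer does not add positional information itself, a common approach is a sinusoidal positional encoding module along the lines of the original paper. The sketch below is one possible implementation, not something nn.Transformer provides; the class name PositionalEncoding and max_len are my own choices.

import math
import torch
import torch.nn as nn

class PositionalEncoding(nn.Module):
    """Add sinusoidal positional encodings to inputs of shape (seq_len, batch, d_model)."""
    def __init__(self, d_model, max_len=5000):
        super().__init__()
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        self.register_buffer('pe', pe.unsqueeze(1))  # (max_len, 1, d_model)

    def forward(self, x):  # x: (seq_len, batch, d_model)
        return x + self.pe[:x.size(0)]

# usage sketch: add positions to the embeddings before they enter the transformer
pos_enc = PositionalEncoding(d_model=512)
src = pos_enc(torch.rand(10, 32, 512))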

First, let's look at the forward() parameters listed in the official documentation:

  • src – the sequence to the encoder (required).
  • tgt – the sequence to the decoder (required).
  • src_mask – the additive mask for the src sequence (optional).
  • tgt_mask – the additive mask for the tgt sequence (optional).
  • memory_mask – the additive mask for the encoder output (optional).
  • src_key_padding_mask – the ByteTensor mask for src keys per batch (optional).
  • tgt_key_padding_mask – the ByteTensor mask for tgt keys per batch (optional).
  • memory_key_padding_mask – the ByteTensor mask for memory keys per batch (optional).

The key distinction is between *_mask and *_key_padding_mask; whether the prefix is src, tgt, or memory matters much less. When the attention module sits in the encoder, the prefix is src; in the decoder's self-attention it is tgt; and when the second sub-layer of each decoder block does cross-attention with the encoder output, it is memory.

Inside the attention modules, *_mask corresponds to the attn_mask argument, and *_key_padding_mask corresponds to the key_padding_mask argument.
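As a quick illustration of where those two argument names appear, here is a minimal sketch that calls nn.MultiheadAttention directly (the sizes are arbitrary, and the all-False masks mean nothing is actually masked):

import torch
import torch.nn as nn

mha = nn.MultiheadAttention(embed_dim=512, num_heads=8)
x = torch.rand(10, 32, 512)                                # (S, N, E) = (seq_len, batch, embed_dim)
key_padding_mask = torch.zeros(32, 10, dtype=torch.bool)   # (N, S), True = this key is <PAD>
attn_mask = torch.zeros(10, 10, dtype=torch.bool)          # (L, S), True = do not attend here
out, attn_weights = mha(x, x, x,
                        key_padding_mask=key_padding_mask,
                        attn_mask=attn_mask)
# out: (10, 32, 512), attn_weights: (32, 10, 10)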

Let's look at how the MultiheadAttention module in torch/nn/modules/activation.py documents these two arguments:

def forward(self, query, key, value, key_padding_mask=None,
            need_weights=True, attn_mask=None):
    # type: (Tensor, Tensor, Tensor, Optional[Tensor], bool, Optional[Tensor]) -> Tuple[Tensor, Optional[Tensor]]
    r"""
    Args:
        query, key, value: map a query and a set of key-value pairs to an output.
            See "Attention Is All You Need" for more details.
        key_padding_mask: if provided, specified padding elements in the key will
            be ignored by the attention. When given a binary mask and a value is True,
            the corresponding value on the attention layer will be ignored. When given
            a byte mask and a value is non-zero, the corresponding value on the attention
            layer will be ignored
        need_weights: output attn_output_weights.
        attn_mask: 2D or 3D mask that prevents attention to certain positions. A 2D mask will be
            broadcasted for all the batches while a 3D mask allows to specify a different mask for
            the entries of each batch.

    Shape:
        - Inputs:
        - query: :math:`(L, N, E)` where L is the target sequence length, N is the batch size, E is
          the embedding dimension.
        - key: :math:`(S, N, E)`, where S is the source sequence length, N is the batch size, E is
          the embedding dimension.
        - value: :math:`(S, N, E)` where S is the source sequence length, N is the batch size, E is
          the embedding dimension.
        - key_padding_mask: :math:`(N, S)` where N is the batch size, S is the source sequence length.
          If a ByteTensor is provided, the non-zero positions will be ignored while the position
          with the zero positions will be unchanged. If a BoolTensor is provided, the positions with the
          value of ``True`` will be ignored while the position with the value of ``False`` will be unchanged.
        - attn_mask: 2D mask :math:`(L, S)` where L is the target sequence length, S is the source sequence length.
          3D mask :math:`(N*num_heads, L, S)` where N is the batch size, L is the target sequence length,
          S is the source sequence length. attn_mask ensure that position i is allowed to attend the unmasked
          positions. If a ByteTensor is provided, the non-zero positions are not allowed to attend
          while the zero positions will be unchanged. If a BoolTensor is provided, positions with ``True``
          is not allowed to attend while ``False`` values will be unchanged. If a FloatTensor
          is provided, it will be added to the attention weight.

        - Outputs:
        - attn_output: :math:`(L, N, E)` where L is the target sequence length, N is the batch size,
          E is the embedding dimension.
        - attn_output_weights: :math:`(N, L, S)` where N is the batch size,
          L is the target sequence length, S is the source sequence length.
    """
  • key_padding_mask: used to cover up <PAD> positions, so that the embeddings of pad tokens are not attended to. Required shape: (N, S).
  • attn_mask: a 2-D or 3-D matrix used to block attention to specified positions. A 2-D mask must have shape (L, S); 3-D input is also supported, with shape (N*num_heads, L, S).

Here N is the batch size, L is the target sequence length, and S is the source sequence length. This module corresponds to all three attention blocks of the Transformer (the three orange areas in the usual architecture figure), so "the target sequence" does not necessarily mean the decoder's input sequence, and "the source sequence" does not necessarily mean the encoder's input sequence.

A better way to understand it: the target sequence is the q (query) sequence of multi-head attention, and the source sequence supplies the k (key) and v (value) sequences. For example, when the decoder does self-attention, the target sequence and the source sequence are both the decoder's own input, so at that point L = S, both equal to the length of the sequence being decoded.
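Conversely, in the decoder's cross-attention the queries come from the decoder sequence while the keys and values come from the encoder memory, so L and S can differ. A quick shape check, calling nn.MultiheadAttention directly with arbitrary sizes:

import torch
import torch.nn as nn

cross_attn = nn.MultiheadAttention(embed_dim=512, num_heads=8)
tgt = torch.rand(20, 32, 512)      # decoder states: (L, N, E)
memory = torch.rand(10, 32, 512)   # encoder output:  (S, N, E)
out, w = cross_attn(query=tgt, key=memory, value=memory)
print(out.shape)   # torch.Size([20, 32, 512])  -> (L, N, E)
print(w.shape)     # torch.Size([32, 20, 10])   -> (N, L, S)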

What key_padding_mask does

Here is a simple example:

Suppose we have a batch with batch_size = 3 and sequence length 4; the tokens look like this:

[
['a', 'b', 'c', '<PAD>'],
['a', 'b', 'c', 'd'],
['a', 'b', '<PAD>', '<PAD>']
]

Now suppose we compute self-attention over this batch (it could be in the encoder or in the decoder). Taking the third row as an example, when 'a' does its q-k-v computation it will see 'b', '<PAD>', '<PAD>'. We don't want 'a' to see '<PAD>', because those tokens carry no meaning in themselves, so we need key_padding_mask to cover them.

key_padding_mask has shape (N, S), so for this example it takes the following form, with key_padding_mask.shape = (3, 4):

[
[False, False, False, True],
[False, False, False, False],
[False, False, True, True]
]
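In practice you would usually derive this mask from the token ids rather than write it by hand. A minimal sketch, assuming <PAD> has id 0 and 'a'-'d' have ids 1-4 (these ids are made up for illustration):

import torch

pad_id = 0
tokens = torch.tensor([
    [1, 2, 3, 0],   # 'a', 'b', 'c', '<PAD>'
    [1, 2, 3, 4],   # 'a', 'b', 'c', 'd'
    [1, 2, 0, 0],   # 'a', 'b', '<PAD>', '<PAD>'
])
key_padding_mask = tokens == pad_id   # BoolTensor of shape (N, S) = (3, 4), True at <PAD>
print(key_padding_mask)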

It's worth noting that key_padding_mask essentially masks out those positions as keys (their attention weight is set to 0); the <PAD> token itself still takes part in the q-k-v computation. Taking the third position of the third row as an example: its q is the embedding of <PAD>, its k and v come from the first token 'a' and the second token 'b', and it still produces an output embedding.

So when you compute the loss from the Transformer's final output during training, you also need to specify ignore_index=pad_index. Taking the third row as an example, its supervision signal is [3205, 1890, 0, 0] with pad_index = 0. This way, even though the Transformer still does q-k-v at the meaningless <PAD> positions and outputs embeddings there, we simply don't count them in the loss and let them do whatever they like.
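A sketch of that loss computation, assuming the projection to vocabulary logits has already been done and using the supervision signal of the third row from the example (pad_index = 0; the vocabulary size is made up):

import torch
import torch.nn as nn

pad_index = 0
criterion = nn.CrossEntropyLoss(ignore_index=pad_index)

vocab_size = 4000
logits = torch.randn(4, vocab_size)         # model outputs for the 4 positions of the third row
target = torch.tensor([3205, 1890, 0, 0])   # supervision signal; the two 0s are <PAD>
loss = criterion(logits, target)            # the <PAD> positions contribute nothing to the loss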

What attn_mask does

When I first saw these two mask parameters I was confused too, not least because their shapes are different. Where on earth is attn_mask used?

When the decoder does self-attention, each position, unlike in the encoder, can only see the positions before it. key_padding_mask has shape (batch_size, source_length), which means that after key_padding_mask is applied, every query position sees the same thing (even though the mask can differ between rows of the batch). That cannot satisfy the requirement of the following module:

[Figure: the masked multi-head attention module in the decoder]

The mask needed here looks like this:

[Figure: yellow cells are the visible part, purple cells are the invisible part; different positions need to mask different parts]

And PyTorch's nn.Transformer already provides a function that builds it for us:

def generate_square_subsequent_mask(self, sz: int) -> Tensor:
    r"""Generate a square mask for the sequence. The masked positions are filled with float('-inf').
    Unmasked positions are filled with float(0.0).
    """
    mask = (torch.triu(torch.ones(sz, sz)) == 1).transpose(0, 1)
    mask = mask.float().masked_fill(mask == 0, float('-inf')).masked_fill(mask == 1, float(0.0))
    return mask

Returning to the example above, take the first row of data, ['a', 'b', 'c', '<PAD>'] (suppose we are using the decoder for generation and are looking at the first sub-layer of a decoder block, i.e. its self-attention). Then each position can see the following (the resulting mask is shown after the list):

  • 'a' can see 'a'
  • 'b' can see 'a', 'b'
  • 'c' can see 'a', 'b', 'c'
  • '<PAD>' should in principle see nothing, but as long as its supervision signal is ignore_index it does not matter, so we let it see 'a', 'b', 'c', '<PAD>'
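The helper above produces exactly this pattern. For the length-4 example (reusing transformer_model from the first snippet), the additive mask is 0 where attention is allowed and -inf where it is blocked:

>>> transformer_model.generate_square_subsequent_mask(4)
tensor([[0., -inf, -inf, -inf],
        [0., 0., -inf, -inf],
        [0., 0., 0., -inf],
        [0., 0., 0., 0.]])

Row i is the query position; it may attend to columns 0 through i.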

Recall the shape requirements of attn_mask: (L, S) for a 2-D mask and (N*num_heads, L, S) for a 3-D one. Here q, k, and v all come from the same sequence (the decoder input), so L = S; and because the masking rule is the same for every row of the batch (position i may only see positions up to i), a two-dimensional attn_mask is enough. The internal implementation broadcasts this mask to every row of the batch.

Generally speaking, unless you want to modify the Transformer in unusual ways, for example letting different heads see different information, the two-dimensional mask is enough.
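Putting the pieces together, a typical training-time forward pass of nn.Transformer looks roughly like this (a sketch; the embeddings, positional encoding, and padding masks are assumed to be built as above, with all-False padding masks used here just as placeholders):

import torch
import torch.nn as nn

transformer = nn.Transformer(d_model=512, nhead=8)
src = torch.rand(10, 32, 512)                                 # (S, N, E) encoder input (embeddings + positions)
tgt = torch.rand(20, 32, 512)                                 # (T, N, E) decoder input
tgt_mask = transformer.generate_square_subsequent_mask(20)    # (T, T) additive -inf mask for decoder self-attention
src_key_padding_mask = torch.zeros(32, 10, dtype=torch.bool)  # (N, S), True at <PAD> positions in src
tgt_key_padding_mask = torch.zeros(32, 20, dtype=torch.bool)  # (N, T), True at <PAD> positions in tgt

out = transformer(src, tgt,
                  tgt_mask=tgt_mask,
                  src_key_padding_mask=src_key_padding_mask,
                  tgt_key_padding_mask=tgt_key_padding_mask,
                  memory_key_padding_mask=src_key_padding_mask)  # memory pads where src pads
# out: (20, 32, 512)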

When should you use key_padding_mask, and when attn_mask?

I think it's best to follow the conventions above. In practice, the two masks act on the same model together. Could you use attn_mask alone, instead of key_padding_mask, to cover up <PAD>? Certainly, it just increases your workload (see the sketch below).
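For the curious, here is a rough sketch of what that extra work looks like, reusing the made-up token ids from before (0 is <PAD>): you would have to expand the (N, S) padding information into a per-head 3-D attn_mask yourself.

import torch

num_heads, L = 8, 4
tokens = torch.tensor([[1, 2, 3, 0],
                       [1, 2, 3, 4],
                       [1, 2, 0, 0]])             # 0 = <PAD>
pad = tokens == 0                                 # (N, S)
N, S = pad.shape
# broadcast the key-padding information to every query position -> (N, L, S),
# then repeat it per head so that heads of the same batch element sit next to each other
attn_mask = pad[:, None, :].expand(N, L, S)
attn_mask = attn_mask.repeat_interleave(num_heads, dim=0)
print(attn_mask.shape)   # torch.Size([24, 4, 4]) = (N*num_heads, L, S)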

That's all for this note.

This article is from the WeChat official account "Deep Learning and Natural Language Processing" (zenRRan).


Original publication time: 2021-04-04


Copyright notice: this article was written by zenRRan; please include a link to the original when reposting.
https://pythonmana.com/2021/04/20210408111750715X.html
