zenRRan
2021-04-08 11:54:48

Tags: practice, python, nn.Transformer

PyTorch also ships its own transformer implementation, `nn.Transformer`. Unlike huggingface and elsewhere, PyTorch's mask parameters are hard to understand (even with the documentation), so here are some supplementary explanations. (By the way, this transformer requires you to do the position embedding yourself; don't just happily feed it data.)

```python
>>> transformer_model = nn.Transformer(nhead=16, num_encoder_layers=12)
>>> src = torch.rand((10, 32, 512))
>>> tgt = torch.rand((20, 32, 512))
>>> out = transformer_model(src, tgt)
# position embedding is not implemented, and you must supply the mask
# mechanism yourself -- otherwise it's not the transformer you think it is
```
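Since `nn.Transformer` adds no positional information on its own, here is a minimal sketch of the sinusoidal positional encoding from the original paper that you could add to the inputs yourself (the function name and `max_len`/`d_model` parameters are my own choices, not part of the PyTorch API):

```python
import math
import torch

def positional_encoding(max_len, d_model):
    """Sinusoidal positional encodings with shape (max_len, 1, d_model),
    matching nn.Transformer's (seq_len, batch, embed) input layout."""
    pe = torch.zeros(max_len, d_model)
    position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
    div_term = torch.exp(torch.arange(0, d_model, 2).float()
                         * (-math.log(10000.0) / d_model))
    pe[:, 0::2] = torch.sin(position * div_term)  # even dims: sine
    pe[:, 1::2] = torch.cos(position * div_term)  # odd dims: cosine
    return pe.unsqueeze(1)  # (max_len, 1, d_model), broadcasts over batch

src = torch.rand((10, 32, 512))           # (seq_len, batch, embed)
src = src + positional_encoding(10, 512)  # add positions before the transformer
```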

First, let's look at the parameters from the official docs:

- **src** – the sequence to the encoder (required).
- **tgt** – the sequence to the decoder (required).
- **src_mask** – the additive mask for the src sequence (optional).
- **tgt_mask** – the additive mask for the tgt sequence (optional).
- **memory_mask** – the additive mask for the encoder output (optional).
- **src_key_padding_mask** – the ByteTensor mask for src keys per batch (optional).
- **tgt_key_padding_mask** – the ByteTensor mask for tgt keys per batch (optional).
- **memory_key_padding_mask** – the ByteTensor mask for memory keys per batch (optional).

The important distinction is between `*_mask` and `*_key_padding_mask`. The prefix (src, tgt, or memory) matters less: when the attention module sits in the encoder, the prefix is src; in the decoder, it's tgt; and when the second sub-layer of each decoder block does cross attention against the encoder output, it's memory.

`*_mask` corresponds to the `attn_mask` argument, and `*_key_padding_mask` corresponds to the `key_padding_mask` argument.

Let's look at how the **MultiheadAttention** module in `torch/nn/modules/activation.py` explains these two arguments:

```python
def forward(self, query, key, value, key_padding_mask=None,
            need_weights=True, attn_mask=None):
    # type: (Tensor, Tensor, Tensor, Optional[Tensor], bool, Optional[Tensor]) -> Tuple[Tensor, Optional[Tensor]]
    r"""
    Args:
        query, key, value: map a query and a set of key-value pairs to an output.
            See "Attention Is All You Need" for more details.
        key_padding_mask: if provided, specified padding elements in the key will
            be ignored by the attention. When given a binary mask and a value is True,
            the corresponding value on the attention layer will be ignored. When given
            a byte mask and a value is non-zero, the corresponding value on the
            attention layer will be ignored
        need_weights: output attn_output_weights.
        attn_mask: 2D or 3D mask that prevents attention to certain positions.
            A 2D mask will be broadcasted for all the batches while a 3D mask
            allows to specify a different mask for the entries of each batch.

    Shape:
        - Inputs:
        - query: :math:`(L, N, E)` where L is the target sequence length,
          N is the batch size, E is the embedding dimension.
        - key: :math:`(S, N, E)`, where S is the source sequence length,
          N is the batch size, E is the embedding dimension.
        - value: :math:`(S, N, E)` where S is the source sequence length,
          N is the batch size, E is the embedding dimension.
        - key_padding_mask: :math:`(N, S)` where N is the batch size,
          S is the source sequence length. If a ByteTensor is provided,
          the non-zero positions will be ignored while the positions with
          zero will be unchanged. If a BoolTensor is provided, the positions
          with the value of ``True`` will be ignored while the positions with
          the value of ``False`` will be unchanged.
        - attn_mask: 2D mask :math:`(L, S)` where L is the target sequence length,
          S is the source sequence length. 3D mask :math:`(N*num_heads, L, S)`
          where N is the batch size, L is the target sequence length, S is the
          source sequence length. attn_mask ensures that position i is allowed
          to attend the unmasked positions. If a ByteTensor is provided, the
          non-zero positions are not allowed to attend while the zero positions
          will be unchanged. If a BoolTensor is provided, positions with ``True``
          are not allowed to attend while ``False`` values will be unchanged.
          If a FloatTensor is provided, it will be added to the attention weight.

        - Outputs:
        - attn_output: :math:`(L, N, E)` where L is the target sequence length,
          N is the batch size, E is the embedding dimension.
        - attn_output_weights: :math:`(N, L, S)` where N is the batch size,
          L is the target sequence length, S is the source sequence length.
    """
```

- `key_padding_mask`: masks out `<PAD>` positions so that pad-token embeddings are ignored as attention keys. Required shape: (N, S).
- `attn_mask`: a 2D or 3D matrix that blocks attention to given positions. A 2D matrix must have shape (L, S); 3D input is also supported, with shape (N*num_heads, L, S).

Here N is the batch size, L is the target sequence length, and S is the source sequence length. This module appears in all three (orange) attention blocks of the transformer architecture diagram, so "the target sequence" does not necessarily mean the decoder's input sequence, and "the source sequence" is not necessarily the encoder's input sequence.

A better way to think about it: the target sequence is the q (query) sequence in multi-head attention, and the source sequence supplies k (key) and v (value). For example, when the decoder does self-attention, the target and source sequences are both the decoder's own input, so L = S, both equal to the length of the sequence being decoded.
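This can be seen concretely in a small sketch (the sizes here are made up): when q comes from one sequence and k, v from another, the output length follows the query, and the attention weights have shape (N, L, S).

```python
import torch
import torch.nn as nn

mha = nn.MultiheadAttention(embed_dim=512, num_heads=8)
q = torch.rand(20, 32, 512)   # (L, N, E): target/query sequence, L = 20
kv = torch.rand(10, 32, 512)  # (S, N, E): source/key-value sequence, S = 10

out, weights = mha(q, kv, kv)
print(out.shape)      # torch.Size([20, 32, 512]) -> (L, N, E)
print(weights.shape)  # torch.Size([32, 20, 10]) -> (N, L, S)
```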

Here is a simple example:

Suppose we have a batch with batch_size = 3 and sequence length 4, whose tokens look like this:

```python
[['a', 'b', 'c', '<PAD>'],
 ['a', 'b', 'c', 'd'],
 ['a', 'b', '<PAD>', '<PAD>']]
```

Now suppose we do a self-attention computation (it can be in the encoder or in the decoder). Take the third row as an example: when 'a' does its qkv computation, it sees 'b', '<PAD>', '<PAD>'. But we don't want 'a' to attend to '<PAD>', because pad tokens carry no meaning, so we need **key_padding_mask** to cover them.

key_padding_mask has shape (N, S), so for this example it takes the following form, with key_padding_mask.shape = (3, 4):

```python
[[False, False, False, True],
 [False, False, False, False],
 [False, False, True, True]]
```
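A sketch of building this mask programmatically from the token batch above (comparing against the string `'<PAD>'` is just for illustration; in practice you would compare token ids against the pad id):

```python
import torch

batch = [['a', 'b', 'c', '<PAD>'],
         ['a', 'b', 'c', 'd'],
         ['a', 'b', '<PAD>', '<PAD>']]

# True marks the key positions that attention should ignore
key_padding_mask = torch.tensor(
    [[tok == '<PAD>' for tok in row] for row in batch]
)
print(key_padding_mask)
# tensor([[False, False, False,  True],
#         [False, False, False, False],
#         [False, False,  True,  True]])
```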

It's worth noting that key_padding_mask essentially only masks out a position's value as a *key* (its attention weight becomes 0); the `<PAD>` token itself still does the qkv computation. Take the third position of the third row: its q is the `<PAD>` embedding, its k and v come from the first token 'a' and the second token 'b', and it still outputs an embedding.

So when your training computes the loss on the transformer's final output, you also need to specify ignore_index=pad_index. Taking the third row as an example, its supervision signal is [3205, 1890, 0, 0], with pad_index = 0. That way, even though the `<PAD>` positions still do qkv over the meaningful positions and output embeddings, we don't count them in the loss, and they can do whatever they like.
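A sketch of that loss setup (the vocabulary size is made up; the labels match the [3205, 1890, 0, 0] example with pad_index = 0):

```python
import torch
import torch.nn as nn

pad_index = 0
criterion = nn.CrossEntropyLoss(ignore_index=pad_index)

# logits for one row of the batch: (seq_len, vocab_size)
logits = torch.randn(4, 5000)
target = torch.tensor([3205, 1890, 0, 0])  # last two positions are <PAD>

# the two pad positions contribute nothing to the loss
loss = criterion(logits, target)
```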

When I first saw these two mask parameters I was also confused, not least because their shapes differ. Where on earth is attn_mask used?

When the decoder does self-attention, each position, unlike in the encoder, can only see the positions before it. key_padding_mask has shape (batch_size, source_length), which means that within one row every query position sees the same set of keys after key_padding_mask is applied (even though the mask can differ across rows of the batch). That does not satisfy what the following module needs:

The decoder's masked multi-head attention module

The mask needed here looks like this:

(Figure: yellow marks the visible part and purple the invisible part; the masked region differs from position to position.)

And PyTorch's nn.Transformer already provides a function that implements this for us:

```python
def generate_square_subsequent_mask(self, sz: int) -> Tensor:
    r"""Generate a square mask for the sequence. The masked positions are
    filled with float('-inf'). Unmasked positions are filled with float(0.0).
    """
    mask = (torch.triu(torch.ones(sz, sz)) == 1).transpose(0, 1)
    mask = mask.float().masked_fill(mask == 0, float('-inf')).masked_fill(mask == 1, float(0.0))
    return mask
```
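Running the body of that function for sz = 4 shows the triangular structure; the -inf entries are added to the attention scores and become 0 after softmax:

```python
import torch

sz = 4
# same computation as generate_square_subsequent_mask(4)
mask = (torch.triu(torch.ones(sz, sz)) == 1).transpose(0, 1)
mask = mask.float().masked_fill(mask == 0, float('-inf')).masked_fill(mask == 1, float(0.0))
print(mask)
# tensor([[0., -inf, -inf, -inf],
#         [0., 0., -inf, -inf],
#         [0., 0., 0., -inf],
#         [0., 0., 0., 0.]])
```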

Back to the example above, take the first row ['a', 'b', 'c', '<PAD>'] (suppose we are using the decoder for generation and looking at the first sub-layer of a block, i.e., self-attention). Here:

- 'a' can see 'a'
- 'b' can see 'a', 'b'
- 'c' can see 'a', 'b', 'c'
- '<PAD>' should in theory see nothing, but as long as its supervision signal is set to ignore_index it doesn't matter, so we let it see 'a', 'b', 'c', '<PAD>'

Recall attn_mask's shape requirements: (L, S) when 2D, (N*num_heads, L, S) when 3D. Here q, k, v all come from the same sequence (the decoder's input), so L = S. And because the mask mechanism is the same for every row of the batch (position i can only see the positions before it), our attn_mask only needs to be two-dimensional; the internal implementation broadcasts the mask matrix to every row of the batch.

Generally speaking, unless you need to modify the transformer, for example letting different heads see different information, the two-dimensional matrix is enough.
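To make the broadcast concrete, here is a sketch of a decoder-style self-attention call that combines a 2D causal attn_mask with the per-row key_padding_mask from the earlier example (toy sizes; the variable names are my own):

```python
import torch
import torch.nn as nn

N, L, E, H = 3, 4, 512, 8
decoder_input = torch.rand(L, N, E)  # (L, N, E)

# causal additive mask, equivalent to generate_square_subsequent_mask(L):
# position i may only attend to positions <= i
attn_mask = torch.triu(torch.full((L, L), float('-inf')), diagonal=1)

# per-row pad mask for the batch
# [['a','b','c','<PAD>'], ['a','b','c','d'], ['a','b','<PAD>','<PAD>']]
key_padding_mask = torch.tensor([[False, False, False, True],
                                 [False, False, False, False],
                                 [False, False, True, True]])

mha = nn.MultiheadAttention(embed_dim=E, num_heads=H)
out, _ = mha(decoder_input, decoder_input, decoder_input,
             attn_mask=attn_mask,              # broadcast over the batch
             key_padding_mask=key_padding_mask)  # differs per row
print(out.shape)  # torch.Size([4, 3, 512])
```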

I think it's best to follow the convention above. That said, since the two masks act jointly on the same model, could you use attn_mask instead of key_padding_mask to cover `<PAD>`? Certainly, it just increases your workload.


This article is from the WeChat official account "Deep learning of natural language processing" (zenRRan).


Original publication date: 2021-04-04

