从头开始实现Transformer

GPT-3,BERT,XLNet这些都是当前自然语言处理(NLP)的新技术,它们都使用一种称为 transformer 的特殊架构组件,这是因为,transformer 这种新机制非常强大,完整的transformer 通常包含三个结构:

  • scaled dot-product attention
    • self-attention
    • cross-attention
  • multi-head attention
  • positional encoding

从头开始实现Transformer-图片1

让我们从Scaled Dot-Product Attention开始,因为我们还需要它来构建 Multi-Head Attention。

Scaled Dot-Product Attention

从头开始实现Transformer-图片2

在数学上,Scaled Dot-Product Attention表示为:
从头开始实现Transformer-图片3

Q,K和V是经过卷积后得到的特征,其形状为(batch_size,seq_length,num_features)。

将查询(Q)和键(K)相乘会得到(batch_size,seq_length,seq_length)特征,这大致告诉我们序列中每个元素的重要性,确定我们“注意”哪些元素。 注意数组使用softmax标准化,因此所有权重之和为1。 最后,注意力将通过矩阵乘法应用于值(V)数组。

scaled dot-product attention 的代码 非常简单-只需几个矩阵乘法,再加上softmax函数。 为了更加简单,我们省略了可选的Mask操作。

  1. from torch import Tensor
  2. import torch.nn.functional as f
  3.  
  4.  
  5. def scaled_dot_product_attention(query: Tensor, key: Tensor, value: Tensor) -> Tensor:
  6. temp = query.bmm(key.transpose(1, 2))
  7. scale = query.size(-1) ** 0.5
  8. softmax = f.softmax(temp / scale, dim=-1)
  9. return softmax.bmm(value)

请注意,MatMul操作在PyTorch中对应为torch.bmm。 这是因为Q,K和V(查询,键和值数组)都是矩阵,每个矩阵的形状均为(batch_size,sequence_length,num_features),矩阵乘法仅在最后两个维度上执行。

在了解了Scaled Dot-Product Attention之后,就很容易理解self-attention和cross-attention了,区别仅仅是Q,K和V的来源不同。

  • self-attention的Q,K和V都是同一个输入, 即当前序列由上一层输出的高维表达。
  • cross-attention的Q代表当前序列;而K和V是同一个输入,对应的是encoder最后一层的输出结果

Multi-Head Attention

从头开始实现Transformer-图片4

从上图可以看出, Multi-Head Attention 由几个相同的Head Attention组成。 每个关注头包含3个线性层,
从头开始实现Transformer-图片5

代码如下:

  1. import torch
  2. from torch import nn
  3.  
  4.  
  5. class HeadAttention(nn.Module):
  6. def __init__(self, dim_in: int, dim_k: int, dim_v: int):
  7. super().__init__()
  8. self.q = nn.Linear(dim_in, dim_k)
  9. self.k = nn.Linear(dim_in, dim_k)
  10. self.v = nn.Linear(dim_in, dim_v)
  11.  
  12. def forward(self, query: Tensor, key: Tensor, value: Tensor) -> Tensor:
  13. return scaled_dot_product_attention(self.q(query), self.k(key), self.v(value))

现在,建立Multi-Head Attention 就非常容易。 只需将num_heads个不同的关注头和一个Linear层组合在一起即可输出。

  1. class MultiHeadAttention(nn.Module):
  2. def __init__(self, num_heads: int, dim_in: int, dim_k: int, dim_v: int):
  3. super().__init__()
  4. self.heads = nn.ModuleList(
  5. [HeadAttention(dim_in, dim_k, dim_v) for _ in range(num_heads)]
  6. )
  7. self.linear = nn.Linear(num_heads * dim_v, dim_in)
  8.  
  9. def forward(self, query: Tensor, key: Tensor, value: Tensor) -> Tensor:
  10. return self.linear(
  11. torch.cat([h(query, key, value) for h in self.heads], dim=-1)
  12. )

Positional Encoding

在构建完整的transformer之前,我们还需要一个组件:Positional Encoding。 请注意,MultiHeadAttention没有在序列维度上运行, 一切都在特征维上进行,因此它与序列长度无关。 我们必须向模型提供位置信息,以便它知道输入序列中数据点的相对位置。

transformer 论文里使用三角函数对位置信息进行编码:
从头开始实现Transformer-图片6

为什么使用正弦编码呢? 因为正弦/余弦函数是周期性的,并且它们覆盖[0,1]的范围。所以,尽管事实证明学习的嵌入表现出同样良好的效果,但作者仍然选择使用正弦编码。

我们只需几行代码即可实现:

  1. def position_encoding(
  2. seq_len: int, dim_model: int, device: torch.device = torch.device("cpu"),
  3. ) -> Tensor:
  4. pos = torch.arange(seq_len, dtype=torch.float, device=device).reshape(1, -1, 1)
  5. dim = torch.arange(dim_model, dtype=torch.float, device=device).reshape(1, 1, -1)
  6. phase = (pos / 1e4) ** (dim // dim_model)
  7.  
  8. return torch.where(dim.long() % 2 == 0, -torch.sin(phase), torch.cos(phase))

Transformer

最后,我们准备构建“Transformer”了! 让我们再看一下完整的网络图:
从头开始实现Transformer-图片7

注意,transformer使用编码器-解码器体系结构。 编码器(左)处理输入序列并返回特征向量(或存储向量)。 解码器处理目标序列,并合并来自编码器存储器的信息。 解码器的输出是我们模型的预测!

我们可以彼此独立地对编码器/解码器模块进行编码,然后最后将它们组合。 首先,我们先构建encoder。如下:

  1. def feed_forward(dim_input: int = 512, dim_feedforward: int = 2048) -> nn.Module:
  2. return nn.Sequential(
  3. nn.Linear(dim_input, dim_feedforward),
  4. nn.ReLU(),
  5. nn.Linear(dim_feedforward, dim_input),
  6. )
  7.  
  8. class Residual(nn.Module):
  9. def __init__(self, sublayer: nn.Module, dimension: int, dropout: float = 0.1):
  10. super().__init__()
  11. self.sublayer = sublayer
  12. self.norm = nn.LayerNorm(dimension)
  13. self.dropout = nn.Dropout(dropout)
  14.  
  15. def forward(self, *tensors: Tensor) -> Tensor:
  16. # Assume that the "value" tensor is given last, so we can compute the
  17. # residual. This matches the signature of 'MultiHeadAttention'.
  18. return self.norm(tensors[-1] self.dropout(self.sublayer(*tensors)))
  19.  
  20. class TransformerEncoderLayer(nn.Module):
  21. def __init__(
  22. self,
  23. dim_model: int = 512,
  24. num_heads: int = 6,
  25. dim_feedforward: int = 2048,
  26. dropout: float = 0.1,
  27. ):
  28. super().__init__()
  29. dim_k = dim_v = dim_model // num_heads
  30. self.attention = Residual(
  31. MultiHeadAttention(num_heads, dim_model, dim_k, dim_v),
  32. dimension=dim_model,
  33. dropout=dropout,
  34. )
  35. self.feed_forward = Residual(
  36. feed_forward(dim_model, dim_feedforward),
  37. dimension=dim_model,
  38. dropout=dropout,
  39. )
  40.  
  41. def forward(self, src: Tensor) -> Tensor:
  42. src = self.attention(src, src, src)
  43. return self.feed_forward(src)
  44.  
  45.  
  46. class TransformerEncoder(nn.Module):
  47. def __init__(
  48. self,
  49. num_layers: int = 6,
  50. dim_model: int = 512,
  51. num_heads: int = 8,
  52. dim_feedforward: int = 2048,
  53. dropout: float = 0.1,
  54. ):
  55. super().__init__()
  56. self.layers = nn.ModuleList([
  57. TransformerEncoderLayer(dim_model, num_heads, dim_feedforward, dropout)
  58. for _ in range(num_layers)
  59. ])
  60.  
  61. def forward(self, src: Tensor) -> Tensor:
  62. seq_len, dimension = src.size(1), src.size(2)
  63. src = position_encoding(seq_len, dimension)
  64. for layer in self.layers:
  65. src = layer(src)
  66.  
  67. return src

解码器模块非常相似。只是一些小的区别:

  • 解码器接受两个参数(target和memory),而不是一个;
  • 每层有两个多头部注意力模块,而不是一个;
  • 第二个多头注意力接受两个输入的记忆;
  • 解码器中包含了self-attention和cross-attention。
  1. class TransformerDecoderLayer(nn.Module):
  2. def __init__(
  3. self,
  4. dim_model: int = 512,
  5. num_heads: int = 6,
  6. dim_feedforward: int = 2048,
  7. dropout: float = 0.1,
  8. ):
  9. super().__init__()
  10. dim_k = dim_v = dim_model // num_heads
  11. self.attention_1 = Residual(
  12. MultiHeadAttention(num_heads, dim_model, dim_k, dim_v),
  13. dimension=dim_model,
  14. dropout=dropout,
  15. )
  16. self.attention_2 = Residual(
  17. MultiHeadAttention(num_heads, dim_model, dim_k, dim_v),
  18. dimension=dim_model,
  19. dropout=dropout,
  20. )
  21. self.feed_forward = Residual(
  22. feed_forward(dim_model, dim_feedforward),
  23. dimension=dim_model,
  24. dropout=dropout,
  25. )
  26.  
  27. def forward(self, tgt: Tensor, memory: Tensor) -> Tensor:
  28. tgt = self.attention_1(tgt, tgt, tgt)
  29. tgt = self.attention_2(memory, memory, tgt)
  30. return self.feed_forward(tgt)
  31.  
  32.  
  33. class TransformerDecoder(nn.Module):
  34. def __init__(
  35. self,
  36. num_layers: int = 6,
  37. dim_model: int = 512,
  38. num_heads: int = 8,
  39. dim_feedforward: int = 2048,
  40. dropout: float = 0.1,
  41. ):
  42. super().__init__()
  43. self.layers = nn.ModuleList([
  44. TransformerDecoderLayer(dim_model, num_heads, dim_feedforward, dropout)
  45. for _ in range(num_layers)
  46. ])
  47. self.linear = nn.Linear(dim_model, dim_model)
  48.  
  49. def forward(self, tgt: Tensor, memory: Tensor) -> Tensor:
  50. seq_len, dimension = tgt.size(1), tgt.size(2)
  51. tgt = position_encoding(seq_len, dimension)
  52. for layer in self.layers:
  53. tgt = layer(tgt, memory)
  54.  
  55. return torch.softmax(self.linear(tgt), dim=-1)

最后,我们需要将所有内容打包成一个Transformer类,只要把一个编码器和解码器放在一起,然后以正确的顺序通过它们传递数据。

  1. class Transformer(nn.Module):
  2. def __init__(
  3. self,
  4. num_encoder_layers: int = 6,
  5. num_decoder_layers: int = 6,
  6. dim_model: int = 512,
  7. num_heads: int = 6,
  8. dim_feedforward: int = 2048,
  9. dropout: float = 0.1,
  10. activation: nn.Module = nn.ReLU(),
  11. ):
  12. super().__init__()
  13. self.encoder = TransformerEncoder(
  14. num_layers=num_encoder_layers,
  15. dim_model=dim_model,
  16. num_heads=num_heads,
  17. dim_feedforward=dim_feedforward,
  18. dropout=dropout,
  19. )
  20. self.decoder = TransformerDecoder(
  21. num_layers=num_decoder_layers,
  22. dim_model=dim_model,
  23. num_heads=num_heads,
  24. dim_feedforward=dim_feedforward,
  25. dropout=dropout,
  26. )
  27.  
  28. def forward(self, src: Tensor, tgt: Tensor) -> Tensor:
  29. return self.decoder(tgt, self.encoder(src))

让我们创建一个简单的测试,作为实现的健全性检查。我们可以构造src和tgt的随机张量,检查我们的模型执行没有错误,并确认输出张量具有正确的形状。

  1. src = torch.rand(64, 16, 512)
  2. tgt = torch.rand(64, 16, 512)
  3. out = Transformer()(src, tgt)
  4. print(out.shape)
  5. # torch.Size([64, 16, 512])

Conclusions

希望这篇有助于了解transformer是如何搭建的,以及它们是如何工作的。计算机视觉领域,以前可能没有遇到过这些模型,但DETR和ViT已经取得了突破性的成果,预计在未来几年里会看到更多这样的模型。

发表评论

匿名网友

拖动滑块以完成验证
加载失败