Implementing Transformer and Longformer in PyTorch, with a Detailed Code Walkthrough

Below is a simple example of implementing a Transformer model in PyTorch, with detailed comments on each step. Note that this example focuses on the model implementation itself and does not cover data processing or training.
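Before reading the full model, it helps to keep the core operation in mind: scaled dot-product attention, softmax(QK^T / sqrt(d_k)) V. The standalone sketch below only traces the tensor shapes involved (the batch size, head count, sequence length, and head dimension are toy values chosen purely for illustration, with no learned projections), which makes the MultiHeadAttention class that follows easier to read.

import math
import torch
import torch.nn.functional as F

# Toy shapes for illustration only: batch=2, heads=4, seq_len=5, head_dim=8
Q = torch.randn(2, 4, 5, 8)
K = torch.randn(2, 4, 5, 8)
V = torch.randn(2, 4, 5, 8)
scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(Q.size(-1))  # (2, 4, 5, 5)
weights = F.softmax(scores, dim=-1)                                    # each row sums to 1
out = torch.matmul(weights, V)                                         # (2, 4, 5, 8)
print(out.shape)  # torch.Size([2, 4, 5, 8])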

import torch
import torch.nn as nn
import torch.nn.functional as F
import math

# Sinusoidal positional encoder
class PositionalEncoder(nn.Module):
    def __init__(self, d_model, max_len=512):
        super(PositionalEncoder, self).__init__()
        self.d_model = d_model
        self.max_len = max_len
        # Precompute the positional encoding table
        pe = torch.zeros(max_len, d_model)  # (max_len, d_model) tensor holding the encodings
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)  # (max_len, 1) position indices
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))  # frequency terms
        pe[:, 0::2] = torch.sin(position * div_term)  # sine on even dimensions
        pe[:, 1::2] = torch.cos(position * div_term)  # cosine on odd dimensions
        pe = pe.unsqueeze(0)  # add a batch dimension
        self.register_buffer('pe', pe)  # register as a buffer (saved with the model, not trained)

    def forward(self, x):
        # x: (batch_size, seq_len, d_model)
        x = x * math.sqrt(self.d_model)  # scale the embeddings so they are not swamped by the positional encodings
        seq_len = x.size(1)  # length of the input sequence
        # Add the positional encodings to the input
        x = x + self.pe[:, :seq_len]
        return x

# Multi-head attention
class MultiHeadAttention(nn.Module):
    def __init__(self, d_model, num_heads):
        super(MultiHeadAttention, self).__init__()
        assert d_model % num_heads == 0, "d_model must be divisible by num_heads"
        self.d_model = d_model
        self.num_heads = num_heads
        self.head_dim = d_model // num_heads  # dimension of each head
        # Linear projections
        self.W_q = nn.Linear(d_model, d_model)  # query projection
        self.W_k = nn.Linear(d_model, d_model)  # key projection
        self.W_v = nn.Linear(d_model, d_model)  # value projection
        self.W_o = nn.Linear(d_model, d_model)  # output projection

    def forward(self, query, key, value, mask=None):
        batch_size = query.size(0)
        # Project the inputs and split into heads: (batch, num_heads, seq_len, head_dim)
        Q = self.W_q(query).view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
        K = self.W_k(key).view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
        V = self.W_v(value).view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
        # Scaled dot-product attention scores
        attention_scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(self.head_dim)
        if mask is not None:
            attention_scores = attention_scores.masked_fill(mask == 0, float('-inf'))  # block masked positions
        attention_weights = F.softmax(attention_scores, dim=-1)  # attention weights
        # Weighted sum of the values
        attention_output = torch.matmul(attention_weights, V)
        # Merge the heads back: (batch, seq_len, d_model)
        attention_output = attention_output.transpose(1, 2).contiguous().view(batch_size, -1, self.d_model)
        # Final output projection
        output = self.W_o(attention_output)
        return output

# Position-wise feed-forward network
class FeedForward(nn.Module):
    def __init__(self, d_model, d_ff):
        super(FeedForward, self).__init__()
        self.d_model = d_model
        self.d_ff = d_ff
        # Two linear layers
        self.linear1 = nn.Linear(d_model, d_ff)  # first projection, d_model -> d_ff
        self.linear2 = nn.Linear(d_ff, d_model)  # second projection, d_ff -> d_model

    def forward(self, x):
        x = F.relu(self.linear1(x))  # non-linearity after the first projection
        x = self.linear2(x)          # project back to d_model
        return x

# Transformer encoder
class Transformer(nn.Module):
    def __init__(self, d_model, num_heads, d_ff, num_layers):
        super(Transformer, self).__init__()
        self.d_model = d_model
        self.num_heads = num_heads
        self.d_ff = d_ff
        self.num_layers = num_layers
        # Positional encoder (applied to the input embeddings)
        self.positional_encoder = PositionalEncoder(d_model)
        # Stack of encoder layers
        self.encoder_layers = nn.ModuleList([
            nn.ModuleList([
                MultiHeadAttention(d_model, num_heads),
                nn.LayerNorm(d_model),
                FeedForward(d_model, d_ff),
                nn.LayerNorm(d_model)
            ])
            for _ in range(num_layers)
        ])

    def forward(self, src, mask=None):
        # Add positional information to the input
        x = self.positional_encoder(src)
        # Pass through the stacked encoder layers
        for i in range(self.num_layers):
            # Multi-head self-attention sublayer with residual connection
            attention = self.encoder_layers[i][0]
            norm1 = self.encoder_layers[i][1]
            x = x + attention(x, x, x, mask=mask)
            x = norm1(x)
            # Feed-forward sublayer with residual connection
            feed_forward = self.encoder_layers[i][2]
            norm2 = self.encoder_layers[i][3]
            x = x + feed_forward(x)
            x = norm2(x)
        return x

# Quick test of the Transformer model
if __name__ == "__main__":
    # Input of shape (batch_size, seq_len, d_model) = (16, 20, 512)
    input_tensor = torch.randn(16, 20, 512)
    transformer_model = Transformer(d_model=512, num_heads=8, d_ff=2048, num_layers=6)
    output = transformer_model(input_tensor)
    print(output.shape)  # torch.Size([16, 20, 512])
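The test above never exercises the mask argument. Below is a minimal usage sketch, assuming the convention used inside MultiHeadAttention (1 = may attend, 0 = masked out) and broadcasting a padding mask to the (batch, num_heads, seq_len, seq_len) score shape; the sequence lengths are made up for illustration.

# Hypothetical padding-mask usage for the Transformer above (1 = keep, 0 = ignore)
batch_size, seq_len, d_model = 4, 10, 512
x = torch.randn(batch_size, seq_len, d_model)
lengths = torch.tensor([10, 7, 5, 3])  # assumed true lengths of the four sequences
padding_mask = (torch.arange(seq_len).unsqueeze(0) < lengths.unsqueeze(1)).long()  # (batch, seq_len)
# attention_scores is (batch, num_heads, seq_len, seq_len), so broadcast over heads and query positions
attn_mask = padding_mask.unsqueeze(1).unsqueeze(2)  # (batch, 1, 1, seq_len)
model = Transformer(d_model=512, num_heads=8, d_ff=2048, num_layers=2)
out = model(x, mask=attn_mask)
print(out.shape)  # torch.Size([4, 10, 512])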

 

Below is a PyTorch implementation of Longformer. Longformer replaces full self-attention with sliding-window (local) attention: each token only attends to tokens within a fixed window around its position, which reduces the attention cost from quadratic to roughly linear in sequence length. (The full Longformer also adds global attention on a few selected tokens; the simplified version below implements only the local component.)

import torch
import torch.nn as nn
import torch.nn.functional as F
import math

# Sinusoidal positional encoder (same scheme as above)
class PositionalEncoder(nn.Module):
    def __init__(self, d_model, max_len=512):
        super(PositionalEncoder, self).__init__()
        self.d_model = d_model
        self.max_len = max_len
        # Precompute the positional encoding table
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        pe = pe.unsqueeze(0)
        self.register_buffer('pe', pe)

    def forward(self, x):
        x = x * math.sqrt(self.d_model)
        seq_len = x.size(1)
        x = x + self.pe[:, :seq_len]
        return x

# Local (sliding-window) attention
class LocalAttention(nn.Module):
    def __init__(self, d_model, local_window):
        super(LocalAttention, self).__init__()
        self.d_model = d_model
        self.local_window = local_window
        # A single attention head keeps this example simple
        self.attention = nn.MultiheadAttention(d_model, 1)

    def forward(self, x, mask=None):
        # x: (batch_size, seq_len, d_model); mask: (batch_size, seq_len) with 1 = real token, 0 = padding
        batch_size, seq_len, _ = x.size()
        # Band mask: token i may only attend to tokens within local_window positions of i
        local_mask = torch.zeros(seq_len, seq_len, device=x.device)
        for i in range(seq_len):
            local_mask[i, max(0, i - self.local_window):i + self.local_window + 1] = 1
        attn_mask = local_mask == 0  # nn.MultiheadAttention blocks positions where attn_mask is True
        key_padding_mask = (mask == 0) if mask is not None else None  # True = padded position to ignore
        # nn.MultiheadAttention expects (seq_len, batch_size, d_model)
        x = x.permute(1, 0, 2)
        output, _ = self.attention(x, x, x, attn_mask=attn_mask, key_padding_mask=key_padding_mask)
        return output.permute(1, 0, 2)  # back to (batch_size, seq_len, d_model)

# Longformer-style encoder
class Longformer(nn.Module):
    def __init__(self, d_model, num_heads, d_ff, num_layers, local_window):
        super(Longformer, self).__init__()
        self.d_model = d_model
        self.num_heads = num_heads
        self.d_ff = d_ff
        self.num_layers = num_layers
        self.local_window = local_window
        # Positional encoder and local attention layer
        self.positional_encoder = PositionalEncoder(d_model)
        self.local_attention = LocalAttention(d_model, local_window)
        # Stack of encoder layers
        self.encoder_layers = nn.ModuleList([
            nn.ModuleList([
                nn.LayerNorm(d_model),
                nn.Linear(d_model, d_ff),
                nn.ReLU(),
                nn.Linear(d_ff, d_model),
                nn.LayerNorm(d_model)
            ])
            for _ in range(num_layers)
        ])

    def forward(self, src, mask=None):
        x = self.positional_encoder(src)
        for i in range(self.num_layers):
            norm1 = self.encoder_layers[i][0]
            linear1 = self.encoder_layers[i][1]
            relu = self.encoder_layers[i][2]
            linear2 = self.encoder_layers[i][3]
            norm2 = self.encoder_layers[i][4]
            # Local attention sublayer with residual connection
            x = x + self.local_attention(x, mask=mask)
            x = norm1(x)
            # Feed-forward sublayer with residual connection
            x = linear2(relu(linear1(x))) + x
            x = norm2(x)
        return x

# Quick test of the Longformer model
if __name__ == "__main__":
    input_tensor = torch.randn(16, 512, 512)  # (batch_size, seq_len, d_model) = (16, 512, 512)
    mask = torch.ones(16, 512)  # all 512 positions are real tokens (no padding)
    longformer_model = Longformer(d_model=512, num_heads=8, d_ff=2048, num_layers=6, local_window=128)
    output = longformer_model(input_tensor, mask=mask)
    print(output.shape)  # torch.Size([16, 512, 512])
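One practical note: building local_mask with a Python loop is fine for a demo but slow for long sequences. The same band mask can be built in a vectorized way; here is a small sketch with the same semantics as LocalAttention above (1 = may attend, 0 = blocked).

# Vectorized construction of the sliding-window (band) mask
seq_len, local_window = 512, 128
idx = torch.arange(seq_len)
band = (idx.unsqueeze(0) - idx.unsqueeze(1)).abs() <= local_window  # (seq_len, seq_len) bool
local_mask = band.long()  # 1 inside the window, 0 outside, identical to the loop version

Also keep in mind that this toy implementation still materializes the full seq_len x seq_len score matrix and merely masks it; the real Longformer uses custom attention kernels so that only the O(seq_len x local_window) in-window entries are ever computed, which is where the memory and speed savings actually come from.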

 
