基于pytorch实现transformer以及longformer以及代码详解

以下是一个基于PyTorch实现Transformer模型的简单示例代码,并对每个步骤进行了详细的注释。请注意,这个示例主要着重于模型的实现,不包括数据处理和训练部分。

import torch
import torch.nn as nn
import torch.nn.functional as F
import math

# 定义位置编码器
class PositionalEncoder(nn.Module):
    def __init__(self, d_model, max_len=512):
        super(PositionalEncoder, self).__init__()
        self.d_model = d_model
        self.max_len = max_len
        
        # 计算位置编码表
        pe = torch.zeros(max_len, d_model)  # 创建一个形状为(max_len, d_model)的零张量pe,用于存储位置编码
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)  # 创建一个形状为(max_len, 1)的张量,表示位置
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))  # 计算位置编码中的分母部分
        pe[:, 0::2] = torch.sin(position * div_term)  # 计算位置编码中偶数位置的值
        pe[:, 1::2] = torch.cos(position * div_term)  # 计算位置编码中奇数位置的值
        pe = pe.unsqueeze(0)  # 在第0维上增加一维,用于处理批次数据
        self.register_buffer('pe', pe)  # 将位置编码表pe注册为模型的缓冲区
        
    def forward(self, x):
        # 输入x的维度为(batch_size, seq_len, d_model)
        x = x * math.sqrt(self.d_model)  # 对输入乘以一个缩放因子,以便缓解梯度消失问题
        seq_len = x.size(1)  # 获取输入序列的长度
        # 将位置编码添加到输入中
        x = x + self.pe[:, :seq_len]  # 在对应位置添加位置编码
        return x

# 定义多头注意力机制
class MultiHeadAttention(nn.Module):
    def __init__(self, d_model, num_heads):
        super(MultiHeadAttention, self).__init__()
        assert d_model % num_heads == 0, "d_model必须被num_heads整除"
        self.d_model = d_model
        self.num_heads = num_heads
        self.head_dim = d_model // num_heads  # 每个头的维度
        
        # 定义线性变换层
        self.W_q = nn.Linear(d_model, d_model)  # 查询向量的线性变换层
        self.W_k = nn.Linear(d_model, d_model)  # 键向量的线性变换层
        self.W_v = nn.Linear(d_model, d_model)  # 值向量的线性变换层
        self.W_o = nn.Linear(d_model, d_model)  # 输出向量的线性变换层
        
    def forward(self, query, key, value, mask=None):
        batch_size = query.size(0)  # 获取批次大小
        
        # 将输入的query、key、value通过线性变换得到Q、K、V
        Q = self.W_q(query).view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)  # 计算查询向量Q
        K = self.W_k(key).view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)  # 计算键向量K
        V = self.W_v(value).view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)  # 计算值向量V
        
        # 计算注意力分数
        attention_scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(self.head_dim)  # 计算注意力分数
        if mask is not None:
            attention_scores = attention_scores.masked_fill(mask == 0, float('-inf'))  # 使用mask处理注意力分数
        attention_weights = F.softmax(attention_scores, dim=-1)  # 计算注意力权重
        
        # 计算注意力值
        attention_output = torch.matmul(attention_weights, V)  # 计算注意力值
        attention_output = attention_output.transpose(1, 2).contiguous().view(batch_size, -1, self.d_model)  # 调整注意力值的形状
        
        # 经过线性变换得到最终输出
        output = self.W_o(attention_output)  # 最终输出
        return output

# 定义前向传播层
class FeedForward(nn.Module):
    def __init__(self, d_model, d_ff):
        super(FeedForward, self).__init__()
        self.d_model = d_model
        self.d_ff = d_ff
        
        # 定义两个线性变换层
        self.linear1 = nn.Linear(d_model, d_ff)  # 第一个线性变换层
        self.linear2 = nn.Linear(d_ff, d_model)  # 第二个线性变换层
        
    def forward(self, x):
        x = F.relu(self.linear1(x))  # 使用ReLU激活函数进行非线性变换
        x = self.linear2(x)  # 进行第二个线性变换
        return x

# 定义一个Transformer模型
class Transformer(nn.Module):
    def __init__(self, d_model, num_heads, d_ff, num_layers):
        super(Transformer, self).__init__()
        self.d_model = d_model
        self.num_heads = num_heads
        self.d_ff = d_ff
        self.num_layers = num_layers
        
        # 定义多个编码器层
        self.encoder_layers = nn.ModuleList([
            nn.ModuleList([
                MultiHeadAttention(d_model, num_heads),
                nn.LayerNorm(d_model),
                FeedForward(d_model, d_ff),
                nn.LayerNorm(d_model)
            ])
            for _ in range(num_layers)
        ])
        
    def forward(self, src, mask=None):
        x = src
        
        # 通过多个编码器层进行前向传播
        for i in range(self.num_layers):
            # 多头注意力层
            attention = self.encoder_layers[i][0]
            norm1 = self.encoder_layers[i][1]
            x = x + attention(x, x, x, mask=mask)
            x = norm1(x)
            
            # 前向传播层
            feed_forward = self.encoder_layers[i][2]
            norm2 = self.encoder_layers[i][3]
            x = x + feed_forward(x)
            x = norm2(x)
        
        return x

# 测试Transformer模型
if __name__ == "__main__":
    # 假设输入维度为(16, 20, 512),即(batch_size, seq_len, d_model)
    input_tensor = torch.randn(16, 20, 512)
    transformer_model = Transformer(d_model=512, num_heads=8, d_ff=2048, num_layers=6)
    output = transformer_model(input_tensor)
    print(output.shape)  # 输出:torch.Size([16, 20, 512])

 

以下是longfromer的pytorch实现版本

import torch
import torch.nn as nn
import torch.nn.functional as F
import math

# 定义位置编码器
class PositionalEncoder(nn.Module):
    def __init__(self, d_model, max_len=512):
        super(PositionalEncoder, self).__init__()
        self.d_model = d_model
        self.max_len = max_len
        
        # 计算位置编码表
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        pe = pe.unsqueeze(0)
        self.register_buffer('pe', pe)
        
    def forward(self, x):
        x = x * math.sqrt(self.d_model)
        seq_len = x.size(1)
        x = x + self.pe[:, :seq_len]
        return x

# 定义局部注意力机制
class LocalAttention(nn.Module):
    def __init__(self, d_model, local_window):
        super(LocalAttention, self).__init__()
        self.d_model = d_model
        self.local_window = local_window
        self.attention = nn.MultiheadAttention(d_model, 1)
        
    def forward(self, x, mask=None):
        batch_size, seq_len, _ = x.size()
        local_mask = torch.zeros(seq_len, seq_len)
        for i in range(seq_len):
            local_mask[i, max(0, i - self.local_window):i + self.local_window + 1] = 1
        local_mask = local_mask.unsqueeze(0).to(x.device)
        local_mask = local_mask * mask.unsqueeze(1) if mask is not None else local_mask
        
        return self.attention(x.permute(1, 0, 2), x.permute(1, 0, 2), x.permute(1, 0, 2), key_padding_mask=local_mask)

# 定义Longformer模型
class Longformer(nn.Module):
    def __init__(self, d_model, num_heads, d_ff, num_layers, local_window):
        super(Longformer, self).__init__()
        self.d_model = d_model
        self.num_heads = num_heads
        self.d_ff = d_ff
        self.num_layers = num_layers
        self.local_window = local_window
        
        # 定义位置编码器和局部注意力层
        self.positional_encoder = PositionalEncoder(d_model)
        self.local_attention = LocalAttention(d_model, local_window)
        
        # 定义多个编码器层
        self.encoder_layers = nn.ModuleList([
            nn.ModuleList([
                nn.LayerNorm(d_model),
                nn.Linear(d_model, d_ff),
                nn.ReLU(),
                nn.Linear(d_ff, d_model),
                nn.LayerNorm(d_model)
            ])
            for _ in range(num_layers)
        ])
        
    def forward(self, src, mask=None):
        x = self.positional_encoder(src)
        for i in range(self.num_layers):
            norm1 = self.encoder_layers[i][0]
            linear1 = self.encoder_layers[i][1]
            relu = self.encoder_layers[i][2]
            linear2 = self.encoder_layers[i][3]
            norm2 = self.encoder_layers[i][4]
            
            # 局部注意力层
            if mask is not None:
                mask[:, :, :self.local_window] = 0
            x = x + self.local_attention(x.permute(1, 0, 2), mask=mask)[0].permute(1, 0, 2)
            
            # 前向传播层
            x = norm1(x)
            x = linear2(relu(linear1(x))) + x
            x = norm2(x)
        
        return x

# 测试Longformer模型
if __name__ == "__main__":
    input_tensor = torch.randn(16, 512, 512)  # 假设输入维度为(16, 512, 512)
    mask = torch.ones(16, 512)  # 假设有512个标记
    longformer_model = Longformer(d_model=512, num_heads=8, d_ff=2048, num_layers=6, local_window=128)
    output = longformer_model(input_tensor, mask=mask)
    print(output.shape)  # 输出:torch.Size([16, 512, 512])

 

发表评论

匿名网友