• Transformer的Encoder和Decoder之间的交互


    Transformer的Encoder和Decoder之间的交互

    flyfish

    这个示例代码创建了一个小的Transformer模型,并演示了如何在Encoder和Decoder之间进行交互。

    import torch
    import torch.nn as nn
    import torch.nn.functional as F
    import math
    
    # 定义位置编码
    class PositionalEncoding(nn.Module):
        def __init__(self, d_model, max_len=5000):
            super(PositionalEncoding, self).__init__()
            pe = torch.zeros(max_len, d_model)
            position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
            div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
            pe[:, 0::2] = torch.sin(position * div_term)
            pe[:, 1::2] = torch.cos(position * div_term)
            pe = pe.unsqueeze(0).transpose(0, 1)
            self.register_buffer('pe', pe)
    
        def forward(self, x):
            return x + self.pe[:x.size(0), :]
    
    # 定义Transformer模型
    class TransformerModel(nn.Module):
        def __init__(self, input_dim, model_dim, output_dim, nhead, num_encoder_layers, num_decoder_layers, dim_feedforward=512, max_len=5000):
            super(TransformerModel, self).__init__()
            self.model_dim = model_dim
            self.embedding = nn.Embedding(input_dim, model_dim)
            self.positional_encoding = PositionalEncoding(model_dim, max_len)
    
            self.encoder = nn.TransformerEncoder(
                nn.TransformerEncoderLayer(model_dim, nhead, dim_feedforward), num_encoder_layers)
    
            self.decoder = nn.TransformerDecoder(
                nn.TransformerDecoderLayer(model_dim, nhead, dim_feedforward), num_decoder_layers)
    
            self.fc_out = nn.Linear(model_dim, output_dim)
    
        def forward(self, src, tgt, src_mask=None, tgt_mask=None, memory_mask=None):
            src_emb = self.positional_encoding(self.embedding(src) * math.sqrt(self.model_dim))
            tgt_emb = self.positional_encoding(self.embedding(tgt) * math.sqrt(self.model_dim))
    
            memory = self.encoder(src_emb, mask=src_mask)
            output = self.decoder(tgt_emb, memory, tgt_mask=tgt_mask, memory_mask=memory_mask)
            return self.fc_out(output)
    
    # 超参数定义
    input_dim = 1000
    model_dim = 512
    output_dim = 1000
    nhead = 8
    num_encoder_layers = 6
    num_decoder_layers = 6
    dim_feedforward = 2048
    max_len = 5000
    
    # 创建Transformer模型实例
    model = TransformerModel(input_dim, model_dim, output_dim, nhead, num_encoder_layers, num_decoder_layers, dim_feedforward, max_len)
    
    # 定义示例输入
    src = torch.randint(0, input_dim, (10, 32))  # (source sequence length, batch size)
    tgt = torch.randint(0, input_dim, (20, 32))  # (target sequence length, batch size)
    
    # 前向传播
    output = model(src, tgt)
    
    # 打印输出形状
    print(output.shape)  # (target sequence length, batch size, output dimension)
    

    输出

    torch.Size([20, 32, 1000])
    

    位置编码 (Positional Encoding):
    用于在输入中加入位置信息,使模型能够考虑序列顺序。

    Transformer模型 (TransformerModel):
    包括Embedding层、位置编码层、Encoder、Decoder和输出的全连接层。

    前向传播:
    将输入源序列 (src) 和目标序列 (tgt) 通过嵌入层和位置编码。
    使用Encoder对源序列进行编码,得到记忆 (memory)。
    使用Decoder对目标序列进行解码,结合记忆生成输出。

    超参数:
    定义模型的维度、头数、层数等。

    register_buffer 是 PyTorch 中 nn.Module 类的方法,用于注册一个持久的缓冲区,这些缓冲区不是模型的参数,但在训练和推理过程中需要被保存和加载。例如,位置编码就是这样一种缓冲区,它不需要进行梯度更新,但需要在模型保存和加载时保持不变。

    下面是一个简单的例子,展示如何使用 register_buffer 注册一个缓冲区:

    import torch
    import torch.nn as nn
    
    class ExampleModule(nn.Module):
        def __init__(self):
            super(ExampleModule, self).__init__()
            # 注册一个缓冲区
            buffer = torch.tensor([1, 2, 3, 4], dtype=torch.float32)
            self.register_buffer('my_buffer', buffer)
            # 一个简单的线性层
            self.linear = nn.Linear(4, 2)
    
        def forward(self, x):
            # 使用缓冲区进行一些操作
            x = x + self.my_buffer
            return self.linear(x)
    
    # 创建模型实例
    model = ExampleModule()
    
    # 打印模型结构
    print(model)
    
    # 定义输入张量
    input_tensor = torch.tensor([1, 1, 1, 1], dtype=torch.float32)
    
    # 前向传播
    output = model(input_tensor)
    print(output)
    
    # 打印缓冲区
    print("Buffer:", model.my_buffer)
    

    输出

    ExampleModule(
      (linear): Linear(in_features=4, out_features=2, bias=True)
    )
    tensor([-3.2452,  0.5913], grad_fn=)
    Buffer: tensor([1., 2., 3., 4.])
    
  • 相关阅读:
    帝国模板留言板增加自定义字段教程
    2-1线性表-顺序表
    华为数通方向HCIP-DataCom H12-821题库(单选题:221-240)
    java毕业设计晶研电子公司业务网站(附源码、数据库)
    【SRE】MySQL8使用方式
    牛客网——Java刷题篇
    ConsoleAppender简介说明
    jquery 选择器深入
    全球二氧化碳排放数据1deg产品(ODIAC)数据
    车载前置摄像头学习笔记 ———— 摄像头输出数据格式(JPEG)
  • 原文地址:https://blog.csdn.net/flyfish1986/article/details/139472618