• A Look at the ChatGLM-6B Model Structure from Its Source Code


    This article is based on the first version of ChatGLM-6B; note that ChatGLM2-6B and ChatGLM3-6B have since been released.

    Overview

    ChatGLM is a neural network built on the Transformer architecture, so the natural way to analyze its source code is to start from the Transformer structure.
    The Transformer architecture (diagram omitted):


    Positional Encoding

    ChatGLM-6B implements its positional encoding with Rotary Position Embedding (RoPE). Source code:

    class RotaryEmbedding(torch.nn.Module):
        def __init__(self, dim, base=10000, precision=torch.half, learnable=False):
            super().__init__()
            # Inverse frequencies 1 / base^(2i/dim), one per pair of hidden dimensions
            inv_freq = 1. / (base ** (torch.arange(0, dim, 2).float() / dim))
            inv_freq = inv_freq.half()
            self.learnable = learnable
            if learnable:
                self.inv_freq = torch.nn.Parameter(inv_freq)
                self.max_seq_len_cached = None
            else:
                self.register_buffer('inv_freq', inv_freq)
                self.max_seq_len_cached = None
                self.cos_cached = None
                self.sin_cached = None
            self.precision = precision
    
        def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys,
                                  error_msgs):
            # Intentionally a no-op: nothing is loaded from the checkpoint for this module
            pass
    
        def forward(self, x, seq_dim=1, seq_len=None):
            if seq_len is None:
                seq_len = x.shape[seq_dim]
            # Rebuild the cos/sin cache when the requested length exceeds what is cached
            if self.max_seq_len_cached is None or (seq_len > self.max_seq_len_cached):
                self.max_seq_len_cached = None if self.learnable else seq_len
                t = torch.arange(seq_len, device=x.device, dtype=self.inv_freq.dtype)
                # Outer product: angle[s, i] = position s * inv_freq[i]
                freqs = torch.einsum('i,j->ij', t, self.inv_freq)
                # Different from paper, but it uses a different permutation in order to obtain the same calculation
                emb = torch.cat((freqs, freqs), dim=-1).to(x.device)
                if self.precision == torch.bfloat16:
                    emb = emb.float()
    
                # [sx, 1 (b * np), hn]
                cos_cached = emb.cos()[:, None, :]
                sin_cached = emb.sin()[:, None, :]
                if self.precision == torch.bfloat16:
                    cos_cached = cos_cached.bfloat16()
                    sin_cached = sin_cached.bfloat16()
                if self.learnable:
                    return cos_cached, sin_cached
                self.cos_cached, self.sin_cached = cos_cached, sin_cached
            return self.cos_cached[:seq_len, ...], self.sin_cached[:seq_len, ...]
    
        def _apply(self, fn):
            if self.cos_cached is not None:
                self.cos_cached = fn(self.cos_cached)
            if self.sin_cached is not None:
                self.sin_cached = fn(self.sin_cached)
            return super()._apply(fn)
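
    The RotaryEmbedding module above only precomputes and caches the cos/sin tables; the rotation itself is applied to the query and key vectors inside the attention layer. The actual ChatGLM-6B attention code consumes these tables with its own position-id indexing, so the following is only a minimal, self-contained sketch (assuming simple [seq, heads, head_dim] tensors rather than the exact ChatGLM shapes) of the standard "rotate half" application:

    import torch

    def rotate_half(x):
        # Split the last dimension into two halves and recombine as (-x2, x1)
        x1, x2 = x.chunk(2, dim=-1)
        return torch.cat((-x2, x1), dim=-1)

    def apply_rotary_pos_emb(q, k, cos, sin):
        # Rotate queries and keys by position-dependent angles;
        # cos/sin must broadcast against q/k over the last dimension.
        return q * cos + rotate_half(q) * sin, k * cos + rotate_half(k) * sin

    # Tiny demo with assumed shapes [seq, heads, head_dim]
    seq_len, n_heads, head_dim = 10, 8, 64
    inv_freq = 1.0 / (10000 ** (torch.arange(0, head_dim, 2).float() / head_dim))
    t = torch.arange(seq_len).float()
    freqs = torch.einsum('i,j->ij', t, inv_freq)                # [seq, head_dim/2]
    emb = torch.cat((freqs, freqs), dim=-1)                     # [seq, head_dim]
    cos, sin = emb.cos()[:, None, :], emb.sin()[:, None, :]     # [seq, 1, head_dim]

    q = torch.randn(seq_len, n_heads, head_dim)
    k = torch.randn(seq_len, n_heads, head_dim)
    q_rot, k_rot = apply_rotary_pos_emb(q, k, cos, sin)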
    
    

    Activation Function

    ChatGLM-6B uses GeLU (Gaussian Error Linear Unit) as its activation function. Source code:

    @torch.jit.script
    def gelu_impl(x):
        """OpenAI's gelu implementation."""
        # Tanh approximation: 0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3)))
        return 0.5 * x * (1.0 + torch.tanh(0.7978845608028654 * x *
                                           (1.0 + 0.044715 * x * x)))
    
    
    def gelu(x):
        return gelu_impl(x)
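
    The constant 0.7978845608028654 is sqrt(2/pi), so this is the familiar tanh approximation of GELU. As a quick sanity-check sketch (not part of the ChatGLM source), it can be compared against PyTorch's exact, erf-based GELU:

    import torch
    import torch.nn.functional as F

    def gelu_tanh(x):
        # Same tanh approximation as gelu_impl above, without torch.jit.script
        return 0.5 * x * (1.0 + torch.tanh(0.7978845608028654 * x * (1.0 + 0.044715 * x * x)))

    x = torch.linspace(-5.0, 5.0, steps=101)
    exact = F.gelu(x)                        # erf-based GELU
    approx = gelu_tanh(x)
    print((exact - approx).abs().max())      # the two curves differ by well under 1e-2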
    

    Encoder-Decoder

    Next comes the encoder-decoder structure. How do we find the entry point of the model? Start from the transformers API:

    from transformers import AutoTokenizer, AutoModel
    tokenizer = AutoTokenizer.from_pretrained("THUDM/chatglm-6b", trust_remote_code=True)
    model = AutoModel.from_pretrained("THUDM/chatglm-6b", trust_remote_code=True).half().to("cuda:1").eval()
    
    print(model)
    
    

    Output:

    ChatGLMForConditionalGeneration(
      (transformer): ChatGLMModel(
        (word_embeddings): Embedding(130528, 4096)
        (layers): ModuleList(
          (0-27): 28 x GLMBlock(
            (input_layernorm): LayerNorm((4096,), eps=1e-05, elementwise_affine=True)
            (attention): SelfAttention(
              (rotary_emb): RotaryEmbedding()
              (query_key_value): Linear(in_features=4096, out_features=12288, bias=True)
              (dense): Linear(in_features=4096, out_features=4096, bias=True)
            )
            (post_attention_layernorm): LayerNorm((4096,), eps=1e-05, elementwise_affine=True)
            (mlp): GLU(
              (dense_h_to_4h): Linear(in_features=4096, out_features=16384, bias=True)
              (dense_4h_to_h): Linear(in_features=16384, out_features=4096, bias=True)
            )
          )
        )
        (final_layernorm): LayerNorm((4096,), eps=1e-05, elementwise_affine=True)
      )
      (lm_head): Linear(in_features=4096, out_features=130528, bias=False)
    )
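
    The printed shapes are enough for a back-of-envelope parameter count that explains the "6B" in the name. The sketch below assumes lm_head shares its weight with word_embeddings (otherwise the total would be about 0.5B higher); it is an estimate from the printed dimensions, not a value read from the checkpoint:

    # Rough parameter count from the printed module shapes
    vocab, hidden, ffn, layers = 130528, 4096, 16384, 28

    embedding  = vocab * hidden                                   # word_embeddings (assumed shared with lm_head)
    attention  = hidden * 3 * hidden + 3 * hidden                 # query_key_value (weight + bias)
    attention += hidden * hidden + hidden                         # dense (weight + bias)
    mlp        = hidden * ffn + ffn + ffn * hidden + hidden       # dense_h_to_4h + dense_4h_to_h
    norms      = 2 * 2 * hidden                                   # two LayerNorms per block (weight + bias)
    per_block  = attention + mlp + norms

    total = embedding + layers * per_block + 2 * hidden           # plus final_layernorm
    print(f"{total / 1e9:.2f}B parameters")                       # ~6.17B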
    

    Sorting out the structure as a mind map

    The structure can be represented by the following diagram (diagram omitted):

    Comparing this structure diagram with the Transformer architecture diagram at the beginning, the two line up fairly well.
    The official source code notes that the encoder and decoder share a single implementation; switching to decoder behavior is just a matter of configuration parameters (screenshot omitted).
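
    To see the decoder side in action, here is a minimal generation sketch using the chat helper that the remote code attaches to the model (this assumes the weights are downloaded and a GPU with enough memory is available):

    from transformers import AutoTokenizer, AutoModel

    tokenizer = AutoTokenizer.from_pretrained("THUDM/chatglm-6b", trust_remote_code=True)
    model = AutoModel.from_pretrained("THUDM/chatglm-6b", trust_remote_code=True).half().cuda().eval()

    # chat() wraps tokenization, autoregressive decoding and detokenization
    response, history = model.chat(tokenizer, "你好", history=[])
    print(response)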


  • Original post: https://www.cnblogs.com/zhiyong-ITNote/p/17949232