• Cumulative-Attention Large Model


    Unstable version

    import paddle
    
    
    class HeadLoss(paddle.nn.Layer):
        def __init__(self):
            super(HeadLoss, self).__init__()
            self.loss = paddle.nn.CrossEntropyLoss()
    
        def forward(self, x_list, label):
            loss = 0
            h = x_list[0].shape[-1]
            p = len(x_list) + 1
            acc_data = 1
            for i, out in enumerate(x_list, start=1):
                # The i-th head predicts the i-th base-h digit of the label.
                one_label = (label % h ** (p - i) // h ** ((p - i) - 1)).astype("int64")
                loss += self.loss(out.reshape([-1, h]), one_label.reshape([-1]))
                with paddle.no_grad():
                    # A token counts as correct only if every digit is correct.
                    acc_data *= (paddle.argmax(out, -1) == one_label).numpy()
    
            return loss, acc_data.mean()
    
    
    class HiddenHead(paddle.nn.Layer):
        def __init__(self, voc_size=19, hidden_size=512):
            super(HiddenHead, self).__init__()
    
            self.hidden_size = hidden_size
            # p + 1 heads are enough to cover voc_size in base hidden_size.
            p = 0
            while True:
                voc_size //= hidden_size
                if voc_size == 0:
                    break
                p += 1
            self.head = paddle.nn.LayerList(
                [paddle.nn.Linear(hidden_size, hidden_size, bias_attr=False) for _ in range(p + 1)])
    
        def forward(self, head_add_x):
            # One logit vector per digit position.
            return [head(head_add_x) for head in self.head]
    
        def sample(self, head_add_x):
            h = head_add_x.shape[-1]
            x0_res = 0
            p = len(self.head) + 1
            for i, head in enumerate(self.head, start=1):
                # Greedy-decode each digit, then recompose the token id.
                arg_max = paddle.argmax(head(head_add_x), -1)
                x0_res += arg_max * h ** ((p - i) - 1)
            return x0_res
    
    
    class EmAdd(paddle.nn.Layer):
        def __init__(self, voc_size=9999, hidden_size=256):
            super(EmAdd, self).__init__()
            self.hidden_size = hidden_size
            p = 0
            while True:
                voc_size //= hidden_size
                if voc_size == 0:
                    break
                p += 1
            self.em = paddle.nn.LayerList([paddle.nn.Embedding(hidden_size, hidden_size) for _ in range(p + 1)])
            # self.em_zero = paddle.nn.LayerList([paddle.nn.Embedding(hidden_size, hidden_size, padding_idx=0)
            # for _ in range(p + 1)])
    
        def forward(self, em_add_x):
            add = 0
            p = len(self.em) + 1
            # mask: nonzero once any digit seen so far is nonzero.
            mask = paddle.zeros(em_add_x.shape)
            for i, em in enumerate(self.em, start=1):
                x0 = em_add_x % self.hidden_size ** (p - i) // self.hidden_size ** ((p - i) - 1)
                mask += x0
                mask = (mask != 0).astype("int")
                x0 = em(x0) * mask.unsqueeze(-1)
                add = paddle.sin(x0 + add)
    
            return add
    
    
    class ReNet(paddle.nn.Layer):
        def __init__(self, voc_size, hidden_dim, num_head):
            super(ReNet, self).__init__()
            # Instantiate the norm once so its parameters are registered and trained.
            self.group_norm = paddle.nn.GroupNorm(num_head, num_head)
            self.em = EmAdd(voc_size, hidden_dim * num_head)
            self.head = num_head
    
        def forward(self, x):
            x = self.em(x)
            seq_len = x.shape[1]
            bsz = x.shape[0]
            x = x.reshape([bsz, seq_len, self.head, -1]).transpose([0, 2, 1, 3])
    
            q = paddle.sin(x)
            k = paddle.sin(q + x)
            v = paddle.sin(k + x)
            attention = self.ParallelRetention(seq_len, q, k, v)
            return attention
    
        def sample_forward(self, x):
            x = self.em(x)
            seq_len = x.shape[1]
            bsz = x.shape[0]
            x = x.reshape([bsz, seq_len, self.head, -1]).transpose([0, 2, 1, 3])
    
            q = paddle.sin(x)
            k = paddle.sin(q + x)
            v = paddle.sin(k + x)
            state = paddle.zeros([bsz, self.head, q.shape[-1], v.shape[-1]])
            attention, state = self.RecurrentRetention(q, k, v, state)
            return self.group_norm(attention)
    
        def ParallelRetention(self, seq_len,
                              q,  # bsz * num_head * len * qk_dim
                              k,  # bsz * num_head * len * qk_dim
                              v,  # bsz * num_head * len * v_dim
                              ):
            # Lower-triangular causal mask shared by all heads.
            decay_mask = paddle.triu(paddle.ones([seq_len, seq_len])).T  # len * len
            retention = q @ k.transpose([0, 1, 3, 2])
            retention = retention * decay_mask
            output = retention @ v
            output = self.group_norm(output)
            return output
    
        def RecurrentRetention(self, q, k, v, past_kv, decay=1.0):
            # A minimal sketch of the recurrent form (a RetNet-style update is
            # assumed here):  S_n = decay * S_{n-1} + k_n^T v_n,  out_n = q_n S_n
            outputs = []
            state = past_kv  # bsz * num_head * qk_dim * v_dim
            for t in range(q.shape[2]):
                state = decay * state + k[:, :, t].unsqueeze(-1) * v[:, :, t].unsqueeze(-2)
                outputs.append(paddle.sum(q[:, :, t].unsqueeze(-1) * state, -2))
            return paddle.stack(outputs, 2), state
    def demo():
        import numpy as np
        # Two-chunk kv accumulation: distributivity lets the chunk states be
        # summed instead of recomputing the product from scratch.
        kup = np.array([[5, 2],
                        [3, 4]])
    
        vup = np.array([[5, 6],
                        [7, 8]])
    
        knext = np.array([[9, 10],
                          [11, 12]])
    
        vnext = np.array([[13, 14],
                          [15, 16]])
    
        # Left side: kup @ (vup + vnext) + knext @ (vup + vnext)
        left = np.dot(kup, vup + vnext) + np.dot(knext, vup + vnext)
    
        # Right side: (kup + knext) @ (vup + vnext)
        right = np.dot(kup + knext, vup + vnext)
    
        # Equal by distributivity of matrix multiplication.
        if np.array_equal(left, right):
            print("equal: kup @ (vup + vnext) + knext @ (vup + vnext) == (kup + knext) @ (vup + vnext)")
        else:
            print("not equal")
            
    if __name__ == '__main__':
        net = ReNet(12935, 128, 8)
        # net(paddle.randint(1, 123, [3, 23]))
        # net.sample_forward(paddle.randint(1, 123, [3, 23]))
        s = 10
        h = 12
        a = paddle.randint(1, 123, [s, h]).astype("float32")
        b = paddle.randint(1, 123, [s, h]).astype("float32")
        c = paddle.randint(1, 123, [s, h]).astype("float32")
        print()
    
    
    
    
    # Scratch work: element-wise expansion of a chained matrix product
    # (A* and B* are the scalar entries of the first and second rows).
    #  A1*A4+A2*A5+A3*A6  ,A1*B4+A2*B5+A3*B6
    #  B1*A4+B2*A5+B3*A6  ,B1*B4+B2*B5+B3*B6
    
    #  A7(A1*A4+A2*A5+A3*A6) + B7*(A1*B4+A2*B5+A3*B6),  A8(A1*A4+A2*A5+A3*A6) + B8*(A1*B4+A2*B5+A3*B6)
    #  A7*(B1*A4+B2*A5+B3*A6)+B7*(B1*B4+B2*B5+B3*B6) ,A8*(B1*A4+B2*A5+B3*A6)+B8*(B1*B4+B2*B5+B3*B6)
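
    The digit trick shared by HeadLoss, HiddenHead and EmAdd above is a plain base-hidden_size expansion of the token id: a vocabulary of 9999 needs only two embedding tables of width 256, since 256 ** 2 = 65536. A minimal sketch of the round trip (plain Python; the helper names are illustrative, not from the original code):

    def digits_base_h(token_id, h, p):
        # Most-significant digit first; mirrors label % h ** (p - i) // h ** ((p - i) - 1).
        return [token_id % h ** (p - i) // h ** (p - i - 1) for i in range(p)]

    def recompose(digits, h):
        p = len(digits)
        return sum(d * h ** (p - i - 1) for i, d in enumerate(digits))

    h, p = 256, 2        # EmAdd(voc_size=9999, hidden_size=256) builds p + 1 = 2 tables
    tid = 9998
    ds = digits_base_h(tid, h, p)
    print(ds)            # [39, 14], since 39 * 256 + 14 == 9998
    assert recompose(ds, h) == tid

    HeadLoss then trains one softmax per digit, and HiddenHead.sample greedily decodes each digit and recomposes the token id the same way.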
    
    
    

    Stable version

    import paddle
    import numpy as np
    
    
    class HeadLoss(paddle.nn.Layer):
        def __init__(self):
            super(HeadLoss, self).__init__()
            self.loss = paddle.nn.CrossEntropyLoss()
    
        def forward(self, x_list, label):
            loss = 0
            h = x_list[0].shape[-1]
            p = len(x_list) + 1
            acc_data = 1
            for i, out in enumerate(x_list, start=1):
                # The i-th head predicts the i-th base-h digit of the label.
                one_label = (label % h ** (p - i) // h ** ((p - i) - 1)).astype("int64")
                loss += self.loss(out.reshape([-1, h]), one_label.reshape([-1]))
                with paddle.no_grad():
                    # A token counts as correct only if every digit is correct.
                    acc_data *= (paddle.argmax(out, -1) == one_label).numpy()
    
            return loss, acc_data.mean()
    
    
    class HiddenHead(paddle.nn.Layer):
        def __init__(self, voc_size=19, hidden_size=512):
            super(HiddenHead, self).__init__()
    
            self.hidden_size = hidden_size
            # p + 1 heads are enough to cover voc_size in base hidden_size.
            p = 0
            while True:
                voc_size //= hidden_size
                if voc_size == 0:
                    break
                p += 1
            self.head = paddle.nn.LayerList(
                [paddle.nn.Linear(hidden_size, hidden_size, bias_attr=False) for _ in range(p + 1)])
    
        def forward(self, head_add_x):
            # One logit vector per digit position.
            return [head(head_add_x) for head in self.head]
    
        def sample(self, head_add_x):
            h = head_add_x.shape[-1]
            x0_res = 0
            p = len(self.head) + 1
            for i, head in enumerate(self.head, start=1):
                # Greedy-decode each digit, then recompose the token id.
                arg_max = paddle.argmax(head(head_add_x), -1)
                x0_res += arg_max * h ** ((p - i) - 1)
            return x0_res
    
    
    class EmAdd(paddle.nn.Layer):
        def __init__(self, voc_size=9999, hidden_size=256):
            super(EmAdd, self).__init__()
            self.hidden_size = hidden_size
            p = 0
            while True:
                voc_size //= hidden_size
                if voc_size == 0:
                    break
                p += 1
            self.em = paddle.nn.LayerList([paddle.nn.Embedding(hidden_size, hidden_size) for _ in range(p + 1)])
            # self.em_zero = paddle.nn.LayerList([paddle.nn.Embedding(hidden_size, hidden_size, padding_idx=0)
            # for _ in range(p + 1)])
    
        def forward(self, em_add_x):
            add = 0
            p = len(self.em) + 1
            # mask: nonzero once any digit seen so far is nonzero.
            mask = paddle.zeros(em_add_x.shape)
            for i, em in enumerate(self.em, start=1):
                x0 = em_add_x % self.hidden_size ** (p - i) // self.hidden_size ** ((p - i) - 1)
                mask += x0
                mask = (mask != 0).astype("int")
                x0 = em(x0) * mask.unsqueeze(-1)
                add = paddle.sin(x0 + add)
    
            return add
    
    
    class ReNet(paddle.nn.Layer):
        def __init__(self, voc_size, hidden_dim, num_head_dim, n_layers):
            super(ReNet, self).__init__()
            self.em = EmAdd(voc_size, hidden_dim * num_head_dim)
            self.qk_list = paddle.nn.LayerList(
                [paddle.nn.Linear(num_head_dim, hidden_dim, bias_attr=False) for _ in range(n_layers)])
            self.head = num_head_dim
            self.out_layer = HiddenHead(voc_size, hidden_dim)
    
        def forward(self, x, state):
            x = self.em(x)
            q = paddle.sin(x)
            k = paddle.sin(x + q)
            seq_len = q.shape[1]
            bsz = q.shape[0]
            q = q.reshape([bsz, seq_len, self.head, -1]).transpose([0, 2, 1, 3])
            k = k.reshape([bsz, seq_len, self.head, -1]).transpose([0, 2, 1, 3])
            # Lower-triangular causal mask; sin(0) == 0, so masked scores vanish.
            mask = paddle.triu(paddle.ones([seq_len, seq_len])).T
    
            # Cumulative attention: each position accumulates sin(q·k) over its
            # prefix, offset by the state carried in from the previous call.
            qk = state + paddle.sum(paddle.sin(q @ k.transpose([0, 1, 3, 2]) * mask), -1)
            state = qk[:, :, -1:]
            new_qk = 0
            for one_k in self.qk_list:
                new_qk += one_k(qk.transpose([0, 2, 1]))
                new_qk = paddle.sin(new_qk)
    
            out = self.out_layer(new_qk)
            return out, state
    
        def sample(self, x, state):
            x = self.em(x)
            q = paddle.sin(x)
            k = paddle.sin(x + q)
            seq_len = q.shape[1]
            bsz = q.shape[0]
            q = q.reshape([bsz, seq_len, self.head, -1]).transpose([0, 2, 1, 3])
            k = k.reshape([bsz, seq_len, self.head, -1]).transpose([0, 2, 1, 3])
            mask = paddle.triu(paddle.ones([seq_len, seq_len])).T
    
            qk = state + paddle.sum(paddle.sin(q @ k.transpose([0, 1, 3, 2]) * mask), -1)
            state = qk[:, :, -1:]
            new_qk = 0
            for one_k in self.qk_list:
                new_qk += one_k(qk.transpose([0, 2, 1]))
                new_qk = paddle.sin(new_qk)
    
            # Same pipeline as forward, but with greedy digit decoding.
            out = self.out_layer.sample(new_qk)
            return out, state
    
    
    
    def emheading_train_and_sample():
        print("*" * 100)
        net = ReNet(12935, 80, 80, 8)
        # net.eval()
        x = paddle.to_tensor([
            np.random.randint(1, 124, 100),
            np.random.randint(1, 124, 100),
        ], dtype='int64')
    
        xx = x
    
        # Simulated training
    
        loss_f = HeadLoss()
        opt = paddle.optimizer.Adam(parameters=net.parameters(), learning_rate=0.0003)
    
        for i in range(80):
            out, state = net(x[:, :-1], 0)
            loss, av_ac = loss_f(out, x[:, 1:])
    
            print(i, loss.item(), av_ac)
    
            opt.clear_grad()
            loss.backward()
            opt.step()
    
        # Decode and verify
        net.eval()
    
        out, _ = net.sample(xx[:, :-1], 0)
        print((out == xx[:, 1:]).numpy().mean())
        out, _ = net.sample(xx[:, :-2], 0)
        print((out[:, -1:] == xx[:, -2:-1]).numpy().mean())
        out, _ = net.sample(xx[:, :-1], 0)
        print((out[:, -1:] == xx[:, -1:]).numpy().mean())
        for i in range(50):
            print("超长依赖检验")
            out, _ = net.sample(xx[:, :i + 30], 0)
            print((out[:, -1:] == xx[:, i + 30:i + 31]).numpy().mean())
    
        out, _ = net(xx[:, :-1], 0)
        loss, av_ac = loss_f(out, xx[:, 1:])
        print(loss.item(), av_ac)
    
    
    # Run model training and prediction
    if __name__ == '__main__':
        emheading_train_and_sample()
    
    # if __name__ == '__main__':
    #     # demo()
    #     net = ReNet(12935, 12, 8,8)
    #     net(paddle.randint(1, 123, [3, 23]), 0)
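
    The training script above always passes state=0 and re-feeds the full prefix on every call. A minimal sketch of what the returned state could be used for, assuming it is meant to be threaded through successive chunks (this calling pattern is not exercised in the original script):

    # Hypothetical chunked decoding: feed the sequence in two halves and
    # carry the per-head summary `state` forward between the calls.
    net = ReNet(12935, 80, 80, 8)
    net.eval()
    seq = paddle.randint(1, 124, [2, 100])

    out_a, state = net.sample(seq[:, :50], 0)      # first chunk, zero initial state
    out_b, state = net.sample(seq[:, 50:], state)  # second chunk starts from the summary

    Because the state is a single scalar per head, it is a lossy summary of the prefix rather than an exact attention cache.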
    
    
  • Original article: https://blog.csdn.net/weixin_32759777/article/details/132675187