• 使用taichi 写了一个任意平台任意显卡推理的Linear


    这东西就是在于任意的显卡都能加速任意模型
    当然如何有人使用taichi写一个卷积那么计算机视觉也能任意显卡加速人工智能
    如果还有人写了个深度学习训练框架那么恭喜AMD,ARM 等任何芯片厂商乐疯

    import taichi as ti
    import numpy as np
    ti.init(arch=ti.vulkan)
    
    
    class Linear():
        def __init__(self, input_size, output_size, weights=None):
            if weights:
                self.weights = weights
            else:
                # self.weights = ti.Matrix([[0] * output_size] * input_size)
                self.weights = ti.Matrix(np.random.random(output_size*input_size).reshape([input_size,output_size]))
    
        @staticmethod
        def taichi_mul(a, b):
            ar, al = a.to_numpy().shape
            br, bl = b.to_numpy().shape
            assert al == br
    
            @ti.kernel
            def mlp() -> ti.types.matrix(ar, bl, dtype=ti.float32):
                return a @ b
    
            return mlp()
    
        def forward(self, a):
            return self.taichi_mul(a, self.weights)
    
        def set_weights(self, weights):
            self.weights = ti.Matrix(weights)
    
    
    l1 = Linear(3, 5)
    print(l1.forward(ti.Matrix([[1] * 3])))
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15
    • 16
    • 17
    • 18
    • 19
    • 20
    • 21
    • 22
    • 23
    • 24
    • 25
    • 26
    • 27
    • 28
    • 29
    • 30
    • 31
    • 32
    • 33
    • 34
    这段代码实现了一个简单的线性层(Linear layer)类,通过调用 forward 函数可以对输入进行线性变换。其中 Linear 类中的 weights 表示线性层的权重矩阵,可以通过构造函数的输入或者 set_weights 函数进行设置。forward 函数使用 taichi_mul 函数实现了矩阵乘法,并返回了乘积结果。
    
    在 __init__ 函数中,如果 weights 参数被传入,则将其作为权重矩阵;否则通过 np.random.random 函数生成随机的权重矩阵。其中,输入参数 input_size 和 output_size 分别表示输入和输出的特征数。
    
    在 taichi_mul 函数中,输入参数 a 和 b 分别表示两个矩阵,通过 Taichi 提供的 @ 运算符实现了矩阵乘法,返回乘积结果。值得注意的是,Taichi 不能直接操作 Python 中的数据类型,因此在使用 Taichi 前,需要将 Python 中的数据类型转换为 Taichi 中的数据类型,可以调用 to_numpy() 函数将 Taichi 中的数据类型转换为 NumPy 中的数据类型,然后对其进行操作,最后再调用 ti.Matrix() 函数将其转换为 Taichi 中的数据类型。
    
    在主函数中,首先构造了一个输入矩阵 [[1] * 3],然后通过 l1.forward() 函数将其输入到 l1 线性层中,得到线性变换的结果,最后将结果打印输出。
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
  • 相关阅读:
    【节能学院】数据机房中智能小母线与列头柜方案的对比分析
    数学基础之曲线拟合
    《一文搞懂IoU发展历程》GIoU、DIoU、CIoU、EIoU、αIoU、SIoU
    Spring的setter方法注入和构造器注入的对比
    网络安全笔记-业务安全
    vue移动端高德地图的使用及实现最简单的地图功能
    比特集训营第一课
    关于#java#的问题,请各位专家解答!
    【0】数学的魅力
    基于SSM的网络财务管理系统设计与实现
  • 原文地址:https://blog.csdn.net/weixin_32759777/article/details/133553012