• CK草稿本


    调用流程

      1. 获得op_ptr,ck有个工厂模式:
      const auto op_ptrs = ck::tensor_operation::device::instance::DeviceOperationInstanceFactory<DeviceOp>::GetInstances();
      
      • 1
      1. 设置参数,这些参数包括输入输出,以及其他必要的配置
      auto argument_ptr = op_ptr->MakeArgumentPointer(a_device_buf.GetDeviceBuffer(),
                              b_device_buf.GetDeviceBuffer(),
                              c_device_buf.GetDeviceBuffer(),
                              M,
                              N,
                              K,
                              StrideA,
                              StrideB,
                              StrideC,
                              a_element_op,
                              b_element_op,
                              c_element_op);
      
      • 1
      • 2
      • 3
      • 4
      • 5
      • 6
      • 7
      • 8
      • 9
      • 10
      • 11
      • 12
      1. 获得invoker_ptr:auto invoker_ptr = op_ptr->MakeInvokerPointer();
      1. run:float ave_time = invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, true});
      1. 结果后处理

    Invoker

    • 有一个基类BaseInvoker,定义了赋值拷贝,和Run函数(用于算子运行),以及一个虚析构
      • 地址:include/ck/tensor_operation/gpu/device/device_base.hpp
    • 然后每个算子里面会实现一个Invoker,来实现run的操作
      struct BaseInvoker
      {
          BaseInvoker()                   = default;
          BaseInvoker(const BaseInvoker&) = default;
          BaseInvoker& operator=(const BaseInvoker&) = default;
      
          virtual float Run(const BaseArgument*, const StreamConfig& = StreamConfig{})
          {
              return float{0};
          }
      
          virtual ~BaseInvoker() {}
      };
      
      
      struct Invoker : public BaseInvoker
      {
          float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
          {
              // run kernel ....
              // cost time ....
          };
      
          float Run(const BaseArgument* p_arg,
                      const StreamConfig& stream_config = StreamConfig{}) override
          {
              return Run(*dynamic_cast(p_arg), stream_config);
          };
      };
      
      • 1
      • 2
      • 3
      • 4
      • 5
      • 6
      • 7
      • 8
      • 9
      • 10
      • 11
      • 12
      • 13
      • 14
      • 15
      • 16
      • 17
      • 18
      • 19
      • 20
      • 21
      • 22
      • 23
      • 24
      • 25
      • 26
      • 27
      • 28
      • 29

    Argument

    • 同样有个基类BaseArgument,有一个p_workspace_的void指针参数,暂不清楚做啥的
      • 地址:include/ck/tensor_operation/gpu/device/device_base.hpp
    • 而每个Operator中都会定义一个Argument子类,里面存一些输入输出,配置等参数
      struct BaseArgument
      {
          BaseArgument()                    = default;
          BaseArgument(const BaseArgument&) = default;
          BaseArgument& operator=(const BaseArgument&) = default;
      
          virtual ~BaseArgument() {}
      
          void* p_workspace_ = nullptr;
      };
      
      struct Argument : public ck::tensor_operation::device::BaseArgument
      {
          Argument(const Tensor& a_gs_ms_ks,
                      const Tensor& b_gs_ns_ks,
                      Tensor& e_gs_ms_ns,
                      AElementwiseOperation a_element_op,
                      BElementwiseOperation b_element_op,
                      CDEElementwiseOperation cde_element_op)
              : a_gs_ms_ks_{a_gs_ms_ks},
                  b_gs_ns_ks_{b_gs_ns_ks},
                  e_gs_ms_ns_{e_gs_ms_ns},
                  a_element_op_{a_element_op},
                  b_element_op_{b_element_op},
                  cde_element_op_{cde_element_op}
          {
          }
      
          const Tensor& a_gs_ms_ks_;
          const Tensor& b_gs_ns_ks_;
          Tensor& e_gs_ms_ns_;
      
          AElementwiseOperation a_element_op_;
          BElementwiseOperation b_element_op_;
          CDEElementwiseOperation cde_element_op_;
      };
      
      • 1
      • 2
      • 3
      • 4
      • 5
      • 6
      • 7
      • 8
      • 9
      • 10
      • 11
      • 12
      • 13
      • 14
      • 15
      • 16
      • 17
      • 18
      • 19
      • 20
      • 21
      • 22
      • 23
      • 24
      • 25
      • 26
      • 27
      • 28
      • 29
      • 30
      • 31
      • 32
      • 33
      • 34
      • 35
      • 36

    Operator

    • 基类叫BaseOperator,定义如下函数 都是一些比较通用的基础属性:
      • IsSupportedArgument
      • GetTypeString
      • GetTypeIdName
      • GetTypeIdHashCode
      • GetWorkSpaceSize
      • SetWorkSpacePointer
    • 通常子类中需要有定义:
      • struct Argument/MakeArgumentPointer
      • struct Invoke/MakeInvokerPointer
      struct BaseOperator
      {
          BaseOperator()                    = default;
          BaseOperator(const BaseOperator&) = default;
          BaseOperator& operator=(const BaseOperator&) = default;
      
          virtual bool IsSupportedArgument(const BaseArgument*) { return false; }
          virtual std::string GetTypeString() const { return ""; }
      
          virtual std::string GetTypeIdName() const { return typeid(*this).name(); }
      
          virtual std::string GetTypeIdHashCode() const
          {
              std::ostringstream oss;
      
              oss << std::hex << typeid(*this).hash_code();
      
              return oss.str();
          };
      
          virtual size_t GetWorkSpaceSize(const BaseArgument*) const { return 0; }
      
          virtual void SetWorkSpacePointer(BaseArgument* p_arg, void* p_workspace) const
          {
              assert(p_arg);
              p_arg->p_workspace_ = p_workspace;
          }
      
          virtual ~BaseOperator() {}
      };
      
      • 1
      • 2
      • 3
      • 4
      • 5
      • 6
      • 7
      • 8
      • 9
      • 10
      • 11
      • 12
      • 13
      • 14
      • 15
      • 16
      • 17
      • 18
      • 19
      • 20
      • 21
      • 22
      • 23
      • 24
      • 25
      • 26
      • 27
      • 28
      • 29
      • 30

    DeviceOperationInstanceFactory

    • library/include/ck/library/tensor_operation_instance/device_operation_instance_factory.hpp
      • 在这个文件中声明了工厂,也就是:
          template 
          struct DeviceOperationInstanceFactory;
      
      • 1
      • 2
    • library/include/ck/library/tensor_operation_instance/add_device_operation_instance.hpp
      • 这里面有个add_device_operation_instances方法,定义了将op实现加入到vector(instance)中
    • 在这之上,有一些函数是用于添加这些instance的,比如device_gemm_dl_f16_f16_f16_km_kn_mn_instances
      • 位于library/src/tensor_operation_instance/gpu/gemm/device_gemm_dl_f16_f16_f16_km_kn_mn_instance.cpp
      • 原理就是把tuple中的元素在add_device_operation_instances中全部加到vector中去
      using device_gemm_dl_f16_f16_f16_km_kn_mn_instances = std::tuple<
              // MPerBlock=8, NPerBlock=8
              DeviceGemmDl<.....>,
              DeviceGemmDl<.....>,
              DeviceGemmDl<.....>,
              .....
          >;
      
      void add_device_gemm_dl_f16_f16_f16_km_kn_mn_instances(
          std::vector>>&
              instances)
      {
          add_device_operation_instances(instances, device_gemm_dl_f16_f16_f16_km_kn_mn_instances{});
      }
      
      • 1
      • 2
      • 3
      • 4
      • 5
      • 6
      • 7
      • 8
      • 9
      • 10
      • 11
      • 12
      • 13
      • 14
      • 15
    • 然后这个函数会在DeviceOperationInstanceFactory中的GetInstances中被调用到,于是就得到了一个vector数组,里面装满了invoke_ptr实现
      • 对于上面这个例子,在这个文件中被调用到:library/include/ck/library/tensor_operation_instance/gpu/gemm.hpp

    案例

    • client_example/01_gemm/gemm.cpp
    • 在这个example中有这样一句代码:
      • 很显然,这是通过工厂类拿到算子实例集合
      using DeviceOp =
          ck::tensor_operation::device::DeviceGemm;
      
      // get device op instances
      const auto op_ptrs = ck::tensor_operation::device::instance::DeviceOperationInstanceFactory<
          DeviceOp>::GetInstances();
      
      • 1
      • 2
      • 3
      • 4
      • 5
      • 6
      • 7
      • 8
      • 9
      • 10
      • 11
      • 12
      • 13
      • 14
    • DeviceGemm这个operator长这样,当然这也是个虚基类,真正的实现实在Impl文件夹中定义的:
      template 
      struct DeviceGemm : public BaseOperator
      {
          virtual std::unique_ptr
          MakeArgumentPointer(const void* p_a,
                              const void* p_b,
                              void* p_c,
                              ck::index_t M,
                              ck::index_t N,
                              ck::index_t K,
                              ck::index_t StrideA,
                              ck::index_t StrideB,
                              ck::index_t StrideC,
                              AElementwiseOperation a_element_op,
                              BElementwiseOperation b_element_op,
                              CElementwiseOperation c_element_op) = 0;
      
          virtual std::unique_ptr MakeInvokerPointer() = 0;
      };
      
      • 1
      • 2
      • 3
      • 4
      • 5
      • 6
      • 7
      • 8
      • 9
      • 10
      • 11
      • 12
      • 13
      • 14
      • 15
      • 16
      • 17
      • 18
      • 19
      • 20
      • 21
      • 22
      • 23
      • 24
      • 25
      • 26
      • 27
    • 然后会在下一级子类中真正实现:
      struct DeviceGemm_Xdl_CShuffle : public DeviceGemm
      ........
      
      • 1
      • 2
      • 3
      • 4
      • 5
      • 6
      • 7
      • 8
      • 9
      • 10
    • 然后通过工厂类的GetInstances拿到op_ptrs,接下来就是遍历,在for的过程中需要经过:
      • auto argument_ptr = op_ptr->MakeArgumentPointer
      • auto invoker_ptr = op_ptr->MakeInvokerPointer
      • invoker_ptr->Run
    • 这就是这个example干的事儿,实际上在调用的过程中factory应该可以不用,而直接使用实例化的op_ptr

    特有名词

    • 在阅读demo(如gemm.cc)的时候会发现一些特有的名词,如:

      • using F16 = ck::half_t;
      • using Row = ck::tensor_layout::gemm::RowMajor;
      • using Col = ck::tensor_layout::gemm::ColumnMajor;
      • using PassThrough = ck::tensor_operation::element_wise::PassThrough;
    • 有一些比较好理解,如:半精度之类

    • 有一些可以勉强看出来,如layerout是列优先还是行优先(RowMajor/ColumnMajor)

    • 有一些比较抽象,如PassThrough

    以PassThrough为例
    • 这是一个传值操作,代码实现位于:include/ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp
    • 下面展示了一部分可以看到,函数的作用是传值
    struct PassThrough
    {
        template <typename Y, typename X>
        __host__ __device__ void operator()(Y& y, const X& x) const;
    
        template <>
        __host__ __device__ void operator()<double, double>(double& y, const double& x) const
        {
            y = x;
        }
    
        template <>
        __host__ __device__ void operator()<float, float>(float& y, const float& x) const
        {
            y = x;
        }
        ....
    };
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15
    • 16
    • 17
    • 18
  • 相关阅读:
    (Java高级教程)第三章Java网络编程-第八节:博客系统搭建(前后端分离)
    ubuntu18.04上安装protubuf3.19.4
    nodejs+vue健身房课程预约评分系统
    ArrayList,LinkedList和Vector的区别
    vue接入高德地图获取经纬度
    [附源码]java毕业设计某公司酬薪管理系统
    8.自定义组件布局和详解Context上下文
    重学SpringBoot3-日志Logging
    高并发下丢失更新的解决方案
    浅谈原型链
  • 原文地址:https://blog.csdn.net/symuamua/article/details/132847231