// Query the instance factory for the kernels that implement DeviceOp
// (presumably one entry per compiled kernel configuration -- confirm).
const auto op_ptrs = ck::tensor_operation::device::instance::DeviceOperationInstanceFactory<DeviceOp>::GetInstances();
// NOTE(review): `op_ptr` is not declared in this excerpt; presumably it is one element of
// `op_ptrs`, bound inside a loop over the instances -- confirm against the full example.
// Arguments follow DeviceGemm::MakeArgumentPointer's parameter order: A/B/C device buffers,
// problem sizes M/N/K, strides of A/B/C, then the A/B/C element-wise operations.
auto argument_ptr = op_ptr->MakeArgumentPointer(a_device_buf.GetDeviceBuffer(),
b_device_buf.GetDeviceBuffer(),
c_device_buf.GetDeviceBuffer(),
M,
N,
K,
StrideA,
StrideB,
StrideC,
a_element_op,
b_element_op,
c_element_op);
// Polymorphic root of every invoker. An invoker executes an operator on a
// type-erased argument; concrete invokers override Run().
struct BaseInvoker
{
    BaseInvoker()                              = default;
    BaseInvoker(const BaseInvoker&)            = default;
    BaseInvoker& operator=(const BaseInvoker&) = default;

    // Type-erased entry point. The base version ignores both parameters,
    // performs no work and reports 0.
    virtual float Run(const BaseArgument*, const StreamConfig& = StreamConfig{}) { return 0.0f; }

    // Virtual so that invokers can be destroyed through a BaseInvoker pointer.
    virtual ~BaseInvoker() = default;
};
struct Invoker : public BaseInvoker
{
float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
{
// run kernel ....
// cost time ....
};
float Run(const BaseArgument* p_arg,
const StreamConfig& stream_config = StreamConfig{}) override
{
return Run(*dynamic_cast(p_arg), stream_config);
};
};
// Polymorphic root of every type-erased argument pack handed to operators and invokers.
struct BaseArgument
{
    BaseArgument()                               = default;
    BaseArgument(const BaseArgument&)            = default;
    BaseArgument& operator=(const BaseArgument&) = default;

    // Virtual so that arguments can be destroyed through a BaseArgument pointer.
    virtual ~BaseArgument() = default;

    // Workspace pointer; stays null until someone assigns it
    // (e.g. BaseOperator::SetWorkSpacePointer).
    void* p_workspace_{nullptr};
};
// Argument pack that borrows three tensors and stores copies of three element-wise
// operations. Judging from the gs/ms/ns/ks index names this looks like a batched
// contraction -- confirm.
// NOTE(review): `Tensor` appears to have lost its template arguments in this excerpt.
struct Argument : public ck::tensor_operation::device::BaseArgument
{
    // The tensors are bound by reference, so the caller must keep them alive for the
    // lifetime of this Argument. The element-wise operations are copied in.
    Argument(const Tensor& a_gs_ms_ks,
             const Tensor& b_gs_ns_ks,
             Tensor& e_gs_ms_ns,
             AElementwiseOperation a_element_op,
             BElementwiseOperation b_element_op,
             CDEElementwiseOperation cde_element_op)
        : a_gs_ms_ks_(a_gs_ms_ks),
          b_gs_ns_ks_(b_gs_ns_ks),
          e_gs_ms_ns_(e_gs_ms_ns),
          a_element_op_(a_element_op),
          b_element_op_(b_element_op),
          cde_element_op_(cde_element_op)
    {
    }

    const Tensor& a_gs_ms_ks_; // input A, read-only, borrowed
    const Tensor& b_gs_ns_ks_; // input B, read-only, borrowed
    Tensor& e_gs_ms_ns_;       // writable E (presumably the output), borrowed

    AElementwiseOperation a_element_op_;
    BElementwiseOperation b_element_op_;
    CDEElementwiseOperation cde_element_op_;
};
// Polymorphic root of every device operator: support checks, descriptive names and
// workspace handling. The virtual members keep their original declaration order so the
// vtable layout is unchanged.
struct BaseOperator
{
    BaseOperator()                               = default;
    BaseOperator(const BaseOperator&)            = default;
    BaseOperator& operator=(const BaseOperator&) = default;

    // Base answer: no argument is supported; concrete operators override this.
    virtual bool IsSupportedArgument(const BaseArgument*) { return false; }

    // Human-readable description; empty in the base class.
    virtual std::string GetTypeString() const { return ""; }

    // Implementation-defined RTTI name of the dynamic type.
    virtual std::string GetTypeIdName() const { return typeid(*this).name(); }

    // RTTI hash code of the dynamic type, rendered as lowercase hexadecimal.
    virtual std::string GetTypeIdHashCode() const
    {
        std::ostringstream hex_stream;
        hex_stream << std::hex << typeid(*this).hash_code();
        return hex_stream.str();
    }

    // Bytes of workspace the argument needs; the base class needs none.
    virtual size_t GetWorkSpaceSize(const BaseArgument*) const { return 0; }

    // Attaches caller-provided workspace memory to the argument (p_arg must be non-null).
    virtual void SetWorkSpacePointer(BaseArgument* p_arg, void* p_workspace) const
    {
        assert(p_arg);
        p_arg->p_workspace_ = p_workspace;
    }

    // Virtual so that operators can be destroyed through a BaseOperator pointer.
    virtual ~BaseOperator() = default;
};
// Factory mapping an abstract device-operation type to its available instances.
// The primary template is only declared. Callers use
// DeviceOperationInstanceFactory<DeviceOp>::GetInstances(), so each supported DeviceOp
// presumably provides a specialization defining GetInstances().
// (The excerpt's `template struct ...;` had lost its parameter list and was ill-formed.)
template <typename DeviceOp>
struct DeviceOperationInstanceFactory;
// Compile-time list (a std::tuple of types) of concrete DeviceGemmDl kernel configurations.
// Judging from the alias name: f16 A/B/C data with km, kn, mn layouts.
// NOTE(review): the `<.....>` template arguments and the trailing `.....` entries are elided
// in this excerpt, so it does not compile as written.
using device_gemm_dl_f16_f16_f16_km_kn_mn_instances = std::tuple<
// MPerBlock=8, NPerBlock=8
DeviceGemmDl<.....>,
DeviceGemmDl<.....>,
DeviceGemmDl<.....>,
.....
>;
// Passes a default-constructed device_gemm_dl_f16_f16_f16_km_kn_mn_instances tuple to
// add_device_operation_instances, which presumably appends one object per tuple entry
// to `instances`.
// NOTE(review): the parameter type `std::vector>>&` lost its template arguments in this
// excerpt -- presumably a vector of owning pointers to the matching DeviceGemm interface;
// confirm against the CK source.
void add_device_gemm_dl_f16_f16_f16_km_kn_mn_instances(
std::vector>>&
instances)
{
add_device_operation_instances(instances, device_gemm_dl_f16_f16_f16_km_kn_mn_instances{});
}
// Alias for the abstract GEMM interface whose implementations we want to enumerate.
// NOTE(review): the template arguments of DeviceGemm were stripped from this excerpt.
using DeviceOp =
ck::tensor_operation::device::DeviceGemm;
// get device op instances
const auto op_ptrs = ck::tensor_operation::device::instance::DeviceOperationInstanceFactory<
DeviceOp>::GetInstances();
template
struct DeviceGemm : public BaseOperator
{
virtual std::unique_ptr
MakeArgumentPointer(const void* p_a,
const void* p_b,
void* p_c,
ck::index_t M,
ck::index_t N,
ck::index_t K,
ck::index_t StrideA,
ck::index_t StrideB,
ck::index_t StrideC,
AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op,
CElementwiseOperation c_element_op) = 0;
virtual std::unique_ptr MakeInvokerPointer() = 0;
};
// A concrete implementation deriving from the DeviceGemm interface; its body is elided here.
// NOTE(review): the base class's template arguments were stripped from this excerpt.
struct DeviceGemm_Xdl_CShuffle : public DeviceGemm
........
在阅读demo(如gemm.cc)的时候会发现一些特有的名词,如:
有一些比较好理解,如 F16(半精度浮点)之类
有一些可以勉强看出来,如 layout(布局),表示矩阵是行优先还是列优先存储(RowMajor/ColumnMajor)
有一些比较抽象,如 PassThrough:它其实是一个恒等的逐元素操作(y = x),定义如下:
// Element-wise functor that copies its input to its output unchanged (y = x) for each
// type pair it is specialized for. The HIP/CUDA qualifiers __host__ __device__ make it
// callable from both host and device code.
struct PassThrough
{
// Primary template: declared but not defined. Only the explicit specializations below
// (plus any elided by "....") are usable; calling an unspecialized (Y, X) pair would
// presumably fail at link time.
template <typename Y, typename X>
__host__ __device__ void operator()(Y& y, const X& x) const;
// double -> double: plain copy.
// NOTE(review): explicit specialization at class scope (CWG 727) is accepted by Clang/hipcc
// but has historically been rejected by GCC.
template <>
__host__ __device__ void operator()<double, double>(double& y, const double& x) const
{
y = x;
}
// float -> float: plain copy.
template <>
__host__ __device__ void operator()<float, float>(float& y, const float& x) const
{
y = x;
}
....
};