NLLLossGrad算子
将python侧的checktype迁移到c++侧,需检测total_weight的type和weight_type一致,编译正常,测试时报错,请问checktype的时候是有什么其他的表示方式吗?
python侧迁移前:
c++侧迁移后:
【截图信息】
const std::set
std::map
(void)args.insert({"var_type", var_type});
(void)args.insert({"accum_type", accum_type});
(void)args.insert({"accum_update_type", accum_update_type});
(void)args.insert({"grad_type", grad_type});
(void)CheckAndConvertUtils::CheckTensorTypeSame(args, valid_types, prim_name);