• 笔尖笔帽检测3:Android实现笔尖笔帽检测算法(含源码 可是实时检测)


    目录

    1. 前言

    2.笔尖笔帽检测方法

    (1)Top-Down(自上而下)方法

    (2)Bottom-Up(自下而上)方法:

    3.笔尖笔帽关键点检测模型训练

    4.笔尖笔帽关键点检测模型Android部署

    (1) 将Pytorch模型转换ONNX模型

    (2) 将ONNX模型转换为TNN模型

    (3) Android端上部署模型

    (4) Android测试效果 

    (5) 运行APP闪退:dlopen failed: library "libomp.so" not found

    5.Android项目源码下载

    6.C++实现笔尖笔帽关键点检测

    7.特别版: 笔尖指尖检测


    1. 前言

    目前在AI智慧教育领域,有一个比较火热的教育产品,即指尖点读或者笔尖点读功能,其核心算法就是通过深度学习的方法获得笔尖或者指尖的位置,在通过OCR识别文本,最后通过TTS(TextToSpeech)将文本转为语音;其中OCR和TTS算法都已经研究非常成熟了,而指尖或者笔尖检测的方法也有一些开源的项目可以参考实现。本项目将实现笔尖笔帽关键点检测算法,其中使用YOLOv5模型实现手部检测(手握着笔目标检测),使用HRNet,LiteHRNet和Mobilenet-v2模型实现笔尖笔帽关键点检测。项目分为数据标注,模型训练和Android部署等多个章节,本篇是项目《笔尖笔帽检测》系列文章之Android实现笔尖笔帽检测算法;为了方便后续模型工程化和Android平台部署,项目支持高精度HRNet检测模型,轻量化模型LiteHRNet和Mobilenet模型训练和测试,并提供Python/C++/Android多个版本;

    轻量化Mobilenet-v2模型在普通Android手机上可以达到实时的检测效果,CPU(4线程)约50ms左右,GPU约30ms左右 ,基本满足业务的性能需求。下表格给出HRNet,以及轻量化模型LiteHRNet和Mobilenet的计算量和参数量,以及其检测精度。

    模型input-sizeparams(M)GFLOPsAP
    HRNet-w32192×19228.48M5734.05M0.8418
    LiteHRNet18192×1921.10M182.15M0.7469
    Mobilenet-v2192×1922.63M529.25M0.7531

    尊重原创,转载请注明出处https://blog.csdn.net/guyuealian/article/details/134070497

    Android笔尖笔帽关键点检测APP Demo体验(下载):

    https://download.csdn.net/download/guyuealian/88535143

        


    更多项目《笔尖笔帽检测》系列文章请参考:

     


    2.笔尖笔帽检测方法

    笔尖笔帽目标较小,如果直接使用目标检测,很难达到像素级别的检测精度;一般建议使用类似于人体关键点检测的方案。目前主流的关键点方法主要两种:一种是Top-Down(自上而下)方法,另外一种是Bottom-Up(自下而上)方法;

    (1)Top-Down(自上而下)方法

    将手部检测(手握笔的情况)和笔尖笔帽关键点检测分离,在图像上首先进行手部目标检测,定位手部位置;然后crop每一个手部图像,再估计笔尖笔帽关键点;这类方法往往比较慢,但姿态估计准确度较高。目前主流模型主要有CPN,Hourglass,CPM,Alpha Pose,HRNet等。

    (2)Bottom-Up(自下而上)方法:

    先估计图像中所有笔尖笔帽关键点,然后在通过Grouping的方法组合成一个一个实例;因此这类方法在测试推断的时候往往更快速,准确度稍低。典型就是COCO2016年人体关键点检测冠军Open Pose。

    通常来说,Top-Down具有更高的精度,而Bottom-Up具有更快的速度;就目前调研而言, Top-Down的方法研究较多,精度也比Bottom-Up(自下而上)方法高。

    本项目采用Top-Down(自上而下)方法,使用YOLOv5模型实现手部检测(手握笔检测),使用HRNet进行手部关键点检测;也可以简单理解为,先使用YOLOv5定位手握笔的区域位置,再使用HRNet进行笔尖笔帽精细化位置定位。

    本项目基于开源的HRNet进行改进,关于HRNet项目请参考GitHub

    HRNet: https://github.com/leoxiaobin/deep-high-resolution-net.pytorch


    3.笔尖笔帽关键点检测模型训练

    本项目采用Top-Down(自上而下)方法,使用YOLOv5模型实现手部检测(手笔检测),并基于开源的HRNet进行改进实现笔尖笔帽关键点检测;为了方便后续模型工程化和Android平台部署,项目支持轻量化模型LiteHRNet和Mobilenet模型训练和测试,并提供Python/C++/Android多个版本;轻量化Mobilenet-v2模型在普通Android手机上可以达到实时的检测效果,CPU(4线程)约50ms左右,GPU约30ms左右 ,基本满足业务的性能需求

    关于笔尖笔帽关键点检测模型训练,可参考 :

    笔尖笔帽检测2:Pytorch实现笔尖笔帽检测算法(含训练代码和数据集)

    下表格给出HRNet,以及轻量化模型LiteHRNet和Mobilenet的计算量和参数量,以及其检测精度AP; 高精度检测模型HRNet-w32,AP可以达到0.8418,但其参数量和计算量比较大,不合适在移动端部署;LiteHRNet18和Mobilenet-v2参数量和计算量比较少,合适在移动端部署;虽然LiteHRNet18的理论计算量和参数量比Mobilenet-v2低,但在实际测试中,发现Mobilenet-v2运行速度更快。轻量化Mobilenet-v2模型在普通Android手机上可以达到实时的检测效果,CPU(4线程)约50ms左右,GPU约30ms左右 ,基本满足业务的性能需求

    模型input-sizeparams(M)GFLOPsAP
    HRNet-w32192×19228.48M5734.05M0.8418
    LiteHRNet18192×1921.10M182.15M0.7469
    Mobilenet-v2192×1922.63M529.25M0.7531

    HRNet-w32参数量和计算量太大,不适合在Android手机部署,本项目Android版本只支持部署LiteHRNet和Mobilenet-v2模型;C++版本可支持部署HRNet-w32,LiteHRNet和Mobilenet-v2模型 


    4.笔尖笔帽关键点检测模型Android部署

    目前CNN模型有多种部署方式,可以采用TNN,MNN,NCNN,以及TensorRT等部署工具,鄙人采用TNN进行Android端上部署。部署流程可分为四步:训练模型->将模型转换ONNX模型->将ONNX模型转换为TNN模型->Android端上部署TNN模型。

    (1) 将Pytorch模型转换ONNX模型

    训练好Pytorch模型后,我们需要先将模型转换为ONNX模型,以便后续模型部署。

    • 原始Python项目提供转换脚本,你只需要修改model_file和config_file为你模型路径即可
    •  convert_torch_to_onnx.py实现将Pytorch模型转换ONNX模型的脚本
    python libs/convert_tools/convert_torch_to_onnx.py
    1. """
    2. This code is used to convert the pytorch model into an onnx format model.
    3. """
    4. import os
    5. import torch.onnx
    6. from pose.inference import PoseEstimation
    7. from basetrainer.utils.converter import pytorch2onnx
    8. def load_model(config_file, model_file, device="cuda:0"):
    9. pose = PoseEstimation(config_file, model_file, device=device)
    10. model = pose.model
    11. config = pose.config
    12. return model, config
    13. def convert2onnx(config_file, model_file, device="cuda:0", onnx_type="kp"):
    14. """
    15. :param model_file:
    16. :param input_size:
    17. :param device:
    18. :param onnx_type:
    19. :return:
    20. """
    21. model, config = load_model(config_file, model_file, device=device)
    22. model = model.to(device)
    23. model.eval()
    24. model_name = os.path.basename(model_file)[:-len(".pth")]
    25. onnx_file = os.path.join(os.path.dirname(model_file), model_name + ".onnx")
    26. # dummy_input = torch.randn(1, 3, 240, 320).to("cuda")
    27. input_size = tuple(config.MODEL.IMAGE_SIZE) # w,h
    28. input_shape = (1, 3, input_size[1], input_size[0])
    29. pytorch2onnx.convert2onnx(model,
    30. input_shape=input_shape,
    31. input_names=['input'],
    32. output_names=['output'],
    33. onnx_file=onnx_file,
    34. opset_version=11)
    35. if __name__ == "__main__":
    36. config_file = "../../work_space/pen/mobilenet_v2_2_192_192_custom_coco_20231114_000651_3262/mobilenetv2_192_192.yaml"
    37. model_file = "../../work_space/pen/mobilenet_v2_2_192_192_custom_coco_20231114_000651_3262/model/model_199_0.7518.pth"
    38. convert2onnx(config_file, model_file)

    (2) 将ONNX模型转换为TNN模型

    目前CNN模型有多种部署方式,可以采用TNN,MNN,NCNN,以及TensorRT等部署工具,鄙人采用TNN进行Android端上部署

    TNN转换工具:

    ​​

    (3) Android端上部署模型

    项目实现了Android版本的手部检测(手握着笔)和笔尖笔帽关键点检测Demo,部署框架采用TNN,支持多线程CPU和GPU加速推理,在普通手机上可以实时处理。项目Android源码,核心算法均采用C++实现,上层通过JNI接口调用。

    如果你想在这个Android Demo部署你自己训练的分类模型,你可将训练好的Pytorch模型转换ONNX ,再转换成TNN模型,然后把TNN模型代替你模型即可。 

    HRNet-w32参数量和计算量太大,不适合在Android手机部署,本项目Android版本只支持部署LiteHRNet和Mobilenet-v2模型;C++版本可支持部署HRNet-w32,LiteHRNet和Mobilenet-v2模型 

    • 这是项目Android源码JNI接口 ,Java部分
    1. package com.cv.tnn.model;
    2. import android.graphics.Bitmap;
    3. public class Detector {
    4. static {
    5. System.loadLibrary("tnn_wrapper");
    6. }
    7. /***
    8. * 初始化检测模型
    9. * @param dets_model: 检测模型(不含后缀名)
    10. * @param pose_model: 识别模型(不含后缀名)
    11. * @param root:模型文件的根目录,放在assets文件夹下
    12. * @param model_type:模型类型
    13. * @param num_thread:开启线程数
    14. * @param useGPU:是否开启GPU进行加速
    15. */
    16. public static native void init(String dets_model, String pose_model, String root, int model_type, int num_thread, boolean useGPU);
    17. /***
    18. * 返回检测和识别结果
    19. * @param bitmap 图像(bitmap),ARGB_8888格式
    20. * @param score_thresh:置信度阈值
    21. * @param iou_thresh: IOU阈值
    22. * @param pose_thresh: 关键点阈值
    23. * @return
    24. */
    25. public static native FrameInfo[] detect(Bitmap bitmap, float score_thresh, float iou_thresh, float pose_thresh);
    26. }
    • 这是Android项目源码JNI接口 ,C++部分
    1. #include
    2. #include
    3. #include
    4. #include "src/yolov5.h"
    5. #include "src/pose_detector.h"
    6. #include "src/Types.h"
    7. #include "debug.h"
    8. #include "android_utils.h"
    9. #include "opencv2/opencv.hpp"
    10. #include "file_utils.h"
    11. using namespace dl;
    12. using namespace vision;
    13. static YOLOv5 *detector = nullptr;
    14. static PoseDetector *pose = nullptr;
    15. JNIEXPORT jint JNI_OnLoad(JavaVM *vm, void *reserved) {
    16. return JNI_VERSION_1_6;
    17. }
    18. JNIEXPORT void JNI_OnUnload(JavaVM *vm, void *reserved) {
    19. }
    20. extern "C"
    21. JNIEXPORT void JNICALL
    22. Java_com_cv_tnn_model_Detector_init(JNIEnv *env,
    23. jclass clazz,
    24. jstring dets_model,
    25. jstring pose_model,
    26. jstring root,
    27. jint model_type,
    28. jint num_thread,
    29. jboolean use_gpu) {
    30. if (detector != nullptr) {
    31. delete detector;
    32. detector = nullptr;
    33. }
    34. std::string parent = env->GetStringUTFChars(root, 0);
    35. std::string dets_model_ = env->GetStringUTFChars(dets_model, 0);
    36. std::string pose_model_ = env->GetStringUTFChars(pose_model, 0);
    37. string dets_model_file = path_joint(parent, dets_model_ + ".tnnmodel");
    38. string dets_proto_file = path_joint(parent, dets_model_ + ".tnnproto");
    39. string pose_model_file = path_joint(parent, pose_model_ + ".tnnmodel");
    40. string pose_proto_file = path_joint(parent, pose_model_ + ".tnnproto");
    41. DeviceType device = use_gpu ? GPU : CPU;
    42. LOGW("parent : %s", parent.c_str());
    43. LOGW("useGPU : %d", use_gpu);
    44. LOGW("device_type: %d", device);
    45. LOGW("model_type : %d", model_type);
    46. LOGW("num_thread : %d", num_thread);
    47. YOLOv5Param model_param = YOLOv5s05_320;//模型参数
    48. detector = new YOLOv5(dets_model_file,
    49. dets_proto_file,
    50. model_param,
    51. num_thread,
    52. device);
    53. PoseParam pose_param = POSE_MODEL_TYPE[model_type];//模型类型
    54. pose = new PoseDetector(pose_model_file,
    55. pose_proto_file,
    56. pose_param,
    57. num_thread,
    58. device);
    59. }
    60. extern "C"
    61. JNIEXPORT jobjectArray JNICALL
    62. Java_com_cv_tnn_model_Detector_detect(JNIEnv *env, jclass clazz, jobject bitmap,
    63. jfloat score_thresh, jfloat iou_thresh, jfloat pose_thresh) {
    64. cv::Mat bgr;
    65. BitmapToMatrix(env, bitmap, bgr);
    66. int src_h = bgr.rows;
    67. int src_w = bgr.cols;
    68. // 检测区域为整张图片的大小
    69. FrameInfo resultInfo;
    70. // 开始检测
    71. if (detector != nullptr) {
    72. detector->detect(bgr, &resultInfo, score_thresh, iou_thresh);
    73. } else {
    74. ObjectInfo objectInfo;
    75. objectInfo.x1 = 0;
    76. objectInfo.y1 = 0;
    77. objectInfo.x2 = (float) src_w;
    78. objectInfo.y2 = (float) src_h;
    79. objectInfo.label = 0;
    80. resultInfo.info.push_back(objectInfo);
    81. }
    82. int nums = resultInfo.info.size();
    83. LOGW("object nums: %d\n", nums);
    84. if (nums > 0) {
    85. // 开始检测
    86. pose->detect(bgr, &resultInfo, pose_thresh);
    87. // 可视化代码
    88. //classifier->visualizeResult(bgr, &resultInfo);
    89. }
    90. //cv::cvtColor(bgr, bgr, cv::COLOR_BGR2RGB);
    91. //MatrixToBitmap(env, bgr, dst_bitmap);
    92. auto BoxInfo = env->FindClass("com/cv/tnn/model/FrameInfo");
    93. auto init_id = env->GetMethodID(BoxInfo, "", "()V");
    94. auto box_id = env->GetMethodID(BoxInfo, "addBox", "(FFFFIF)V");
    95. auto ky_id = env->GetMethodID(BoxInfo, "addKeyPoint", "(FFF)V");
    96. jobjectArray ret = env->NewObjectArray(resultInfo.info.size(), BoxInfo, nullptr);
    97. for (int i = 0; i < nums; ++i) {
    98. auto info = resultInfo.info[i];
    99. env->PushLocalFrame(1);
    100. //jobject obj = env->AllocObject(BoxInfo);
    101. jobject obj = env->NewObject(BoxInfo, init_id);
    102. // set bbox
    103. //LOGW("rect:[%f,%f,%f,%f] label:%d,score:%f \n", info.rect.x,info.rect.y, info.rect.w, info.rect.h, 0, 1.0f);
    104. env->CallVoidMethod(obj, box_id, info.x1, info.y1, info.x2 - info.x1, info.y2 - info.y1,
    105. info.label, info.score);
    106. // set keypoint
    107. for (const auto &kps : info.keypoints) {
    108. //LOGW("point:[%f,%f] score:%f \n", lm.point.x, lm.point.y, lm.score);
    109. env->CallVoidMethod(obj, ky_id, (float) kps.point.x, (float) kps.point.y,
    110. (float) kps.score);
    111. }
    112. obj = env->PopLocalFrame(obj);
    113. env->SetObjectArrayElement(ret, i, obj);
    114. }
    115. return ret;
    116. }

    (4) Android测试效果 

    Android Demo在普通手机CPU/GPU上可以达到实时检测效果;CPU(4线程)约50ms左右,GPU约30ms左右 ,基本满足业务的性能需求。

    Android笔尖笔帽关键点检测APP Demo体验(下载):

    https://download.csdn.net/download/guyuealian/88535143

           

         

    (5) 运行APP闪退:dlopen failed: library "libomp.so" not found

    参考解决方法:
    解决dlopen failed: library “libomp.so“ not found_PKing666666的博客-CSDN博客_dlopen failed

     Android SDK和NDK相关版本信息,请参考: 

     


    5.Android项目源码下载

    Android项目源码下载地址:Android实现笔尖笔帽检测算法(含源码 可是实时检测)

    整套Android项目源码内容包含:

    1. Android Demo源码支持YOLOv5手部检测(手握笔检测)
    2. Android Demo源码支持轻量化模型LiteHRNet和Mobilenet-v2笔尖笔帽关键点检测
    3. Android Demo在普通手机CPU/GPU上可以实时检测,CPU约50ms,GPU约30ms左右
    4. Android Demo支持图片,视频,摄像头测试
    5. 所有依赖库都已经配置好,可直接build运行,若运行出现闪退,请参考dlopen failed: library “libomp.so“ not found 解决。


    6.C++实现笔尖笔帽关键点检测


    7.特别版: 笔尖指尖检测

    碍于篇幅,本文章只实现了笔尖笔帽关键点检测;实质上,要实现指尖点读或者笔尖点读功能,我们可能并不需要笔帽检测,而是需要实现笔尖+指尖检测功能;其实现方法与笔尖笔帽关键点检测类似。

    下面是成功产品落地应用的笔尖+指尖检测算法Demo,其检测精度和速度性能都比笔尖笔帽检测的效果要好。

    如果你需要笔尖+指尖检测算法,可在公众号咨询联系

  • 相关阅读:
    巧用CSS3之雨伞
    记录nacos2.0+使用nginx代理出现的问题
    ch03:算数运算(长沙师范学院)
    C2基础设施威胁情报对抗策略
    深入探索Python开发:打造高质量技术的实战之路
    Python读写文本、图片、xml
    为什么要学MySQL数据库,它有什么用?
    断网情况下,华为init接口持续调用,导致手机耗电严重
    eyb:工资账套页面设计到聊天数据显示(五)
    mysql复习
  • 原文地址:https://blog.csdn.net/guyuealian/article/details/134070497