• 边缘计算:基于tflite实现andriod边缘端回归预测推理实战


    读了本文,你可以实现从云端利用DNN模型进行训练,模型保存.h5格式(基于keras)或是saved model格式(tf2.0版本),模型转化为tflite,利用android studio 编写java接口程序,实现模型最终的推理预测,并利用studio自带的手机模拟器,将推理结果显示到手机上,最终的效果如下。

    实现步骤:

    1、下载MPG数据集

    2、利用tf2.0实现云端训练,生成mymodel.h5或者.savedmodel目录

    可以参考这篇帖子,写的相对清晰

    时序信号的模型使用tflite的示例_sinat_18131557的博客-CSDN博客_tflite文件怎么打开

     3、转化为tflite

    参考代码:以.h5为例,其他类同

    mymodel = load_model('mymodel.h5')
    converter = tf.lite.TFLiteConverter.from_keras_model(mymodel)
    converter.post_training_quantize = True
    tflite_model = converter.convert()
    open('converted_model.tflite', 'wb').write(tflite_model)

    4、利用android studio打包成apk,完成手机端推理预测

    这一步对于仅熟悉云端的同学来说很陌生,因为手机安卓端是不一样的,因此可以打包成apk实现,本步骤写的稍微详细些

    4.1 利用android studio新建工程,设置项参考下图

     4.2 添加一个可以显示预测结果的控件

    在Android -> app -> res -> layout -> activity_main.xml
    android:id="@+id/result"
    

    4.3 将转化完成的tflite模型,放到app-src-main-assets目录下;

    4.4 修改如下文件,此步骤很关键: 

     修改app下的 Gradle Scripts -> build.gradle,注意是Module:My_Application.app这个,添加依赖项:

    implementation 'org.tensorflow:tensorflow-lite:+'

     同时添加如下代码:

    aaptOptions {
        noCompress "tflite"
    }

    完整build.gradle代码,供参考不一定完全相同:

    plugins {
        id 'com.android.application'
    }
    
    android {
        compileSdkVersion 33
        buildToolsVersion "33.0.0"
    
        defaultConfig {
            applicationId "com.example.myapplication"
            minSdkVersion 26
            targetSdkVersion 33
            versionCode 1
            versionName "1.0"
    
            testInstrumentationRunner "androidx.test.runner.AndroidJUnitRunner"
        }
    
        buildTypes {
            release {
                minifyEnabled false
                proguardFiles getDefaultProguardFile('proguard-android-optimize.txt'), 'proguard-rules.pro'
            }
        }
    
        aaptOptions {
            noCompress "tflite"
        }
    
        compileOptions {
            sourceCompatibility JavaVersion.VERSION_1_8
            targetCompatibility JavaVersion.VERSION_1_8
        }
        buildFeatures {
            viewBinding true
        }
        sourceSets {
            main {
                assets {
                    srcDirs 'src\\main\\assets'
                }
            }
        }
    }
    
    dependencies {
    
        implementation 'androidx.appcompat:appcompat:1.2.0'
        implementation 'com.google.android.material:material:1.2.1'
        implementation 'androidx.constraintlayout:constraintlayout:2.0.1'
        testImplementation 'junit:junit:4.+'
        androidTestImplementation 'androidx.test.ext:junit:1.1.2'
        androidTestImplementation 'androidx.test.espresso:espresso-core:3.3.0'
        implementation 'org.tensorflow:tensorflow-lite:+'
    }
    

    4.5  修改app -> java -> com.example.myapplication -> MainActivity,该步骤主要是加载tflite模型,定义输入输出,调用模型推理,并在安卓手机模拟器上显示推理结果:

    MainActivity的完整代码如下,主要改动以黑色加粗标出:

    package com.example.myapplication;
    
    import androidx.appcompat.app.AppCompatActivity;
    
    import android.content.Context;
    import android.content.res.AssetFileDescriptor;
    import android.os.Bundle;
    import android.util.Log;
    import android.widget.TextView;
    import android.widget.Toast;
    
    import java.io.FileInputStream;
    import java.io.IOException;
    import java.nio.MappedByteBuffer;
    import java.nio.channels.FileChannel;
    import org.tensorflow.lite.Interpreter;
    
    
    
    public class MainActivity extends AppCompatActivity {
        private static final String TAG = "Test";
        private Interpreter tflite;
        private Context mContext;
        @Override
        protected void onCreate(Bundle savedInstanceState) {
            super.onCreate(savedInstanceState);
            setContentView(R.layout.activity_main);
            loadModule();
        }
    
        private void loadModule() {
            String model = "converted_model";//模型的名字
            try {
                Interpreter.Options options = new Interpreter.Options();
                options.setNumThreads(4);
                options.setUseNNAPI(true);
                options.setAllowFp16PrecisionForFp32(true);
                // 加载模型文件
                tflite = new Interpreter(loadModelFile(model), options);
                Toast.makeText(MainActivity.this, model + " model load success", Toast.LENGTH_SHORT).show();
                // 调用Test函数
                test();
            } catch (IOException e) {
                Toast.makeText(MainActivity.this, model + " model load fail", Toast.LENGTH_SHORT).show();
                e.printStackTrace();
    
            }
        }
    
        // 加载模型文件的函数
        private MappedByteBuffer loadModelFile(String model) throws IOException {
            AssetFileDescriptor fileDescriptor = getApplicationContext().getAssets().openFd(model + ".tflite");
            FileInputStream inputStream = new FileInputStream(fileDescriptor.getFileDescriptor());
            FileChannel fileChannel = inputStream.getChannel();
            long startOffset = fileDescriptor.getStartOffset();
            long declaredLength = fileDescriptor.getDeclaredLength();
            return fileChannel.map(FileChannel.MapMode.READ_ONLY, startOffset, declaredLength);
        }
    
        public void test() {
            Log.d(TAG, "-----test()----");
            try {
    //          推理预测数据初始化
    
    //            30.02904
                float[] inputData = new float[]{-0.869348F, -1.009459F, -0.784052F, -1.025303F, -0.379759F,
                       -0.516397F, 0.774676F, -0.465148F, -0.495225F};
    
    //          模型输出,定义需要与模型的输出一致,上一节有检查输入输出
                float[][] labelProbArray = new float[1][1];
    //             这里需要使用ByteBuffer的形式进行输入,输入的数据是(1,400)个float的数
    
    //            运行模型的predict功能
                tflite.run(inputData, labelProbArray);
    
    //             根据模型的定义,输出相应的信息
                String result;
                result = labelProbArray[0][0] + "";
                Log.d("test", "setText = " + result);
                TextView tv = findViewById(R.id.result);
                tv.setText(result);
    
            } catch (Exception e) {
                e.printStackTrace();
            }
        }
    }

     4.6 run app即可,会生成apk

     4.5 最终显示结果,结束

    错误处理:

    1、Apps targeting Android 12 and higher are required to specify an explicit value for `android:exported` when the corresponding component has an intent filter defined. See https://developer.android.com/guide/topics/manifest/activity-element#exported for details

    如下:

     修改:app -> manifests -> AndroidManifest.xml ,在添加如下代码

    android:exported="true"

  • 相关阅读:
    第二次作业
    QLayout 删除widget中的子控件,父控件大小不能自适应
    An工具介绍之形状工具及渐变变形工具
    CSS特效015:7个小球转圈圈加载效果
    计算机网络-应用层(1)
    食品级接触材料的检测标准有哪些?
    【Linux:环境变量】
    abap字段符号(指针)的用法:FIELD-SYMBOLS
    【修改mysql密码的两种方式】
    CentOS-7-x86_64 iso镜像的安装(Linux操作系统)
  • 原文地址:https://blog.csdn.net/qq_18256855/article/details/126440016