读了本文,你可以实现从云端利用DNN模型进行训练,模型保存.h5格式(基于keras)或是saved model格式(tf2.0版本),模型转化为tflite,利用android studio 编写java接口程序,实现模型最终的推理预测,并利用studio自带的手机模拟器,将推理结果显示到手机上,最终的效果如下。

实现步骤:
1、下载MPG数据集

2、利用tf2.0实现云端训练,生成mymodel.h5或者.savedmodel目录
可以参考这篇帖子,写的相对清晰
时序信号的模型使用tflite的示例_sinat_18131557的博客-CSDN博客_tflite文件怎么打开
3、转化为tflite
参考代码:以.h5为例,其他类同
mymodel = load_model('mymodel.h5')
converter = tf.lite.TFLiteConverter.from_keras_model(mymodel)
converter.post_training_quantize = True
tflite_model = converter.convert()
open('converted_model.tflite', 'wb').write(tflite_model)
4、利用android studio打包成apk,完成手机端推理预测
这一步对于仅熟悉云端的同学来说很陌生,因为手机安卓端是不一样的,因此可以打包成apk实现,本步骤写的稍微详细些
4.1 利用android studio新建工程,设置项参考下图



4.2 添加一个可以显示预测结果的控件
在Android -> app -> res -> layout -> activity_main.xml android:id="@+id/result"

4.3 将转化完成的tflite模型,放到app-src-main-assets目录下;

4.4 修改如下文件,此步骤很关键:
修改app下的 Gradle Scripts -> build.gradle,注意是Module:My_Application.app这个,添加依赖项:
implementation 'org.tensorflow:tensorflow-lite:+'

同时添加如下代码:
aaptOptions {
noCompress "tflite"
}

完整build.gradle代码,供参考不一定完全相同:
plugins {
id 'com.android.application'
}
android {
compileSdkVersion 33
buildToolsVersion "33.0.0"
defaultConfig {
applicationId "com.example.myapplication"
minSdkVersion 26
targetSdkVersion 33
versionCode 1
versionName "1.0"
testInstrumentationRunner "androidx.test.runner.AndroidJUnitRunner"
}
buildTypes {
release {
minifyEnabled false
proguardFiles getDefaultProguardFile('proguard-android-optimize.txt'), 'proguard-rules.pro'
}
}
aaptOptions {
noCompress "tflite"
}
compileOptions {
sourceCompatibility JavaVersion.VERSION_1_8
targetCompatibility JavaVersion.VERSION_1_8
}
buildFeatures {
viewBinding true
}
sourceSets {
main {
assets {
srcDirs 'src\\main\\assets'
}
}
}
}
dependencies {
implementation 'androidx.appcompat:appcompat:1.2.0'
implementation 'com.google.android.material:material:1.2.1'
implementation 'androidx.constraintlayout:constraintlayout:2.0.1'
testImplementation 'junit:junit:4.+'
androidTestImplementation 'androidx.test.ext:junit:1.1.2'
androidTestImplementation 'androidx.test.espresso:espresso-core:3.3.0'
implementation 'org.tensorflow:tensorflow-lite:+'
}
4.5 修改app -> java -> com.example.myapplication -> MainActivity,该步骤主要是加载tflite模型,定义输入输出,调用模型推理,并在安卓手机模拟器上显示推理结果:

MainActivity的完整代码如下,主要改动以黑色加粗标出:
package com.example.myapplication;
import androidx.appcompat.app.AppCompatActivity;
import android.content.Context;
import android.content.res.AssetFileDescriptor;
import android.os.Bundle;
import android.util.Log;
import android.widget.TextView;
import android.widget.Toast;
import java.io.FileInputStream;
import java.io.IOException;
import java.nio.MappedByteBuffer;
import java.nio.channels.FileChannel;
import org.tensorflow.lite.Interpreter;
public class MainActivity extends AppCompatActivity {
private static final String TAG = "Test";
private Interpreter tflite;
private Context mContext;
@Override
protected void onCreate(Bundle savedInstanceState) {
super.onCreate(savedInstanceState);
setContentView(R.layout.activity_main);
loadModule();
}
private void loadModule() {
String model = "converted_model";//模型的名字
try {
Interpreter.Options options = new Interpreter.Options();
options.setNumThreads(4);
options.setUseNNAPI(true);
options.setAllowFp16PrecisionForFp32(true);
// 加载模型文件
tflite = new Interpreter(loadModelFile(model), options);
Toast.makeText(MainActivity.this, model + " model load success", Toast.LENGTH_SHORT).show();
// 调用Test函数
test();
} catch (IOException e) {
Toast.makeText(MainActivity.this, model + " model load fail", Toast.LENGTH_SHORT).show();
e.printStackTrace();
}
}
// 加载模型文件的函数
private MappedByteBuffer loadModelFile(String model) throws IOException {
AssetFileDescriptor fileDescriptor = getApplicationContext().getAssets().openFd(model + ".tflite");
FileInputStream inputStream = new FileInputStream(fileDescriptor.getFileDescriptor());
FileChannel fileChannel = inputStream.getChannel();
long startOffset = fileDescriptor.getStartOffset();
long declaredLength = fileDescriptor.getDeclaredLength();
return fileChannel.map(FileChannel.MapMode.READ_ONLY, startOffset, declaredLength);
}
public void test() {
Log.d(TAG, "-----test()----");
try {
// 推理预测数据初始化
// 30.02904
float[] inputData = new float[]{-0.869348F, -1.009459F, -0.784052F, -1.025303F, -0.379759F,
-0.516397F, 0.774676F, -0.465148F, -0.495225F};
// 模型输出,定义需要与模型的输出一致,上一节有检查输入输出
float[][] labelProbArray = new float[1][1];
// 这里需要使用ByteBuffer的形式进行输入,输入的数据是(1,400)个float的数
// 运行模型的predict功能
tflite.run(inputData, labelProbArray);
// 根据模型的定义,输出相应的信息
String result;
result = labelProbArray[0][0] + "";
Log.d("test", "setText = " + result);
TextView tv = findViewById(R.id.result);
tv.setText(result);
} catch (Exception e) {
e.printStackTrace();
}
}
}
4.6 run app即可,会生成apk


4.5 最终显示结果,结束

错误处理:
1、Apps targeting Android 12 and higher are required to specify an explicit value for `android:exported` when the corresponding component has an intent filter defined. See https://developer.android.com/guide/topics/manifest/activity-element#exported for details
如下:

修改:app -> manifests -> AndroidManifest.xml ,在添加如下代码
android:exported="true"
