• TensorFlow 系列 第五讲:移动端部署模型


    项目地址:https://github.com/LionJackson/imageClassification
    Flutter项目地址:https://github.com/LionJackson/flutter_image

    一. 模型转换

    编写tflite模型工具类:

    import os
    
    import PIL
    import tensorflow as tf
    import keras
    import numpy as np
    from PIL.Image import Image
    from matplotlib import pyplot as plt
    
    from utils.dataset_loader import DatasetLoader
    from utils.utils import Utils
    
    """
    tflite模型工具类
    """
    
    
    class TFLiteUtil:
        def __init__(self, saved_model_dir, path_url):
            self.save_model_dir = saved_model_dir
            self.path_url = path_url
    
        # 训练的模型生成标签列表
        def get_folder_names(self):
            folder_names = []
            for root, dirs, files in os.walk(self.path_url + '/train'):
                for dir_name in dirs:
                    folder_names.append(dir_name)
    
            with open(self.save_model_dir + '.label', 'w') as file:
                for name in folder_names:
                    file.write(name + '\n')
            return folder_names
    
        # 模型转成tflite格式
        def convert_tflite(self):
            self.get_folder_names()
            converter = tf.lite.TFLiteConverter.from_saved_model(self.save_model_dir)
            tflite_model = converter.convert()
    
            # 将转换后的 TFLite 模型保存为文件
            with open(self.save_model_dir + '.tflite', 'wb') as f:
                f.write(tflite_model)
    
            print("转换成功,已保存为 tflite")
    
        # 加载keras并转成tflite
        def convert_model_tflite(self):
            self.get_folder_names()
            model = keras.models.load_model(self.save_model_dir + ".keras")
            converter = tf.lite.TFLiteConverter.from_keras_model(model)
            converter.optimizations = [tf.lite.Optimize.DEFAULT]
            converter.target_spec.supported_types = [tf.float16]
            tflite_model = converter.convert()
            # 将转换后的 TFLite 模型保存为文件
            with open(self.save_model_dir + '.tflite', 'wb') as f:
                f.write(tflite_model)
    
            print("转换成功(model),已保存为 tflite")
    
        # 批量识别 进行可视化显示
        def batch_evaluation(self, class_mode='categorical', image_size=(224, 224), num_images=25):
            dataset_loader = DatasetLoader(self.path_url, image_size=image_size, class_mode=class_mode)
            train_ds, val_ds, test_ds, class_names = dataset_loader.load_data()
    
            interpreter = tf.lite.Interpreter(self.save_model_dir + '.tflite')
            interpreter.allocate_tensors()
            # 获取输入和输出张量的信息
            input_details = interpreter.get_input_details()
            output_details = interpreter.get_output_details()
    
            plt.figure(figsize=(10, 10))
            for images, labels in test_ds.take(1):
                outputs = []
                for img in images:
                    img_expanded = np.expand_dims(img, axis=0)
                    interpreter.set_tensor(input_details[0]['index'], img_expanded)
                    interpreter.invoke()
                    output = interpreter.get_tensor(output_details[0]['index'])
                    outputs.append(output)
    
                for i in range(num_images):
                    plt.subplot(5, 5, i + 1)
                    image = np.array(images[i]).astype("uint8")
                    plt.imshow(image)
                    index = int(np.argmax(outputs[i]))
                    prediction = outputs[i][0][index]
                    percentage_str = "{:.2f}%".format(prediction * 100)
                    plt.title(f"{class_names[index]}: {percentage_str}")
                    plt.axis("off")
            plt.subplots_adjust(hspace=0.5, wspace=0.5)
            plt.show()
    
        # 查看tflite模型信息
        def tflite_analyzer(self):
            # 加载 TFLite 模型
            interpreter = tf.lite.Interpreter(model_path=self.save_model_dir + '.tflite')
            interpreter.allocate_tensors()
    
            # 获取输入和输出的详细信息
            input_details = interpreter.get_input_details()
            output_details = interpreter.get_output_details()
    
            # 打印输入和输出的详细信息
            print("Input Details:")
            for detail in input_details:
                print(detail)
    
            print("\nOutput Details:")
            for detail in output_details:
                print(detail)
    
            # 列出所有使用的算子
            tensor_details = interpreter.get_tensor_details()
    
            print("\nTensor Details:")
            for tensor_detail in tensor_details:
                print("Index:", tensor_detail['index'])
                print("Name:", tensor_detail['name'])
                print("Shape:", tensor_detail['shape'])
                print("Shape Signature:", tensor_detail['shape_signature'])
                print("dtype:", tensor_detail['dtype'])
                print("Quantization:", tensor_detail['quantization'])
                print("Quantization Parameters:", tensor_detail['quantization_parameters'])
                print("Sparsity Parameters:", tensor_detail['sparsity_parameters'])
                print()
    
    

    调用工具类:

    if __name__ == '__main__':
        # Alternatives: call train() to (re)train the model, or use
        # ModelUtil(SAVED_MODEL_DIR, PATH_URL).batch_evaluation() to check the Keras model.
        util = TFLiteUtil(SAVED_MODEL_DIR, PATH_URL)
        # Convert to TFLite, dump the model's tensor info, then visually spot-check predictions.
        for step in (util.convert_tflite, util.tflite_analyzer, util.batch_evaluation):
            step()
    
    

    运行后会生成 .tflite 模型文件(以及对应的 .label 标签文件):

    (此处原为截图:生成的 tflite 模型文件)

    二. 使用模型

    创建 Flutter 项目,在 pubspec.yaml 的 dependencies 中引入以下库:

      image: ^4.0.17
      path: ^1.8.3
      path_provider: ^2.0.15
      image_picker: ^0.8.8
      tflite_flutter: ^0.10.4
      camera: ^0.10.5+2
    

    把 .tflite 模型文件和 .label 标签文件拷贝到项目的 assets/models 目录中(并在 pubspec.yaml 的 assets 中声明该目录):

    (此处原为截图:模型文件在项目中的位置)
    核心代码:

    
    
    import 'dart:developer';
    import 'dart:io';
    import 'dart:isolate';
    
    import 'package:camera/camera.dart';
    import 'package:flutter/services.dart';
    import 'package:image/image.dart';
    import 'package:tflite_flutter/tflite_flutter.dart';
    
    import 'isolate_inference.dart';
    
    /// Loads the TFLite classifier and its labels from assets and runs
    /// inference through an [IsolateInference] worker.
    ///
    /// [initHelper] must complete before [inferenceImage] or
    /// [inferenceCameraFrame] is called.
    class ImageClassificationHelper {
      static const modelPath = 'assets/models/fruits.tflite';
      // Newline-separated class names, one per model output index.
      static const labelsPath = 'assets/models/fruits.label';

      late final Interpreter interpreter;
      late final List<String> labels;
      late final IsolateInference isolateInference;
      late Tensor inputTensor;
      late Tensor outputTensor;

      // Load the model from assets (with a platform delegate) and cache its
      // first input/output tensors.
      Future<void> _loadModel() async {
        final options = InterpreterOptions();

        // Use XNNPACK Delegate
        if (Platform.isAndroid) {
          options.addDelegate(XNNPackDelegate());
        }

        // Use GPU Delegate
        // doesn't work on emulator
        // if (Platform.isAndroid) {
        //   options.addDelegate(GpuDelegateV2());
        // }

        // Use Metal Delegate
        if (Platform.isIOS) {
          options.addDelegate(GpuDelegate());
        }

        // Load model from assets
        interpreter = await Interpreter.fromAsset(modelPath, options: options);
        // Input shape is presumably [1, height, width, 3] (224x224 with the
        // Python export defaults); TODO confirm against the exported model.
        inputTensor = interpreter.getInputTensors().first;
        // Output shape is presumably [1, number of labels].
        outputTensor = interpreter.getOutputTensors().first;

        log('Interpreter loaded successfully');
      }

      // Load labels from assets. Blank lines (e.g. from a trailing newline) and
      // '\r' from CRLF files are dropped so labels[i] lines up with output i.
      Future<void> _loadLabels() async {
        final labelTxt = await rootBundle.loadString(labelsPath);
        labels = labelTxt
            .split('\n')
            .map((line) => line.trim())
            .where((line) => line.isNotEmpty)
            .toList();
      }

      Future<void> initHelper() async {
        // Await both loads. Otherwise inference could read the `late` fields
        // before they are set, and load errors would be silently dropped.
        await Future.wait([_loadLabels(), _loadModel()]);
        isolateInference = IsolateInference();
        await isolateInference.start();
      }

      // Send one request to the inference isolate and wait for its single reply.
      Future<Map<String, double>> _inference(InferenceModel inferenceModel) async {
        ReceivePort responsePort = ReceivePort();
        isolateInference.sendPort
            .send(inferenceModel..responsePort = responsePort.sendPort);
        // get inference result.
        var results = await responsePort.first;
        return results;
      }

      // inference camera frame
      Future<Map<String, double>> inferenceCameraFrame(
          CameraImage cameraImage) async {
        var isolateModel = InferenceModel(cameraImage, null, interpreter.address,
            labels, inputTensor.shape, outputTensor.shape);
        return _inference(isolateModel);
      }

      // inference still image
      Future<Map<String, double>> inferenceImage(Image image) async {
        var isolateModel = InferenceModel(null, image, interpreter.address, labels,
            inputTensor.shape, outputTensor.shape);
        return _inference(isolateModel);
      }

      Future<void> close() async {
        // NOTE(review): the native interpreter is not closed here. The isolate
        // uses it via `interpreter.address`, so call interpreter.close() only
        // once the isolate is guaranteed to have stopped.
        isolateInference.close();
      }
    }
    
    

    页面部分:

    (此处原为截图:页面效果)

  • 相关阅读:
    通过java爬取动态网页
    SSM篇目录总结
    input输入框小写字母自动转换成大写字母的几种方式
    Java日志系统之Logback
    3-3主机发现-四层发现
    谷粒商城 (十九) --------- 商品服务 概念 SPU & SKU
    WiFi 四次握手&Omnipeek抓包
    上半年绩效差,「营销分析报告」无从下手,这套模板领导一看就懂
    Apache Doris以Routine Load方式流式的导入Kafka数据
    Linux服务器部署Nginx并发布web项目
  • 原文地址:https://blog.csdn.net/wang_yong_hui_1234/article/details/140355816