• Tensorflow pb模型转tflite,并量化


    一、tensorflow2.x版本pb模型转换tflite及量化

    1、h5模型转tflite,不进行量化

    import tensorflow as tf
    import numpy as np
    from pathlib import Path
    print("TensorFlow version: ", tf.__version__)
    
    model = tf.keras.models.load_model('model.h5')
    
    # Convert the Keras model as-is: no optimizations requested, so the
    # resulting flatbuffer keeps float32 weights and float32 I/O.
    keras_converter = tf.lite.TFLiteConverter.from_keras_model(model)
    flatbuffer = keras_converter.convert()
    Path("mnist_model_null.tflite").write_bytes(flatbuffer)
    
    # Sanity-check the converted model's input/output dtypes.
    probe = tf.lite.Interpreter(model_content=flatbuffer)
    input_type = probe.get_input_details()[0]['dtype']
    print('input: ', input_type)
    output_type = probe.get_output_details()[0]['dtype']
    print('output: ', output_type)
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15
    • 16
    • 17
    • 18

    `

    2、h5模型转tflite,进行动态范围量化 (官方参考代码)

    import tensorflow as tf
    import numpy as np
    from pathlib import Path
    print("TensorFlow version: ", tf.__version__)
    
    model = tf.keras.models.load_model('model.h5')
    # Dynamic-range quantization: weights are stored as int8 while activations
    # stay float and are quantized on the fly at inference time.
    dyn_converter = tf.lite.TFLiteConverter.from_keras_model(model)
    dyn_converter.optimizations = [tf.lite.Optimize.DEFAULT]
    dyn_bytes = dyn_converter.convert()
    Path("mnist_model_dynamic.tflite").write_bytes(dyn_bytes)
    
    # Input/output stay float32 under dynamic-range quantization.
    probe = tf.lite.Interpreter(model_content=dyn_bytes)
    input_type = probe.get_input_details()[0]['dtype']
    print('input: ', input_type)
    output_type = probe.get_output_details()[0]['dtype']
    print('output: ', output_type)
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15
    • 16
    • 17
    • 18

    `

    3、h5模型转tflite,进行int8整型量化 (官方参考代码)

    import tensorflow as tf
    import numpy as np
    from pathlib import Path
    print("TensorFlow version: ", tf.__version__)
    
    model = tf.keras.models.load_model('model.h5')
    mnist = tf.keras.datasets.mnist
    (train_images, train_labels), (test_images, test_labels) = mnist.load_data()
    print(type(train_images), train_images.shape)
    # Scale pixels into [0, 1] as float32; calibration samples must go through
    # the same preprocessing the model was trained with.
    train_images = train_images.astype(np.float32) / 255.0
    
    def representative_data_gen():
        """Yield ~100 single-image batches for int8 range calibration."""
        samples = tf.data.Dataset.from_tensor_slices(train_images)
        for batch in samples.batch(1).take(100):
            yield [batch]
    
    converter = tf.lite.TFLiteConverter.from_keras_model(model)
    converter.optimizations = [tf.lite.Optimize.DEFAULT]
    converter.representative_dataset = representative_data_gen
    # Force every op onto its int8 kernel and expose uint8 tensors at the
    # model boundary (convenient for camera-style uint8 inputs).
    converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS_INT8]
    converter.inference_input_type = tf.uint8
    converter.inference_output_type = tf.uint8
    quantized_bytes = converter.convert()
    Path("mnist_model_int8.tflite").write_bytes(quantized_bytes)
    
    probe = tf.lite.Interpreter(model_content=quantized_bytes)
    input_type = probe.get_input_details()[0]['dtype']
    print('input: ', input_type)
    output_type = probe.get_output_details()[0]['dtype']
    print('output: ', output_type)
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15
    • 16
    • 17
    • 18
    • 19
    • 20
    • 21
    • 22
    • 23
    • 24
    • 25
    • 26
    • 27
    • 28
    • 29
    • 30

    `

    4、h5模型转tflite,进行float16量化 (官方参考代码)

    import tensorflow as tf
    import numpy as np
    from pathlib import Path
    print("TensorFlow version: ", tf.__version__)
    
    model = tf.keras.models.load_model('model.h5')
    
    # Float16 quantization: weights are stored as float16 (~half the size);
    # compute still runs in float32 unless the delegate supports fp16.
    fp16_converter = tf.lite.TFLiteConverter.from_keras_model(model)
    fp16_converter.optimizations = [tf.lite.Optimize.DEFAULT]
    fp16_converter.target_spec.supported_types = [tf.float16]
    fp16_bytes = fp16_converter.convert()
    Path("mnist_model_float16.tflite").write_bytes(fp16_bytes)
    
    # Model boundary remains float32.
    probe = tf.lite.Interpreter(model_content=fp16_bytes)
    input_type = probe.get_input_details()[0]['dtype']
    print('input: ', input_type)
    output_type = probe.get_output_details()[0]['dtype']
    print('output: ', output_type)
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15
    • 16
    • 17
    • 18
    • 19
    • 20

    `

    二、tensorflow2.x版本调用1.x(.compat.v1)pb模型转换tflite及量化 (官方api)

    1、pb模型转tflite,不进行量化

    import tensorflow as tf
    
    # Convert a frozen TF1.x GraphDef to TFLite without any quantization.
    # Tensor names and shapes must match the frozen graph exactly.
    converter = tf.compat.v1.lite.TFLiteConverter.from_frozen_graph(
            graph_def_file = '0824.pb',
            input_arrays = ['x_img_g', 'is_training'],
            input_shapes = {'x_img_g' : [1, 256, 512, 3], 'is_training' : [1]},
            output_arrays = ['encoder_generator/classifier/SINET_output/BiasAdd']
    )
    tflite_model = converter.convert()
    # Use a context manager so the file handle is closed deterministically
    # (the original `open(...).write(...)` left the handle open).
    with open("model_null.tflite", "wb") as f:
        f.write(tflite_model)
    interpreter = tf.lite.Interpreter(model_content=tflite_model)
    # Renamed from `input`/`output` to avoid shadowing Python builtins.
    input_details = interpreter.get_input_details()
    print(input_details)
    output_details = interpreter.get_output_details()
    print(output_details)
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13

    2、pb模型转tflite,进行动态范围量化

    import tensorflow as tf
    
    # Dynamic-range quantization of a frozen TF1.x graph: weights are stored
    # as int8, activations remain float at runtime.
    converter = tf.compat.v1.lite.TFLiteConverter.from_frozen_graph(
            graph_def_file = '0824.pb',
            input_arrays = ['x_img_g', 'is_training'],
            input_shapes = {'x_img_g' : [1, 256, 512, 3], 'is_training' : [1]},
            output_arrays = ['encoder_generator/classifier/SINET_output/BiasAdd']
    )
    # FIX: the original never enabled quantization — it only set
    # `quantized_input_stats`, which is consumed solely by the quantized
    # (uint8/int8) inference path and is a no-op otherwise, so the emitted
    # model stayed plain float32. Dynamic-range quantization is switched on
    # via `optimizations`.
    converter.optimizations = [tf.lite.Optimize.DEFAULT]
    tflite_model = converter.convert()
    with open("model_dynamic.tflite", "wb") as f:
        f.write(tflite_model)
    
    interpreter = tf.lite.Interpreter(model_content=tflite_model)
    # Renamed from `input`/`output` to avoid shadowing Python builtins.
    input_details = interpreter.get_input_details()
    print(input_details)
    output_details = interpreter.get_output_details()
    print(output_details)
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15
    • 16

    3、pb模型转tflite,进行int8整型量化

     # 整型量化
    converter = tf.compat.v1.lite.TFLiteConverter.from_frozen_graph(
            graph_def_file = '0824.pb',
            input_arrays = ['x_img_g', 'is_training'],
            input_shapes = {'x_img_g' : [1, 256, 512, 3], 'is_training' : [1]},
            output_arrays = ['encoder_generator/classifier/SINET_output/BiasAdd']
    )
    converter.quantized_input_stats = {"x_img_g": (0., 1.), "is_training": (0., 1.)}
    converter.inference_type = tf.int8
    tflite_model = converter.convert()
    open("model_int8.tflite", "wb").write(tflite_model)
    
    interpreter = tf.lite.Interpreter(model_content=tflite_model)
    input = interpreter.get_input_details()
    print(input)
    output = interpreter.get_output_details()
    print(output)
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15
    • 16
    • 17

    4、pb模型转tflite,进行float16量化

    import tensorflow as tf
    
    # Float16 quantization of a frozen TF1.x graph: weights stored as float16
    # (~half the size); compute stays float32 unless the delegate supports fp16.
    converter = tf.compat.v1.lite.TFLiteConverter.from_frozen_graph(
            graph_def_file = '0824.pb',
            input_arrays = ['x_img_g', 'is_training'],
            input_shapes = {'x_img_g' : [1, 256, 512, 3], 'is_training' : [1]},
            output_arrays = ['encoder_generator/classifier/SINET_output/BiasAdd']
    )
    # FIX: the original assigned `converter.inference_type = tf.float16`,
    # which is not an accepted value for the TF1 converter (it takes
    # float32/int8/uint8 only), and set `quantized_input_stats`, which is
    # ignored outside the quantized path. Float16 post-training quantization
    # is requested with `optimizations` + `target_spec.supported_types`.
    converter.optimizations = [tf.lite.Optimize.DEFAULT]
    converter.target_spec.supported_types = [tf.float16]
    tflite_model = converter.convert()
    with open("model_float16.tflite", "wb") as f:
        f.write(tflite_model)
    
    interpreter = tf.lite.Interpreter(model_content=tflite_model)
    # Renamed from `input`/`output` to avoid shadowing Python builtins.
    input_details = interpreter.get_input_details()
    print(input_details)
    output_details = interpreter.get_output_details()
    print(output_details)
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15
    • 16
    • 17

    ·

    三、调用tflite

    import os
    import cv2
    import time
    import numpy as np
    import tensorflow as tf
    from PIL import Image
    
    #os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
    #os.environ["CUDA_VISIBLE_DEVICES"] = "0"
    
    # A helper function to evaluate the TF Lite model using "test" dataset.
    def evaluate_model(interpreter):
      """Run the interpreter over the global `test_images` and return the
      top-1 accuracy against the global `test_labels`."""
      in_idx = interpreter.get_input_details()[0]["index"]
      out_idx = interpreter.get_output_details()[0]["index"]
    
      predictions = []
      for image in test_images:
        # Add a batch dimension and cast to float32 to match the model input.
        batch = np.expand_dims(image, axis=0).astype(np.float32)
        interpreter.set_tensor(in_idx, batch)
        interpreter.invoke()
        # Drop the batch dimension and keep the highest-scoring class.
        scores = interpreter.tensor(out_idx)()[0]
        predictions.append(np.argmax(scores))
    
      # Fraction of predictions that match the ground-truth labels.
      correct = 0
      for i, digit in enumerate(predictions):
        if digit == test_labels[i]:
          correct += 1
      return correct * 1.0 / len(predictions)
    
    
    # Load one of the converted models; swap the path to compare variants.
    # interpreter = tf.compat.v1.lite.Interpreter(model_path="model_null.tflite")
    interpreter = tf.compat.v1.lite.Interpreter(model_path="model_int8.tflite")
    # interpreter = tf.compat.v1.lite.Interpreter(model_path="model_float16.tflite")
    # interpreter = tf.compat.v1.lite.Interpreter(model_path="model_dynamic.tflite")
    interpreter.allocate_tensors()
    
    input_details = interpreter.get_input_details()
    output_details = interpreter.get_output_details()
    print(input_details)
    print(output_details)
    
    test_image = cv2.imread('test.png')                             # e.g. (1080, 1920, 3), BGR
    r_w, r_h = 512, 256
    img_data = cv2.resize(test_image, (r_w, r_h))                   # (256, 512, 3)
    # FIX: cast to the dtype the loaded model actually declares instead of
    # hard-coding np.int8 — a uint8- or float32-input model rejects an int8
    # tensor in set_tensor().
    img_data = np.expand_dims(img_data, axis=0).astype(input_details[0]['dtype'])
    
    # NOTE(review): assumes input_details[0] is the image and [1] is the
    # `is_training` flag — verify the order against the printed details above.
    interpreter.set_tensor(input_details[0]['index'], img_data)
    # FIX: set_tensor needs a typed ndarray, not a bare Python list.
    flag = np.array([False]).astype(input_details[1]['dtype'])
    interpreter.set_tensor(input_details[1]['index'], flag)
    t1 = time.time()
    interpreter.invoke()
    t2 = time.time()
    prediction = interpreter.get_tensor(output_details[0]['index'])
    print(t2-t1)
    
    print(prediction.shape)
    prediction = prediction[0]          # drop the batch dimension
    print(prediction.shape)
    
    # Visualize each output channel separately.
    # NOTE(review): Image.fromarray expects uint8 (or another supported mode);
    # if the output is float/int8 logits, rescale before displaying.
    prediction1 = prediction[:,:,0]
    print(prediction1.shape)
    print(np.max(prediction1),np.min(prediction1))
    img = Image.fromarray(prediction1)
    img.show()
    
    prediction2 = prediction[:,:,1]
    print(prediction2.shape)
    print(np.max(prediction2),np.min(prediction2))
    img = Image.fromarray(prediction2)
    img.show()
    
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15
    • 16
    • 17
    • 18
    • 19
    • 20
    • 21
    • 22
    • 23
    • 24
    • 25
    • 26
    • 27
    • 28
    • 29
    • 30
    • 31
    • 32
    • 33
    • 34
    • 35
    • 36
    • 37
    • 38
    • 39
    • 40
    • 41
    • 42
    • 43
    • 44
    • 45
    • 46
    • 47
    • 48
    • 49
    • 50
    • 51
    • 52
    • 53
    • 54
    • 55
    • 56
    • 57
    • 58
    • 59
    • 60
    • 61
    • 62
    • 63
    • 64
    • 65
    • 66
    • 67
    • 68
    • 69
    • 70
    • 71
    • 72
    • 73
    • 74
    • 75
    • 76
    • 77
    • 78
    • 79
    • 80
    • 81
    • 82

    `

    四、参考

    1、官方转换教程参考
    2、tensorflow 将.pb文件量化操作为.tflite
    3、tensorflow2转tflite提示OP不支持的解决方案
    4、Tensorflow2 lite 模型量化

  • 相关阅读:
    【云原生】阿里云容器镜像服务产品
    Github 2024-04-22 开源项目日报Top10
    消息队列 RocketMQ 消息重复消费问题(原因及解决)
    每个程序员都应该了解的 10 大隐私计算技术
    c++ || mutable_explicit_volatile关键字
    FPGA之旅设计99例之第十例-----串口上位机模拟OLED屏
    CCF CSP 201403-2 窗口 题解
    基于 LSTM 的分布式能源发电预测(Matlab代码实现)
    【音视频|ALSA】SS528开发板编译Linux内核ALSA驱动、移植alsa-lib、采集与播放usb耳机声音
    C语言折半查找算法及代码实现
  • 原文地址:https://blog.csdn.net/weixin_39715012/article/details/126669968