• TensorFlow学习:在web前端如何使用Keras 模型


    前言

    在上篇文章 TensorFlow学习:使用官方模型进行图像分类、使用自己的数据对模型进行微调中我们学习了如何使用官方模型,以及使用自己的数据微调模型。

    但是吧,代码一直是跑在Python里,而我本身是做前端开发的。我是很想让它在前端进行浏览器里进行运行。

    谷歌贴心的为我们准备了 TensorFlow.jsTensorFlow.js 是一个 JavaScript 库,用于在浏览器和 Node.js 训练和部署机器学习模型。

    这篇文章我们来学习如何在前端运行模型,模型的话就用上一篇文章里训练的花朵分类模型。

    官方文档:TensorFlow.js 官方文档

    注: 下面是我的采坑心得,我这是第一次学习,第一次搞。你要是按照我的步骤遇到了其他问题,不要问我,我也不会。

    建议按顺序观看,这是一个小系列,适合像我这样的初学者入门

    配置环境:windows环境下tensorflow安装

    图片分类案例学习:TensorFlow案例学习:对服装图像进行分类

    使用官方模型,并进行微调:TensorFlow学习:使用官方模型进行图像分类、使用自己的数据对模型进行微调

    将模型转换,在前端使用:TensorFlow学习:在web前端如何使用Keras 模型

    学习

    处理模型

    Keras 模型(通常通过 Python API 创建)可能保存成多种格式之一。完整模型格式可以转换成 TensorFlow.js Layers 格式,这种格式可以直接加载到 TensorFlow.js 中进行推断或进一步训练。

    目标 TensorFlow.js Layers 格式是一个包含 model.json 文件和一组二进制格式的分片权重文件的目录。model.json 文件包含模型拓扑(又称“架构”或“计算图”:它是对层及其连接方式的描述)和权重文件的清单。

    我们上一篇文章训练出的模型就是Keras 模型,这里需要对其进行转换。

    安装

    pip install tensorflowjs
    
    • 1

    就是这一步上来就被搞惨了。
    在这里插入图片描述
    最开始下载,还没下载多少就超时了,直接下载不了。后来查到可以使用国内镜像下载

     pip install tensorflowjs  -i https://pypi.tuna.tsinghua.edu.cn/simple
    
    • 1

    下载速度是变快了很多,结果下载到最后又来一个依赖冲突,最后又下载失败了。最终解决完这个问题是因为在web端使用,也需要下载

    npm install @tensorflow/tfjs
    
    • 1

    当时我在想这两个是不是一个东西啊,问了一下gpt,npm下载的这个还真的可以用来进行模型转换。

    这里我建议,即使你pip下载成功了,最好还是使用npm下载的这个进行模型转换。因为这样可以保证tensorflowjs版本一致,避免因为版本问题导致最后使用时又出问题

    转换

    这个也是个坑啊,文档是这样说的
    在这里插入图片描述
    但是在上一篇文章中,我最后保存的不是.h5格式啊,然后又回去跑模型,最后model.save('my_model.h5'),将模型保存为.h5格式。再然后转换模型

    tensorflowjs_converter --input_format=keras flower_model.h5 flower_js_model
    
    • 1

    在这里插入图片描述
    在这里插入图片描述
    看样子是成功了,结果还真没成功,在前端加载时又报错了。没办法,百度查、翻文档。然后看见了这个
    在这里插入图片描述
    还真需要用这个,不过上面的代码有点问题,不需要有\ 符号。正确代码应该是

    tensorflowjs_converter --input_format=tf_saved_model   flower_model web_model
    
    • 1

    这里要注意:

    • 我们还是使用的npm下载的依赖,不是pip下载的依赖
    • --input_format=tf_saved_model,指定输入格式
    • flower_model web_model这是两个路径,前面的是模型的逻辑,后面那个是转换完成后的输出路径
    • 这里加载的模型不是.h5,就是.pb文件所在的文件夹,记住是文件夹,不是目录

    总之就是将flower_model下的模型进行转换,将转换后的模型输出到web_model目录下
    在这里插入图片描述

    在前端使用

    这里要特别注意对输入图片的处理,一开始就是因为输入图片处理的不正确,导致模型在预测时结果不正确。后来各种查资料,才解决,使用代码如下:

    <template>
      <div class="page-container">
        <div class="first-title">
          官方文档:
          <a href="https://tensorflow.google.cn/js/models?hl=zh-cn"
            >https://tensorflow.google.cn/js/models?hl=zh-cn</a
          >
        </div>
        <div class="img-list">
          <img
            v-for="img in imageList"
            :key="img.name"
            :src="img.url"
            :id="img.name"
            :class="activeImg == img.name ? 'img-item img-item-active' : 'img-item'"
            @click="changeImg(img)"
          />
        </div>
        <div style="margin-top: 20px">结果是:{{ result }}</div>
      </div>
    </template>
    
    <script setup>
    import { ref, onMounted } from "vue";
    import * as tf from "@tensorflow/tfjs";
    
    // 图片
    const imageList = ref([]);
    // 当前选中的图片
    const activeImg = ref("f1");
    // 结果
    const result = ref("");
    // 图片列表
    const IMAGES = [
      {
        name: "f1",
        url: "../assets/f1.jpg",
      },
      {
        name: "f2",
        url: "../assets/f2.jpg",
      },
      {
        name: "f3",
        url: "../assets/f3.jpg",
      },
      {
        name: "f4",
        url: "../assets/f4.jpg",
      },
    ];
    const IMAGENET_CLASSES = ["雏菊", "蒲公英", "玫瑰", "向日葵", "郁金香"];
    
    onMounted(() => {
      imageList.value = [];
    
      IMAGES.forEach((item) => {
        import(item.url).then((img) => {
          imageList.value.push({
            name: item.name,
            url: img.default,
          });
        });
      });
    });
    
    // 切换图片
    const changeImg = async (img) => {
      activeImg.value = img.name;
      // 识别图片
      await identify(img.name);
    };
    
    // 识别图片
    const identify = async (id) => {
      const imageElement = await document.getElementById(id);
      console.log("图片", imageElement);
      // 载入模型
      const model = await tf.loadGraphModel("../../public/web_model/model.json");
      console.log("模型:", model);
    
      // 图像预处理
      const imageTensor = preprocessImage(imageElement);
    
      // 对图片进行预测
      const predictions = await model.predict(imageTensor);
    
      console.log("predictions:", predictions);
    
      // 获取预测结果
      const predictedIndex = tf.argMax(predictions, 1).dataSync()[0];
      const predictedLabel = IMAGENET_CLASSES[predictedIndex];
      result.value = predictedLabel;
      console.log("结果:", predictedLabel, predictedIndex);
    };
    
    // 图像预处理
    const preprocessImage = (img) => {
      // 将图像转换为张量对象并将像素值转换为浮点数类型
      const tensor = tf.browser.fromPixels(img).toFloat();
      // 张量的轴上添加一个维度,以适应模型的输入要求
      const expandedDims = tensor.expandDims();
      // 调整图像的尺寸为224x224,尺寸是模型的要求
      const resizedImg = tf.image.resizeBilinear(expandedDims, [224, 224]);
      // 将像素值归一化到范围[0, 1]之间
      const normalizedImg = tf.div(resizedImg, 255.0);
      // 返回归一化后的图像张量
      return normalizedImg;
    };
    </script>
    
    <style lang="scss" scoped>
    .img-list {
      display: flex;
    
      .img-item {
        width: 240px;
        height: 180px;
        border-radius: 5px;
        cursor: pointer;
        padding: 10px;
      }
    
      .img-item-active {
        border: 2px solid red;
      }
    }
    </style>
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15
    • 16
    • 17
    • 18
    • 19
    • 20
    • 21
    • 22
    • 23
    • 24
    • 25
    • 26
    • 27
    • 28
    • 29
    • 30
    • 31
    • 32
    • 33
    • 34
    • 35
    • 36
    • 37
    • 38
    • 39
    • 40
    • 41
    • 42
    • 43
    • 44
    • 45
    • 46
    • 47
    • 48
    • 49
    • 50
    • 51
    • 52
    • 53
    • 54
    • 55
    • 56
    • 57
    • 58
    • 59
    • 60
    • 61
    • 62
    • 63
    • 64
    • 65
    • 66
    • 67
    • 68
    • 69
    • 70
    • 71
    • 72
    • 73
    • 74
    • 75
    • 76
    • 77
    • 78
    • 79
    • 80
    • 81
    • 82
    • 83
    • 84
    • 85
    • 86
    • 87
    • 88
    • 89
    • 90
    • 91
    • 92
    • 93
    • 94
    • 95
    • 96
    • 97
    • 98
    • 99
    • 100
    • 101
    • 102
    • 103
    • 104
    • 105
    • 106
    • 107
    • 108
    • 109
    • 110
    • 111
    • 112
    • 113
    • 114
    • 115
    • 116
    • 117
    • 118
    • 119
    • 120
    • 121
    • 122
    • 123
    • 124
    • 125
    • 126
    • 127
    • 128

    最终效果
    在这里插入图片描述

  • 相关阅读:
    【HBZ分享】AQS + CAS +LockSupport 实现ReentrantLock的原理
    从电大搜题到上海开放大学,广播电视大学引领学习新风尚
    开源去中心化社交平台-Misskey
    Java 如何复制 List ?
    【递归、搜索与回溯算法】第四节.50. Pow(x, n)和2331. 计算布尔二叉树的值
    药品研发--原料储存管理制度
    【无标题】
    基于jsp+ssm手机综合类门户网站
    C++11 基础知识
    金色传说:SAP-ABAP- PM工单:IW32组件增强
  • 原文地址:https://blog.csdn.net/weixin_41897680/article/details/133760043