• cnn 实现图片识别


    在入门之后需要对机器学习的一些思维和方法体验下2333

    本文基于GitHub上的一些源码改写而成,用来体验一下现在深度学习识别图片的速度之快、准确度之好。

    使用了TensorFlow里的高级神经网络库,通过其keras.applications模块获取在ILSVRC竞赛中获胜的多个卷积网络模型,可识别的物体类别从10类增加到1001类,例如:狗熊、椅子、汽车、键盘、箱子、婴儿床、旗杆、iPod播放器、轮船、面包车、项链、降落伞、桌子、钱包、球拍、步枪等等。
    接着导入ResNet50网络模型进行处理,主要图像数据处理函数如下:


    image.img_to_array:将PIL格式的图像转换为numpy数组。


    np.expand_dims:在最前面(axis=0)增加一维,将我们的(224,224,3)大小的图像数组转换为(1,224,224,3)(TensorFlow后端默认通道在最后;若使用通道在前的格式,则是由(3,224,224)转换为(1,3,224,224))。因为model.predict函数需要4维数组作为输入,其中新增的第1维为每批预测图像的数量。这也就是说,我们可以一次性分类多个图像。


    preprocess_input:使用训练数据集中各通道的平均值对图像数据进行零均值化处理,即从每个像素中减去对应通道的平均值,使数据大致以0为中心。这是非常重要的步骤,如果跳过,将大大影响实际预测效果。这个步骤属于数据归一化(中心化)预处理。


    model.predict:对我们的数据分批处理并返回预测值。


    decode_predictions:采用与model.predict函数相同的编码标签,并从ImageNet ILSVRC集返回可读的标签。
    然后通过调用官方api,得到了这个classifier,可以读取同目录下images文件夹里的指定命名的图片,这个分类器在我的笔记本tensorflow中训练差不多需要两个小时的时间,训练好之后实际识别的速率快达1~2分钟一张图,准确率高达90%以上。

    代码:

    GUI:(PyQt5)

    from PyQt5 import QtWidgets
    from PyQt5.QtWidgets import QFileDialog
    from PyQt5 import QtCore, QtGui
    import classify
    class MyWindow(QtWidgets.QWidget):
        """Main window: choose an image file, preview it and start classification.

        The selected path is handed to the ``classify`` module through its
        module-level ``imgf`` attribute; ``classify.sjsy()`` performs the
        actual recognition.
        """

        def __init__(self):
            """Create the fixed-geometry child widgets and connect the buttons."""
            super(MyWindow, self).__init__()
            self.setObjectName("widget")
            self.resize(490, 506)
            self.setMinimumSize(QtCore.QSize(100, 100))
            self.setCursor(QtGui.QCursor(QtCore.Qt.ArrowCursor))
            self.gridLayoutWidget = QtWidgets.QWidget(self)
            self.gridLayoutWidget.setGeometry(QtCore.QRect(60, 120, 381, 301))
            self.gridLayoutWidget.setObjectName("gridLayoutWidget")
            self.gridLayout = QtWidgets.QGridLayout(self.gridLayoutWidget)
            self.gridLayout.setContentsMargins(0, 0, 0, 0)
            self.gridLayout.setObjectName("gridLayout")
            self.label = QtWidgets.QLabel(self)
            self.label.setGeometry(QtCore.QRect(70, 50, 54, 12))
            self.label.setObjectName("label")
            self.textEdit = QtWidgets.QTextEdit(self)
            self.textEdit.setGeometry(QtCore.QRect(120, 45, 261, 25))
            self.textEdit.setObjectName("textEdit")
            self.toolButton = QtWidgets.QToolButton(self)
            self.toolButton.setGeometry(QtCore.QRect(379, 43, 50, 28))
            self.toolButton.setObjectName("toolButton")
            self.toolButton.clicked.connect(self.msg)
            self.pushButton = QtWidgets.QPushButton(self)
            self.pushButton.setGeometry(QtCore.QRect(200, 80, 81, 31))
            self.pushButton.setObjectName("pushButton")
            self.pushButton.clicked.connect(self.sbing)
            # Label that displays the preview of the selected picture.
            self.label2 = QtWidgets.QLabel(self)
            # QRect arguments: x offset from the left edge, y offset from the
            # top edge, width, height.
            self.label2.setGeometry(QtCore.QRect(72, 150, 360, 300))
            self.label2.setObjectName("label2")

            self.retranslateUi(self)
            QtCore.QMetaObject.connectSlotsByName(self)

        def retranslateUi(self, widget):
            """Set the window title and the captions of the label and buttons."""
            _translate = QtCore.QCoreApplication.translate
            widget.setWindowTitle(_translate("widget", "图片识别器"))
            self.label.setText(_translate("widget", "目标图片"))
            self.toolButton.setText(_translate("widget", "浏览"))
            self.pushButton.setText(_translate("widget", "开始识别"))

        def msg(self):
            """Let the user pick a file, preview it and pass its path to ``classify``.

            Nothing is changed when the dialog is cancelled, so a previously
            chosen picture stays selected.
            """
            # Filters in the last argument are separated by a double semicolon.
            fileName1, _filetype = QFileDialog.getOpenFileName(
                self,
                "选取文件",
                "C:/",
                "All Files (*);;Text Files (*.txt)")
            if not fileName1:
                # A cancelled dialog yields an empty path; keep the current
                # preview and selection instead of overwriting them with "".
                return

            # Stretch the picture to exactly fill the preview label.
            png = QtGui.QPixmap(fileName1).scaled(self.label2.width(), self.label2.height())
            self.label2.setPixmap(png)
            self.textEdit.setText(fileName1)
            classify.imgf = fileName1

        def sbing(self):
            """Classify the selected image, showing a busy caption while it runs."""
            if not classify.imgf:
                # No picture chosen yet; classify.sjsy() would try to open "".
                return

            self.pushButton.setText("识别中")
            # classify.sjsy() blocks this thread, so process the pending paint
            # events now; otherwise the busy caption would never be drawn.
            QtWidgets.QApplication.processEvents()
            try:
                classify.sjsy()
            finally:
                # Restore the caption even when classification raises.
                self.pushButton.setText("开始识别")
    if __name__ == "__main__":
        import sys

        # Start the Qt application, show the main window and use the event
        # loop's return code as the process exit status. sys.exit() never
        # returns, so nothing may follow it.
        app = QtWidgets.QApplication(sys.argv)
        myshow = MyWindow()
        myshow.show()
        sys.exit(app.exec_())
    

    classify

    import sys
    import argparse
    import numpy as np
    from PIL import Image
    import requests
    from io import BytesIO
    import matplotlib.pyplot as plt
    
    from keras.preprocessing import image
    from keras.applications.resnet50 import ResNet50, preprocess_input, decode_predictions
    
    # ResNet50 network, constructed once when this module is imported, using
    # the 'imagenet' weights.
    model = ResNet50(weights='imagenet')
    # Size passed to predict(), which resizes every input image to it.
    target_size = (224, 224)
    # Path of the image to classify. Read by sjsy(); expected to be assigned
    # from outside this module before sjsy() is called.
    imgf = ""
    
    def predict(model, img, target_size, top_n=3):
        """Run model prediction on image
      Args:
        model: keras model
        img: PIL format image; converted to RGB and resized when necessary
        target_size: (w,h) tuple (any two-item sequence is accepted)
        top_n: # of top predictions to return
      Returns:
        list of predicted labels and their probabilities
      """
        # RGBA, palette and grayscale images would otherwise produce an array
        # whose channel count is not the three channels fed to the network.
        if img.mode != "RGB":
            img = img.convert("RGB")

        # PIL reports ``size`` as a tuple; normalise so that a list argument
        # compares equal when the image already has the requested size.
        target_size = tuple(target_size)
        if img.size != target_size:
            img = img.resize(target_size)

        x = image.img_to_array(img)
        # Add a leading axis: the model is given a batch holding one image.
        x = np.expand_dims(x, axis=0)
        x = preprocess_input(x)
        preds = model.predict(x)
        # Index 0 selects the decoded predictions of the only image in the batch.
        return decode_predictions(preds, top=top_n)[0]
    
    
    def plot_preds(image, preds):
        """Displays image and the top-n predicted probabilities in a bar graph
      Args:
        image: PIL image
        preds: list of predicted labels and their probabilities; item [1] of
          each entry is used as the label and item [2] as the probability
      """
        # First figure: the picture itself, without axes.
        plt.imshow(image)
        plt.axis('off')

        # Second figure: horizontal bar chart of the probabilities.
        plt.figure()
        # Descending positions, so the first prediction gets the largest y value.
        order = list(reversed(range(len(preds))))
        bar_preds = [pr[2] for pr in preds]
        # A real list rather than a one-shot generator: yticks expects an
        # array-like of labels matching the tick positions.
        labels = [pr[1] for pr in preds]
        plt.barh(order, bar_preds, alpha=0.5)
        plt.yticks(order, labels)
        plt.xlabel('Probability')
        plt.xlim(0, 1.01)
        plt.tight_layout()
        plt.show()
    
    
    def sjsy():
        """Classify the image at the module-level path ``imgf`` and plot the result.

        Prints the path, runs predict() with the module-level ``model`` and
        ``target_size``, then shows the figures through plot_preds().
        """
        print(imgf)
        # The context manager closes the image file once plotting has finished,
        # including when prediction or plotting raises.
        with Image.open(imgf) as img:
            preds = predict(model, img, target_size)
            plot_preds(img, preds)
    
    
    # Disabled command-line entry point, kept only as an inert string literal;
    # none of the code inside it is executed.
    '''if __name__=="__main__":
      img = Image.open("images/3.jpg")
      preds = predict(model, img, target_size)
      plot_preds(img, preds)
      a = argparse.ArgumentParser()
      a.add_argument("--image", help="path to image")
      a.add_argument("--image_url", help="url to image")
      args = a.parse_args()
    
      if args.image is None and args.image_url is None:
        a.print_help()
        sys.exit(1)
    
      if args.image is not None:
        img = Image.open(args.image)
        preds = predict(model, img, target_size)
        plot_preds(img, preds)
    
      if args.image_url is not None:
        response = requests.get(args.image_url)
        img = Image.open(BytesIO(response.content))
        preds = predict(model, img, target_size)
        plot_preds(img, preds)
    '''
  • 相关阅读:
    我的世界Bukkit服务器插件开发教程(九)NMS
    疫情重压下,屈臣氏为何上半年仍盈利?
    Linux常用操作集合
    第六章:利用dumi搭建组件文档【前端工程化入门-----从零实现一个react+ts+vite+tailwindcss组件库】
    什么是嵌套路由?如何定义嵌套路由?
    Spring Boot面试杀手锏————自动配置原理
    springboot常用注解
    6.MySQL内置函数
    UMA Frame Buffer Size 核显显存与CSGO帧率
    视频监控智能分析系统
  • 原文地址:https://blog.csdn.net/cqn2bd2b/article/details/127649650