• deeplearning4j使用vgg19图片向量比对springboot+es环境



    一、在桌面创建两个目录:template(存放模板图片)和 target(存放待比对的图片),供程序读取

    二、POM

    <dependency>
        <groupId>org.springframework.data</groupId>
        <artifactId>spring-data-elasticsearch</artifactId>
    </dependency>
    
    <dependency>
        <groupId>org.deeplearning4j</groupId>
        <artifactId>deeplearning4j-core</artifactId>
        <version>1.0.0-beta7</version>
    </dependency>
    <dependency>
        <groupId>org.deeplearning4j</groupId>
        <artifactId>deeplearning4j-zoo</artifactId>
        <version>1.0.0-beta7</version>
    </dependency>
    <dependency>
        <groupId>org.elasticsearch</groupId>
        <artifactId>elasticsearch</artifactId>
    </dependency>
    <dependency>
        <groupId>org.elasticsearch.client</groupId>
        <artifactId>transport</artifactId>
    </dependency>
    <dependency>
        <groupId>org.elasticsearch.client</groupId>
        <artifactId>elasticsearch-rest-client</artifactId>
    </dependency>
    <dependency>
        <groupId>org.elasticsearch.plugin</groupId>
        <artifactId>transport-netty4-client</artifactId>
    </dependency>
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15
    • 16
    • 17
    • 18
    • 19
    • 20
    • 21
    • 22
    • 23
    • 24
    • 25
    • 26
    • 27
    • 28
    • 29
    • 30
    • 31

    三、code

    import org.datavec.image.loader.NativeImageLoader;
    import org.deeplearning4j.nn.graph.ComputationGraph;
    import org.deeplearning4j.zoo.model.VGG19;
    import org.nd4j.linalg.api.ndarray.INDArray;
    import org.springframework.beans.factory.annotation.Autowired;
    import org.springframework.stereotype.Service;
    import org.springframework.web.multipart.MultipartFile;
    
    import javax.annotation.PostConstruct;
    import java.io.File;
    import java.io.IOException;
    import java.io.InputStream;
    import java.util.HashMap;
    import java.util.Map;
    import java.util.stream.Collectors;
    
    /**
     * Extracts feature vectors from images with a pretrained VGG19 network,
     * stores them through {@code INDArrayPojoRepository}, and ranks the stored
     * images against a query image by cosine similarity.
     *
     * <p>NOTE(review): images are fed to the network exactly as produced by
     * {@link NativeImageLoader}; no pixel normalisation is applied. Confirm
     * whether the pretrained weights expect VGG-style mean subtraction.
     */
    @Service("vgg19Service")
    public class Vgg19ServiceImpl implements Vgg19Service {

        /** Query image read from disk when no file is uploaded. */
        private static final String TEMPLATE_IMAGE_PATH = "C:\\Users\\Administrator\\Desktop\\template\\1.png";

        /** Folder whose files are indexed and compared against the query image. */
        private static final String IMAGE_FOLDER = "C:\\Users\\Administrator\\Desktop\\target";

        /** Number of best matches reported by {@link #find(MultipartFile)}. */
        private static final int TOP_MATCHES = 3;

        /** Pretrained network, assigned once in {@link #init()}. */
        private static ComputationGraph vgg19Model;

        /**
         * Builds the VGG19 zoo model and loads its pretrained weights.
         *
         * @throws IOException if the pretrained weights cannot be loaded
         */
        @PostConstruct
        public void init() throws IOException {
            VGG19 vgg19 = VGG19.builder().build();
            vgg19Model = (ComputationGraph) vgg19.initPretrained();
        }

        @Autowired
        private INDArrayPojoRepository indArrayPojoRepository;

        /**
         * Re-indexes every file in the target folder and reports the closest matches
         * to the query image.
         *
         * <p>The repository is cleared and refilled with one document per file, each
         * holding that image's real feature vector (previously a 3-element placeholder
         * was stored, which does not fit the 1000-dimension mapping declared on
         * {@code ImagesArrayPojo}). The repository is left untouched when the target
         * folder cannot be listed.
         *
         * @param file uploaded query image; when {@code null} or empty the template
         *             image on disk is used instead
         * @return up to three lines of the form {@code Image: <name>, Similarity: <score>},
         *         best match first, or an empty string when there is nothing to compare
         * @throws IOException if the query image or any file in the folder cannot be read
         */
        @Override
        public String find(MultipartFile file) throws IOException {
            NativeImageLoader imageLoader = new NativeImageLoader(224, 224, 3);
            double[] queryFeatures = extractFeatures(loadQueryImage(imageLoader, file));

            File[] imageFiles = new File(IMAGE_FOLDER).listFiles();
            if (imageFiles == null) {
                // Folder missing or unreadable: do not wipe the existing index for nothing.
                return "";
            }

            indArrayPojoRepository.deleteAll();

            Map<String, Double> similarityMap = new HashMap<>();
            long id = 1L;
            for (File imageFile : imageFiles) {
                if (!imageFile.isFile()) {
                    // Sub-directories cannot be decoded as images.
                    continue;
                }
                double[] currentFeatures = extractFeatures(imageLoader.asMatrix(imageFile));
                indArrayPojoRepository.save(new ImagesArrayPojo(id, currentFeatures));
                similarityMap.put(imageFile.getName(), cosineSimilarity(queryFeatures, currentFeatures));
                id++;
            }

            return similarityMap.entrySet().stream()
                    .sorted(Map.Entry.<String, Double>comparingByValue().reversed())
                    .limit(TOP_MATCHES)
                    .map(entry -> "Image: " + entry.getKey() + ", Similarity: " + entry.getValue())
                    .collect(Collectors.joining("\n"));
        }

        /** Loads the uploaded image, or the on-disk template when nothing was uploaded. */
        private static INDArray loadQueryImage(NativeImageLoader imageLoader, MultipartFile file) throws IOException {
            if (file == null || file.isEmpty()) {
                return imageLoader.asMatrix(new File(TEMPLATE_IMAGE_PATH));
            }
            try (InputStream in = file.getInputStream()) {
                return imageLoader.asMatrix(in);
            }
        }

        /**
         * Runs the network on one image and flattens its output to a plain array.
         * Presumably this is the 1000-value output of the pretrained network, which
         * would match {@code dims = 1000} on {@code ImagesArrayPojo} - TODO confirm.
         */
        private static double[] extractFeatures(INDArray image) {
            return vgg19Model.outputSingle(image).toDoubleVector();
        }

        /** Cosine similarity of two equal-length vectors; 0.0 when either has zero length or zero norm. */
        private static double cosineSimilarity(double[] a, double[] b) {
            if (a.length != b.length) {
                throw new IllegalArgumentException("Vector lengths differ: " + a.length + " vs " + b.length);
            }
            double dot = 0.0;
            double normA = 0.0;
            double normB = 0.0;
            for (int k = 0; k < a.length; k++) {
                dot += a[k] * b[k];
                normA += a[k] * a[k];
                normB += b[k] * b[k];
            }
            if (normA == 0.0 || normB == 0.0) {
                return 0.0;
            }
            return dot / (Math.sqrt(normA) * Math.sqrt(normB));
        }
    }
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15
    • 16
    • 17
    • 18
    • 19
    • 20
    • 21
    • 22
    • 23
    • 24
    • 25
    • 26
    • 27
    • 28
    • 29
    • 30
    • 31
    • 32
    • 33
    • 34
    • 35
    • 36
    • 37
    • 38
    • 39
    • 40
    • 41
    • 42
    • 43
    • 44
    • 45
    • 46
    • 47
    • 48
    • 49
    • 50
    • 51
    • 52
    • 53
    • 54
    • 55
    • 56
    • 57
    • 58
    • 59
    • 60
    • 61
    • 62
    • 63
    • 64
    • 65
    • 66
    • 67
    • 68
    • 69
    • 70
    • 71
    • 72
    • 73
    • 74
    • 75
    • 76
    • 77
    • 78
    • 79
    • 80
    • 81
    • 82
    • 83
    • 84
    • 85
    • 86

    java实体类

    /**
     * Elasticsearch document stored in the {@code images_double} index,
     * pairing an identifier with a vector of doubles.
     */
    @Document(indexName = "images_double")
    @NoArgsConstructor
    @AllArgsConstructor
    @Data
    public class ImagesArrayPojo {

        /** Document identifier. */
        @Id
        private Long id;

        /** Vector payload, mapped as a dense vector with 1000 dimensions. */
        @Field(dims = 1000, type = FieldType.Dense_Vector)
        private double[] ndDoubleArray;
    }
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12

    搭配

    <dependency>
         <groupId>org.springframework.boot</groupId>
         <artifactId>spring-boot-starter-data-elasticsearch</artifactId>
     </dependency>
    
    • 1
    • 2
    • 3
    • 4

    四、es查询脚本

    注意:不同版本的 es 中向量相似度脚本的写法略有不同,请以对应版本的官方文档为准;本文使用的版本是 7.4.2

    docker run -d -e ES_JAVA_OPTS="-Xms128m -Xmx128m" -e "discovery.type=single-node" -e "script.disable_dynamic: false" -p 9200:9200 -p 9300:9300 -e ES_MIN_MEM=128m -e ES_MAX_MEM=4096m --name es elasticsearch:7.4.2 
    
    • 1
    {
      "query": {
        "script_score": {
          "query": {
            "match_all": {}
          },
          "script": {
            "source": "cosineSimilarity(params.query_vector,doc['ndDoubleArray']) + 1.0",
            "params": {
              "query_vector": [维度数组]
            }
          }
        }
      }
    }
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15

    五、没测试的代码

    import org.springframework.data.domain.Page;
    import org.springframework.data.domain.Pageable;
    import org.springframework.data.elasticsearch.annotations.Query;
    import org.springframework.data.elasticsearch.repository.ElasticsearchRepository;
    import org.springframework.data.repository.query.Param;
    import org.springframework.stereotype.Repository;
    
    /**
     * Repository for {@code ImagesArrayPojo} documents, with a custom query that
     * scores documents by cosine similarity to a supplied vector.
     */
    @Repository
    public interface INDArrayPojoRepository extends ElasticsearchRepository<ImagesArrayPojo,Long> {

        /**
         * Scores every document by {@code cosineSimilarity(query_vector, ndDoubleArray) + 1.0}.
         *
         * <p>The annotation holds only the query clause: Spring Data Elasticsearch sends
         * this string as the value of the request's {@code query} element, so an outer
         * {@code size}/{@code from}/{@code query} wrapper does not belong here. Paging
         * is taken from {@code pageable}.
         *
         * <p>The placeholder is {@code ?0} because placeholders are zero-based and the
         * vector is the first method parameter; {@code ?1} would bind the
         * {@code Pageable}.
         *
         * <p>NOTE(review): assumes the framework renders the array as a comma-separated
         * list inside the surrounding brackets - confirm for the Spring Data
         * Elasticsearch version in use.
         *
         * @param queryVector vector to compare against the stored {@code ndDoubleArray}
         * @param pageable    page number and size for the result
         * @return one page of matching documents
         */
        @Query("{\n" +
                "  \"script_score\": {\n" +
                "    \"query\": {\n" +
                "      \"match_all\": {}\n" +
                "    },\n" +
                "    \"script\": {\n" +
                "      \"source\": \"cosineSimilarity(params.query_vector,doc['ndDoubleArray']) + 1.0\",\n" +
                "      \"params\": {\n" +
                "        \"query_vector\": [?0]\n" +
                "      }\n" +
                "    }\n" +
                "  }\n" +
                "}")
        Page<ImagesArrayPojo> findBySimilarity(@Param("queryVector") double[] queryVector, Pageable pageable);
    }
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15
    • 16
    • 17
    • 18
    • 19
    • 20
    • 21
    • 22
    • 23
    • 24
    • 25
    • 26
    • 27
    • 28
    • 29

    总结

    思路:先使用 deeplearning4j 加载 vgg19 模型提取图片的向量值,再将向量值存储到 es 中,后续搜索时使用 es 的余弦相似度脚本进行查询

  • 相关阅读:
    二次确认弹窗提示
    人工神经网络(ANN)相关介绍
    Vue中的$nextTick有什么作用?
    Flask框架:如何运用Ajax轮询动态绘图
    Windows11安装Maven
    mybatis02
    如何避免CMDB沦为数据孤岛?
    YOLOv5算法进阶改进(5)— 主干网络中引入SCConv | 即插即用的空间和通道维度重构卷积
    02 pycharts 结果生成为 html 、图片(示例代码+效果图)
    算法练习-LeetCode Hot 100 543. 二叉树的直径
  • 原文地址:https://blog.csdn.net/weixin_45479938/article/details/136298929