ElasticSearch Learning Notes - Vector Search Record


    Elasticsearch 7.0 introduced field types for high-dimensional vectors:

    dense_vector stores a dense vector. Each element is a single float value, which may be zero, negative, or positive. A dense_vector array may hold at most 1024 elements, and the array length may differ from document to document (note: from 7.3 onward the mapping's dims parameter fixes the number of dimensions, as in the mapping below).

    sparse_vector stores a sparse vector. Each value is likewise a single float (zero, negative, or positive), but the field is stored as a non-nested JSON object whose keys are the vector positions: strings representing integers in the range [0, 65535]. A short indexing sketch showing both value formats follows the mapping below.

    ElasticSearch version: elasticsearch-7.3.0

    Environment setup:

    curl -H "Content-Type: application/json" -XPUT 'http://192.168.0.1:9200/article_v1/' -d '
    {
      "settings": {
        "number_of_shards": 1,
        "number_of_replicas": 0
      },
      "mappings": {
        "dynamic": "strict",
        "properties": {
          "id": {
            "type": "keyword"
          },
          "title": {
            "analyzer": "ik_smart",
            "type": "text"
          },
          "title_dv": {
            "type": "dense_vector",
            "dims": 200
          },
          "title_sv": {
            "type": "sparse_vector"
          }
        }
      }
    }
    '
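
    To make the two value formats concrete, here is a minimal indexing sketch against the article_v1 mapping above. This is a hedged example: the host is reused from the article, the document id and every numeric value are made up for illustration, the dense vector must contain exactly 200 floats to match dims, and the sparse vector is an object keyed by position strings.

    # Sketch: index a single document carrying both vector fields
    # (illustrative values only; assumes the article_v1 mapping above
    # and the elasticsearch-py 7.x client)
    from elasticsearch import Elasticsearch

    es = Elasticsearch([{'host': '192.168.0.1', 'port': 9200}])

    doc = {
        "id": "demo-1",
        "title": "a placeholder title",
        # dense_vector: a plain float array with exactly dims=200 entries
        "title_dv": [0.01] * 200,
        # sparse_vector: positions ("0"-"65535") as string keys, floats as values
        "title_sv": {"12": 0.37, "831": -0.2, "4096": 1.5}
    }

    es.index(index="article_v1", id="demo-1", body=doc)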
    

    Test and verification code:

    # -*- coding:utf-8 -*-
    
    import os
    import sys
    import jieba
    import logging
    import pymongo
    from elasticsearch import Elasticsearch
    from elasticsearch.serializer import TextSerializer, JSONSerializer
    from gensim.models.doc2vec import TaggedDocument, Doc2Vec
    
    default_encoding = 'utf-8'
    if sys.getdefaultencoding() != default_encoding:
        reload(sys)
        sys.setdefaultencoding(default_encoding)
    
    logging.basicConfig(format='%(asctime)s:%(levelname)s:%(message)s', level=logging.INFO)
    
    # News articles crawled from the web and stored in MongoDB beforehand
    client = pymongo.MongoClient(host='192.168.0.1', port=27017)
    db = client['news']
    
    es = Elasticsearch([{'host': '192.168.0.1', 'port': 9200}], timeout=3600)
    
    chinese_stop_words_file = os.path.abspath(os.getcwd() + os.sep + '..' + os.sep + 'static' + os.sep + 'dic' + os.sep + 'chinese_stop_words.txt')
    chinese_stop_words = [line.strip() for line in open(chinese_stop_words_file, 'r').readlines()]
    
    total_cut_word_count = 0
    
    
    # Sentence segmentation: tokenize with jieba and drop stop words
    def sentence_segment(sentence):
        global total_cut_word_count
        result = []
        cut_words = jieba.cut(sentence)
        for cut_word in cut_words:
            if cut_word not in chinese_stop_words:
                result.append(cut_word)
                total_cut_word_count += 1
        return result
    
    
    # Prepare the training corpus
    def prepare_doc_corpus():
        datas = db['netease_ent_news_detail'].find({"create_time": {"$ne": None}}).sort('create_time', pymongo.ASCENDING)
        print datas.count()
        for i, data in enumerate(datas):
            if data['title'] is not None and data['content'] is not None:
                title = str(data['title']).strip()
                yield TaggedDocument(sentence_segment(title), [data['_id']])
    
    
    # Train the Doc2Vec model
    def train_doc_model():
        # prepare_doc_corpus() returns a generator, which would be exhausted after
        # build_vocab(); materialize it so train() can iterate over it again
        corpus = list(prepare_doc_corpus())
        doc2vec = Doc2Vec(vector_size=200, min_count=2, window=5, workers=4, epochs=20)
        doc2vec.build_vocab(corpus)
        doc2vec.train(corpus, total_examples=doc2vec.corpus_count, epochs=doc2vec.epochs)
        doc2vec.save('doc2vec.model')
    
    
    def insert_data_to_es():
        datas = db['netease_ent_news_detail'].find({"create_time": {"$ne": None}}).sort('create_time', pymongo.ASCENDING)
        print datas.count()
        doc2vec = Doc2Vec.load('doc2vec.model')
        for data in datas:
            if data['title'] is not None and data['content'] is not None:
                sentence = str(data['title']).strip()
                title_dv = doc2vec.infer_vector(sentence_segment(sentence)).tolist()
                body = {"id": data['_id'], "title": data['title'], "title_dv": title_dv}
                es_result = es.create(index="article_v1", doc_type="_doc",
                    id=data['_id'], body=body, ignore=[400, 409])
                print es_result
    
    
    # cosineSimilarity computes the similarity between the query vector and each indexed document's dense_vector
    def search_es_dense_vertor_1(sentence):
        doc2vec = Doc2Vec.load('doc2vec.model')
        query_vector = doc2vec.infer_vector(sentence_segment(sentence)).tolist()
        body = {
            "query": {
                "script_score": {
                    "query": {
                        "match_all": {}
                    },
                    "script": {
                        "source": "cosineSimilarity(params.queryVector, doc['title_dv']) + 1",
                        "params": {
                            "queryVector": query_vector
                        }
                    }
                }
            },
            "from": 0,
            "size": 5
        }
        result = es.search(index="article_v1", body=body)
        hits = result['hits']['hits']
        for hit in hits:
            source = hit['_source']
            for key, value in source.items():
                print '%s %s' % (key, value)
            print '----------'
    
    
    # dotProduct computes the dot product between the query vector and each indexed document's dense_vector
    def search_es_dense_vertor_2(sentence):
        doc2vec = Doc2Vec.load('doc2vec.model')
        query_vector = doc2vec.infer_vector(sentence_segment(sentence)).tolist()
        body = {
            "query": {
                "script_score": {
                    "query": {
                        "match_all": {}
                    },
                    "script": {
                        "source": "dotProduct(params.queryVector, doc['title_dv']) + 1",
                        "params": {
                            "queryVector": query_vector
                        }
                    }
                }
            },
            "from": 0,
            "size": 5
        }
        result = es.search(index="article_v1", body=body)
        hits = result['hits']['hits']
        for hit in hits:
            source = hit['_source']
            for key, value in source.items():
                print '%s %s' % (key, value)
            print '----------'
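
    The two search functions above only exercise the dense_vector field. A sparse_vector query can be written the same way; in the 7.3 script_score functions the sparse counterparts are cosineSimilaritySparse and dotProductSparse (a hedged sketch: these sparse helpers, like sparse_vector itself, were deprecated in later 7.x releases, so verify them against your exact version). The query vector for a sparse field is passed as a position-keyed object rather than an array, and the snippet below reuses the module-level es client from the script above:

    # Sketch: score documents by cosine similarity on the sparse_vector field
    # title_sv. Assumes title_sv has been populated; query_vector is a dict
    # whose keys are position strings, e.g. {"12": 0.37, "831": -0.2}
    def search_es_sparse_vector(query_vector):
        body = {
            "query": {
                "script_score": {
                    "query": {
                        "match_all": {}
                    },
                    "script": {
                        # sparse counterpart of cosineSimilarity in ES 7.3 Painless
                        "source": "cosineSimilaritySparse(params.queryVector, doc['title_sv']) + 1",
                        "params": {
                            "queryVector": query_vector
                        }
                    }
                }
            },
            "from": 0,
            "size": 5
        }
        result = es.search(index="article_v1", body=body)
        for hit in result['hits']['hits']:
            print('%s %s' % (hit['_id'], hit['_score']))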
    
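    Finally, a minimal driver tying the steps together in the order the notes imply (train the Doc2Vec model once, push the vectors into ES, then query); the query sentence is only a placeholder:

    if __name__ == '__main__':
        # one-off steps: train the model, then index the title vectors
        train_doc_model()
        insert_data_to_es()
        # query by dense-vector cosine similarity (replace with a real sentence)
        search_es_dense_vertor_1(u'replace with a query sentence')
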
    Original article: https://blog.csdn.net/jiey0407/article/details/126361622