A Brief Look at Engineering Sparse Vector Retrieval


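    In two earlier articles, "Introduction to the BGE M3-Embedding Model" and "Sparse Retrieval: Introduction and Practice", we introduced sparse retrieval. Today we look at how to build an engineered system for retrieving sparse vectors.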

    As mentioned before, the latest Milvus release, V2.4, supports sparse retrieval, so let's start with how Milvus implements it.

    Milvus's sparse retrieval implementation

    The search engine underneath Milvus is knowhere, and the relevant code lives mainly under src/index/sparse.

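    First, the SparseRow data structure represents a sparse vector as a packed array of (index, value) pairs sorted by index. It supports float values only.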

```cpp
template <typename T>
class SparseRow {
    static_assert(std::is_same_v<T, fp32>, "SparseRow supports float only");

 public:
    // construct an SparseRow with memory allocated to hold `count` elements.
    SparseRow(size_t count = 0)
        : data_(count ? new uint8_t[count * element_size()] : nullptr), count_(count), own_data_(true) {
    }

    SparseRow(size_t count, uint8_t* data, bool own_data) : data_(data), count_(count), own_data_(own_data) {
    }

    // copy constructor and copy assignment operator perform deep copy
    SparseRow(const SparseRow& other) : SparseRow(other.count_) {
        std::memcpy(data_, other.data_, data_byte_size());
    }

    SparseRow(SparseRow&& other) noexcept : SparseRow() {
        swap(*this, other);
    }

    SparseRow& operator=(const SparseRow& other) {
        if (this != &other) {
            SparseRow tmp(other);
            swap(*this, tmp);
        }
        return *this;
    }

    SparseRow& operator=(SparseRow&& other) noexcept {
        swap(*this, other);
        return *this;
    }

    ~SparseRow() {
        if (own_data_ && data_ != nullptr) {
            delete[] data_;
            data_ = nullptr;
        }
    }

    size_t size() const { return count_; }

    size_t memory_usage() const { return data_byte_size() + sizeof(*this); }

    // return the number of bytes used by the underlying data array.
    size_t data_byte_size() const { return count_ * element_size(); }

    void* data() { return data_; }
    const void* data() const { return data_; }

    // dim of a sparse vector is the max index + 1, or 0 for an empty vector.
    int64_t dim() const {
        if (count_ == 0) {
            return 0;
        }
        auto* elem = reinterpret_cast<const ElementProxy*>(data_) + count_ - 1;
        return elem->index + 1;
    }

    SparseIdVal<T> operator[](size_t i) const {
        auto* elem = reinterpret_cast<const ElementProxy*>(data_) + i;
        return {elem->index, elem->value};
    }

    void set_at(size_t i, table_t index, T value) {
        auto* elem = reinterpret_cast<ElementProxy*>(data_) + i;
        elem->index = index;
        elem->value = value;
    }

    float dot(const SparseRow& other) const {
        float product_sum = 0.0f;
        size_t i = 0;
        size_t j = 0;
        // TODO: improve with _mm_cmpistrm or the AVX512 alternative.
        while (i < count_ && j < other.count_) {
            auto* left = reinterpret_cast<const ElementProxy*>(data_) + i;
            auto* right = reinterpret_cast<const ElementProxy*>(other.data_) + j;
            if (left->index < right->index) {
                ++i;
            } else if (left->index > right->index) {
                ++j;
            } else {
                product_sum += left->value * right->value;
                ++i;
                ++j;
            }
        }
        return product_sum;
    }

    friend void swap(SparseRow& left, SparseRow& right) {
        using std::swap;
        swap(left.count_, right.count_);
        swap(left.data_, right.data_);
        swap(left.own_data_, right.own_data_);
    }

    static inline size_t element_size() { return sizeof(table_t) + sizeof(T); }

 private:
    // ElementProxy is used to access elements in the data_ array and should
    // never be actually constructed.
    struct __attribute__((packed)) ElementProxy {
        table_t index;
        T value;
        ElementProxy() = delete;
        ElementProxy(const ElementProxy&) = delete;
    };
    // data_ must be sorted by column id. use raw pointer for easy mmap and zero
    // copy.
    uint8_t* data_;
    size_t count_;
    bool own_data_;
};
```
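    To make the layout concrete, here is a minimal Go sketch of a sparse row kept sorted by index, with the same two-pointer merge that SparseRow::dot uses. The names Entry and SparseRow are hypothetical illustrations, not knowhere's API, and the sketch assumes entries are already sorted by Index, the invariant the C++ comment on data_ requires.

```go
package main

import "fmt"

// Entry is one non-zero element of a sparse vector: a term id and its weight.
type Entry struct {
	Index uint32
	Value float32
}

// SparseRow is a hypothetical Go counterpart of knowhere's SparseRow.
// Assumption: entries are sorted by Index in ascending order.
type SparseRow []Entry

// Dim returns max index + 1, or 0 for an empty row.
func (r SparseRow) Dim() int64 {
	if len(r) == 0 {
		return 0
	}
	return int64(r[len(r)-1].Index) + 1
}

// Dot walks both sorted rows with two pointers and multiplies only the
// entries whose indexes match, i.e. the terms the two vectors share.
func (r SparseRow) Dot(other SparseRow) float32 {
	var sum float32
	i, j := 0, 0
	for i < len(r) && j < len(other) {
		switch {
		case r[i].Index < other[j].Index:
			i++
		case r[i].Index > other[j].Index:
			j++
		default:
			sum += r[i].Value * other[j].Value
			i++
			j++
		}
	}
	return sum
}

func main() {
	a := SparseRow{{1, 0.5}, {7, 2}, {9, 1}}
	b := SparseRow{{7, 3}, {8, 4}, {9, 0.5}}
	fmt.Println(a.Dot(b), a.Dim()) // 6.5 10
}
```

    Because both rows are sorted, the merge costs O(len(a) + len(b)) and never needs a hash lookup.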

    The index itself is implemented in the InvertedIndex class in sparse_inverted_index.h. Let's start with its private fields.

```cpp
std::vector<SparseRow<T>> raw_data_;
mutable std::shared_mutex mu_;

std::unordered_map<table_t, std::vector<SparseIdVal<T>>> inverted_lut_;
bool use_wand_ = false;
// If we want to drop small values during build, we must first train the
// index with all the data to compute value_threshold_.
bool drop_during_build_ = false;
// when drop_during_build_ is true, any value smaller than value_threshold_
// will not be added to inverted_lut_. value_threshold_ is set to the
// drop_ratio_build-th percentile of all absolute values in the index.
T value_threshold_ = 0.0f;
std::unordered_map<table_t, T> max_in_dim_;
size_t max_dim_ = 0;
```
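    • raw_data_ holds the original vectors.
    • inverted_lut_ is the inverted index: it maps each term id to its posting list of (doc id, weight) pairs.
    • use_wand_ controls whether queries use the WAND algorithm. WAND is a classic query optimization that, somewhat like a skip list, jumps over documents that cannot make it into the top k, which reduces computation and speeds up queries.
    • max_in_dim_ serves WAND: it records the maximum weight of each term.
    • drop_during_build_ and value_threshold_: when enabled, weights whose magnitude is below the threshold are not added to inverted_lut_ at build time.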
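    The comments above say value_threshold_ is the drop_ratio_build-th percentile of all absolute weights. A minimal Go sketch of that computation follows; the function name valueThreshold is hypothetical, and the rounding rule (index floor(dropRatio * n) into the sorted magnitudes) is an assumption that may differ from knowhere's exact choice.

```go
package main

import (
	"fmt"
	"math"
	"sort"
)

// valueThreshold returns the dropRatio-th quantile (0 <= dropRatio <= 1) of the
// absolute values of all weights; weights below it would be dropped at build time.
// Assumption: nearest-rank style position floor(dropRatio * n).
func valueThreshold(rows [][]float32, dropRatio float64) float32 {
	var abs []float64
	for _, row := range rows {
		for _, v := range row {
			abs = append(abs, math.Abs(float64(v)))
		}
	}
	if len(abs) == 0 || dropRatio <= 0 {
		return 0 // nothing to drop
	}
	sort.Float64s(abs)
	pos := int(dropRatio * float64(len(abs)))
	if pos >= len(abs) {
		pos = len(abs) - 1
	}
	return float32(abs[pos])
}

func main() {
	rows := [][]float32{{0.1, -0.9, 0.3}, {0.05, 0.7}}
	// magnitudes sorted: 0.05 0.1 0.3 0.7 0.9; position floor(0.4*5) = 2
	fmt.Println(valueThreshold(rows, 0.4)) // 0.3
}
```

    With dropRatio 0.4, two of the five weights (0.05 and 0.1) fall below the threshold, i.e. 40% of the values would be skipped. This is also why the index must see all data first: the threshold depends on the global distribution, so Add refuses new rows once such an index has been built.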

    Index construction

    Construction is exposed mainly through a single Add method:

```cpp
Status Add(const SparseRow<T>* data, size_t rows, int64_t dim) {
    std::unique_lock lock(mu_);
    auto current_rows = n_rows_internal();
    if (current_rows > 0 && drop_during_build_) {
        LOG_KNOWHERE_ERROR_ << "Not allowed to add data to a built index with drop_ratio_build > 0.";
        return Status::invalid_args;
    }
    if ((size_t)dim > max_dim_) {
        max_dim_ = dim;
    }
    raw_data_.insert(raw_data_.end(), data, data + rows);
    for (size_t i = 0; i < rows; ++i) {
        add_row_to_index(data[i], current_rows + i);
    }
    return Status::success;
}
```

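    Add updates max_dim_, appends the rows to raw_data_, and then calls add_row_to_index for each new row. add_row_to_index appends the new doc to the posting lists in inverted_lut_ and updates max_in_dim_, the per-term maximum weight that lets WAND skip computation at query time.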

```cpp
inline void add_row_to_index(const SparseRow<T>& row, table_t id) {
    for (size_t j = 0; j < row.size(); ++j) {
        auto [idx, val] = row[j];
        // Skip values close enough to zero(which contributes little to
        // the total IP score).
        if (drop_during_build_ && fabs(val) < value_threshold_) {
            continue;
        }
        if (inverted_lut_.find(idx) == inverted_lut_.end()) {
            inverted_lut_[idx];
            if (use_wand_) {
                max_in_dim_[idx] = 0;
            }
        }
        inverted_lut_[idx].emplace_back(id, val);
        if (use_wand_) {
            max_in_dim_[idx] = std::max(max_in_dim_[idx], val);
        }
    }
}
```
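    For readers more at home in Go, the same per-row logic can be sketched as below. The Index type and its field names are hypothetical stand-ins for inverted_lut_, max_in_dim_, use_wand_, drop_during_build_ and value_threshold_; the sketch assumes doc ids arrive in increasing order, which is what keeps every posting list sorted.

```go
package main

import (
	"fmt"
	"math"
)

// Posting is one posting-list entry: a doc id and the term's weight in that doc.
type Posting struct {
	DocID uint32
	Value float32
}

// Index is a hypothetical Go counterpart of the fields add_row_to_index touches.
type Index struct {
	lut             map[uint32][]Posting // term id -> posting list (inverted_lut_)
	maxInDim        map[uint32]float32   // term id -> max weight (max_in_dim_)
	useWand         bool
	dropDuringBuild bool
	valueThreshold  float32
}

func NewIndex(useWand bool) *Index {
	return &Index{lut: map[uint32][]Posting{}, maxInDim: map[uint32]float32{}, useWand: useWand}
}

// addRow appends doc id to the posting list of every term in the row.
func (ix *Index) addRow(id uint32, terms []uint32, values []float32) {
	for k, term := range terms {
		v := values[k]
		// skip near-zero weights, they contribute little to the inner product
		if ix.dropDuringBuild && math.Abs(float64(v)) < float64(ix.valueThreshold) {
			continue
		}
		ix.lut[term] = append(ix.lut[term], Posting{id, v})
		// like max_in_dim_, the running max starts from 0 (the map's zero value)
		if ix.useWand && v > ix.maxInDim[term] {
			ix.maxInDim[term] = v
		}
	}
}

func main() {
	ix := NewIndex(true)
	ix.addRow(0, []uint32{3, 5}, []float32{0.2, 0.8})
	ix.addRow(1, []uint32{5}, []float32{0.5})
	fmt.Println(len(ix.lut[5]), ix.maxInDim[5], ix.maxInDim[3]) // 2 0.8 0.2
}
```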

    Saving and loading the index

    Save writes a custom binary format:

```cpp
Status Save(MemoryIOWriter& writer) {
    /**
     * zero copy is not yet implemented, now serializing in a zero copy
     * compatible way while still copying during deserialization.
     *
     * Layout:
     *
     * 1. int32_t rows, sign indicates whether to use wand
     * 2. int32_t cols
     * 3. for each row:
     *     1. int32_t len
     *     2. for each non-zero value:
     *        1. table_t idx
     *        2. T val
     *     With zero copy deserization, each SparseRow object should
     *     reference(not owning) the memory address of the first element.
     *
     * inverted_lut_ and max_in_dim_ not serialized, they will be
     * constructed dynamically during deserialization.
     *
     * Data are densly packed in serialized bytes and no padding is added.
     */
    std::shared_lock lock(mu_);
    writeBinaryPOD(writer, n_rows_internal() * (use_wand_ ? 1 : -1));
    writeBinaryPOD(writer, n_cols_internal());
    writeBinaryPOD(writer, value_threshold_);
    for (size_t i = 0; i < n_rows_internal(); ++i) {
        auto& row = raw_data_[i];
        writeBinaryPOD(writer, row.size());
        if (row.size() == 0) {
            continue;
        }
        writer.write(row.data(), row.size() * SparseRow<T>::element_size());
    }
    return Status::success;
}
```

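    Index file format:

      1. rows: the total number of rows; its sign encodes whether WAND is used (positive means use_wand_ is true)
      2. cols: the number of columns (dimensions)
      3. for each row:
         1. len: the number of non-zero entries
         2. for each non-zero value:
            1. table_t idx: the term id
            2. T val: the term weight

    The layout comment is slightly behind the code: Save also writes value_threshold_ right after cols, and judging from the types Load reads back (int64_t rows, size_t max_dim_, size_t count), rows, cols and len are written as 8-byte values rather than int32_t.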

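    Note that the inverted_lut_ posting lists are not stored; they are rebuilt at load time. Load is therefore simply Save in reverse, calling add_row_to_index for every row it reads: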

```cpp
Status Load(MemoryIOReader& reader) {
    std::unique_lock lock(mu_);
    int64_t rows;
    readBinaryPOD(reader, rows);
    use_wand_ = rows > 0;
    rows = std::abs(rows);
    readBinaryPOD(reader, max_dim_);
    readBinaryPOD(reader, value_threshold_);

    raw_data_.reserve(rows);

    for (int64_t i = 0; i < rows; ++i) {
        size_t count;
        readBinaryPOD(reader, count);
        raw_data_.emplace_back(count);
        if (count == 0) {
            continue;
        }
        reader.read(raw_data_[i].data(), count * SparseRow<T>::element_size());
        add_row_to_index(raw_data_[i], i);
    }

    return Status::success;
}
```

    Search

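    Recall that compute_lexical_matching_score simply multiplies the weights of the terms shared by the query and the document and sums the products. Brute-force search is therefore easy to picture: take the union of the docs in the posting lists of all query terms, compute lexical_matching_score for each of them, and keep the top k.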
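    Written out (the notation is ours: T(x) is the set of term ids with a non-zero weight in x, and w_x(t) is the weight of term t in x), the score both search paths compute is:

```latex
s(q, d) = \sum_{t \in T(q) \cap T(d)} w_q(t) \, w_d(t)
```

    For example, with q = {101: 0.5, 150: 0.2} and d = {101: 0.4, 500: 0.9}, only term 101 is shared, so s(q, d) = 0.5 × 0.4 = 0.2. Docs that share no term with the query score 0, which is why it is enough to visit the posting lists of the query's terms.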

    Now to Milvus's implementation, starting with brute-force search:

```cpp
// find the top-k candidates using brute force search, k as specified by the capacity of the heap.
// any value in q_vec that is smaller than q_threshold and any value with dimension >= n_cols() will be ignored.
// TODO: may switch to row-wise brute force if filter rate is high. Benchmark needed.
void search_brute_force(const SparseRow<T>& q_vec, T q_threshold, MaxMinHeap<T>& heap,
                        const BitsetView& bitset) const {
    auto scores = compute_all_distances(q_vec, q_threshold);
    for (size_t i = 0; i < n_rows_internal(); ++i) {
        if ((bitset.empty() || !bitset.test(i)) && scores[i] != 0) {
            heap.push(i, scores[i]);
        }
    }
}

std::vector<float> compute_all_distances(const SparseRow<T>& q_vec, T q_threshold) const {
    std::vector<float> scores(n_rows_internal(), 0.0f);
    for (size_t idx = 0; idx < q_vec.size(); ++idx) {
        auto [i, v] = q_vec[idx];
        if (v < q_threshold || i >= n_cols_internal()) {
            continue;
        }
        auto lut_it = inverted_lut_.find(i);
        if (lut_it == inverted_lut_.end()) {
            continue;
        }
        // TODO: improve with SIMD
        auto& lut = lut_it->second;
        for (size_t j = 0; j < lut.size(); j++) {
            auto [idx, val] = lut[j];
            scores[idx] += v * float(val);
        }
    }
    return scores;
}
```
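    • The core is compute_all_distances: for each term id in q_vec, it looks up the term's posting list in inverted_lut_, computes the term's contribution for every doc in that list, and accumulates the contributions per doc id into a dense scores array.
    • search_brute_force then skips docs whose bit is set in bitset and docs with a zero score, and uses the MaxMinHeap to keep the top k.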
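    The two steps above can be sketched in Go as follows. This is a hedged illustration, not Milvus code: bruteForce, Posting, Hit and minHeap are hypothetical names, filtered plays the role of the bitset (true means the doc is excluded), and a size-k min-heap from container/heap stands in for MaxMinHeap.

```go
package main

import (
	"container/heap"
	"fmt"
	"sort"
)

type Posting struct {
	DocID int
	Value float32
}

type Hit struct {
	DocID int
	Score float32
}

// minHeap keeps the current top-k hits with the smallest score at the root,
// so a better candidate can replace it in O(log k).
type minHeap []Hit

func (h minHeap) Len() int           { return len(h) }
func (h minHeap) Less(i, j int) bool { return h[i].Score < h[j].Score }
func (h minHeap) Swap(i, j int)      { h[i], h[j] = h[j], h[i] }
func (h *minHeap) Push(x any)        { *h = append(*h, x.(Hit)) }
func (h *minHeap) Pop() any {
	old := *h
	x := old[len(old)-1]
	*h = old[:len(old)-1]
	return x
}

// bruteForce accumulates scores term by term (like compute_all_distances),
// skips filtered docs and zero scores, and returns the top k by descending score.
func bruteForce(lut map[uint32][]Posting, nRows int, query map[uint32]float32, k int, filtered []bool) []Hit {
	if k <= 0 {
		return nil
	}
	scores := make([]float32, nRows)
	for term, qv := range query {
		for _, p := range lut[term] {
			scores[p.DocID] += qv * p.Value
		}
	}
	h := &minHeap{}
	for id, s := range scores {
		if s == 0 || (filtered != nil && filtered[id]) {
			continue
		}
		if h.Len() < k {
			heap.Push(h, Hit{id, s})
		} else if s > (*h)[0].Score {
			(*h)[0] = Hit{id, s} // replace the current minimum
			heap.Fix(h, 0)
		}
	}
	out := []Hit(*h)
	sort.Slice(out, func(i, j int) bool { return out[i].Score > out[j].Score })
	return out
}

func main() {
	lut := map[uint32][]Posting{
		1: {{0, 0.5}, {2, 1}},
		7: {{1, 2}, {2, 0.25}},
	}
	query := map[uint32]float32{1: 2, 7: 1}
	fmt.Println(bruteForce(lut, 4, query, 2, nil)) // [{2 2.25} {1 2}]
}
```

    The dense scores slice touches every doc once, so the cost grows with the collection size; WAND, described next, avoids scoring docs that cannot enter the heap.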

    Brute-force search guarantees exact results, but it is slow. Next, the WAND-optimized search:

```cpp
// any value in q_vec that is smaller than q_threshold will be ignored.
void search_wand(const SparseRow<T>& q_vec, T q_threshold, MaxMinHeap<T>& heap, const BitsetView& bitset) const {
    auto q_dim = q_vec.size();
    std::vector<std::shared_ptr<Cursor<std::vector<SparseIdVal<T>>>>> cursors(q_dim);
    auto valid_q_dim = 0;
    // one cursor per posting list
    for (size_t i = 0; i < q_dim; ++i) {
        // idx is the term id
        auto [idx, val] = q_vec[i];
        if (std::abs(val) < q_threshold || idx >= n_cols_internal()) {
            continue;
        }
        auto lut_it = inverted_lut_.find(idx);
        if (lut_it == inverted_lut_.end()) {
            continue;
        }
        auto& lut = lut_it->second;
        // max_in_dim_ holds the term's max weight, so max * val bounds its contribution
        cursors[valid_q_dim++] = std::make_shared<Cursor<std::vector<SparseIdVal<T>>>>(
            lut, n_rows_internal(), max_in_dim_.find(idx)->second * val, val, bitset);
    }
    if (valid_q_dim == 0) {
        return;
    }
    cursors.resize(valid_q_dim);
    auto sort_cursors = [&cursors] {
        std::sort(cursors.begin(), cursors.end(),
                  [](auto& x, auto& y) { return x->cur_vec_id() < y->cur_vec_id(); });
    };
    sort_cursors();
    // true if the heap is not full yet, or the score beats the heap top
    auto score_above_threshold = [&heap](float x) { return !heap.full() || x > heap.top().val; };
    while (true) {
        // upper bound on the score
        float upper_bound = 0;
        // pivot: index of the cursor at which the upper bound first passes the threshold
        size_t pivot;
        bool found_pivot = false;
        for (pivot = 0; pivot < cursors.size(); ++pivot) {
            // stop at the first exhausted posting list
            if (cursors[pivot]->is_end()) {
                break;
            }
            upper_bound += cursors[pivot]->max_score();
            if (score_above_threshold(upper_bound)) {
                found_pivot = true;
                break;
            }
        }
        if (!found_pivot) {
            break;
        }
        // pivot_id: the doc id the pivot cursor points at
        table_t pivot_id = cursors[pivot]->cur_vec_id();
        // if the first cursor is already at pivot_id (doc_id), every cursor up to
        // the pivot is too, so score pivot_id starting from cursor 0
        if (pivot_id == cursors[0]->cur_vec_id()) {
            float score = 0;
            // walk the cursors and accumulate the score
            for (auto& cursor : cursors) {
                if (cursor->cur_vec_id() != pivot_id) {
                    break;
                }
                score += cursor->cur_distance() * cursor->q_value();
                // advance this posting list
                cursor->next();
            }
            // push into the heap
            heap.push(pivot_id, score);
            // re-sort so the smallest vec_id comes first
            sort_cursors();
        } else {
            // the first cursor is not at pivot_id, so pivot >= 1; scan backwards from
            // the pivot for the last cursor whose cur_vec_id is not pivot_id
            size_t next_list = pivot;
            for (; cursors[next_list]->cur_vec_id() == pivot_id; --next_list) {
            }
            // cursors[next_list] is behind pivot_id; seek it to pivot_id. after the seek,
            // cur_vec_id() >= pivot_id, so the vec ids in between are skipped
            cursors[next_list]->seek(pivot_id);
            // restore sorted order, starting from next_list + 1
            for (size_t i = next_list + 1; i < cursors.size(); ++i) {
                // stop once the current cur_vec_id >= the previous one
                if (cursors[i]->cur_vec_id() >= cursors[i - 1]->cur_vec_id()) {
                    break;
                }
                // otherwise swap, so cursors at pivot_id move toward the front
                std::swap(cursors[i], cursors[i - 1]);
            }
        }
    }
}
```
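    • First, the posting list of each query term is wrapped in a cursor, and the cursors are sorted by their current vec_id so that the one with the smallest vec_id comes first.
    • Next, walking the cursors in order and summing their max_score upper bounds, score_above_threshold finds the pivot: the first cursor at which the accumulated bound could enter the heap (the heap is not full, or the bound exceeds the heap top). Docs before the pivot's doc cannot beat the heap top, so they can be skipped.
    • Then take pivot_id, the doc id the pivot cursor points at, and check whether the first cursor's cur_vec_id equals pivot_id:
      • If it does, every cursor up to the pivot is at pivot_id, so compute pivot_id's full score over the cursors, push it into the min-heap, and re-sort the cursors.
      • If it does not, seek a lagging cursor forward to pivot_id, skipping the entries in its posting list with vec_id < pivot_id (pruning), and bubble it back into sorted position so cursors at pivot_id move toward the front.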

    A lightweight sparse vector search in Go

    Following an approach similar to Milvus, here is a simple Go version:

```go
package main

import (
	"container/heap"
	"encoding/binary"
	"fmt"
	"io"
	"math"
	"math/rand"
	"os"
	"sort"
	"time"
)

// endDocID is returned by an exhausted cursor so that it sorts after all live
// cursors (similar to cur_vec_id() in Milvus).
const endDocID = math.MaxInt32

// Cursor walks one posting list whose doc ids are in ascending order.
type Cursor struct {
	docIDs     []int32
	weights    []float64
	maxScore   float64 // upper bound of this term's contribution: max weight * query weight
	termWeight float64 // weight of the term in the query
	currentIdx int
}

func NewCursor(docIDs []int32, weights []float64, maxScore float64, weight float64) *Cursor {
	return &Cursor{docIDs: docIDs, weights: weights, maxScore: maxScore, termWeight: weight}
}

func (c *Cursor) Next() { c.currentIdx++ }

// Seek advances the cursor to the first doc id >= docID.
func (c *Cursor) Seek(docID int32) {
	for !c.IsEnd() && c.CurrentDocID() < docID {
		c.Next()
	}
}

func (c *Cursor) IsEnd() bool { return c.currentIdx >= len(c.docIDs) }

func (c *Cursor) CurrentDocID() int32 {
	if c.IsEnd() {
		return endDocID
	}
	return c.docIDs[c.currentIdx]
}

func (c *Cursor) CurrentDocWeight() float64 { return c.weights[c.currentIdx] }

// DocVectors maps a docID to its sparse vector (termID -> weight).
type DocVectors map[int32]map[int32]float64

// InvertedIndex maps a termID to the docIDs containing it, in insertion order.
type InvertedIndex map[int32][]int32

// TermMaxScores keeps the maximum weight of each term.
type TermMaxScores map[int32]float64

type SparseIndex struct {
	docVectors    DocVectors
	invertedIndex InvertedIndex
	termMaxScores TermMaxScores
	dim           int32
}

func NewSparseIndex() *SparseIndex {
	return &SparseIndex{
		docVectors:    make(DocVectors),
		invertedIndex: make(InvertedIndex),
		termMaxScores: make(TermMaxScores),
	}
}

// Add indexes one document. DocIDs must be added in increasing order, otherwise
// the posting lists are not sorted and WandSearch returns wrong results.
func (index *SparseIndex) Add(docID int32, vec map[int32]float64) {
	index.docVectors[docID] = vec
	for termID, score := range vec {
		index.invertedIndex[termID] = append(index.invertedIndex[termID], docID)
		if maxScore, ok := index.termMaxScores[termID]; !ok || score > maxScore {
			index.termMaxScores[termID] = score
		}
		if termID > index.dim {
			index.dim = termID
		}
	}
}

// Save writes dim, then for each doc: docID, vector size, (termID, weight) pairs.
// Docs are written in ascending docID order so that Load rebuilds sorted posting lists.
func (index *SparseIndex) Save(filename string) error {
	file, err := os.Create(filename)
	if err != nil {
		return err
	}
	defer file.Close()
	docIDs := make([]int32, 0, len(index.docVectors))
	for docID := range index.docVectors {
		docIDs = append(docIDs, docID)
	}
	sort.Slice(docIDs, func(i, j int) bool { return docIDs[i] < docIDs[j] })
	write := func(v any) { // records the first write error
		if err == nil {
			err = binary.Write(file, binary.LittleEndian, v)
		}
	}
	write(index.dim)
	for _, docID := range docIDs {
		vec := index.docVectors[docID]
		write(docID)
		write(int32(len(vec)))
		for termID, score := range vec {
			write(termID)
			write(score)
		}
	}
	return err
}

// Load reads a file written by Save and rebuilds the inverted index via Add.
func (index *SparseIndex) Load(filename string) error {
	file, err := os.Open(filename)
	if err != nil {
		return err
	}
	defer file.Close()
	if err := binary.Read(file, binary.LittleEndian, &index.dim); err != nil {
		return err
	}
	for {
		var docID int32
		err := binary.Read(file, binary.LittleEndian, &docID)
		if err == io.EOF {
			break // end of file
		} else if err != nil {
			return err
		}
		var vecSize int32
		if err := binary.Read(file, binary.LittleEndian, &vecSize); err != nil {
			return err
		}
		vec := make(map[int32]float64, vecSize)
		for i := int32(0); i < vecSize; i++ {
			var termID int32
			var score float64
			if err := binary.Read(file, binary.LittleEndian, &termID); err != nil {
				return err
			}
			if err := binary.Read(file, binary.LittleEndian, &score); err != nil {
				return err
			}
			vec[termID] = score
		}
		index.Add(docID, vec) // rebuild the inverted index
	}
	return nil
}

// bruteSearch scores every doc that shares a term with the query and keeps the top K.
func (index *SparseIndex) bruteSearch(queryVec map[int32]float64, K int) []int32 {
	docHeap := &DocScoreHeap{}
	for docID, score := range computeAllDistances(queryVec, index) {
		pushTopK(docHeap, K, docID, score)
	}
	return sortedDocIDs(docHeap)
}

func computeAllDistances(queryVec map[int32]float64, index *SparseIndex) map[int32]float64 {
	scores := make(map[int32]float64)
	for term, qWeight := range queryVec {
		for _, docID := range index.invertedIndex[term] {
			scores[docID] += qWeight * index.docVectors[docID][term]
		}
	}
	return scores
}

// WandSearch returns the top K docs using WAND, ported from Milvus's search_wand.
func (index *SparseIndex) WandSearch(queryVec map[int32]float64, K int) []int32 {
	if K <= 0 {
		return nil
	}
	docHeap := &DocScoreHeap{}
	// one cursor per query term that has a posting list
	postingLists := make([]*Cursor, 0, len(queryVec))
	for term, termWeight := range queryVec {
		postingList, exists := index.invertedIndex[term]
		if !exists {
			continue
		}
		// weight of the term in each doc of the posting list
		weights := make([]float64, len(postingList))
		for i, docID := range postingList {
			weights[i] = index.docVectors[docID][term]
		}
		postingLists = append(postingLists,
			NewCursor(postingList, weights, index.termMaxScores[term]*termWeight, termWeight))
	}
	// sort cursors by current docID; exhausted cursors (endDocID) go last
	sortPostings := func() {
		sort.Slice(postingLists, func(i, j int) bool {
			return postingLists[i].CurrentDocID() < postingLists[j].CurrentDocID()
		})
	}
	sortPostings()
	scoreAboveThreshold := func(value float64) bool {
		return docHeap.Len() < K || (*docHeap)[0].score < value
	}
	for {
		upperBound := 0.0
		foundPivot := false
		pivot := 0
		for i, cursor := range postingLists {
			if cursor.IsEnd() {
				break
			}
			upperBound += cursor.maxScore
			if scoreAboveThreshold(upperBound) {
				foundPivot = true
				pivot = i
				break
			}
		}
		if !foundPivot {
			break
		}
		pivotID := postingLists[pivot].CurrentDocID()
		if pivotID == postingLists[0].CurrentDocID() {
			// every cursor up to the pivot is at pivotID: compute its full score
			score := 0.0
			for _, cursor := range postingLists {
				if cursor.CurrentDocID() != pivotID {
					break
				}
				score += cursor.CurrentDocWeight() * cursor.termWeight
				cursor.Next() // move to the next docID
			}
			pushTopK(docHeap, K, pivotID, score)
			sortPostings() // keep the smallest docID first
		} else {
			// pivot >= 1: find the last cursor before the pivot that is not at pivotID
			nextList := pivot
			for postingLists[nextList].CurrentDocID() == pivotID {
				nextList--
			}
			// skip its docs below pivotID; they cannot beat the heap threshold
			postingLists[nextList].Seek(pivotID)
			// bubble the advanced cursor back into sorted position
			for i := nextList + 1; i < len(postingLists); i++ {
				if postingLists[i].CurrentDocID() >= postingLists[i-1].CurrentDocID() {
					break
				}
				postingLists[i], postingLists[i-1] = postingLists[i-1], postingLists[i]
			}
		}
	}
	return sortedDocIDs(docHeap)
}

// pushTopK keeps at most K docs in the min-heap, evicting the lowest score.
func pushTopK(h *DocScoreHeap, K int, docID int32, score float64) {
	if h.Len() < K {
		heap.Push(h, &DocScore{docID, score})
	} else if K > 0 && (*h)[0].score < score {
		heap.Pop(h)
		heap.Push(h, &DocScore{docID, score})
	}
}

// sortedDocIDs drains the heap and returns the docIDs in ascending docID order,
// which makes brute-force and WAND results easy to compare.
func sortedDocIDs(h *DocScoreHeap) []int32 {
	topDocs := make([]int32, 0, h.Len())
	for h.Len() > 0 {
		topDocs = append(topDocs, heap.Pop(h).(*DocScore).docID)
	}
	sort.Slice(topDocs, func(i, j int) bool { return topDocs[i] < topDocs[j] })
	return topDocs
}

// DocScore and DocScoreHeap implement a min-heap on score for the top-K docs.
type DocScore struct {
	docID int32
	score float64
}

type DocScoreHeap []*DocScore

func (h DocScoreHeap) Len() int            { return len(h) }
func (h DocScoreHeap) Less(i, j int) bool  { return h[i].score < h[j].score }
func (h DocScoreHeap) Swap(i, j int)       { h[i], h[j] = h[j], h[i] }
func (h *DocScoreHeap) Push(x interface{}) { *h = append(*h, x.(*DocScore)) }
func (h *DocScoreHeap) Pop() interface{} {
	old := *h
	n := len(old)
	x := old[n-1]
	*h = old[:n-1]
	return x
}

func main() {
	index := NewSparseIndex()
	r := rand.New(rand.NewSource(time.Now().UnixNano()))
	// docIDs are added in increasing order, as WandSearch requires
	for i := 1; i <= 1000; i++ {
		index.Add(int32(i), map[int32]float64{101: r.Float64(), 150: r.Float64(), 190: r.Float64(), 500: r.Float64()})
	}
	// index.Save("index.bin")
	// index.Load("index.bin")
	query := map[int32]float64{101: r.Float64(), 150: r.Float64(), 190: r.Float64(), 500: r.Float64()}
	fmt.Println("Top Docs (WAND): ", index.WandSearch(query, 10))
	fmt.Println("Top Docs (brute):", index.bruteSearch(query, 10))
}
```
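    • The code implements index construction, saving and loading, and for search it implements both brute-force and WAND search.
    • Note that docs must be added in increasing docID order. In a real system, the engine can maintain a mapping from external ids to monotonically increasing internal docIDs.
    • The code is commented, so the details are not repeated here. Note that it has not been thoroughly debugged and may still contain bugs.
    • The whole inverted index is kept in memory, which is fast but memory-hungry.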
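    The id mapping suggested in the list above could look like the Go sketch below. IDMapper, Assign and External are hypothetical names, and the sketch only covers inserts; updates and deletes (which would break the increasing-order assumption if a doc were re-added under its old id) are out of scope.

```go
package main

import "fmt"

// IDMapper assigns dense, increasing internal doc ids to external ids, so that
// posting lists built in insertion order stay sorted by doc id.
type IDMapper struct {
	toInternal map[string]int32
	toExternal []string
}

func NewIDMapper() *IDMapper {
	return &IDMapper{toInternal: make(map[string]int32)}
}

// Assign returns the internal id of ext, allocating the next one if ext is new.
func (m *IDMapper) Assign(ext string) int32 {
	if id, ok := m.toInternal[ext]; ok {
		return id
	}
	id := int32(len(m.toExternal))
	m.toInternal[ext] = id
	m.toExternal = append(m.toExternal, ext)
	return id
}

// External maps an internal id returned by a search back to the caller's id.
func (m *IDMapper) External(id int32) string {
	return m.toExternal[id]
}

func main() {
	m := NewIDMapper()
	fmt.Println(m.Assign("doc-b"), m.Assign("doc-a"), m.Assign("doc-b"), m.External(1)) // 0 1 0 doc-a
}
```

    A slice indexed by internal id makes the reverse lookup O(1) and costs no more memory than the map already used for the forward direction.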

    Summary

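    Sparse retrieval is, on the whole, similar to traditional text retrieval, so the classic engineering optimizations from text search carry over to it. This article analyzed Milvus's implementation and built a sparse retrieval engine in Go.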

    Original article: https://www.cnblogs.com/xiaoqi/p/18150639/golang-sparse-retrival