• ChatGLM系列六:基于知识库的问答


    在这里插入图片描述

    1、安装milvus

    下载milvus-standalone-docker-compose.yml并保存为docker-compose.yml

    wget https://github.com/milvus-io/milvus/releases/download/v2.3.2/milvus-standalone-docker-compose.yml -O docker-compose.yml
    
    • 1

    运行milvus

    sudo docker-compose up -d
    
    • 1

    2、文档预处理

    import os
    import re
    import jieba
    import torch
    import pandas as pd
    from pymilvus import utility
    from pymilvus import connections, CollectionSchema, FieldSchema, Collection, DataType
    from transformers import AutoTokenizer, AutoModel
    
    # Open the "default" Milvus connection at import time, so any module that
    # imports this file also gets a connection to localhost:19530.
    connections.connect(
        alias="default",
        host='localhost',
        port='19530'
    )
    
    # Collection name, vector dimension, and the folder scanned for documents.
    # `dimension` is the `dim` of the "vector" field in create_collection and so
    # must equal the length of the list returned by get_vector
    # (presumably 768 for bert-base-chinese — confirm if the model changes).
    collection_name = "document"
    dimension = 768
    docs_folder = "./knowledge/"
    
    # Tokenizer and encoder used by get_vector to embed text.
    tokenizer = AutoTokenizer.from_pretrained("bert-base-chinese")
    model = AutoModel.from_pretrained("bert-base-chinese")
    
    
    # 获取文本的向量
    def get_vector(text):
        """Return the embedding of *text* as a plain list of floats.

        The text is tokenized (truncated to the tokenizer's limit), run through
        the module-level ``model`` without gradient tracking, and the hidden
        state at the first token position of the first sequence is returned.
        """
        encoded = tokenizer(text, padding=True, truncation=True, return_tensors="pt")
        token_ids = encoded["input_ids"]
        with torch.no_grad():
            # model(...)[0] is the first model output; position 0 on the second
            # axis is the leading token of each sequence.
            hidden_states = model(token_ids)[0]
            first_token = hidden_states[:, 0, :].numpy()
        vectors = first_token.tolist()
        return vectors[0]
    
    
    def create_collection():
        """(Re)create the Milvus collection that stores the knowledge base.

        The schema has an auto-generated INT64 primary key ``id``, VARCHAR
        fields ``title`` and ``content``, and a FLOAT_VECTOR field ``vector``
        of ``dimension`` floats. If a collection named ``collection_name``
        already exists it is dropped first, so every call starts from an empty
        collection. An IVF_FLAT index with the inner-product metric is then
        created on ``vector``.
        """
        # Collection fields.
        fields = [
            FieldSchema(name="id", dtype=DataType.INT64, is_primary=True, auto_id=True, description="primary id"),
            FieldSchema(name="title", dtype=DataType.VARCHAR, max_length=50),
            FieldSchema(name="content", dtype=DataType.VARCHAR, max_length=10000),
            FieldSchema(name="vector", dtype=DataType.FLOAT_VECTOR, dim=dimension),
        ]

        # Collection schema.
        schema = CollectionSchema(fields=fields, description="collection schema")

        # Drop any previous collection so it is rebuilt from scratch. To keep
        # adding documents to an existing collection instead, return here
        # rather than dropping it.
        if utility.has_collection(collection_name):
            utility.drop_collection(collection_name)

        # Creation is identical whether or not an old collection existed, so it
        # is done once here instead of being duplicated in both branches.
        collection = Collection(name=collection_name, schema=schema, using='default', shards_num=2)

        # Create the vector index.
        default_index = {"index_type": "IVF_FLAT", "params": {"nlist": 2048}, "metric_type": "IP"}
        collection.create_index(field_name="vector", index_params=default_index)
        print(f"Collection {collection_name} created successfully")
    
    
    def init_knowledge():
        """Embed every ``.txt`` file under ``docs_folder`` and insert it into Milvus.

        For each file the whitespace is collapsed, the text is segmented with
        jieba and re-joined with single spaces, and the vector of
        ``title + content`` is computed with ``get_vector``. The file name
        without its extension is used as the title. All rows are inserted in
        one call. If no ``.txt`` file is found nothing is inserted and a
        message is printed instead.
        """
        collection = Collection(collection_name)
        # Walk the folder recursively and collect one row per text file.
        docs = []
        for root, _dirs, files in os.walk(docs_folder):
            for file in files:
                # Only plain-text files ending in .txt are processed.
                if not file.endswith(".txt"):
                    continue
                file_path = os.path.join(root, file)
                with open(file_path, "r", encoding="utf-8") as f:
                    content = f.read()
                # Clean the text: collapse every whitespace run to one space.
                content = re.sub(r"\s+", " ", content)
                title = os.path.splitext(file)[0]
                # Segment into words, then join the words back into a string.
                words = jieba.lcut(content)
                content = " ".join(words)
                # Embed the title together with the content.
                vector = get_vector(title + content)
                docs.append({"title": title, "content": content, "vector": vector})

        # An empty DataFrame has none of the schema's columns, so inserting it
        # would fail; report the situation and stop instead.
        if not docs:
            print(f"No .txt documents found under {docs_folder}; nothing inserted")
            return

        # Insert text and vectors together through a DataFrame.
        df = pd.DataFrame(docs)
        collection.insert(df)
        print("Documents inserted successfully")
    
    
    # Script entry point: recreate the collection, then load the knowledge
    # files into it. Not executed when this module is merely imported.
    if __name__ == "__main__":
        create_collection()
        init_knowledge()
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15
    • 16
    • 17
    • 18
    • 19
    • 20
    • 21
    • 22
    • 23
    • 24
    • 25
    • 26
    • 27
    • 28
    • 29
    • 30
    • 31
    • 32
    • 33
    • 34
    • 35
    • 36
    • 37
    • 38
    • 39
    • 40
    • 41
    • 42
    • 43
    • 44
    • 45
    • 46
    • 47
    • 48
    • 49
    • 50
    • 51
    • 52
    • 53
    • 54
    • 55
    • 56
    • 57
    • 58
    • 59
    • 60
    • 61
    • 62
    • 63
    • 64
    • 65
    • 66
    • 67
    • 68
    • 69
    • 70
    • 71
    • 72
    • 73
    • 74
    • 75
    • 76
    • 77
    • 78
    • 79
    • 80
    • 81
    • 82
    • 83
    • 84
    • 85
    • 86
    • 87
    • 88
    • 89
    • 90
    • 91
    • 92
    • 93
    • 94

    3、知识库匹配

    通过向量索引库计算出与问题最为相似的文档

    import torch
    from document_preprocess import get_vector
    from pymilvus import Collection
    
    # Bind to the already-created "document" collection and load it before
    # searching. No connection is opened in this file; presumably the
    # `document_preprocess` import above establishes it — confirm.
    collection = Collection("document")  # Get an existing collection.
    collection.load()
    # Prefer CUDA, then Apple MPS, then CPU.
    # NOTE(review): DEVICE is not referenced anywhere else in the visible code.
    DEVICE = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"
    
    
    # 定义查询函数
    def search_similar_text(input_text, top_k=3):
        """Return the knowledge-base documents most similar to *input_text*.

        Args:
            input_text: The question or text to look up.
            top_k: Maximum number of documents to return (default 3, the
                value that used to be hard-coded).

        Returns:
            A list of rows, each with the keys ``id``, ``content`` and
            ``title``, ordered from the most to the least similar hit. The
            list is empty when the vector search finds nothing. The result is
            also printed.
        """
        # Convert the input text into a vector.
        input_vector = get_vector(input_text)
        # Find the ids of the top_k closest vectors.
        similarity = collection.search(
            data=[input_vector],
            anns_field="vector",
            param={"metric_type": "IP", "params": {"nprobe": 10}, "offset": 0},
            limit=top_k,
            expr=None,
            consistency_level="Strong"
        )
        # Materialize as a real list so the expression below is always
        # rendered as "[id1, id2, ...]".
        ids = list(similarity[0].ids)
        if not ids:
            # Nothing matched: skip the second round trip with an empty filter.
            print([])
            return []
        # Fetch the matching knowledge-base documents by id.
        res = collection.query(
            expr=f"id in {ids}",
            offset=0,
            limit=top_k,
            output_fields=["id", "content", "title"],
            consistency_level="Strong"
        )
        # The query is not guaranteed to return rows in the order of the id
        # list, so restore the similarity ranking produced by the search.
        rank = {doc_id: position for position, doc_id in enumerate(ids)}
        res = sorted(res, key=lambda row: rank.get(row["id"], len(ids)))
        print(res)
        return res
    
    
    # Manual test entry point: read one question and print the matching
    # documents. Both body lines are indented with spaces only; the original
    # mixed a tab with spaces, which Python 3 rejects.
    if __name__ == "__main__":
        question = input('Please enter your question: ')
        search_similar_text(question)
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15
    • 16
    • 17
    • 18
    • 19
    • 20
    • 21
    • 22
    • 23
    • 24
    • 25
    • 26
    • 27
    • 28
    • 29
    • 30
    • 31
    • 32
    • 33
    • 34
    • 35
    • 36
    • 37
    • 38

    4、完成回答

    from transformers import AutoModel, AutoTokenizer
    from knowledge_query import search_similar_text
    
    
    # Load tokenizer and model from the "THUDM/chatglm-6b-int4" checkpoint;
    # trust_remote_code=True lets the model code shipped with it be executed.
    tokenizer = AutoTokenizer.from_pretrained("THUDM/chatglm-6b-int4", trust_remote_code=True)
    # Cast to half precision and move to the CUDA device (a GPU is required).
    model = AutoModel.from_pretrained("THUDM/chatglm-6b-int4", trust_remote_code=True).half().cuda()
    # Switch to evaluation (inference) mode.
    model = model.eval()
    
    
    def predict(input, max_length=2048, top_p=0.7, temperature=0.95, history=None):
        """Stream an answer to *input* grounded in the knowledge base.

        The documents returned by ``search_similar_text`` are embedded into a
        prompt together with the question, and the prompt is sent to
        ``model.stream_chat``.

        Args:
            input: The user's question. (The name shadows the builtin but is
                kept so keyword callers keep working.)
            max_length: Forwarded to ``model.stream_chat``.
            top_p: Forwarded to ``model.stream_chat``.
            temperature: Forwarded to ``model.stream_chat``.
            history: Previous chat turns; a fresh list is used when omitted,
                so separate calls no longer share one default list.

        Yields:
            ``(chatbot, history)`` after every streamed chunk, where
            ``chatbot`` is a one-element list holding the
            ``(question, partial_answer)`` pair.
        """
        if history is None:
            history = []
        res = search_similar_text(input)
        prompt_template = f"""基于以下已知信息,简洁和专业的来回答用户的问题。
    如果无法从中得到答案,请说 "当前会话仅支持解决一个类型的问题,请清空历史信息重试",不允许在答案中添加编造成分,答案请使用中文。
    
    已知内容:
    {res}
    
    问题:
    {input}
    """
        query = prompt_template
        # This snippet defines neither a `chatbot` list nor `parse_text`, so
        # the conversation pair is kept locally and the raw text is yielded.
        chatbot = [(input, "")]
        for response, history in model.stream_chat(tokenizer, query, history, max_length=max_length, top_p=top_p,
                                                   temperature=temperature):
            chatbot[-1] = (input, response)

            yield chatbot, history
    
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15
    • 16
    • 17
    • 18
    • 19
    • 20
    • 21
    • 22
    • 23
    • 24
    • 25
    • 26
    • 27
    from transformers import AutoModel, AutoTokenizer
    import gradio as gr
    import mdtex2html
    
    from knowledge_query import search_similar_text
    
    # Load tokenizer and model from the "THUDM/chatglm-6b-int4" checkpoint;
    # trust_remote_code=True lets the model code shipped with it be executed.
    tokenizer = AutoTokenizer.from_pretrained("THUDM/chatglm-6b-int4", trust_remote_code=True)
    # Cast to half precision and move to the CUDA device (a GPU is required).
    model = AutoModel.from_pretrained("THUDM/chatglm-6b-int4", trust_remote_code=True).half().cuda()
    # Switch to evaluation (inference) mode.
    model = model.eval()
    # While True, predict() wraps the question in a knowledge-base prompt and
    # then clears the flag, so only one question per session is augmented.
    is_knowledge = True
    
    """Override Chatbot.postprocess"""
    
    
    def postprocess(self, y):
        """Render every (message, response) pair of *y* through mdtex2html.

        ``None`` yields an empty list. Otherwise *y* is rewritten in place,
        with each non-``None`` entry converted and each ``None`` entry kept,
        and the same list object is returned.
        """
        if y is None:
            return []

        def _render(text):
            # None marks an absent side of the exchange and is passed through.
            return None if text is None else mdtex2html.convert(text)

        for idx, pair in enumerate(y):
            user_msg, bot_msg = pair
            y[idx] = (_render(user_msg), _render(bot_msg))
        return y
    
    
    # Monkey-patch gr.Chatbot at class level so every Chatbot instance uses
    # the postprocess function defined above.
    gr.Chatbot.postprocess = postprocess
    
    
    def parse_text(text):
        """copy from https://github.com/GaiZhenbiao/ChuanhuChatGPT/

        Convert chat text to the HTML shown in the chat widget. Empty lines
        are dropped. A line containing a code fence alternately opens a
        ``<pre><code>`` block (using the text after the last backtick as the
        language class) or closes it. Every other line except the first is
        prefixed with ``<br>``, and while inside a code block its special
        characters are replaced by HTML entities.
        """
        lines = text.split("\n")
        lines = [line for line in lines if line != ""]
        count = 0  # number of fence lines seen so far; odd means inside a block
        for i, line in enumerate(lines):
            if "```" in line:
                count += 1
                items = line.split('`')
                if count % 2 == 1:
                    lines[i] = f'<pre><code class="language-{items[-1]}">'
                else:
                    lines[i] = '<br></code></pre>'
            else:
                if i > 0:
                    if count % 2 == 1:
                        # Escape characters the markdown/HTML renderer would
                        # otherwise interpret inside a code block.
                        line = line.replace("`", "\\`")
                        line = line.replace("<", "&lt;")
                        line = line.replace(">", "&gt;")
                        line = line.replace(" ", "&nbsp;")
                        line = line.replace("*", "&ast;")
                        line = line.replace("_", "&lowbar;")
                        line = line.replace("-", "&#45;")
                        line = line.replace(".", "&#46;")
                        line = line.replace("!", "&#33;")
                        line = line.replace("(", "&#40;")
                        line = line.replace(")", "&#41;")
                        line = line.replace("$", "&#36;")
                    lines[i] = "<br>" + line
        text = "".join(lines)
        return text


    def predict(input, chatbot, max_length, top_p, temperature, history):
        """Stream the model's answer into the Gradio chat widget.

        On the first question after start-up or after "Clear History"
        (``is_knowledge`` is True) the question is wrapped in a prompt that
        contains the documents found by ``search_similar_text``; the flag is
        then cleared so follow-up questions are sent as typed.

        Yields:
            ``(chatbot, history)`` after each streamed chunk, with the last
            chatbot entry updated to the current partial response.
        """
        global is_knowledge
        chatbot.append((parse_text(input), ""))
        query = input
        if is_knowledge:
            res = search_similar_text(input)
            prompt_template = f"""基于以下已知信息,简洁和专业的来回答用户的问题。
    如果无法从中得到答案,请说 "当前会话仅支持解决一个类型的问题,请清空历史信息重试",不允许在答案中添加编造成分,答案请使用中文。

    已知内容:
    {res}

    问题:
    {input}
    """
            query = prompt_template
            is_knowledge = False
        for response, history in model.stream_chat(tokenizer, query, history, max_length=max_length, top_p=top_p,
                                                   temperature=temperature):
            chatbot[-1] = (parse_text(input), parse_text(response))

            yield chatbot, history


    def reset_user_input():
        """Clear the input textbox."""
        return gr.update(value='')


    def reset_state():
        """Clear the chat and history and re-arm the knowledge-base lookup.

        The prompt tells the user to clear the history and retry when a
        question cannot be answered, so the next question after a reset must
        consult the knowledge base again; the flag is therefore set back to
        True (it was previously left False, which disabled the lookup for
        the rest of the process).
        """
        global is_knowledge
        is_knowledge = True
        return [], []


    with gr.Blocks() as demo:
        gr.HTML("""<h1 align="center">ChatGLM</h1>""")

        chatbot = gr.Chatbot()
        with gr.Row():
            with gr.Column(scale=4):
                with gr.Column(scale=12):
                    user_input = gr.Textbox(show_label=False, placeholder="Input...", lines=10).style(
                        container=False)
                with gr.Column(min_width=32, scale=1):
                    submitBtn = gr.Button("Submit", variant="primary")
            with gr.Column(scale=1):
                emptyBtn = gr.Button("Clear History")
                max_length = gr.Slider(0, 4096, value=2048, step=1.0, label="Maximum length", interactive=True)
                top_p = gr.Slider(0, 1, value=0.7, step=0.01, label="Top P", interactive=True)
                temperature = gr.Slider(0, 1, value=0.95, step=0.01, label="Temperature", interactive=True)

        history = gr.State([])

        # Submitting both runs the prediction and clears the input box.
        submitBtn.click(predict, [user_input, chatbot, max_length, top_p, temperature, history], [chatbot, history],
                        show_progress=True)
        submitBtn.click(reset_user_input, [], [user_input])

        emptyBtn.click(reset_state, outputs=[chatbot, history], show_progress=True)

    demo.queue().launch(share=False, inbrowser=True)
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15
    • 16
    • 17
    • 18
    • 19
    • 20
    • 21
    • 22
    • 23
    • 24
    • 25
    • 26
    • 27
    • 28
    • 29
    • 30
    • 31
    • 32
    • 33
    • 34
    • 35
    • 36
    • 37
    • 38
    • 39
    • 40
    • 41
    • 42
    • 43
    • 44
    • 45
    • 46
    • 47
    • 48
    • 49
    • 50
    • 51
    • 52
    • 53
    • 54
    • 55
    • 56
    • 57
    • 58
    • 59
    • 60
    • 61
    • 62
    • 63
    • 64
    • 65
    • 66
    • 67
    • 68
    • 69
    • 70
    • 71
    • 72
    • 73
    • 74
    • 75
    • 76
    • 77
    • 78
    • 79
    • 80
    • 81
    • 82
    • 83
    • 84
    • 85
    • 86
    • 87
    • 88
    • 89
    • 90
    • 91
    • 92
    • 93
    • 94
    • 95
    • 96
    • 97
    • 98
    • 99
    • 100
    • 101
    • 102
    • 103
    • 104
    • 105
    • 106
    • 107
    • 108
    • 109
    • 110
    • 111
    • 112
    • 113
    • 114
    • 115
    • 116
    • 117
    • 118
    • 119
    • 120
    • 121
    • 122
    • 123
  • 相关阅读:
    ElementUI的表格设置勾选toggleRowSelection
    flask学习笔记
    Python_操作记录
    使用OpenVINO实现人体动作识别
    如何对用户输入进行校验
    二分查找算法
    unocss在vue-cli中的使用
    【配置管理日常管理活动】配置项的控制流程及步骤
    前端实现克里金插值分析(一)
    vue项目TypeScript intellisense is disabled on template.异常解决方案
  • 原文地址:https://blog.csdn.net/qq236237606/article/details/134079126