• ChatGLM系列八:微调医疗问答系统


    一、ChatGLM2-6B

    ChatGLM2-6B 是 ChatGLM-6B 的第二代版本,在保留了初代模型对话流畅、部署门槛较低等众多优秀特性的基础之上,同时引入了许多新特性,如:更强大的性能、更长的上下文、更高效的推理、更开放的协议等。

    二、P-tuning v2

    P-tuning v2 微调技术利用 deep prompt tuning,即对预训练 Transformer 的每一层输入应用 continuous prompts 。deep prompt tuning 增加了 continuo us prompts 的能力,并缩小了跨各种设置进行微调的差距,特别是对于小型模型和困难任务。
    在这里插入图片描述
    上图左边为 P-Tuning,右边为P-Tuning v2。P-Tuning v2 层与层之间的 continuous prompt 是相互独立的。

    三、ChatGLM2-6B 模型下载

    huggingface 地址:https://huggingface.co/THUDM/chatglm2-6b/tree/main
    
    • 1

    在这里插入图片描述

    四、数据集下载

    https://github.com/Toyhom/Chinese-medical-dialogue-data
    
    • 1

    数据格式:
    在这里插入图片描述

    五、数据预处理

    import json
    import pandas as pd
    
    data_path = [
        "./data/Chinese-medical-dialogue-data-master/Data_数据/IM_内科/内科5000-33000.csv",
        "./data/Chinese-medical-dialogue-data-master/Data_数据/Oncology_肿瘤科/肿瘤科5-10000.csv",
        "./data/Chinese-medical-dialogue-data-master/Data_数据/Pediatric_儿科/儿科5-14000.csv",
        "./data/Chinese-medical-dialogue-data-master/Data_数据/Surgical_外科/外科5-14000.csv",
    ]
    
    train_json_path = "./data/train.json"
    val_json_path = "./data/val.json"
    # 每个数据取 10000 条作为训练
    train_size = 10000
    # 每个数据取 2000 条作为验证
    val_size = 2000
    
    
    def doHandler():
        train_f = open(train_json_path, "a", encoding='utf-8')
        val_f = open(val_json_path, "a", encoding='utf-8')
        for path in data_path:
            data = pd.read_csv(path, encoding='ANSI')
            train_count = 0
            val_count = 0
            for index, row in data.iterrows():
                ask = row["ask"]
                answer = row["answer"]
                line = {
                    "content": ask,
                    "summary": answer
                }
                line = json.dumps(line, ensure_ascii=False)
                if train_count < train_size:
                    train_f.write(line + "\n")
                    train_count = train_count + 1
                elif val_count < val_size:
                    val_f.write(line + "\n")
                    val_count = val_count + 1
                else:
                    break
        print("数据处理完毕!")
        train_f.close()
        val_f.close()
    
    
    if __name__ == '__main__':
        doHandler()
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15
    • 16
    • 17
    • 18
    • 19
    • 20
    • 21
    • 22
    • 23
    • 24
    • 25
    • 26
    • 27
    • 28
    • 29
    • 30
    • 31
    • 32
    • 33
    • 34
    • 35
    • 36
    • 37
    • 38
    • 39
    • 40
    • 41
    • 42
    • 43
    • 44
    • 45
    • 46
    • 47
    • 48

    六、模型微调训练

    git clone https://github.com/THUDM/ChatGLM2-6B
    pip install -r requirements.txt -i https://pypi.tuna.tsinghua.edu.cn/simple
    pip install rouge_chinese nltk jieba datasets -i https://pypi.tuna.tsinghua.edu.cn/simple
    
    • 1
    • 2
    • 3

    修改 ptuning 下的 train.sh 文件:

    PRE_SEQ_LEN=300
    LR=2e-2
    NUM_GPUS=1
    
    torchrun --standalone --nnodes=1 --nproc-per-node=$NUM_GPUS main.py \
        --do_train \
        --train_file data/train.json \
        --validation_file data/val.json \
        --preprocessing_num_workers 10 \
        --prompt_column content \
        --response_column summary \
        --overwrite_cache \
        --model_name_or_path /home/chatglm2/chatglm-6b \
        --output_dir output/adgen-chatglm2-6b-pt-$PRE_SEQ_LEN-$LR \
        --overwrite_output_dir \
        --max_source_length 300 \
        --max_target_length 1024 \
        --per_device_train_batch_size 1 \
        --per_device_eval_batch_size 1 \
        --gradient_accumulation_steps 16 \
        --predict_with_generate \
        --max_steps 3000 \
        --logging_steps 10 \
        --save_steps 1000 \
        --learning_rate $LR \
        --pre_seq_len $PRE_SEQ_LEN \
        --quantization_bit 4
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15
    • 16
    • 17
    • 18
    • 19
    • 20
    • 21
    • 22
    • 23
    • 24
    • 25
    • 26
    • 27
    –standalone` 以单机模式训练。
    –nnodes` 节点数。这里只有一个节点,设置为 1。
    –nproc-per-node` 每个节点上的进程数。
    –do_train` 执行训练任务。
    –train_file` 训练数据文件路径, 上面生成的 train.json 文件。
    –validation_file` 验证数据文件路径, 上面生成的 val.json 文件。
    –preprocessing_num_workers` 指定数据预处理时的 workers 数。
    –prompt_column` 输入信息的字段名称。
    –response_column` 输出信息的字段名称。
    –overwrite_cache` 覆盖缓存文件。
    –model_name_or_path` 预训练模型的名称或路径,注意这里我是用的下载后的模型存放地址,需要修改为你的。
    –output_dir` 模型保存目录。
    –overwrite_output_dir` 覆盖输出目录。
    –max_source_length` 输入文本的最大长度。
    –max_target_length` 输出文本的最大长度。
    –per_device_train_batch_size` 训练时的批次大小。
    –per_device_eval_batch_size` 验证时的批次大小。
    –gradient_accumulation_steps` 累积多少个梯度之后再进行一次反向传播。
    –predict_with_generate` 预测时使用生成模式。
    –max_steps` 最大训练轮数。
    –logging_steps` 多少轮打印一次日志。
    –save_steps` 多少轮保存一次模型。
    –learning_rate` 初始学习率。
    –pre_seq_len` 预处理时选取的序列长度。
    –quantization_bit` 量化位大小。
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15
    • 16
    • 17
    • 18
    • 19
    • 20
    • 21
    • 22
    • 23
    • 24
    • 25

    执行脚本

    bash train.sh
    
    • 1

    七、模型测试

    from fastapi import FastAPI, Request
    from fastapi.middleware.cors import CORSMiddleware
    from transformers import AutoTokenizer, AutoModel, AutoConfig
    import uvicorn, json, datetime
    import torch
    import os
    
    
    def main():
        pre_seq_len = 300
        # 训练权重地址
        checkpoint_path = "ptuning/output/adgen-chatglm2-6b-pt-300-2e-2/checkpoint-3000"
    
        tokenizer = AutoTokenizer.from_pretrained("chatglm-6b", trust_remote_code=True)
        config = AutoConfig.from_pretrained("chatglm-6b", trust_remote_code=True, pre_seq_len=pre_seq_len)
        model = AutoModel.from_pretrained("chatglm-6b", config=config, device_map="auto", trust_remote_code=True)
        prefix_state_dict = torch.load(os.path.join(checkpoint_path, "pytorch_model.bin"))
        new_prefix_state_dict = {}
        for k, v in prefix_state_dict.items():
            if k.startswith("transformer.prefix_encoder."):
                new_prefix_state_dict[k[len("transformer.prefix_encoder."):]] = v
        model.transformer.prefix_encoder.load_state_dict(new_prefix_state_dict)
        # 量化
        model = model.quantize(4)
        model.eval()
    
        # 问题
        question = "突然感到了不适,去检查后竟然得了这个病,请问:宝宝白天爱磨牙会是哪些情况呢"
    
        response, history = model.chat(tokenizer,
                                       question,
                                       history=[],
                                       max_length=2048,
                                       top_p=0.7,
                                       temperature=0.95)
    
        print("回答:", response)
    
        if torch.backends.mps.is_available():
            torch.mps.empty_cache()
    
    
    if __name__ == '__main__':
        main()
    
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15
    • 16
    • 17
    • 18
    • 19
    • 20
    • 21
    • 22
    • 23
    • 24
    • 25
    • 26
    • 27
    • 28
    • 29
    • 30
    • 31
    • 32
    • 33
    • 34
    • 35
    • 36
    • 37
    • 38
    • 39
    • 40
    • 41
    • 42
    • 43
    • 44
    • 45
    from fastapi import FastAPI, Request
    from fastapi.middleware.cors import CORSMiddleware
    from transformers import AutoTokenizer, AutoModel, AutoConfig
    import uvicorn, json, datetime
    import torch
    import os
    
    app = FastAPI()
    
    # 允许所有域的请求
    app.add_middleware(
        CORSMiddleware,
        allow_origins=["*"],
        allow_credentials=True,
        allow_methods=["*"],
        allow_headers=["*"],
    )
    
    
    @app.post("/")
    async def create_item(request: Request):
        global model, tokenizer
        json_post_raw = await request.json()
        json_post = json.dumps(json_post_raw)
        json_post_list = json.loads(json_post)
        prompt = json_post_list.get('prompt')
        history = json_post_list.get('history')
        max_length = json_post_list.get('max_length')
        top_p = json_post_list.get('top_p')
        temperature = json_post_list.get('temperature')
        response, history = model.chat(tokenizer,
                                       prompt,
                                       history=history,
                                       max_length=max_length if max_length else 2048,
                                       top_p=top_p if top_p else 0.7,
                                       temperature=temperature if temperature else 0.95)
        now = datetime.datetime.now()
        time = now.strftime("%Y-%m-%d %H:%M:%S")
        answer = {
            "response": response,
            "history": history,
            "status": 200,
            "time": time
        }
        log = "[" + time + "] " + '", prompt:"' + prompt + '", response:"' + repr(response) + '"'
        print(log)
        if torch.backends.mps.is_available():
            torch.mps.empty_cache()
        return answer
    
    
    if __name__ == '__main__':
        pre_seq_len = 300
        checkpoint_path = "ptuning/output/adgen-chatglm2-6b-pt-300-2e-2/checkpoint-3000"
    
        tokenizer = AutoTokenizer.from_pretrained("chatglm-6b", trust_remote_code=True)
        config = AutoConfig.from_pretrained("chatglm-6b", trust_remote_code=True, pre_seq_len=pre_seq_len)
        model = AutoModel.from_pretrained("chatglm-6b", config=config, device_map="auto", trust_remote_code=True)
        prefix_state_dict = torch.load(os.path.join(checkpoint_path, "pytorch_model.bin"))
        new_prefix_state_dict = {}
        for k, v in prefix_state_dict.items():
            if k.startswith("transformer.prefix_encoder."):
                new_prefix_state_dict[k[len("transformer.prefix_encoder."):]] = v
        model.transformer.prefix_encoder.load_state_dict(new_prefix_state_dict)
        ## 量化
        model = model.quantize(4)
        model = model.cuda()
        model.eval()
        uvicorn.run(app, host='0.0.0.0', port=8103, workers=1)
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15
    • 16
    • 17
    • 18
    • 19
    • 20
    • 21
    • 22
    • 23
    • 24
    • 25
    • 26
    • 27
    • 28
    • 29
    • 30
    • 31
    • 32
    • 33
    • 34
    • 35
    • 36
    • 37
    • 38
    • 39
    • 40
    • 41
    • 42
    • 43
    • 44
    • 45
    • 46
    • 47
    • 48
    • 49
    • 50
    • 51
    • 52
    • 53
    • 54
    • 55
    • 56
    • 57
    • 58
    • 59
    • 60
    • 61
    • 62
    • 63
    • 64
    • 65
    • 66
    • 67
    • 68
    • 69
  • 相关阅读:
    【Java 进阶篇】CSS语法格式详解
    Golang中for循环使用
    糟糕,数据库异常不可用怎么办?
    1. Vue项目中element-ui版本进行升级
    阿里云云效 Maven
    C++图书管理案例
    Handle
    一文搞懂│什么是跨域?如何解决跨域?
    《进化优化》第1章 绪论
    html2canvas 行内元素边框样式生成问题解决(根据文字生成图片)
  • 原文地址:https://blog.csdn.net/qq236237606/article/details/134080707