• Building a local web service for BELLE with Tornado


    My GitHub

    Wrap BELLE as a web backend service using the Tornado framework. The server exposes a single POST endpoint: the prompt arrives as the form field status, is wrapped in BELLE's Human:/Assistant: prompt template, passed to model.generate, and the generated reply is returned as JSON under the key belle.
    import time

    import torch
    import torch.nn as nn

    # GPTQ helper modules shipped alongside this script
    from gptq import *
    from modelutils import *
    from quant import *

    import transformers
    from transformers import AutoTokenizer
    import sys
    import json
    import logging
    import tornado.escape
    import tornado.ioloop
    import tornado.web
    import traceback

    DEV = torch.device('cuda:0')
    
    def get_bloom(model):
        # Monkey-patch the initializers to no-ops: the weights are loaded from
        # the checkpoint anyway, so the default random init is wasted work.
        def skip(*args, **kwargs):
            pass
        torch.nn.init.kaiming_uniform_ = skip
        torch.nn.init.uniform_ = skip
        torch.nn.init.normal_ = skip
        from transformers import BloomForCausalLM
        model = BloomForCausalLM.from_pretrained(model, torch_dtype='auto')
        model.seqlen = 2048
        return model
    
    def load_quant(model, checkpoint, wbits, groupsize):
        from transformers import BloomConfig, BloomForCausalLM
        config = BloomConfig.from_pretrained(model)
        # Skip weight initialization; the quantized checkpoint supplies all weights.
        def noop(*args, **kwargs):
            pass
        torch.nn.init.kaiming_uniform_ = noop
        torch.nn.init.uniform_ = noop
        torch.nn.init.normal_ = noop

        # Build the model skeleton in fp16, then restore the default dtype.
        torch.set_default_dtype(torch.half)
        transformers.modeling_utils._init_weights = False
        model = BloomForCausalLM(config)
        torch.set_default_dtype(torch.float)
        model = model.eval()
        # Swap every linear layer except lm_head for its quantized counterpart.
        layers = find_layers(model)
        for name in ['lm_head']:
            if name in layers:
                del layers[name]
        make_quant(model, layers, wbits, groupsize)

        print('Loading model ...')
        if checkpoint.endswith('.safetensors'):
            from safetensors.torch import load_file as safe_load
            model.load_state_dict(safe_load(checkpoint))
        else:
            model.load_state_dict(torch.load(checkpoint, map_location=torch.device('cuda')))
        model.seqlen = 2048
        print('Done.')

        return model
    
    
    import argparse
    from datautils import *  # GPTQ repo helper (not strictly needed for serving)
    
    parser = argparse.ArgumentParser()
    
    parser.add_argument(
        'model', type=str,
        help='BLOOM/BELLE model directory to load'
    )
    parser.add_argument(
        '--wbits', type=int, default=16, choices=[2, 3, 4, 8, 16],
        help='#bits to use for quantization; use 16 for evaluating base model.'
    )
    parser.add_argument(
        '--groupsize', type=int, default=-1,
        help='Groupsize to use for quantization; default uses full row.'
    )
    parser.add_argument(
        '--load', type=str, default='',
        help='Load quantized model.'
    )
    
    parser.add_argument(
        '--text', type=str,
        help='Input text (unused by the web service; prompts arrive via POST).'
    )
    
    parser.add_argument(
        '--min_length', type=int, default=10,
        help='The minimum length of the sequence to be generated.'
    )
    
    parser.add_argument(
        '--max_length', type=int, default=1024,
        help='The maximum length of the sequence to be generated.'
    )
    
    parser.add_argument(
        '--top_p', type=float , default=0.95,
        help='If set to float < 1, only the smallest set of most probable tokens with probabilities that add up to top_p or higher are kept for generation.'
    )
    
    parser.add_argument(
        '--temperature', type=float, default=0.8,
        help='The value used to module the next token probabilities.'
    )
    
    args = parser.parse_args()
    
    if not isinstance(args.load, str):
        args.load = args.load.as_posix()
    
    if args.load:
        model = load_quant(args.model, args.load, args.wbits, args.groupsize)
    else:
        model = get_bloom(args.model)
        model.eval()
        
    model.to(DEV)
    tokenizer = AutoTokenizer.from_pretrained(args.model)
    print("Human:")
    
    inputs = 'Human: ' +'hello' + '\n\nAssistant:'
    input_ids = tokenizer.encode(inputs, return_tensors="pt").to(DEV)
    """
    with torch.no_grad():
        generated_ids = model.generate(
            input_ids,
            do_sample=True,
            min_length=args.min_length,
            max_length=args.max_length,
            top_p=args.top_p,
            temperature=args.temperature,
        )
    print("Assistant:\n") 
    print(tokenizer.decode([el.item() for el in generated_ids[0]])[len(inputs):]) # generated_ids开头加上了bos_token,需要将inpu的内容截断,只输出Assistant 
    print("\n-------------------------------\n")
    
    """
    # Launch example (model dir / checkpoint names are from the author's setup):
    # python bloom_inference.py BELLE_BLOOM_GPTQ_4BIT --temperature 1.2 \
    #     --wbits 4 --groupsize 128 --load BELLE_BLOOM_GPTQ_4BIT/bloom7b-2m-4bit-128g.pt
    class GateAPIHandler(tornado.web.RequestHandler):
        def initialize(self):
            # The handler replies with JSON, so advertise application/json.
            self.set_header("Content-Type", "application/json")
            self.set_header("Access-Control-Allow-Origin", "*")

        async def post(self):
            postArgs = self.request.body_arguments
            print(postArgs)
            if 'status' not in postArgs:
                raise tornado.web.HTTPError(400)
            try:
                prompt = postArgs.get("status")[0]
                print(prompt)
                # body_arguments values are bytes; decode and wrap in BELLE's
                # Human/Assistant prompt template.
                inputs = 'Human: ' + prompt.decode('utf-8') + '\n\nAssistant:'
                input_ids = tokenizer.encode(inputs, return_tensors="pt").to(DEV)
                
                with torch.no_grad():
                    generated_ids = model.generate(
                        input_ids,
                        do_sample=True,
                        min_length=args.min_length,
                        max_length=args.max_length,
                        top_p=args.top_p,
                        temperature=args.temperature,
                    )
                print("Assistant:\n")
                answer=tokenizer.decode([el.item() for el in generated_ids[0]])[len(inputs):]
                print(answer) # generated_ids开头加上了bos_token,需要将inpu的内容截断,只输出Assistant 
                result = {'belle':answer}
                pred_str = str(json.dumps(result))
                self.write(pred_str)
                #logging.error("callback time : {0} . player id : {1}, result:{2}".format(str(time.time()), str(playerID), pred_str))
            except Exception as e:
                logging.error("Error: {0}.".format(e))
                traceback.print_exc()
                raise tornado.web.HTTPError(500)
    
        def get(self):
            # Only POST is supported.
            raise tornado.web.HTTPError(405)
    
    
    import tornado.options
    from tornado.httpserver import HTTPServer

    if __name__ == "__main__":
        tornado.options.define("port", default=8081, type=int,
                               help="port the API server listens on")
        tornado.options.parse_command_line()
        app = tornado.web.Application([
            (r"/", GateAPIHandler),
        ])
        apiport = tornado.options.options.port
        logging.info("Start Gate API server on port {0}.".format(apiport))

        # Single-process server: the model lives on one GPU, so forking
        # multiple worker processes would not help here.
        server = HTTPServer(app)
        server.listen(apiport)
        tornado.ioloop.IOLoop.instance().start()
    
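    With the server running, the short client below exercises the endpoint: it form-posts the prompt as the field status and prints the JSON reply together with the round-trip time. Note that model.generate runs synchronously inside the handler, so this single-process server answers one request at a time.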
    import json
    import time

    import requests

    # Address of the Tornado server started above.
    URL = 'http://192.168.3.9:8081'

    data = {
        'status': ' 如何理解黑格尔的 量变引起质变规律和否定之否定规律',
    }

    t0 = time.time()
    r = requests.post(URL, data=data)
    t1 = time.time()
    r.encoding = 'utf-8'

    result = json.loads(r.text)
    print(result)
    print('time:', t1 - t0, 's')
    
    
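    One practical caveat: requests.post blocks until generation finishes, and a large --max_length can take a long time on a single GPU. Below is a minimally hardened variant of the call above (the 300-second timeout is an arbitrary assumption, not something the original client sets):

    try:
        # Generation can be slow; fail loudly instead of hanging forever.
        r = requests.post(URL, data=data, timeout=300)
        r.raise_for_status()                   # surface 4xx/5xx errors from the handler
        print(json.loads(r.text)['belle'])     # the generated Assistant reply
    except requests.exceptions.RequestException as e:
        print('request failed:', e)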


  • Original article: https://blog.csdn.net/luoganttcc/article/details/133866024