• TorchServe搭建codeBERT分类模型服务


    背景

    最近在做有关克隆代码检测的相关工作,克隆代码是软件开发过程中的常见现象,它在软件开发前期能够提升生产效率,产生一定的正面效益,然而随着系统规模变大,也会产生降低软件稳定性,软件bug传播,系统维护困难等负面作用。本次训练基于codeBERT的分类模型,任务是给定两个函数片段,判断这两个函数片段是否相似,TorchServe主要用于PyTorch模型的部署,现将使用TorchServe搭建克隆代码检测服务过程总结如下。

    TorchServe简介

    TorchServe是部署PyTorch模型服务的工具,由Facebook和AWS合作开发,是PyTorch开源项目的一部分。它可以使得用户更快地将模型用于生产,提供了低延迟推理API,支持模型的热插拔,多模型服务,A/B test版本控制,以及监控指标等功能。TorchServe架构图如下图所示:

    TorchServe框架主要分为四个部分:Frontend是TorchServe的请求和响应的处理部分;Worker Process 指的是一组运行的模型实例,可以由管理API设定运行的数量;Model Store是模型存储加载的地方;Backend用于管理Worker Process。

    codeBERT是什么?

    codeBERT是一个预训练的语言模型,由微软和哈工大发布。我们知道传统的BERT模型是面向自然语言的,而codeBERT是面向自然语言和编程语言的模型,codeBERT可以处理Python,Java,JavaScript等,能够捕捉自然语言和编程语言的语义关系,可以用来做自然语言代码搜索,代码文档生成,代码bug检查以及代码克隆检测等任务。当然我们也可以利用CodeBERT直接提取编程语言的token embeddings,从而进行相关任务。

    环境搭建

    安装TorchServe

    1. pip install torchserve
    2. pip install torch-model-archiever

    编写Handler类

    Handler是我们自定义开发的类,TorchServe运行的时候会执行Handler类,其主要功能就是处理input data,然后通过一系列处理操作返回结果,其中模型的初始化等也是由handler处理。其中Handler类继承自BaseHandler,我们需要重写其中的initialize,preprocess,inference等。

    1. initialize方法
    1. class CloneDetectionHandler(BaseHandler,ABC):
    2. def __int__(self):
    3. super(CloneDetectionHandler,self).__init__()
    4. self.initialized = False
    5. def initialize(self, ctx):
    6. self.manifest = ctx.manifest
    7. logger.info(self.manifest)
    8. properties = ctx.system_properties
    9. model_dir = properties.get("model_dir")
    10. serialized_file = self.manifest['model']['serializedFile']
    11. model_pt_path = os.path.join(model_dir,serialized_file)
    12. self.device = torch.device("cuda:"+str(properties.get("gpu_id")) if torch.cuda.is_available() else "cpu")
    13. config_class, model_class,tokenizer_class = MODEL_CLASSES['roberta']
    14. config = config_class.from_pretrained("microsoft/codebert-base")
    15. config.num_labels = 2
    16. self.tokenizer = tokenizer_class.from_pretrained("microsoft/codebert-base")
    17. self.bert = model_class(config)
    18. self.model = Model(self.bert,config,self.tokenizer)
    19. self.model.load_state_dict(torch.load(model_pt_path))
    20. self.model.to(self.device)
    21. self.model.eval()
    22. logger.info('Clone codeBert model from path {0} loaded successfully'.format(model_dir))
    23. self.initialized = True
    1. preprocess方法
    1. def preprocess(self, requests):
    2. input_batch = None
    3. for idx,data in enumerate(requests):
    4. input_text = data.get("data")
    5. if input_text is None:
    6. input_text = data.get("body")
    7. logger.info("Received codes:'%s'",input_text)
    8. if isinstance(input_text,(bytes,bytearray)):
    9. input_text = input_text.decode('utf-8')
    10. code1 = input_text['code1']
    11. code2 = input_text['code2']
    12. code1 = " ".join(code1.split())
    13. code2 = " ".join(code2.split())
    14. logger.info("code1:'%s'", code1)
    15. logger.info("code2:'%s'", code2)
    16. inputs = self.tokenizer.encode_plus(code1,code2,max_length=512,pad_to_max_length=True, add_special_tokens=True, return_tensors="pt")
    17. input_ids = inputs["input_ids"].to(self.device)
    18. if input_ids.shape is not None:
    19. if input_batch is None:
    20. input_batch = input_ids
    21. else:
    22. input_batch = torch.cat((input_batch,input_ids),0)
    23. return input_batch
    1. inference方法
    1. def inference(self, input_batch):
    2. inferences = []
    3. logits = self.model(input_batch)
    4. num_rows = logits[0].shape[0]
    5. for i in range(num_rows):
    6. out = logits[0][i].unsqueeze(0)
    7. y_hat = out.argmax(0).item()
    8. predicted_idx = str(y_hat)
    9. inferences.append(predicted_idx)
    10. return inferences

    模型打包

    使用toch-model-archiver工具进行打包,将模型参数文件以及其所依赖包打包在一起,在当前目录下会生成mar文件

    1. torch-model-archiver --model-name BERTClass --version 1.0 \
    2. --serialized-file ./CloneDetection.bin \
    3. --model-file ./model.py \
    4. --handler ./handler.py \

    启动服务

    torchserve --start --ncs --model-store ./modelstore --models BERTClass.mar

    服务测试

    1. import requests
    2. import json
    3. diff_codes = {
    4. "code1": " private void loadProperties() {\n if (properties == null) {\n properties = new Properties();\n try {\n URL url = getClass().getResource(propsFile);\n properties.load(url.openStream());\n } catch (IOException ioe) {\n ioe.printStackTrace();\n }\n }\n }\n",
    5. "code2": " public static void copyFile(File in, File out) throws IOException {\n FileChannel inChannel = new FileInputStream(in).getChannel();\n FileChannel outChannel = new FileOutputStream(out).getChannel();\n try {\n inChannel.transferTo(0, inChannel.size(), outChannel);\n } catch (IOException e) {\n throw e;\n } finally {\n if (inChannel != null) inChannel.close();\n if (outChannel != null) outChannel.close();\n }\n }\n"
    6. }
    7. res = requests.post('http://127.0.0.1:8080/predictions/BERTClass",json=diff_codes).text

    第二个请求输入克隆代码对,模型预测结果为1,两段代码段相似,是克隆代码对。克隆代码大体分为句法克隆和语义克隆,本例展示的句法克隆,即对函数名,类名,变量名等重命名,增删部分代码片段还相同的代码对。

    1. clone_codes = {
    2. "code1":" public String kodetu(String testusoila) {\n MessageDigest md = null;\n try {\n md = MessageDigest.getInstance(\"SHA\");\n md.update(testusoila.getBytes(\"UTF-8\"));\n } catch (NoSuchAlgorithmException e) {\n new MezuLeiho(\"Ez da zifraketa algoritmoa aurkitu\", \"Ados\", \"Zifraketa Arazoa\", JOptionPane.ERROR_MESSAGE);\n e.printStackTrace();\n } catch (UnsupportedEncodingException e) {\n new MezuLeiho(\"Errorea kodetzerakoan\", \"Ados\", \"Kodeketa Errorea\", JOptionPane.ERROR_MESSAGE);\n e.printStackTrace();\n }\n byte raw[] = md.digest();\n String hash = (new BASE64Encoder()).encode(raw);\n return hash;\n }\n",
    3. "code2":" private StringBuffer encoder(String arg) {\n if (arg == null) {\n arg = \"\";\n }\n MessageDigest md5 = null;\n try {\n md5 = MessageDigest.getInstance(\"MD5\");\n md5.update(arg.getBytes(SysConstant.charset));\n } catch (Exception e) {\n e.printStackTrace();\n }\n return toHex(md5.digest());\n }\n"
    4. }
    5. res = requests.post('http://127.0.0.1:8080/predictions/BERTClass",json=clone_codes).text

    关闭服务

    torchserve --stop

    总结

    本文主要介绍了如何用TorchServe部署PyTorch模型的流程,首先需要编写hanlder类型文件,然后用torch-model-archiver工具进行模型打包,最后torchserve启动服务,部署流程相对比较简单。

  • 相关阅读:
    nodejs在pdf中绘制表格
    Android车载应用开发之出识Android Automotive
    ISC技术分享:从BAS视角看积极防御体系实践
    Flink 解析kafka avro格式
    JVM篇---第六篇
    数一满分150分总分451东南大学920电子信息通信考研Jenny老师辅导班同学,真题大纲,参考书。
    普通卷积、转置卷积详细介绍以及用法
    Git命令
    Java项目:基于JSP+Servlet的网上订餐管理系统
    代码随想录训练营二刷第三十天 | 332.重新安排行程 51. N皇后 37. 解数独
  • 原文地址:https://blog.csdn.net/cebawuyue/article/details/127410606