• pytorch模型网页部署——Flask


    一、Flask用法

    Flask是python的轻量级web框架,可用来做简单的模型部署。Flask的基本用法如下:

    step1:定义Flask类的对象,即创建一个基于Flask的服务器

    step2:定义公开的路由及路由对应的调用函数

    step3:运行服务器

    1. """基于flask的web网页"""
    2. from flask import Flask # 导入flask库
    3. app = Flask(__name__) # 创建Flask类的对象,可理解为建立一个基于flask框架的服务器
    4. # 公开路由的名称【my_fcn】,同时修饰下一行定义的函数。
    5. # 定义的函数名要与公开的路由名称一致。
    6. # 后续访问网页的url格式为:http://ip:port/路由名称
    7. @app.route("/my_fcn")
    8. def my_fcn():
    9. return "hello world" # 访问网页时返回内容会显示在网页上
    10. if __name__ == "__main__":
    11. app.run(host='0.0.0.0', port=8000) # 运行服务器。可通过get/post参数请求数据

    运行结果:

    二、在基于flask的网页上部署模型

    在基于flask的网页上部署模型,其实只需在上述例子中定义的函数【my_fcn】中添加模型预测的代码即可。示例如下:

    1. """基于flask的web网页"""
    2. import numpy as np
    3. import cv2
    4. import torch
    5. import torchvision
    6. from flask import Flask # 导入flask库
    7. app = Flask(__name__) # 创建Flask类的对象,可理解为建立一个基于flask框架的服务器
    8. # 公开路由的名称【my_fcn】,同时修饰下一行定义的函数。
    9. # 定义的函数名要与公开的路由名称一致。
    10. # 后续访问网页的url格式为:http://ip:port/路由名称
    11. @app.route("/my_fcn")
    12. def my_fcn():
    13. # 加载图片,并将图片转为0到1之间的浮点数张量
    14. img = cv2.imread("rose.jpg")
    15. img = cv2.resize(img, (224, 224))
    16. img_tensor = torch.tensor(img).permute(2, 0, 1).unsqueeze(0).float()/255.0
    17. model = torchvision.models.resnet50(pretrained=True)
    18. model.eval()
    19. output = model(img_tensor)
    20. output = torch.nn.functional.softmax(output,1)
    21. output = torch.argmax(output)
    22. return "class index={}".format(output.numpy())
    23. if __name__ == "__main__":
    24. app.run(host='0.0.0.0', port=8000) # 运行服务器。可通过get/post参数请求数据

    运行结果:

    三、远程客户端访问网页服务器模型进行推理 

    当需要在远程客户端请求服务器进行推理时,需要将图像的数据post到服务器,第二节的方法就需要进行改进了,具体方法是将图像数据编码为二进制以post方法提交,服务器解析后进行推理。

    举例:web服务器部署了基于resnet50的分类模型,远程客户端读取了一张图片,并提交给服务器进行推理,服务器将推理结果返回给客户端。

    代码如下:

    3.1 服务器代码:server.py

    1. # server.py
    2. import numpy as np
    3. from flask import Flask, request, jsonify
    4. import json
    5. import torch
    6. import torchvision
    7. app = Flask(__name__)
    8. model = torchvision.models.resnet50(pretrained=True)
    9. model.eval()
    10. # 推理过程
    11. def run_inference(in_tensor):
    12. with torch.no_grad():
    13. out_tensor = model(in_tensor)
    14. out_tensor = torch.nn.functional.softmax(out_tensor, 1)
    15. output = torch.argmax(out_tensor)
    16. return output
    17. # flask服务器
    18. @app.route('/predict', methods=['GET', 'POST']) # 注意,这里开放GET、POST请求
    19. def predict():
    20. # 客户端post时,包含将输入张量尺寸以json字符串,input_input = "{"shape": [C, W, H]}",因此需要重新解析为json并提取尺寸
    21. in_shape = json.load(request.files['in_shape'])
    22. # 客户端post时,将图像数据以二进制形式发送,因此需要将图像重新从二进制转换会tensor,并resize
    23. in_blob = request.files['in_blob'].read()
    24. in_tensor = torch.from_numpy(np.frombuffer(in_blob, dtype=np.float32))
    25. in_tensor = in_tensor.view(*in_shape['shape'])
    26. output = run_inference(in_tensor) # 推理
    27. output = '{}'.format(output)
    28. return jsonify(output) # 以json返回结果
    29. if __name__ == "__main__":
    30. app.run(host='0.0.0.0', port=8000)

    3.2 客户端代码:client.py

    1. # 首先运行【server.py】,然后运行本文件,实现请求推理本地图像的预测结果。
    2. # client.py
    3. import torch
    4. import cv2
    5. import io
    6. import json
    7. import requests
    8. # config
    9. IMG_NAME = 'dog.jpg'
    10. MY_URL = 'http://192.168.1.103:8000/predict'
    11. # 处理图像数据
    12. img = cv2.imread(IMG_NAME)
    13. img = cv2.resize(img, (224, 224)) # resnet输入张量shape=3x224x224
    14. in_tensor = torch.tensor(img).permute(2, 0, 1).unsqueeze(0).float()/255.0
    15. blob = io.BytesIO(bytearray(in_tensor.numpy())) # 将输入张量转为numpy,再转为二进制进行post请求
    16. shape = io.StringIO(json.dumps({'shape': [1, 3, 224, 224]})) # 将输入张量尺寸post
    17. my_files = {'in_shape': shape, 'in_blob': blob}
    18. r = requests.post(url=MY_URL, files=my_files)
    19. response = json.loads(r.content) # 由于服务器返回结果为json,因此需要解析json内容
    20. print("the class index of '{}' is: {}".format(IMG_NAME, response))

    运行结果:

    四、缺点

    由于HTTP是串行的,当大量并发请求时,这种方式只能应答完一个请求后才会应答下一个,改进方法可通过使用Sanic框架,实现异步并行处理。

  • 相关阅读:
    c++ 条件变量使用详解 wait_for wait_unitl 虚假唤醒
    从 40% 跌至 4%,“糊”了的 Firefox 还能重回巅峰吗?
    [Qt基础内容-08] Qt中MVC的M(Model)
    五、分类算法 总结
    Spring实现简单的Bean容器
    JVM内存模型
    Python Json 处理,列表生成式的项目使用 笔记
    pmp新考纲全真模拟题,提分敏捷+情景
    勇士大战恶魔?这款桌游明明是套高质量原创手办
    windows11 利用vmware17 安装rocky9操作系统
  • 原文地址:https://blog.csdn.net/wxyczhyza/article/details/128116921