• Deploying a PaddlePaddle Model to Docker and Calling It with FastAPI (3): API Deployment


    This article is first published and updated at https://mwhls.top/4085.html; for images, the table of contents, formatting fixes, and more related posts, see the original page.
    New updates will be posted at mwhls.top.
    Questions and criticism are welcome, thank you very much!

    Deploying a PaddlePaddle model to Docker and calling it with FastAPI

    Calling FastAPI locally with GET

    • The intermediate debugging and testing are omitted; only the final result is shown. After all, the environment may lie to you, but the code will not.

    • Run startup.py and visit http://127.0.0.1:8000/. The terminal output is as follows:

      INFO:     Will watch for changes in these directories: ['/root/code']
      INFO:     Uvicorn running on http://127.0.0.1:8000 (Press CTRL+C to quit)
      INFO:     Started reloader process [21292] using statreload
      /usr/local/lib/python3.8/site-packages/paddle/tensor/creation.py:130: DeprecationWarning: np.object is a deprecated alias for the builtin object. To silence this warning, use object by itself. Doing this will not modify any behavior and is safe. 
      Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
      if data.dtype == np.object:
      INFO:     Started server process [21294]
      INFO:     Waiting for application startup.
      INFO:     Application startup complete.
      2022-06-18 07:26:02 [WARNING]   Cannot find raw_params. Default arguments will be used to construct the model.
      2022-06-18 07:26:02 [INFO]      Model[BIT] loaded.
      ------------------ Inference Time Info ----------------------
      total_time(ms): 4260.3, img_num: 1, batch_size: 1
      average latency time(ms): 4260.30, QPS: 0.234725
      preprocess_time_per_im(ms): 121.90, inference_time_per_batch(ms): 4132.60, postprocess_time_per_im(ms): 5.80
      INFO:     127.0.0.1:32956 - "GET / HTTP/1.1" 200 OK
      INFO:     127.0.0.1:32956 - "GET /favicon.ico HTTP/1.1" 404 Not Found
          • The page shows the following (only part of the base64 is shown):

            {"message":"iVBORw0KGgoAAAANSUhEUgAAAoAAAAH...SK+Z8VWmji1wgxWwAAAABJRU5ErkJggg=="}
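On the client side, the base64 carried in `message` can be decoded back into the PNG bytes with the standard library alone. A minimal sketch with a stand-in payload (the real body carries the full base64 above):

```python
import base64
import json

# stand-in response body; the real one carries the full PNG base64
body = '{"message": "aGVsbG8="}'
payload = json.loads(body)

# decode the base64 payload back into raw bytes (PNG bytes in the real case)
raw = base64.b64decode(payload['message'])
print(raw)  # b'hello'
```

With the real response, writing `raw` to a `.png` file recovers the rendered figure.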
            Directory tree – docker
            root
            └─ code
               ├─ datasets
               │  └─ infer
               │     ├─ before.png
               │     ├─ label_no_use.png
               │     └─ later.png
               ├─ inference_model
               │  ├─ .success
               │  ├─ model.pdiparams
               │  ├─ model.pdiparams.info
               │  ├─ model.pdmodel
               │  ├─ model.yml
               │  └─ pipeline.yml
               ├─ main.py
               ├─ predict.py
               └─ startup.py
            main.py – FastAPI code
            # main.py
            from fastapi import FastAPI
            from predict import predict
            import base64

            app = FastAPI()
            img_before_base64 = base64.b64encode(open('/root/code/datasets/infer/before.png', 'rb').read()).decode('utf8')
            img_after_base64 = base64.b64encode(open('/root/code/datasets/infer/later.png', 'rb').read()).decode('utf8')

            @app.get('/')
            def index():
                img_variation_base64 = predict(img_before_base64, img_after_base64)
                return {'message': img_variation_base64}

            predict.py – model inference code
            • There is a small naming clash between Chinese and English: "predict" means prediction and "inference" means inference, yet the references I followed all talk about inference while calling predict. As far as I understand it, the two refer to the same thing here, so it is not a big deal.
            • Copilot wrote some direct in-memory conversion functions for me, but for some reason the results were not quite right, so I fell back to temporary files, which probably costs a fair amount of performance.
            # predict.py
            from paddlers.deploy import Predictor
            from PIL import Image
            from matplotlib import pyplot as plt
            from io import BytesIO
            import base64
            import tempfile

            def base64_to_img(img_base64):
                # base64 to PIL img
                img_pil = Image.open(BytesIO(base64.b64decode(img_base64)))

                # PIL save as tmp file
                img_tf = tempfile.NamedTemporaryFile()
                img_pil.save(img_tf, format='png')

                return img_tf, img_pil

            def predict(img_before_base64, img_after_base64):
                # ref: https://aistudio.baidu.com/aistudio/projectdetail/4184759
                # build predictor
                predictor = Predictor('/root/code/inference_model', use_gpu=False)

                # base64 to tmp file and PIL img
                img_before_tf, img_before_pil = base64_to_img(img_before_base64)
                img_after_tf, img_after_pil = base64_to_img(img_after_base64)

                # predict
                res_pred = predictor.predict((img_before_tf.name, img_after_tf.name))[0]['label_map']

                # result to PIL img
                img_variation_pil = Image.fromarray(res_pred * 255)

                # show with before and after
                plt.figure(constrained_layout=True)
                plt.subplot(131);  plt.imshow(img_before_pil);  plt.gca().set_axis_off();  plt.title("Before")
                plt.subplot(132);  plt.imshow(img_after_pil);  plt.gca().set_axis_off();  plt.title("After")
                plt.subplot(133);  plt.imshow(img_variation_pil);  plt.gca().set_axis_off();  plt.title("Pred")
                img_variation_tf = tempfile.NamedTemporaryFile()
                plt.savefig(img_variation_tf)

                # plt to base64
                img_variation_base64 = base64.b64encode(open(img_variation_tf.name, "rb").read()).decode('utf8')

                # close tmp files
                img_before_tf.close()
                img_after_tf.close()
                img_variation_tf.close()

                return img_variation_base64
            startup.py – start the FastAPI service
            # startup.py
            import os

            os.chdir('/root/code')
            path_main = 'main'
            command = f'uvicorn {path_main}:app --reload'
            os.system(command)

            Calling FastAPI locally with a Python POST

            • FastAPI hot-reloads all the files, impressive.

            • Run startup.py to start the FastAPI service inside Docker.

            • Run post.py on the host to call the FastAPI service inside Docker.

            • The Docker terminal output is as follows:

              INFO:     Will watch for changes in these directories: ['/root/code']
              INFO:     Uvicorn running on http://127.0.0.1:8000 (Press CTRL+C to quit)
              INFO:     Started reloader process [10056] using statreload
              /usr/local/lib/python3.8/site-packages/paddle/tensor/creation.py:130: DeprecationWarning: np.object is a deprecated alias for the builtin object. To silence this warning, use object by itself. Doing this will not modify any behavior and is safe. 
              Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
              if data.dtype == np.object:
              INFO:     Started server process [10058]
              INFO:     Waiting for application startup.
              INFO:     Application startup complete.
              2022-06-18 13:19:13 [WARNING]   Cannot find raw_params. Default arguments will be used to construct the model.
              2022-06-18 13:19:13 [INFO]      Model[BIT] loaded.
              ------------------ Inference Time Info ----------------------
              total_time(ms): 4365.2, img_num: 1, batch_size: 1
              average latency time(ms): 4365.20, QPS: 0.229085
              preprocess_time_per_im(ms): 127.90, inference_time_per_batch(ms): 4233.30, postprocess_time_per_im(ms): 4.00
              INFO:     127.0.0.1:34572 - "POST /predict HTTP/1.1" 200 OK
                  • The host terminal output is as follows:

                    post consume: 5.9285101890563965, all consume: 8.314009189605713
                    done
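Part of the gap between the server-side `post consume` and the end-to-end `all consume` above comes from base64-encoding the images, which inflates every payload by about a third (4 output bytes per 3 input bytes). A quick stdlib check of that overhead:

```python
import base64
import os

# stand-in for PNG bytes; 300 is divisible by 3, so no padding is added
raw = os.urandom(300)
b64 = base64.b64encode(raw)
print(len(b64))  # 400: the payload grows by one third
```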
                    • 1
                  • The inference result image is generated in the code's directory:

                    • pred.png (1024×1024 resolution, 8-bit depth, consistent with the labels in the training set)
                  Directory tree – host
                  test
                  ├─ before.png
                  ├─ label_no_use.png
                  ├─ later.png
                  ├─ post.py
                  └─ pred.png
                  post.py – on the host, local Python calls the FastAPI service deployed in Docker via POST
                  # post.py
                  import requests
                  import base64
                  import time

                  def post_predict(img_before_base64, img_after_base64):
                      url = 'http://localhost:8000/predict'
                      data = {'img_before_base64': img_before_base64, 'img_after_base64': img_after_base64}
                      res = requests.post(url, json=data)
                      return res.json()

                  def main():
                      # test files
                      img_before_base64 = base64.b64encode(open('./before.png', 'rb').read()).decode('utf8')
                      img_after_base64 = base64.b64encode(open('./later.png', 'rb').read()).decode('utf8')

                      # post to predict
                      time_start = time.time()
                      result_post = post_predict(img_before_base64, img_after_base64)
                      img_variation_base64 = result_post['img_variation_base64']
                      time_consume = result_post['time_consume']

                      # output
                      with open('./pred.png', 'wb') as f:
                          f.write(base64.b64decode(img_variation_base64))
                      print(f'post consume: {time_consume}, all consume: {time.time() - time_start}')
                      print('done')

                  if __name__ == '__main__':
                      main()

                  main.py – FastAPI code
                  • Added a POST endpoint and removed some test code.
                  # main.py
                  from fastapi import FastAPI
                  from predict import predict
                  from pydantic import BaseModel
                  import time

                  class PredictRequest(BaseModel):
                      img_before_base64: str
                      img_after_base64: str

                  app = FastAPI()

                  @app.get('/')
                  def index():
                      # index
                      return 'running'

                  @app.post('/predict')
                  def predict_post(request: PredictRequest):
                      # predict by post
                      time_start = time.time()
                      img_variation_base64 = predict(request.img_before_base64, request.img_after_base64)
                      time_consume = time.time() - time_start
                      return {'img_variation_base64': img_variation_base64, 'time_consume': time_consume}
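FastAPI validates the JSON body against the pydantic model before predict_post ever runs: a well-formed body is parsed into typed attributes, and a malformed one is rejected with a 422. A small standalone sketch of that validation (the 'aGk=' strings are dummy base64 stand-ins):

```python
from pydantic import BaseModel, ValidationError

class PredictRequest(BaseModel):
    img_before_base64: str
    img_after_base64: str

# a valid body parses into typed attributes
req = PredictRequest(img_before_base64='aGk=', img_after_base64='aGk=')

# a body missing a required field is rejected before the endpoint runs
try:
    PredictRequest(img_before_base64='aGk=')
except ValidationError:
    print('rejected')
```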

                  predict.py – inference code
                  • Slightly optimized, though there is surely more room: the temporary files are definitely unnecessary. Still, get it running first.
                  # predict.py
                  from paddlers.deploy import Predictor
                  from PIL import Image
                  from io import BytesIO
                  import base64
                  import tempfile

                  def base64_to_img(img_base64):
                      # convert base64 to tmp file and PIL img
                      # base64 to PIL img
                      img_pil = Image.open(BytesIO(base64.b64decode(img_base64)))

                      # PIL save as tmp file
                      img_tf = tempfile.NamedTemporaryFile()
                      img_pil.save(img_tf, format='png')

                      return img_tf, img_pil

                  def predict(img_before_base64, img_after_base64):
                      # predict the variation field from two images, and return its base64
                      # build predictor
                      predictor = Predictor('/root/code/inference_model', use_gpu=False)

                      # base64 to tmp file and PIL img
                      img_before_tf, _ = base64_to_img(img_before_base64)
                      img_after_tf, _ = base64_to_img(img_after_base64)

                      # predict
                      res_pred = predictor.predict((img_before_tf.name, img_after_tf.name))[0]['label_map'] * 255

                      # result to PIL img
                      img_variation_pil = Image.fromarray(res_pred).convert('L')

                      # save PIL img
                      img_variation_tf = tempfile.NamedTemporaryFile()
                      img_variation_pil.save(img_variation_tf, format='png')
                      img_variation_base64 = base64.b64encode(open(img_variation_tf.name, "rb").read()).decode('utf8')

                      # close tmp files
                      img_before_tf.close()
                      img_after_tf.close()
                      img_variation_tf.close()

                      return img_variation_base64
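As the note above says, the temporary files can likely be dropped. Assuming (not verified here) that PaddleRS's Predictor.predict also accepts ndarray inputs, the whole base64-to-image conversion can stay in memory via BytesIO; a sketch of those two helpers:

```python
import base64
from io import BytesIO

import numpy as np
from PIL import Image

def base64_to_array(img_base64):
    # decode base64 straight into an RGB ndarray, no temp file
    img_pil = Image.open(BytesIO(base64.b64decode(img_base64)))
    return np.asarray(img_pil.convert('RGB'))

def array_to_base64(arr):
    # encode a uint8 image array back to a base64 PNG string, in memory
    buf = BytesIO()
    Image.fromarray(arr).save(buf, format='png')
    return base64.b64encode(buf.getvalue()).decode('utf8')

# round-trip a tiny synthetic image to sanity-check both directions
arr = np.zeros((4, 4, 3), dtype=np.uint8)
b64 = array_to_base64(arr)
```

If Predictor.predict only accepts file paths, the temp files stay necessary for the inputs, but the output side can still use BytesIO as above.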
                  startup.py – start the FastAPI service
                  # startup.py
                  import os

                  os.chdir('/root/code')
                  path_main = 'main'
                  command = f'uvicorn {path_main}:app --reload'
                  os.system(command)

                  References

                  1. Passing matplotlib figures directly to HTML as base64
                  2. Renderer problems using Matplotlib from within a script
                  3. Calling external GET or POST APIs from Python
                  4. Removing the white border from plt.savefig()
                  5. Data visualization: controlling the output resolution of matplotlib charts
                  6. Fixing all-black output when saving single-value images as PNG with PIL
                • Original article: https://blog.csdn.net/asd123pwj/article/details/128065134