• Pytorch人体姿态骨架生成图像


    ControlNet是一种为稳定扩散(Stable Diffusion)模型添加额外条件控制的神经网络结构,可以复制构图和人体姿势,解决了难以生成想要的确切姿势的问题。Human Pose模型使用OpenPose检测关键点,如头部、肩膀、手的位置等。它适用于复制人体姿势,但不会复制服装、发型和背景等其他细节。

    使用方法:输入一张图像和一段提示词(prompt),模型将据此生成新图像。OpenPose会自动检测输入图像中的人体姿势。

    🔹 本案例需使用PyTorch-1.8镜像、GPU P100及以上规格运行(安装依赖时会将PyTorch替换为1.12.1版本)

    🔹 点击Run in ModelArts,将会进入到ModelArts CodeLab中,这时需要你登录华为云账号,如果没有账号,则需要注册一个,且要进行实名认证, 登录之后,等待片刻,即可进入到CodeLab的运行环境

    1. 环境准备

    为了方便用户下载使用及快速体验,本案例已将代码及control_sd15_openpose预训练模型转存至华为云OBS中。注意:为了使用该模型与权重,你必须接受该模型所要求的License,请访问huggingface的lllyasviel/ControlNet, 仔细阅读里面的License。模型下载与加载需要几分钟时间。

    import os
    import moxing as mox

    # Root of the OBS bucket that mirrors the ControlNet code and weights for this case.
    OBS_ROOT = 'obs://modelarts-labs-bj4-v2/case_zoo/ControlNet'

    # 1) Copy the ControlNet code base once; a re-run skips it if the folder exists.
    parent = os.path.join(os.getcwd(), 'ControlNet')
    if not os.path.exists(parent):
        mox.file.copy_parallel(OBS_ROOT + '/ControlNet', parent)
        if os.path.exists(parent):
            print('Code Copy Completed.')
        else:
            raise Exception('Failed to Copy the Code.')
    else:
        print("Code already exists!")

    # 2) Pretrained weights: the ControlNet pose checkpoint and the OpenPose
    #    body/hand models. Every file is checked on its own, so an interrupted
    #    earlier download (e.g. pose model present, hand model missing) is
    #    completed on re-run instead of being reported as "already exists".
    pose_model_path = os.path.join(os.getcwd(), "ControlNet/models/control_sd15_openpose.pth")
    body_model_path = os.path.join(os.getcwd(), "ControlNet/annotator/ckpts/body_pose_model.pth")
    hand_model_path = os.path.join(os.getcwd(), "ControlNet/annotator/ckpts/hand_pose_model.pth")
    model_files = [pose_model_path, body_model_path, hand_model_path]

    missing = [path for path in model_files if not os.path.exists(path)]
    if missing:
        for local_path in missing:
            # OBS keeps all three files flat under ControlNet_models/ with the same file name.
            obs_path = OBS_ROOT + '/ControlNet_models/' + os.path.basename(local_path)
            mox.file.copy_parallel(obs_path, local_path)
        # Verify every file, not only the pose checkpoint.
        if all(os.path.exists(path) for path in model_files):
            print('Models Download Completed')
        else:
            raise Exception('Failed to Copy the Models.')
    else:
        print("Model Packages already exists!")

    检查GPU并安装依赖

    大约耗时1min

    1. !nvidia-smi
    2. %cd ControlNet
    3. !pip uninstall torch torchtext -y
    4. !pip install torch==1.12.1 torchvision==0.13.1 torchaudio==0.12.1
    5. !pip install omegaconf==2.1.1 einops==0.3.0
    6. !pip install pytorch-lightning==1.5.0
    7. !pip install transformers==4.19.2 open_clip_torch==2.0.2
    8. !pip install gradio==3.24.1
    9. !pip install translate==3.6.1
    10. !pip install scikit-image==0.19.3
    11. !pip install basicsr==1.4.2

    导包

    1. import config
    2. import cv2
    3. import einops
    4. import gradio as gr
    5. import numpy as np
    6. import torch
    7. import random
    8. from pytorch_lightning import seed_everything
    9. from annotator.util import resize_image, HWC3
    10. from annotator.openpose import OpenposeDetector
    11. from cldm.model import create_model, load_state_dict
    12. from cldm.ddim_hacked import DDIMSampler
    13. from translate import Translator
    14. from PIL import Image
    15. import matplotlib.pyplot as plt

    2. 加载模型

    def _build_pose_controlnet(config_path, ckpt_path):
        """Build the model described by *config_path* on the CPU, load *ckpt_path* into it and return it on the GPU."""
        net = create_model(config_path).cpu()
        # The checkpoint is read with location='cuda' and then copied into the CPU-side model.
        net.load_state_dict(load_state_dict(ckpt_path, location='cuda'))
        return net.cuda()


    # Pose detector used by infer() to turn the input photo into a skeleton map.
    apply_openpose = OpenposeDetector()
    model = _build_pose_controlnet('./models/cldm_v15.yaml', './models/control_sd15_openpose.pth')
    # DDIM sampler bound to the GPU model.
    ddim_sampler = DDIMSampler(model)

    3. 人体姿态生成图像

    def infer(input_image, prompt, a_prompt, n_prompt, num_samples, image_resolution, detect_resolution, ddim_steps, guess_mode, strength, scale, seed, eta):
        """Generate images that follow the human pose detected in ``input_image``.

        Args:
            input_image: path to an image file (``str`` or ``os.PathLike``), or an
                image array such as the one passed by ``gr.Image(type="numpy")``.
            prompt, a_prompt, n_prompt: main, additional and negative prompts.
                Each non-empty prompt is translated from Chinese to English first;
                ``a_prompt`` is appended to ``prompt`` after ", ".
            num_samples: number of images to generate (cast to ``int``).
            image_resolution: resolution passed to ``resize_image`` for generation.
            detect_resolution: resolution passed to ``resize_image`` before OpenPose.
            ddim_steps: number of DDIM sampling steps.
            guess_mode: if true, the unconditional branch gets no control map and the
                13 control scales decay as ``strength * 0.825 ** (12 - i)``.
            strength: multiplier applied to every ControlNet control scale.
            scale: unconditional (classifier-free) guidance scale.
            seed: random seed; ``-1`` draws one from [0, 65535].
            eta: DDIM eta.

        Returns:
            list of uint8 arrays: the detected pose map followed by
            ``num_samples`` generated images.
        """
        import os

        # UI sliders may hand over floats; range() and list repetition need an int.
        num_samples = int(num_samples)

        # Translate (possibly Chinese) prompts to English for the text encoder.
        # Empty prompts are kept as-is, so no translation request is made for them.
        trans = Translator(from_lang="ZH", to_lang="EN-US")

        def _translate(text):
            return trans.translate(text) if text else text

        prompt = _translate(prompt)
        a_prompt = _translate(a_prompt)
        n_prompt = _translate(n_prompt)

        with torch.no_grad():
            # --- Image preprocessing ---
            if isinstance(input_image, (str, os.PathLike)):
                # Close the file handle once the pixels are read.
                with Image.open(input_image) as im:
                    # Palette/1-bit/16-bit/LA/CMYK images would yield arrays that are not
                    # plain uint8 gray/RGB/RGBA; normalise them (RGBA keeps transparency).
                    pil = im if im.mode in ("L", "RGB", "RGBA") else im.convert("RGBA")
                    input_image = np.array(pil)
            input_image = HWC3(input_image)
            detected_map, _ = apply_openpose(resize_image(input_image, detect_resolution))
            detected_map = HWC3(detected_map)
            img = resize_image(input_image, image_resolution)
            H, W, C = img.shape

            # --- Control map: resize the pose map to the generation size ---
            # Nearest-neighbour keeps the skeleton colours unblended.
            detected_map = cv2.resize(detected_map, (W, H), interpolation=cv2.INTER_NEAREST)
            control = torch.from_numpy(detected_map.copy()).float().cuda() / 255.0
            control = torch.stack([control for _ in range(num_samples)], dim=0)
            control = einops.rearrange(control, 'b h w c -> b c h w').clone()

            # --- Random seed ---
            if seed == -1:
                seed = random.randint(0, 65535)
            seed_everything(seed)

            if config.save_memory:
                model.low_vram_shift(is_diffusing=False)

            cond = {"c_concat": [control], "c_crossattn": [model.get_learned_conditioning([prompt + ', ' + a_prompt] * num_samples)]}
            un_cond = {"c_concat": None if guess_mode else [control], "c_crossattn": [model.get_learned_conditioning([n_prompt] * num_samples)]}
            shape = (4, H // 8, W // 8)

            if config.save_memory:
                model.low_vram_shift(is_diffusing=True)

            # --- Sampling ---
            # 0.825 is an empirical constant carried over from upstream ControlNet:
            # 0.825**12 ~= 0.0994 < 0.1 < 0.826**12 ~= 0.1009, so the weakest of the
            # 13 guess-mode scales is just under 10% of `strength`.
            model.control_scales = [strength * (0.825 ** float(12 - i)) for i in range(13)] if guess_mode else ([strength] * 13)
            samples, _ = ddim_sampler.sample(ddim_steps, num_samples,
                                             shape, cond, verbose=False, eta=eta,
                                             unconditional_guidance_scale=scale,
                                             unconditional_conditioning=un_cond)

            if config.save_memory:
                model.low_vram_shift(is_diffusing=False)

            # --- Post-processing: decode latents, map roughly [-1, 1] to uint8 [0, 255] ---
            x_samples = model.decode_first_stage(samples)
            x_samples = (einops.rearrange(x_samples, 'b c h w -> b h w c') * 127.5 + 127.5).cpu().numpy().clip(0, 255).astype(np.uint8)
            results = list(x_samples)
        return [detected_map] + results

    设置参数,生成图像

    上传您的图像至./ControlNet/test_imgs/ 路径下,然后更改图像路径及其他参数后,点击运行。

    ➡参数说明:

    🔸 img_path:输入图像路径,应为包含人物的图像,OpenPose将从中检测人体姿态

    🔸 prompt:提示词

    🔸 a_prompt:附加提示词,会以逗号拼接在prompt之后一起使用

    🔸 n_prompt: 负面提示,不想要的内容

    🔸 image_resolution:生成图像的分辨率,输入图像按短边等比缩放到该值(宽高取整为64的倍数)

    🔸 detect_resolution:OpenPose检测姿态前,输入图像被缩放到的分辨率

    🔸 scale:文本提示的控制强度,越大越强

    🔸 guess_mode:盲猜模式,默认关闭。开启后,即使不提供prompt,ControlNet也会尽量从姿态图中识别内容并生成图像。实现上,无条件分支不再施加姿态控制,且各层控制强度按0.825的幂次递减,因此结果更多样,但对姿态条件的遵循程度较低

    🔸 seed: 随机种子

    🔸 ddim_steps: 采样步数,一般15-30,值越大越精细,耗时越长

    🔸 DDIM eta:DDIM采样过程中注入随机噪声的系数,一般取0或1。取1时每步加入随机噪声,结果更多样;取0时为确定性采样,固定种子即可复现结果

    🔸 strength:ControlNet控制强度。它作为系数乘到ControlNet各层输出的控制量上(见infer中的model.control_scales),在所有采样步骤中都生效。值越大,生成图像越贴合检测到的姿态;为0时相当于不施加姿态控制。

    #@title ControlNet-OpenPose
    img_path = "test_imgs/pose1.png" #@param {type:"string"}
    prompt = "优雅的女士" #@param {type:"string"}
    seed = 1685862398 #@param {type:"slider", min:-1, max:2147483647, step:1}
    guess_mode = False #@param {type:"raw", dropdown}
    # Chinese prompts are fine here: infer() translates them to English before encoding.
    a_prompt = '质量最好,非常详细'
    n_prompt = '长体,下肢,解剖不好,手不好,手指缺失,手指多,手指少,裁剪,质量最差,质量低'
    num_samples = 1
    image_resolution = 512
    detect_resolution = 512
    ddim_steps = 20
    strength = 1.0
    scale = 9.0
    eta = 0.0
    np_imgs = infer(img_path, prompt, a_prompt, n_prompt, num_samples, image_resolution, detect_resolution, ddim_steps, guess_mode, strength, scale, seed, eta)

    # np_imgs[0] is the detected pose map; np_imgs[1:] are the generated images.
    with Image.open(img_path) as im:
        ori = im.copy()  # load the pixels into memory so the file handle can be closed
    panels = [('Original image', ori), ('Pose image', np_imgs[0])]
    panels += [('Generated image', generated) for generated in np_imgs[1:]]

    # One panel per image; with num_samples == 1 this is the original 1x3 layout of width 25.
    fig, axes = plt.subplots(1, len(panels), figsize=(25 * len(panels) / 3, 10), squeeze=False)
    for ax, (title, picture) in zip(axes[0], panels):
        ax.set_title(title, fontsize=16)
        ax.axis('off')
        ax.imshow(picture)
    plt.show()

    上传自己的照片,输入你的prompt提示词

    运行结果:

    4. Gradio可视化部署

    Gradio应用启动后可在下方页面上传图片根据提示生成图像,您也可以分享public url在手机端,PC端进行访问生成图像。

    请注意:图像生成需要消耗显存,您可以在左侧操作栏查看实时资源使用情况,点击GPU显存使用率即可查看。当显存不足时,生成图像可能会报错,此时可以通过重启kernel的方式重置,然后从头运行即可规避。

    Image

    # Gradio UI: upload a portrait, set the parameters and click Run; results appear in the gallery.
    block = gr.Blocks().queue()
    with block:
        with gr.Row():
            gr.Markdown("## 💃人体姿态生成图像")
        with gr.Row():
            with gr.Column():
                gr.Markdown("请上传一张人像图,设置好参数后,点击Run")
                input_image = gr.Image(source='upload', type="numpy")
                prompt = gr.Textbox(label="描述")
                # In Gradio 3.x a Button's caption is its `value`; `label` is not a Button parameter.
                run_button = gr.Button(value="Run")
                with gr.Accordion("高级选项", open=False):
                    num_samples = gr.Slider(label="Images", minimum=1, maximum=3, value=1, step=1)
                    image_resolution = gr.Slider(label="Image Resolution", minimum=256, maximum=768, value=512, step=64)
                    strength = gr.Slider(label="Control Strength", minimum=0.0, maximum=2.0, value=1.0, step=0.01)
                    guess_mode = gr.Checkbox(label='Guess Mode', value=False)
                    detect_resolution = gr.Slider(label="OpenPose Resolution", minimum=128, maximum=1024, value=512, step=1)
                    ddim_steps = gr.Slider(label="Steps", minimum=1, maximum=30, value=20, step=1)
                    scale = gr.Slider(label="Guidance Scale", minimum=0.1, maximum=30.0, value=9.0, step=0.1)
                    seed = gr.Slider(label="Seed", minimum=-1, maximum=2147483647, step=1, randomize=True)
                    eta = gr.Number(label="eta (DDIM)", value=0.0)
                    a_prompt = gr.Textbox(label="Added Prompt", value='best quality, extremely detailed')
                    n_prompt = gr.Textbox(label="Negative Prompt",
                                          value='longbody, lowres, bad anatomy, bad hands, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality')
            with gr.Column():
                result_gallery = gr.Gallery(label='Output', show_label=False, elem_id="gallery").style(grid=2, height='auto')
        # The order must match infer()'s positional parameters.
        ips = [input_image, prompt, a_prompt, n_prompt, num_samples, image_resolution, detect_resolution, ddim_steps, guess_mode, strength, scale, seed, eta]
        run_button.click(fn=infer, inputs=ips, outputs=[result_gallery])
    # share=True also creates a public gradio.live URL; anyone holding it can use this GPU session.
    block.launch(share=True)
    1. INFO:botocore.vendored.requests.packages.urllib3.connectionpool:Starting new HTTP connection (1): proxy.modelarts.com
    2. INFO:botocore.vendored.requests.packages.urllib3.connectionpool:Starting new HTTPS connection (1): www.huaweicloud.com
    3. Running on local URL: http://127.0.0.1:7860
    4. Running on public URL: https://96b421e81ebf0fe302.gradio.live
    5. This share link expires in 72 hours. For free permanent hosting and GPU upgrades (NEW!), check out Spaces: https://huggingface.co/spaces

  • 相关阅读:
    数据库范式
    方法递归(黑马)
    基于springboot+vue的云南旅游网(前后端分离)
    PostgreSQL 逻辑复制模块(一)
    抖音关键词搜索商品-API工具
    Git使用教程
    grid 布局 grid-column-gap 使用后内容超出网格
    c# 项目重构,创建新的解决方案
    入职中国平安三周年的一些总结
    【华为OD机试】服务失效判断【2023 B卷|200分】
  • 原文地址:https://blog.csdn.net/chengxuyuanlaow/article/details/140961818