DALL·E 2 Text-to-Image Model: A Practical Guide


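    Preface: This post records the references I relied on and the debugging process I went through while running inference with the DALL·E 2 model.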

    Related post: Super detailed! A Practical Guide to the DALL·E Text-to-Image Model




    1. Environment Setup and Pretrained Model Preparation

    The code repository used in this post is: https://github.com/lucidrains/DALLE2-pytorch

    Environment setup

    pip install dalle2-pytorch
    

    Downloading the pretrained models

    URL: https://huggingface.co/laion/DALLE2-PyTorch
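The inference code in section 2 reads four files from a local weights/ directory: prior_config.json, prior_latest.pth, decoder_config.json and decoder_latest.pth. The file names and folder layout in the Hugging Face repository may differ, so this sketch assumes you saved or renamed the downloads to exactly those names. It checks that everything is in place before a long model build fails halfway; missing_weight_files is a hypothetical helper, not part of dalle2-pytorch.

```python
from pathlib import Path
from typing import List

# File names expected by the inference script (taken from its from_json_path /
# torch.load calls). Assumption: you renamed the downloaded files to match.
EXPECTED_FILES = [
    "prior_config.json",
    "prior_latest.pth",
    "decoder_config.json",
    "decoder_latest.pth",
]


def missing_weight_files(weights_dir: str = "weights") -> List[str]:
    """Return the expected file names that are absent from weights_dir, in order."""
    root = Path(weights_dir)
    return [name for name in EXPECTED_FILES if not (root / name).is_file()]


if __name__ == "__main__":
    missing = missing_weight_files("weights")
    if missing:
        print("Missing files:", ", ".join(missing))
    else:
        print("All weight files found.")
```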

    2. Code

    The complete DALLE2 inference pipeline is as follows (from @cest_andre in Issue #282):

    import torch
    from torchvision.transforms import ToPILImage
    from dalle2_pytorch import DALLE2
    from dalle2_pytorch.train_configs import TrainDiffusionPriorConfig, TrainDecoderConfig
    
    
    # Build the diffusion prior from its training config, then load the pretrained weights
    prior_config = TrainDiffusionPriorConfig.from_json_path("weights/prior_config.json").prior
    prior = prior_config.create().cuda()
    
    prior_model_state = torch.load("weights/prior_latest.pth")
    prior.load_state_dict(prior_model_state, strict=True)
    
    # Build the decoder the same way; its checkpoint stores the weights under "model"
    decoder_config = TrainDecoderConfig.from_json_path("weights/decoder_config.json").decoder
    decoder = decoder_config.create().cuda()
    
    decoder_model_state = torch.load("weights/decoder_latest.pth")["model"]
    
    # strict=True requires every key to be present, so copy the CLIP weights held by
    # the freshly created decoder into the state dict under the "clip." prefix
    for k, v in decoder.clip.state_dict().items():
        decoder_model_state["clip." + k] = v
    
    decoder.load_state_dict(decoder_model_state, strict=True)
    
    # Combine prior and decoder into the full text-to-image pipeline
    dalle2 = DALLE2(prior=prior, decoder=decoder).cuda()
    
    # cond_scale > 1 applies classifier-free guidance
    images = dalle2(
        ['your prompt here'],
        cond_scale = 2.
    ).cpu()
    
    print(images.shape)  # (batch, channels, height, width)
    
    # Convert each CHW tensor to a PIL image and display it
    for img in images:
        img = ToPILImage()(img)
        img.show()
    
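The loop over decoder.clip.state_dict() exists because strict=True requires every parameter key of the decoder, including those of its embedded CLIP model, to be present in the loaded dict. A state_dict is an ordered mapping from key names to tensors, so the merge can be sketched with plain dictionaries. merge_prefixed is a hypothetical helper, and the keys and values below are stand-ins, not the real parameter names.

```python
from typing import Dict, Mapping, TypeVar

V = TypeVar("V")


def merge_prefixed(target: Dict[str, V], source: Mapping[str, V], prefix: str) -> Dict[str, V]:
    """Copy every entry of source into target (in place) under prefix + key.

    Mirrors the loop in the inference script:
        decoder_model_state["clip." + k] = v
    Keys already present in target with the same name are overwritten.
    """
    for key, value in source.items():
        target[prefix + key] = value
    return target


# Usage with stand-in keys and values (ints instead of tensors)
checkpoint = {"unet_param": 0}
clip_weights = {"param_a": 1, "param_b": 2}
merged = merge_prefixed(checkpoint, clip_weights, "clip.")
print(sorted(merged))  # ['clip.param_a', 'clip.param_b', 'unet_param']
```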

    3. Bugs & Debugging

    URLError

    The error message is as follows:

    Traceback (most recent call last):
      File "/root/anaconda3/envs/ldm/lib/python3.8/urllib/request.py", line 1350, in do_open
        h.request(req.get_method(), req.selector, req.data, headers,
      File "/root/anaconda3/envs/ldm/lib/python3.8/http/client.py", line 1255, in request
        self._send_request(method, url, body, headers, encode_chunked)
      File "/root/anaconda3/envs/ldm/lib/python3.8/http/client.py", line 1301, in _send_request
        self.endheaders(body, encode_chunked=encode_chunked)
      File "/root/anaconda3/envs/ldm/lib/python3.8/http/client.py", line 1250, in endheaders
        self._send_output(message_body, encode_chunked=encode_chunked)
      File "/root/anaconda3/envs/ldm/lib/python3.8/http/client.py", line 1010, in _send_output
        self.send(msg)
      File "/root/anaconda3/envs/ldm/lib/python3.8/http/client.py", line 950, in send
        self.connect()
      File "/root/anaconda3/envs/ldm/lib/python3.8/http/client.py", line 1424, in connect
        self.sock = self._context.wrap_socket(self.sock,
      File "/root/anaconda3/envs/ldm/lib/python3.8/ssl.py", line 500, in wrap_socket
        return self.sslsocket_class._create(
      File "/root/anaconda3/envs/ldm/lib/python3.8/ssl.py", line 1040, in _create
        self.do_handshake()
      File "/root/anaconda3/envs/ldm/lib/python3.8/ssl.py", line 1309, in do_handshake
        self._sslobj.do_handshake()
    ConnectionResetError: [Errno 104] Connection reset by peer
    
    During handling of the above exception, another exception occurred:
    
    Traceback (most recent call last):
      File "/newdata/SD/extra/dalle2_cest.py", line 11, in <module>
        prior = prior_config.create().cuda()
      File "/root/anaconda3/envs/ldm/lib/python3.8/site-packages/dalle2_pytorch/train_configs.py", line 185, in create
        clip = self.clip.create()
      File "/root/anaconda3/envs/ldm/lib/python3.8/site-packages/dalle2_pytorch/train_configs.py", line 122, in create
        return OpenAIClipAdapter(self.model)
      File "/root/anaconda3/envs/ldm/lib/python3.8/site-packages/dalle2_pytorch/dalle2_pytorch.py", line 313, in __init__
        openai_clip, preprocess = clip.load(name)
      File "/root/anaconda3/envs/ldm/lib/python3.8/site-packages/clip/clip.py", line 122, in load
        model_path = _download(_MODELS[name], download_root or os.path.expanduser("~/.cache/clip"))
      File "/root/anaconda3/envs/ldm/lib/python3.8/site-packages/clip/clip.py", line 59, in _download
        with urllib.request.urlopen(url) as source, open(download_target, "wb") as output:
      File "/root/anaconda3/envs/ldm/lib/python3.8/urllib/request.py", line 222, in urlopen
        return opener.open(url, data, timeout)
      File "/root/anaconda3/envs/ldm/lib/python3.8/urllib/request.py", line 525, in open
        response = self._open(req, data)
      File "/root/anaconda3/envs/ldm/lib/python3.8/urllib/request.py", line 542, in _open
        result = self._call_chain(self.handle_open, protocol, protocol +
      File "/root/anaconda3/envs/ldm/lib/python3.8/urllib/request.py", line 502, in _call_chain
        result = func(*args)
      File "/root/anaconda3/envs/ldm/lib/python3.8/urllib/request.py", line 1393, in https_open
        return self.do_open(http.client.HTTPSConnection, req,
      File "/root/anaconda3/envs/ldm/lib/python3.8/urllib/request.py", line 1353, in do_open
        raise URLError(err)
    urllib.error.URLError: <urlopen error [Errno 104] Connection reset by peer>
    

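    I was using the https://github.com/lucidrains/DALLE2-pytorch repository. The traceback shows that the connection is reset while clip.load downloads the OpenAI CLIP weights to ~/.cache/clip, which happens when prior_config.create() builds the OpenAIClipAdapter.

    Open /root/anaconda3/envs/ldm/lib/python3.8/urllib/request.py and find the corresponding location (line 1349 in my case); the modification is shown in the code below.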

    try:
        h.request(req.get_method(), req.selector, req.data, headers,
                  encode_chunked=req.has_header('Transfer-encoding'))
        time.sleep(0.5)  # the added line
    except OSError as err: # timeout error
        raise URLError(err)
    
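Patching urllib/request.py changes the standard library for every program in that environment. If the connection resets are transient, an alternative is to retry the call that triggers the download. This is a minimal sketch under that assumption: call_with_retries is a hypothetical helper that uses exponential backoff, and wrapping prior_config.create() (the call that failed in the traceback above) is one possible target.

```python
import time
from typing import Callable, Tuple, Type, TypeVar
from urllib.error import URLError

T = TypeVar("T")


def call_with_retries(
    fn: Callable[[], T],
    attempts: int = 5,
    base_delay: float = 1.0,
    retry_on: Tuple[Type[BaseException], ...] = (URLError, ConnectionResetError),
) -> T:
    """Call fn(), retrying with exponential backoff when it raises one of retry_on."""
    if attempts < 1:
        raise ValueError("attempts must be >= 1")
    for attempt in range(attempts):
        try:
            return fn()
        except retry_on as err:
            if attempt == attempts - 1:
                raise  # out of attempts: re-raise the last error
            delay = base_delay * (2 ** attempt)  # 1s, 2s, 4s, ... by default
            print(f"attempt {attempt + 1} failed ({err}); retrying in {delay:.1f}s")
            time.sleep(delay)
    raise AssertionError("unreachable")


# Possible usage in the inference script (assumption, not from the original issue):
#   prior = call_with_retries(lambda: prior_config.create()).cuda()
```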

    CUDA error

    RuntimeError: CUDA error: no kernel image is available for execution on the device
    CUDA kernel errors might be asynchronously reported at some other API call,so the stacktrace below might be incorrect.
    For debugging consider passing CUDA_LAUNCH_BLOCKING=1.
    

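    Solution: the installed PyTorch build does not match the environment (the message means it contains no compiled kernels usable on this GPU), so install a PyTorch build that matches the system CUDA. For example, my CUDA version is 12.0, and I installed PyTorch with the command below; the cu118 wheels run on a CUDA 12.0 driver because NVIDIA drivers are backward compatible with older CUDA runtimes: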

    pip install torch==2.0.0+cu118 torchvision==0.15.1+cu118 torchaudio==2.0.1+cu118 -f https://download.pytorch.org/whl/torch_stable.html
    

    RuntimeError

    Traceback (most recent call last):
      File "/newdata/SD/extra/dalle2_cest.py", line 14, in <module>
        prior.load_state_dict(prior_model_state, strict=True)
      File "/root/anaconda3/envs/ldm/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1604, in load_state_dict
        raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
    RuntimeError: Error(s) in loading state_dict for DiffusionPrior:
            Missing key(s) in state_dict: "net.null_text_encodings", "net.null_text_embeds", "net.null_image_embed". 
            Unexpected key(s) in state_dict: "net.null_text_embed". 
    

    Solution 1️⃣: change strict=True to strict=False in load_state_dict(), as follows:

    ...
    prior.load_state_dict(prior_model_state, strict=False)
    
    decoder.load_state_dict(decoder_model_state, strict=False)
    ...
    

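    However, with strict=False the missing keys (net.null_text_encodings, net.null_text_embeds, net.null_image_embed) simply keep their initial values and the unexpected net.null_text_embed is ignored. This can degrade the model's output; in my case it produced mosaic-like images, which is clearly not what we want.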
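With strict=False, PyTorch's load_state_dict does not raise; it returns a named tuple whose missing_keys and unexpected_keys fields list what was skipped, so it is worth at least printing them. The comparison itself is a set difference. This framework-free sketch uses the key names from the error above; diff_state_dict_keys is a hypothetical helper, and "net.shared" is a stand-in for the keys that do match.

```python
from typing import Iterable, List, Tuple


def diff_state_dict_keys(
    model_keys: Iterable[str], checkpoint_keys: Iterable[str]
) -> Tuple[List[str], List[str]]:
    """Return sorted (missing, unexpected) keys, like load_state_dict reports them.

    missing:    keys the model expects but the checkpoint lacks
    unexpected: keys in the checkpoint that the model does not have
    """
    model_set, ckpt_set = set(model_keys), set(checkpoint_keys)
    return sorted(model_set - ckpt_set), sorted(ckpt_set - model_set)


model_keys = ["net.shared", "net.null_text_encodings", "net.null_text_embeds", "net.null_image_embed"]
ckpt_keys = ["net.shared", "net.null_text_embed"]
missing, unexpected = diff_state_dict_keys(model_keys, ckpt_keys)
print("missing:", missing)
# missing: ['net.null_image_embed', 'net.null_text_embeds', 'net.null_text_encodings']
print("unexpected:", unexpected)
# unexpected: ['net.null_text_embed']
```

With the real model, the equivalent is result = prior.load_state_dict(prior_model_state, strict=False) followed by printing result.missing_keys and result.unexpected_keys.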


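    Solution 2️⃣: follow cest-andre's answer in the Issues. The renamed keys in the error above suggest that the checkpoint was saved with an older version of dalle2-pytorch, which would explain why pinning an older release helps.

    Step (1): downgrade dalle2_pytorch to version 1.1.0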

    pip install dalle2-pytorch==1.1.0
    

    Step (2): after the downgrade, fix a small bug in dalle2_pytorch.py by changing line 2940 to the following:

    images = self.decoder.sample(image_embed = image_embed, text = text_cond, cond_scale = cond_scale)
    

    PydanticUserError

    After downgrading dalle2_pytorch, running the program fails with the following error:

    Traceback (most recent call last):
      File "/newdata/SD/extra/dalle2_cest.py", line 8, in <module>
        from dalle2_pytorch.train_configs import TrainDiffusionPriorConfig, TrainDecoderConfig
      File "/root/anaconda3/envs/ldm/lib/python3.8/site-packages/dalle2_pytorch/train_configs.py", line 34, in <module>
        class TrainSplitConfig(BaseModel):
      File "/root/anaconda3/envs/ldm/lib/python3.8/site-packages/dalle2_pytorch/train_configs.py", line 40, in TrainSplitConfig
        def validate_all(cls, fields):
      File "/root/anaconda3/envs/ldm/lib/python3.8/site-packages/pydantic/deprecated/class_validators.py", line 222, in root_validator
        return root_validator()(*__args)  # type: ignore
      File "/root/anaconda3/envs/ldm/lib/python3.8/site-packages/pydantic/deprecated/class_validators.py", line 228, in root_validator
        raise PydanticUserError(
    pydantic.errors.PydanticUserError: If you use `@root_validator` with pre=False (the default) you MUST specify `skip_on_failure=True`. Note that `@root_validator` is deprecated and should be replaced with `@model_validator`.
    

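    Solution: following JasbirCodeSpace's answer in the Issues, downgrade Pydantic. dalle2-pytorch 1.1.0 uses the Pydantic 1.x @root_validator without skip_on_failure=True, which Pydantic 2.x rejects, as the error message states: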

    pip install pydantic==1.10.6
    
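Since this working setup depends on two pinned versions (dalle2-pytorch==1.1.0 and pydantic==1.10.6), a quick check before running inference can save another round of tracebacks. This minimal sketch uses the standard-library importlib.metadata (Python 3.8+). check_pins is a hypothetical helper, and the pins are simply the ones from this post.

```python
from importlib.metadata import PackageNotFoundError, version
from typing import Callable, Dict, Optional

# Versions this post ended up with
PINS = {"dalle2-pytorch": "1.1.0", "pydantic": "1.10.6"}


def check_pins(
    pins: Dict[str, str],
    get_version: Callable[[str], str] = version,
) -> Dict[str, Optional[str]]:
    """Return {package: installed_version} for every package that does not match its pin.

    The installed version is None when the package is not installed at all.
    """
    mismatches: Dict[str, Optional[str]] = {}
    for package, wanted in pins.items():
        try:
            installed: Optional[str] = get_version(package)
        except PackageNotFoundError:
            installed = None
        if installed != wanted:
            mismatches[package] = installed
    return mismatches


if __name__ == "__main__":
    bad = check_pins(PINS)
    for package, installed in bad.items():
        print(f"{package}: installed {installed}, expected {PINS[package]}")
    if not bad:
        print("All pinned versions match.")
```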

    At this point, the model can run inference end to end~ Hehe! Below is the image generated for the prompt "a red car":


    Afterword: thanks to those who paved the way! 🌹


    References

    1. https://github.com/lucidrains/DALLE2-pytorch/issues/282
    2. python requests error: ConnectionError: ('Connection aborted.', error(104, 'Connection reset by peer')), 铁朵斯提's blog, CSDN
    3. Installing GPU PyTorch (CUDA 12.1) quickly from the Tsinghua mirror, step by step: a beginner's tutorial, CSDN
  • Original article: https://blog.csdn.net/qq_36332660/article/details/134386502