• A Super-Detailed Practical Guide to the DALL·E Text-to-Image Model


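    I recently needed to run DALL·E inference and, working from the existing open-source code, ran into several issues worth noting. This post records them.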

    The code I used is mainly the Inference pipeline.ipynb notebook from the https://github.com/borisdayma/dalle-mini repository.

    Runtime environment: Ubuntu server

    ⚠️ Note: this post covers DALL·E inference only, not training.



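    1. Environment setup

    I recommend using Anaconda to create a new environment named dalle and doing all the setup inside it, so that nothing conflicts with the library versions in your other environments.

    Create the dalle environment with the following command: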

    conda create -n dalle python==3.8.0
    

    Run the following commands in a terminal, one by one, to install the required Python libraries:

    # Install the dependencies dalle needs at runtime (note: the version must be 0.3.25)  # Required only for colab environments + GPU
    pip install jax==0.3.25 jaxlib==0.3.25 -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
    # Install the dalle-specific library
    pip install dalle-mini
    # Install VQGAN
    pip install -q git+https://github.com/patil-suraj/vqgan-jax.git
    

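    PS: if network problems keep pip from downloading VQGAN, fall back to Plan B: download the https://github.com/patil-suraj/vqgan-jax repository to the server and unpack it, cd into the unpacked repository, and run python setup.py install in a terminal to install VQGAN.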
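
Once the installs above have finished, it can be worth confirming that the 0.3.25 pins actually took effect before going further. The following is a minimal sketch that uses only the standard library; the helper name check_pins is made up for this post, and the "+cuda..." suffix handling assumes the CUDA wheels of jaxlib carry a local version suffix.

```python
from importlib import metadata

# Pins taken from the install commands above.
REQUIRED = {"jax": "0.3.25", "jaxlib": "0.3.25"}


def check_pins(required):
    """Return {package: (installed_version_or_None, matches_pin)}."""
    report = {}
    for name, wanted in required.items():
        try:
            installed = metadata.version(name)
        except metadata.PackageNotFoundError:
            installed = None
        # ASSUMPTION: CUDA builds of jaxlib report a local suffix such as
        # "+cuda11.cudnn82", so only the part before "+" is compared.
        base = installed.split("+")[0] if installed else None
        report[name] = (installed, base == wanted)
    return report


if __name__ == "__main__":
    for name, (installed, ok) in check_pins(REQUIRED).items():
        status = "OK" if ok else "MISMATCH"
        print(f"{name}: installed={installed} wanted={REQUIRED[name]} [{status}]")
```

Run it inside the dalle environment; a MISMATCH line for jaxlib is an early hint of the version conflict described in the bug section of this post.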


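    2. Model download

    Because of network problems, I download the models to local disk in advance and load them from there. One thing to be clear about first: in this project DALL·E mini turns the text prompt into a sequence of image tokens, and VQGAN decodes those tokens into an image, so both models have to be downloaded separately.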
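
To make the division of labor between the two models concrete, here is a small worked calculation of how many tokens describe one image. It is a sketch based on my reading of the model name vqgan_imagenet_f16_16384 (downsampling factor 16, codebook of 16384 entries) and of the 256 x 256 reshape in the inference script; the function name tokens_per_image is only illustrative.

```python
# Values read off the model name "vqgan_imagenet_f16_16384" and the
# reshape((-1, 256, 256, 3)) call in the inference script.
IMAGE_SIZE = 256        # output images are 256 x 256 pixels
DOWNSAMPLE_FACTOR = 16  # the "f16" in the VQGAN name
CODEBOOK_SIZE = 16384   # the "16384" in the VQGAN name


def tokens_per_image(image_size, factor):
    """Number of codebook indices VQGAN needs to describe one image."""
    side = image_size // factor
    return side * side


n_tokens = tokens_per_image(IMAGE_SIZE, DOWNSAMPLE_FACTOR)
print(f"{n_tokens} tokens per image, each an index in [0, {CODEBOOK_SIZE})")
```

The result, 256 tokens, agrees with the shapes that appear in the pmap error message quoted later in this post (indices of type int32[1,2,256] and an embedding table of type float32[16384,256]).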

    DALL·E model download links:
    mini version: https://huggingface.co/dalle-mini/dalle-mini/tree/main
    mega version: https://huggingface.co/dalle-mini/dalle-mega/tree/main

    VQGAN model download link:
    https://huggingface.co/dalle-mini/vqgan_imagenet_f16_16384/tree/main

    After downloading, copy the models to the server and note where they are saved.
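
Since a manual download can easily miss a file, a quick completeness check of the two model folders may save a debugging round later. This is a minimal sketch: the file lists are my assumption about what the Hugging Face repositories contain and should be compared with each repository's file listing, and missing_files is a hypothetical helper name.

```python
from pathlib import Path

# ASSUMPTION: these file lists reflect what the Hugging Face repos contain;
# compare them with the file listing of each repo before relying on them.
EXPECTED_FILES = {
    "dalle-mini": ["config.json", "flax_model.msgpack", "tokenizer.json",
                   "enwiki-words-frequency.txt"],
    "vqgan_imagenet_f16_16384": ["config.json", "flax_model.msgpack"],
}


def missing_files(model_dir, expected):
    """Return the names from `expected` that are not present in `model_dir`."""
    root = Path(model_dir)
    return [name for name in expected if not (root / name).is_file()]


if __name__ == "__main__":
    base = Path("/newdata/SD/dalle-mini")  # the save path used in this post
    for folder, expected in EXPECTED_FILES.items():
        missing = missing_files(base / folder, expected)
        print(folder, "->", "complete" if not missing else f"missing {missing}")
```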


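    3. Converting the notebook

    I personally prefer working with py files over ipynb files, so I first convert the notebook with jupyter nbconvert --to script "Inference pipeline.ipynb" (the quotes are needed because the file name contains a space), which produces a py file of the same name. Its main content is shown below, without the CLIP ranking part. The model paths DALLE_MODEL and VQGAN_REPO have been changed to local paths (the save locations from step 2). As you can see, the file is also fairly well commented.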

    # dalle-mini
    DALLE_MODEL = "/newdata/SD/dalle-mini/dalle-mini"
    DALLE_COMMIT_ID = None
    # VQGAN model
    VQGAN_REPO = "/newdata/SD/dalle-mini/vqgan_imagenet_f16_16384"
    VQGAN_COMMIT_ID = "e93a26e7707683d349bf5d5c41c5b0ef69b677a9"
    
    import jax
    import jax.numpy as jnp
    
    # check how many devices are available
    jax.local_device_count()
    
    # Load models & tokenizer
    from dalle_mini import DalleBart, DalleBartProcessor
    from vqgan_jax.modeling_flax_vqgan import VQModel
    # Load dalle-mini
    model, params = DalleBart.from_pretrained(DALLE_MODEL, revision=DALLE_COMMIT_ID, dtype=jnp.float16, _do_init=False)
    # Load VQGAN
    vqgan, vqgan_params = VQModel.from_pretrained(VQGAN_REPO, revision=VQGAN_COMMIT_ID, _do_init=False)
    
    # Model parameters are replicated on each device for faster inference.
    from flax.jax_utils import replicate
    params = replicate(params)
    vqgan_params = replicate(vqgan_params)
    
    # Model functions are compiled and parallelized to take advantage of multiple devices.
    from functools import partial
    
    # model inference
    @partial(jax.pmap, axis_name="batch", static_broadcasted_argnums=(3, 4, 5, 6))
    def p_generate(
        tokenized_prompt, key, params, top_k, top_p, temperature, condition_scale
    ):
        return model.generate(
            **tokenized_prompt,
            prng_key=key,
            params=params,
            top_k=top_k,
            top_p=top_p,
            temperature=temperature,
            condition_scale=condition_scale,
        )
    
    # decode image
    @partial(jax.pmap, axis_name="batch")
    def p_decode(indices, params):
        return vqgan.decode_code(indices, params=params)
    
    # Keys are passed to the model on each device to generate unique inference per device.
    import random
    
    # create a random key
    seed = random.randint(0, 2**32 - 1)
    key = jax.random.PRNGKey(seed)
    
    # ## 🖍 Text Prompt
    # Our model requires processing prompts.
    
    from dalle_mini import DalleBartProcessor 
    # from transformers import AutoProcessor
    processor = DalleBartProcessor.from_pretrained("/newdata/SD/dalle-mini/dalle-mini", revision=DALLE_COMMIT_ID)  # force_download=True, , local_only=True
    # Let's define some text prompts
    prompts = [
        "sunset over a lake in the mountains",
        "the Eiffel tower landing on the moon",
    ]
    # print(prompts)
    # Note: we could use the same prompt multiple times for faster inference.
    tokenized_prompts = processor(prompts)
    # Finally we replicate the prompts onto each device.
    tokenized_prompt = replicate(tokenized_prompts)
    
    # ## 🎨 We generate images using dalle-mini model and decode them with the VQGAN.
    
    # number of predictions per prompt
    n_predictions = 8
    
    # We can customize generation parameters (see https://huggingface.co/blog/how-to-generate)
    gen_top_k = None
    gen_top_p = None
    temperature = None
    cond_scale = 10.0  # the higher the value, the closer the generated images follow the prompt
    
    from flax.training.common_utils import shard_prng_key
    import numpy as np
    from PIL import Image
    from tqdm import trange  # plain terminal progress bar; tqdm.notebook is meant for Jupyter
    
    print(f"Prompts: {prompts}\n")
    # generate images
    images = []
    for i in trange(max(n_predictions // jax.device_count(), 1)):
        # get a new key (jax.device_count() returns the number of available JAX devices; it is 1 on my machine)
        key, subkey = jax.random.split(key)
        # generate images
        encoded_images = p_generate(
            tokenized_prompt,
            shard_prng_key(subkey),
            params,
            gen_top_k,
            gen_top_p,
            temperature,
            cond_scale,
        )
        # remove BOS
        encoded_images = encoded_images.sequences[..., 1:]
        decoded_images = p_decode(encoded_images, vqgan_params)
        decoded_images = decoded_images.clip(0.0, 1.0).reshape((-1, 256, 256, 3))
        
        for idx, decoded_img in enumerate(decoded_images):
            img = Image.fromarray(np.asarray(decoded_img * 255, dtype=np.uint8))
            images.append(img)
    ... 
    
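
The loop bound max(n_predictions // jax.device_count(), 1) in the script above is easy to misread, so here is the arithmetic spelled out in plain Python. It is a sketch of my understanding of the loop, not code from the repository; generation_plan is a made-up helper name.

```python
def generation_plan(n_predictions, n_devices, n_prompts):
    """Return (loop iterations, total images) for the generation loop above."""
    # Same expression as in the script: at least one pass, even when there
    # are more devices than requested predictions.
    iterations = max(n_predictions // n_devices, 1)
    # On every pass, each device generates one image for every prompt.
    total_images = iterations * n_devices * n_prompts
    return iterations, total_images


# The setup in this post: 8 predictions, 1 device (CPU fallback), 2 prompts.
print(generation_plan(8, 1, 2))   # (8, 16)
# Hypothetical 8-GPU machine: one pass already yields 8 images per prompt.
print(generation_plan(8, 8, 2))   # (1, 16)
```

With a single device the progress bar therefore counts to 8, and the images list ends up with 16 entries.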

    4. Running the program

    Run the program with python /newdata/SD/inference_dalle-mini.py. Ideally this directly produces the images generated by DALL·E!


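    5. Bug-fixing guide

    Because of environment factors and a few mistakes of my own, I still hit some problems while running the program. The main ones are listed below, each with its error message and the fix.

    • A file download fails because of network problems. The error output: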
    ...
    requests.exceptions.ConnectTimeout: (MaxRetryError("HTTPSConnectionPool(host='huggingface.co', port=443): Max retries exceeded with url: /dalle-mini/dalle-mini/resolve/main/enwiki-words-frequency.txt (Caused by ConnectTimeoutError(, 'Connection to huggingface.co timed out. (connect timeout=10)'))"), '(Request ID: 61b7c191-3fb8-4dfa-9025-e9acd4ee4d28)')
    
    The above exception was the direct cause of the following exception:
    
    Traceback (most recent call last):
      File "/newdata/SD/inference_dalle-mini.py", line 84, in <module>
        processor = DalleBartProcessor.from_pretrained("/newdata/SD/dalle-mini/dalle-mini", revision=DALLE_COMMIT_ID)  # force_download=True, , local_only=True
      File "/root/anaconda3/envs/dalle/lib/python3.8/site-packages/dalle_mini/model/utils.py", line 25, in from_pretrained
        return super(PretrainedFromWandbMixin, cls).from_pretrained(
      File "/root/anaconda3/envs/dalle/lib/python3.8/site-packages/dalle_mini/model/processor.py", line 62, in from_pretrained
        return cls(tokenizer, config.normalize_text, config.max_text_length)
      File "/root/anaconda3/envs/dalle/lib/python3.8/site-packages/dalle_mini/model/processor.py", line 21, in __init__
        self.text_processor = TextNormalizer()
      File "/root/anaconda3/envs/dalle/lib/python3.8/site-packages/dalle_mini/model/text.py", line 215, in __init__
        self._hashtag_processor = HashtagProcessor()
      File "/root/anaconda3/envs/dalle/lib/python3.8/site-packages/dalle_mini/model/text.py", line 25, in __init__
        #     wiki_word_frequency = hf_hub_download(
      File "/root/anaconda3/envs/dalle/lib/python3.8/site-packages/huggingface_hub/utils/_validators.py", line 118, in _inner_fn
        return fn(*args, **kwargs)
      File "/root/anaconda3/envs/dalle/lib/python3.8/site-packages/huggingface_hub/file_download.py", line 1363, in hf_hub_download
        raise LocalEntryNotFoundError(
    huggingface_hub.utils._errors.LocalEntryNotFoundError: An error happened while trying to locate the file on the Hub and we cannot find the requested files in the local cache. Please check your connection and try again or make sure your Internet connection is on.
    

    Following the traceback leads to this part of /root/anaconda3/envs/dalle/lib/python3.8/site-packages/dalle_mini/model/text.py:

    ...
    class HashtagProcessor:
        # Adapted from wordninja library
        # We use our wikipedia word count + a good heuristic to make it work
        def __init__(self):
            wiki_word_frequency = hf_hub_download(
                "dalle-mini/dalle-mini", filename="enwiki-words-frequency.txt"
            )
            self._word_cost = (
                l.split()[0]
                for l in Path(wiki_word_frequency).read_text(encoding="utf8").splitlines()
            )
    ...
    

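    So the root cause is that when execution reaches this point, the program does not find enwiki-words-frequency.txt locally, tries to download it from huggingface.co, and fails because of the poor connection. The file did exist in my local model folder, which puzzled me at first. The explanation is that hf_hub_download is called here with the hard-coded repository id "dalle-mini/dalle-mini", so it looks in the Hugging Face cache and not in the local directory passed to from_pretrained. The fix is as follows: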

    ...
    class HashtagProcessor:
        # Adapted from wordninja library
        # We use our wikipedia word count + a good heuristic to make it work
        def __init__(self):
            wiki_word_frequency = "/newdata/SD/dalle-mini/dalle-mini/enwiki-words-frequency.txt"
            self._word_cost = (
                l.split()[0]
                for l in Path(wiki_word_frequency).read_text(encoding="utf8").splitlines()
            )
    ...
    

    In other words, assign the local path of enwiki-words-frequency.txt directly to the wiki_word_frequency variable and leave everything else unchanged. Problem solved.
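
Hard-coding an absolute path inside site-packages works, but it breaks as soon as the model folder moves. A slightly more flexible variant, sketched below under my own assumptions, looks for the file in a directory given by an environment variable and only falls back to the hub download when nothing is found. Both the helper name find_local_frequency_file and the variable DALLE_MINI_LOCAL_DIR are invented for this sketch; neither dalle-mini nor huggingface_hub knows about them.

```python
import os
from pathlib import Path
from typing import Optional

FREQUENCY_FILE = "enwiki-words-frequency.txt"


def find_local_frequency_file(env_var: str = "DALLE_MINI_LOCAL_DIR") -> Optional[str]:
    """Return the local path of the word-frequency file, or None if absent.

    DALLE_MINI_LOCAL_DIR is a made-up variable name for this sketch; it is
    not something dalle-mini or huggingface_hub reads on its own.
    """
    local_dir = os.environ.get(env_var)
    if not local_dir:
        return None
    candidate = Path(local_dir) / FREQUENCY_FILE
    return str(candidate) if candidate.is_file() else None


# Inside HashtagProcessor.__init__ the patched line could then read:
#     wiki_word_frequency = find_local_frequency_file() or hf_hub_download(
#         "dalle-mini/dalle-mini", filename="enwiki-words-frequency.txt"
#     )
```

With export DALLE_MINI_LOCAL_DIR=/newdata/SD/dalle-mini/dalle-mini set before the run, the behavior would match the hard-coded fix above, while machines with a working connection keep the original download path.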


    • A version conflict caused by an improper installation
    Fix for "Couldn't invoke ptxas --version"
    

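    This error comes from version conflicts introduced while installing the Python libraries. DALL·E mini requires jax and jaxlib at exactly 0.3.25, but after pip install dalle-mini the installed jaxlib was 0.4.13. A plain pip install jaxlib could not find a 0.3.25 build, and downgrading also caused incompatibilities with flax, orbax-checkpoint, and other libraries. After several failed attempts to downgrade jaxlib cleanly, I found that the answer was already in the ipynb: pip install jax==0.3.25 jaxlib==0.3.25 -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html. The -f option tells pip to also look for packages on that page, which is where the matching jaxlib builds are published.

    💡 Takeaway: start from the official documentation; it saves a lot of detours!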


    • Bonus: a very strange error:
    The above exception was the direct cause of the following exception:
    
    Traceback (most recent call last):
      File "/newdata/SD/inference_dalle-mini.py", line 130, in <module>
        decoded_images = p_decode(encoded_images, vqgan_params)
    ValueError: pmap got inconsistent sizes for array axes to be mapped:
      * most axes (101 of them) had size 512, e.g. axis 0 of argument params['decoder']['conv_in']['bias'] of type float32[512];
      * some axes (71 of them) had size 3, e.g. axis 0 of argument params['decoder']['conv_in']['kernel'] of type float32[3,3,256,512];
      * some axes (69 of them) had size 256, e.g. axis 0 of argument params['decoder']['up_1']['block_0']['norm1']['bias'] of type float32[256];
      * some axes (67 of them) had size 128, e.g. axis 0 of argument params['decoder']['norm_out']['bias'] of type float32[128];
      * some axes (35 of them) had size 1, e.g. axis 0 of argument indices of type int32[1,2,256];
      * one axis had size 16384: axis 0 of argument params['quantize']['embedding']['embedding'] of type float32[16384,256]
    

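    It turned out that I had accidentally commented out the line below while debugging. Without it, the VQGAN parameters lack the leading device axis that pmap expects on every argument, hence the inconsistent sizes. This bug took the longest to track down, which was rather embarrassing 😂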

    vqgan_params = replicate(vqgan_params)
    
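
To see why that one missing line produces such a long error message, the following plain-Python simulation reproduces the check on shapes alone. It does not call JAX; replicate_shapes only mimics what flax.jax_utils.replicate does to array shapes (prepending a device axis), and both helper names are made up for this sketch. The shapes are copied from the error message above.

```python
from collections import Counter


def leading_axis_sizes(tree):
    """Collect the size of axis 0 for every array shape in a nested dict."""
    sizes = []
    for value in tree.values():
        if isinstance(value, dict):
            sizes.extend(leading_axis_sizes(value))
        else:
            sizes.append(value[0])
    return sizes


def replicate_shapes(tree, n_devices):
    """Mimic replicate() on shapes only: prepend a device axis to each leaf."""
    replicated = {}
    for key, value in tree.items():
        if isinstance(value, dict):
            replicated[key] = replicate_shapes(value, n_devices)
        else:
            replicated[key] = (n_devices, *value)
    return replicated


# Shapes copied from the error message above (one device).
indices = {"indices": (1, 2, 256)}
vqgan_params = {
    "decoder": {"conv_in": {"bias": (512,), "kernel": (3, 3, 256, 512)}},
    "quantize": {"embedding": {"embedding": (16384, 256)}},
}

# Without replicate(): axis 0 differs from leaf to leaf, so pmap refuses.
print(Counter(leading_axis_sizes({**indices, **vqgan_params})))
# With replicate(): every leaf starts with the device axis of size 1.
fixed = replicate_shapes(vqgan_params, n_devices=1)
print(Counter(leading_axis_sizes({**indices, **fixed})))  # Counter({1: 4})
```
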
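    • The GPU cannot be used properly because of the version restriction
      Since dalle-mini pins jax and jaxlib to 0.3.25, these two packages cannot be upgraded to the latest versions. I am not sure whether that is the cause, but the following errors appeared: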
    jaxlib.xla_extension.XlaRuntimeError: FAILED_PRECONDITION: Couldn't get ptxas version string: INTERNAL: Couldn't invoke ptxas --version
    
    jaxlib.xla_extension.XlaRuntimeError: INTERNAL: Failed to load in-memory CUBIN: CUDA_ERROR_INVALID_IMAGE: device kernel image is invalid
    
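    • The run also prints some warnings. (My program did not detect the GPU, so it could only run on the CPU.) The org_tensorflow paths in the log show that JAX builds on XLA code from the TensorFlow source tree. The log also hints at why no GPU was found: the CUDA 11 runtime library libcudart.so.11.0 could not be loaded, and the NVIDIA kernel driver version (525.53.0) does not match the driver library version (530.41.3).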
    2023-11-07 11:30:35.139851: W external/org_tensorflow/tensorflow/tsl/platform/default/dso_loader.cc:66] Could not load dynamic library 'libcudart.so.11.0'; dlerror: libcudart.so.11.0: cannot open shared object file: No such file or directory
    2023-11-07 11:30:35.257514: W external/org_tensorflow/tensorflow/tsl/platform/default/dso_loader.cc:66] Could not load dynamic library 'libcudart.so.11.0'; dlerror: libcudart.so.11.0: cannot open shared object file: No such file or directory
    2023-11-07 11:30:35.258648: W external/org_tensorflow/tensorflow/tsl/platform/default/dso_loader.cc:66] Could not load dynamic library 'libcudart.so.11.0'; dlerror: libcudart.so.11.0: cannot open shared object file: No such file or directory
    2023-11-07 11:30:35.628768: E external/org_tensorflow/tensorflow/compiler/xla/stream_executor/cuda/cuda_driver.cc:267] failed call to cuInit: CUDA_ERROR_SYSTEM_DRIVER_MISMATCH: system has unsupported display driver / cuda driver combination
    2023-11-07 11:30:35.628915: E external/org_tensorflow/tensorflow/compiler/xla/stream_executor/cuda/cuda_diagnostics.cc:313] kernel version 525.53.0 does not match DSO version 530.41.3 -- cannot find working devices in this configuration
    WARNING:jax._src.lib.xla_bridge:No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)
    Prompts: ['sunset over a lake in the mountains', 'the Eiffel tower landing on the moon']
    
      0%|          | 0/8 [00:00<?, ?it/s]
    /root/anaconda3/envs/dalle/lib/python3.8/site-packages/jax/_src/ops/scatter.py:87: FutureWarning: scatter inputs have incompatible types: cannot safely cast value from dtype=float16 to dtype=float32. In future JAX releases this will result in an error.
      warnings.warn("scatter inputs have incompatible types: cannot safely cast "
    
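
The two problems visible in the log above can be checked without starting JAX at all. The sketch below uses only the standard library: can_load asks the dynamic loader for a library by name, and driver_mismatch pulls the two version numbers out of the cuda_diagnostics line. Both function names are hypothetical, and the regular expression assumes the message wording shown in my log.

```python
import ctypes
import re

# Wording taken from the cuda_diagnostics.cc line in the log above.
MISMATCH = re.compile(r"kernel version (\S+) does not match DSO version (\S+)")


def can_load(library_name):
    """Return True if the dynamic loader can open `library_name`."""
    try:
        ctypes.CDLL(library_name)
    except OSError:
        return False
    return True


def driver_mismatch(log_line):
    """Return (kernel_version, dso_version) if the line reports a mismatch."""
    match = MISMATCH.search(log_line)
    return match.groups() if match else None


if __name__ == "__main__":
    name = "libcudart.so.11.0"  # the library named in the warnings above
    print(f"{name}: {'found' if can_load(name) else 'NOT found by the loader'}")
    line = ("kernel version 525.53.0 does not match DSO version 530.41.3 "
            "-- cannot find working devices in this configuration")
    print(driver_mismatch(line))  # ('525.53.0', '530.41.3')
```

If the library is reported as not found, adding the directory of a CUDA 11 installation to LD_LIBRARY_PATH is the usual first thing to try; a kernel/library version mismatch typically appears after a driver upgrade without a reboot.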

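    Postscript: this was my first time working with a program written on the JAX framework. It felt fresh, and it differs from PyTorch in several ways. One correction to my first impression: JAX is not a lightweight version of TensorFlow but a separate library from Google, with a NumPy-style API, automatic differentiation, and compilation through XLA, the compiler it shares with TensorFlow. If anything in this post reflects a misunderstanding on my part, corrections are welcome!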

  • Original article: https://blog.csdn.net/qq_36332660/article/details/134273737