Stable Diffusion (version x.x) Text-to-Image Model Practice Guide


    Preface: This post records the reference materials and the workflow I followed when running inference with the Stable Diffusion model.

    Related posts:
    Super Detailed! DALL·E Text-to-Image Model Practice Guide
    DALL·E 2 Text-to-Image Model Practice Guide


    1. Environment Setup and Pretrained Model Preparation

    Environment setup

    pip install diffusers transformers accelerate scipy safetensors
    
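    After installation, a quick sanity check of the environment can be run (a minimal sketch; it assumes PyTorch is already installed alongside the packages above and, optionally, that a CUDA GPU is visible):

    import torch
    import diffusers
    import transformers

    # Print library versions and report whether a CUDA device is available
    print("diffusers:", diffusers.__version__)
    print("transformers:", transformers.__version__)
    print("CUDA available:", torch.cuda.is_available())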

    Pretrained model download

    The huggingface website is often unreachable; one workaround is to download the model with a script:

    Step 1: install huggingface_hub with the command pip install huggingface_hub
    Step 2: download the desired model with the command python model_download.py --repo_id model_id, where model_id identifies the model to download. For example, the model_id for SD v2.1 can be stabilityai/stable-diffusion-2-1, and for SD v1.5 it can be runwayml/stable-diffusion-v1-5. To find a model_id, search for the model directly on the huggingface website (as in the figure below); the "owner/model-name" combination shown there is the model_id you need.

    [Figure: searching for a model on the huggingface website to find its model_id]

    The model_download.py file comes from this link (reference 1 below):

    # usage     : python model_download.py --repo_id repo_id
    # example   : python model_download.py --repo_id facebook/opt-350m
    import argparse
    import time
    import requests
    import json
    import os
    from huggingface_hub import snapshot_download
    import platform
    from tqdm import tqdm
    from urllib.request import urlretrieve
    
    
    def _log(_repo_id, _type, _msg):
        date1 = time.strftime('%Y-%m-%d %H:%M:%S')
        print(date1 + " " + _repo_id + " " + _type + " :" + _msg)
    
    
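    # Download a model or dataset repo from Hugging Face via snapshot_download; returns (success_flag, message).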
    def _download_model(_repo_id, _repo_type):
        if _repo_type == "model":
            _local_dir = 'dataroot/models/' + _repo_id
        else:
            _local_dir = 'dataroot/datasets/' + _repo_id
        try:
            if _check_Completed(_repo_id, _local_dir):
                return True, "check_Completed ok"
        except Exception as e:
            return False, "check_Complete exception," + str(e)
        _cache_dir = 'caches/' + _repo_id
    
        _local_dir_use_symlinks = True
        if platform.system().lower() == 'windows':
            _local_dir_use_symlinks = False
        try:
            if _repo_type == "model":
                snapshot_download(repo_id=_repo_id, cache_dir=_cache_dir, local_dir=_local_dir, local_dir_use_symlinks=_local_dir_use_symlinks,
                                  resume_download=True, max_workers=4)
            else:
                snapshot_download(repo_id=_repo_id, cache_dir=_cache_dir, local_dir=_local_dir, local_dir_use_symlinks=_local_dir_use_symlinks,
                                  resume_download=True, max_workers=4, repo_type="dataset")
        except Exception as e:
            error_msg = str(e)
            if ("401 Client Error" in error_msg):
                return True, error_msg
            else:
                return False, error_msg
        _removeHintFile(_local_dir)
        return True, ""
    
    
    def _writeHintFile(_local_dir):
        file_path = _local_dir + '/~incomplete.txt'
        if not os.path.exists(file_path):
            if not os.path.exists(_local_dir):
                os.makedirs(_local_dir)
            open(file_path, 'w').close()
    
    
    def _removeHintFile(_local_dir):
        file_path = _local_dir + '/~incomplete.txt'
        if os.path.exists(file_path):
            os.remove(file_path)
    
    
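    # Check the local directory against the repo's file list from the Hugging Face API; returns True only if every file is already present.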
    def _check_Completed(_repo_id, _local_dir):
        _writeHintFile(_local_dir)
        url = 'https://huggingface.co/api/models/' + _repo_id
        response = requests.get(url)
        if response.status_code == 200:
            data = json.loads(response.text)
        else:
            return False
        for sibling in data["siblings"]:
            if not os.path.exists(_local_dir + "/" + sibling["rfilename"]):
                return False
        _removeHintFile(_local_dir)
        return True
    
    
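    # Retry the download every 60 seconds, giving up after roughly one day (1440 attempts).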
    def download_model_retry(_repo_id, _repo_type):
        i = 0
        flag = False
        msg = ""
        while True:
            flag, msg = _download_model(_repo_id, _repo_type)
            if flag:
                _log(_repo_id, "success", msg)
                break
            else:
                _log(_repo_id, "fail", msg)
                if i > 1440:
                    msg = "retry over one day"
                    _log(_repo_id, "fail", msg)
                    break
                timeout = 60
                time.sleep(timeout)
                i = i + 1
                _log(_repo_id, "retry", str(i))
        return flag, msg
    
    
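    # Expand directory entries in the mirror's file listing (up to two levels deep) into a flat list of files.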
    def _fetchFileList(files):
        _files = []
        for file in files:
            if file['type'] == 'dir':
                filesUrl = 'https://e.aliendao.cn/' + file['path'] + '?json=true'
                response = requests.get(filesUrl)
                if response.status_code == 200:
                    data = json.loads(response.text)
                    for file1 in data['data']['files']:
                        if file1['type'] == 'dir':
                            filesUrl = 'https://e.aliendao.cn/' + \
                                file1['path'] + '?json=true'
                            response = requests.get(filesUrl)
                            if response.status_code == 200:
                                data = json.loads(response.text)
                                for file2 in data['data']['files']:
                                    _files.append(file2)
                        else:
                            _files.append(file1)
            else:
                if file['name'] != '.gitattributes':
                    _files.append(file)
        return _files
    
    
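    # Download a single file with HTTP Range-based resuming and a tqdm progress bar.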
    def _download_file_resumable(url, save_path, i, j, chunk_size=1024*1024):
        headers = {}
        r = requests.get(url, headers=headers, stream=True, timeout=(20, 60))
        if r.status_code == 403:
            _log(url, "download", 'Download failed; please use a valid token')
            return False
        bar_format = '{desc}{percentage:3.0f}%|{bar}|{n_fmt}M/{total_fmt}M [{elapsed}<{remaining}, {rate_fmt}]'
        _desc = str(i) + ' of ' + str(j) + '(' + save_path.split('/')[-1] + ')'
        total_length = int(r.headers.get('content-length'))
        if os.path.exists(save_path):
            temp_size = os.path.getsize(save_path)
        else:
            temp_size = 0
        retries = 0
        if temp_size >= total_length:
            return True
        # Small files: download in a single pass and show a one-step progress bar
        if total_length < chunk_size:
            with open(save_path, 'wb') as f:
                for chunk in r.iter_content(chunk_size=chunk_size):
                    if chunk:
                        f.write(chunk)
            with tqdm(total=1, desc=_desc, unit='MB', bar_format=bar_format) as pbar:
                pbar.update(1)
        else:
            headers['Range'] = f'bytes={temp_size}-{total_length}'
            r = requests.get(url, headers=headers, stream=True,
                             verify=False, timeout=(20, 60))
            data_size = round(total_length / 1024 / 1024)
            with open(save_path, 'ab') as fd:
                fd.seek(temp_size)
                initial = temp_size//chunk_size
                for chunk in tqdm(iterable=r.iter_content(chunk_size=chunk_size), initial=initial, total=data_size, desc=_desc, unit='MB', bar_format=bar_format):
                    if chunk:
                        temp_size += len(chunk)
                        fd.write(chunk)
                        fd.flush()
        return True
    
    
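    # Download every file of a repo from the aliendao.cn mirror; returns False if the mirror copy is missing or still incomplete.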
    def _download_model_from_mirror(_repo_id, _repo_type, _token, _e):
        if _repo_type == "model":
            filesUrl = 'https://e.aliendao.cn/models/' + _repo_id + '?json=true'
        else:
            filesUrl = 'https://e.aliendao.cn/datasets/' + _repo_id + '?json=true'
        response = requests.get(filesUrl)
        if response.status_code != 200:
            _log(_repo_id, "mirror", str(response.status_code))
            return False
        data = json.loads(response.text)
        files = data['data']['files']
        for file in files:
            if file['name'] == '~incomplete.txt':
                _log(_repo_id, "mirror", 'downloading')
                return False
        files = _fetchFileList(files)
        i = 1
        for file in files:
            url = 'http://61.133.217.142:20800/download' + file['path']
            if _e:
                url = 'http://61.133.217.139:20800/download' + \
                    file['path'] + "?token=" + _token
            file_name = 'dataroot/' + file['path']
            if not os.path.exists(os.path.dirname(file_name)):
                os.makedirs(os.path.dirname(file_name))
            i = i + 1
            if not _download_file_resumable(url, file_name, i, len(files)):
                return False
        return True
    
    
    def download_model_from_mirror(_repo_id, _repo_type, _token, _e):
        if _download_model_from_mirror(_repo_id, _repo_type, _token, _e):
            return
        else:
            #return download_model_retry(_repo_id, _repo_type)
            _log(_repo_id, "download", 'Download failed; please use a valid token')
    
    
    if __name__ == '__main__':
        parser = argparse.ArgumentParser()
        parser.add_argument('--repo_id', default=None, type=str, required=True)
        parser.add_argument('--repo_type', default="model",
                            type=str, required=False)  # choices: model, dataset
        # --mirror downloads from the aliendao.cn mirror; if aliendao.cn has no mirror copy, the download falls back to hf
        # defaults to True
        parser.add_argument('--mirror', action='store_true',
                            default=True, required=False)
        parser.add_argument('--token', default="", type=str, required=False)
        # --e selects the enterprise (paid) tier
        parser.add_argument('--e', action='store_true',
                            default=False, required=False)
        args = parser.parse_args()
        if args.mirror:
            download_model_from_mirror(
                args.repo_id, args.repo_type, args.token, args.e)
        else:
            download_model_retry(args.repo_id, args.repo_type)
    
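    At its core, the script above wraps huggingface_hub.snapshot_download with retry logic and an aliendao.cn mirror fallback. When the Hugging Face site is reachable directly, the same download can be done in a few lines (a minimal sketch; the local_dir path is only an example that mirrors the layout used by the script):

    from huggingface_hub import snapshot_download

    # Download the SD v2.1 repository into a local directory (example path)
    snapshot_download(
        repo_id="stabilityai/stable-diffusion-2-1",
        local_dir="dataroot/models/stabilityai/stable-diffusion-2-1",
        max_workers=4,
    )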

    2. Code

    The complete Stable Diffusion inference flow is as follows (from https://huggingface.co/stabilityai/stable-diffusion-2-1):

    import torch
    from diffusers import StableDiffusionPipeline, DPMSolverMultistepScheduler
    
    model_id = "/dataroot/models/stabilityai/stable-diffusion-2-1"  # local path where the pretrained model was downloaded
    
    # Use the DPMSolverMultistepScheduler (DPM-Solver++) scheduler here instead
    pipe = StableDiffusionPipeline.from_pretrained(model_id, torch_dtype=torch.float16)
    pipe.scheduler = DPMSolverMultistepScheduler.from_config(pipe.scheduler.config)
    pipe = pipe.to("cuda")
    
    prompt = "a photo of an astronaut riding a horse on mars"
    image = pipe(prompt).images[0]
        
    image.save("astronaut_rides_horse.png")
    
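    The pipeline call also accepts several commonly used arguments, such as num_inference_steps, guidance_scale and negative_prompt, and a torch.Generator can be passed for reproducible outputs. Below is a hedged sketch that reuses pipe and prompt from the snippet above; the specific values are only illustrative:

    # Reuses pipe and prompt from the previous snippet; values are examples
    generator = torch.Generator("cuda").manual_seed(42)
    image = pipe(
        prompt,
        negative_prompt="blurry, low quality",
        num_inference_steps=25,
        guidance_scale=7.5,
        generator=generator,
    ).images[0]
    image.save("astronaut_rides_horse_tuned.png")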

    References

    1. https://aliendao.cn/model_download.py
    2. https://github.com/Stability-AI/stablediffusion
  • Original post: https://blog.csdn.net/qq_36332660/article/details/134426792