• Distinguishing channel count from tensor dimensions in Stable Diffusion



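    Preface: both the channel count and the number of tensor dimensions tend to switch between the values 3 and 4, so the two are easy to confuse.

    1. Channel count

    1.1 channel = 3

    An RGB image has 3 channels (red, green, and blue).

    1.2 channel = 4

    Stable Diffusion works on latents with 4 channels (the output of the VAE encoder), not on 3-channel RGB pixels.
    Further reading: How to understand channels in convolutional neural networks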

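    2. Tensor shape

    2.1 3D tensor

    The shape is (C, H, W), where C is the channel count, H the height, and W the width. This describes a single image.

    2.2 4D tensor

    2.2.1 In general

    The shape is (B, C, H, W), where B is the batch size, C the channel count, H the height, and W the width. This describes several images at once (for example, a batch).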
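
    As a minimal sketch of the two layouts (the tensor names and the 64x64 size are illustrative assumptions, not taken from the original), a single RGB image can be promoted to a batch of one with unsqueeze. The number of dimensions changes from 3 to 4, while the channel count stays at 3:

```python
import torch

# A single RGB image: 3 dimensions, 3 channels, laid out as (C, H, W).
image = torch.zeros(3, 64, 64)

# Add a leading batch axis: 4 dimensions, still 3 channels, (B, C, H, W).
batch = image.unsqueeze(0)

print(image.dim(), tuple(image.shape))  # 3 (3, 64, 64)
print(batch.dim(), tuple(batch.shape))  # 4 (1, 3, 64, 64)
```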

    2.2.2 Stable Diffusion

    In img2img, the image is encoded by the VAE and noise is added according to the timestep:

            # This code is copied from diffusers' pipeline_controlnet_img2img.py
            # 6. Prepare latent variables
            latents = self.prepare_latents(
                image,
                latent_timestep,
                batch_size,
                num_images_per_prompt,
                prompt_embeds.dtype,
                device,
                generator,
            )
    

    Here a single image is a 3D tensor (C, H, W), while latents is a 4D tensor (B, C, H, W).
    Let's first look at the prepare_latents function of text2img:

        # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents
        def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None):
            shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor)
            if isinstance(generator, list) and len(generator) != batch_size:
                raise ValueError(
                    f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
                    f" size of {batch_size}. Make sure the batch size matches the length of the generators."
                )
    
            if latents is None:
                latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
            else:
                latents = latents.to(device)
    
            # scale the initial noise by the standard deviation required by the scheduler
            latents = latents * self.scheduler.init_noise_sigma
            return latents
    

    Clearly, shape already fixes both the number of dimensions of latents (4) and their order: (batch, channel, height, width).
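
    The shape computation can be reproduced in plain Python. This is only a sketch: latent_shape is a hypothetical helper that mirrors the shape tuple above, and both the default vae_scale_factor of 8 and the 768x768 input size are assumptions chosen for illustration.

```python
def latent_shape(batch_size, num_channels_latents, height, width, vae_scale_factor=8):
    """Mirror the `shape` tuple built in prepare_latents (hypothetical helper)."""
    return (
        batch_size,
        num_channels_latents,
        height // vae_scale_factor,
        width // vae_scale_factor,
    )


# Assumed inputs: a batch of 10 images at 768x768 pixels, 4 latent channels.
shape = latent_shape(10, 4, 768, 768)
print(shape)       # (10, 4, 96, 96)
print(len(shape))  # 4
```

    The tuple always has 4 entries, whatever the channel count is: the number of dimensions and the channel count are independent quantities.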
    In img2img:

        # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.StableDiffusionImg2ImgPipeline.prepare_latents
        def prepare_latents(self, image, timestep, batch_size, num_images_per_prompt, dtype, device, generator=None):
            if not isinstance(image, (torch.Tensor, PIL.Image.Image, list)):
                raise ValueError(
                    f"`image` has to be of type `torch.Tensor`, `PIL.Image.Image` or list but is {type(image)}"
                )
    
            image = image.to(device=device, dtype=dtype)
    
            batch_size = batch_size * num_images_per_prompt
    
            if image.shape[1] == 4:
                init_latents = image
    
            else:
                if isinstance(generator, list) and len(generator) != batch_size:
                    raise ValueError(
                        f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
                        f" size of {batch_size}. Make sure the batch size matches the length of the generators."
                    )
    
                elif isinstance(generator, list):
                    init_latents = [
                        self.vae.encode(image[i : i + 1]).latent_dist.sample(generator[i]) for i in range(batch_size)
                    ]
                    init_latents = torch.cat(init_latents, dim=0)
                else:
                    init_latents = self.vae.encode(image).latent_dist.sample(generator)
    
                init_latents = self.vae.config.scaling_factor * init_latents
    
            if batch_size > init_latents.shape[0] and batch_size % init_latents.shape[0] == 0:
                # expand init_latents for batch_size
                deprecation_message = (
                    f"You have passed {batch_size} text prompts (`prompt`), but only {init_latents.shape[0]} initial"
                    " images (`image`). Initial images are now duplicating to match the number of text prompts. Note"
                    " that this behavior is deprecated and will be removed in a version 1.0.0. Please make sure to update"
                    " your script to pass as many initial images as text prompts to suppress this warning."
                )
                deprecate("len(prompt) != len(image)", "1.0.0", deprecation_message, standard_warn=False)
                additional_image_per_prompt = batch_size // init_latents.shape[0]
                
                init_latents = torch.cat([init_latents] * additional_image_per_prompt, dim=0)
            elif batch_size > init_latents.shape[0] and batch_size % init_latents.shape[0] != 0:
                raise ValueError(
                    f"Cannot duplicate `image` of batch size {init_latents.shape[0]} to {batch_size} text prompts."
                )
            else:
                init_latents = torch.cat([init_latents], dim=0)
    
            shape = init_latents.shape
            noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
    
            # get latents
            init_latents = self.scheduler.add_noise(init_latents, noise, timestep)
            latents = init_latents
    
            return latents
    

    3. Application

    3.1 Problem

    new_map = texture.permute(1, 2, 0)
    RuntimeError: permute(sparse_coo): number of dimensions in the tensor input does not match the length of the desired ordering of dimensions i.e. input.dim() = 4 is not equal to len(dims) = 3
    

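    This is a tensor-shape problem and has nothing to do with the channel count: texture has 4 dimensions, but permute was given only 3 axis indices.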
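
    A hedged sketch of two possible fixes follows. The original does not show what texture contains, so the (1, 3, 8, 8) stand-in below is an assumption, and which option is appropriate depends on whether the batch axis is still needed afterwards.

```python
import torch

# Assumed stand-in for `texture`: a batch of one 3-channel map, (B, C, H, W).
texture = torch.zeros(1, 3, 8, 8)

try:
    texture.permute(1, 2, 0)  # 3 axis indices for a 4D tensor
except RuntimeError as err:
    print("permute failed:", type(err).__name__)

# Option 1: drop the batch axis first, then reorder (C, H, W) -> (H, W, C).
new_map = texture[0].permute(1, 2, 0)
print(tuple(new_map.shape))  # (8, 8, 3)

# Option 2: keep the batch axis and list all 4 axes, (B, C, H, W) -> (B, H, W, C).
new_maps = texture.permute(0, 2, 3, 1)
print(tuple(new_maps.shape))  # (1, 8, 8, 3)
```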

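    3.2 Example

    Q: For a 4D tensor of shape (B, C, H, W), can C be 3?
    A: Yes. In a 4D tensor of shape (B, C, H, W), C is the channel count, and C = 3 is the usual case for a batch of RGB images (red, green, and blue channels).

    3.3 A tensor can be understood as a multidimensional array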

    print("sample:", sample.shape)
    print("sample:", sample[0].shape)
    print("sample:", sample[0][0].shape)
    
    >>
    sample: torch.Size([10, 4, 96, 96])
    sample: torch.Size([4, 96, 96])
    sample: torch.Size([96, 96])
    

    As this shows, a tensor of shape torch.Size([10, 4, 96, 96]) can be understood as a 4-dimensional array: each index removes the leading dimension.
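
    The same behaviour can be imitated with plain nested lists, which may make the analogy more concrete. This is a sketch only: nested_shape is a hypothetical helper, and the small (2, 4, 3, 3) size is an assumption chosen to keep the example light.

```python
def nested_shape(array):
    """Return the shape of a regular nested list (hypothetical helper)."""
    shape = []
    while isinstance(array, list):
        shape.append(len(array))
        array = array[0]
    return tuple(shape)


# Build a small 4-level nested list shaped like (B, C, H, W) = (2, 4, 3, 3).
sample = [[[[0.0] * 3 for _ in range(3)] for _ in range(4)] for _ in range(2)]

print(nested_shape(sample))        # (2, 4, 3, 3)
print(nested_shape(sample[0]))     # (4, 3, 3)
print(nested_shape(sample[0][0]))  # (3, 3)
```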

  • Original article: https://blog.csdn.net/qq_44324007/article/details/138054676