- import os
- import subprocess
- import time
- from collections import defaultdict, deque
- import datetime
- import pickle
- from packaging import version
- from typing import Optional, List
-
- import torch
- import torch.distributed as dist
- from torch import Tensor
-
- # needed due to empty tensor bug in pytorch and torchvision 0.5
- import torchvision
- if version.parse(torchvision.__version__) < version.parse('0.7'):
- from torchvision.ops import _new_empty_tensor
- from torchvision.ops.misc import _output_size
这段代码导入了多个Python库和模块,其中包括Python标准库、PyTorch库以及一些辅助函数。以下是导入的库和模块的简要解释:
os:Python标准库的一部分,用于与操作系统进行交互,执行文件和目录操作等。
subprocess:Python标准库的一部分,用于启动和管理子进程,可以用于执行系统命令和外部程序。
time:Python标准库的一部分,用于处理时间相关的操作,如等待、测量时间间隔等。
collections:Python标准库的一部分,提供了一些额外的数据结构,如defaultdict和deque,用于更有效地组织和处理数据。
datetime:Python标准库的一部分,用于处理日期和时间。
pickle:Python标准库的一部分,用于序列化和反序列化Python对象。
typing:Python标准库的一部分,用于类型提示,使代码更具可读性和可维护性。
torch:PyTorch库,用于深度学习任务。这是主要的深度学习框架。
torch.distributed:PyTorch中用于分布式训练的模块,支持多台机器上的模型训练。
Tensor:PyTorch中的张量(tensor)数据类型。
torchvision:PyTorch官方的计算机视觉库,用于图像处理和视觉任务。
version:来自packaging模块的version类,用于比较和处理版本号。
这段代码的导入部分主要用于引入所需的库和模块,以便在后续的代码中使用它们执行各种任务,包括深度学习、数据处理、分布式计算等。其中,torch和torchvision是深度学习任务中最常用的库,用于创建和训练神经网络模型。
SmoothedValue类- class SmoothedValue(object):
- """Track a series of values and provide access to smoothed values over a
- window or the global series average.
- """
-
- def __init__(self, window_size=20, fmt=None):
- if fmt is None:
- fmt = "{median:.4f} ({global_avg:.4f})"
- self.deque = deque(maxlen=window_size)
- self.total = 0.0
- self.count = 0
- self.fmt = fmt
- def update(self, value, n=1):
- self.deque.append(value)
- self.count += n
- self.total += value * n
-
- def synchronize_between_processes(self):
- """
- Warning: does not synchronize the deque!
- """
- if not is_dist_avail_and_initialized():
- return
- t = torch.tensor([self.count, self.total], dtype=torch.float64, device='cuda')
- dist.barrier()
- dist.all_reduce(t)
- t = t.tolist()
- self.count = int(t[0])
- self.total = t[1]
-
- @property
- def median(self):
- d = torch.tensor(list(self.deque))
- return d.median().item()
-
- @property
- def avg(self):
- d = torch.tensor(list(self.deque), dtype=torch.float32)
- return d.mean().item()
-
- @property
- def global_avg(self):
- return self.total / self.count
-
- @property
- def max(self):
- return max(self.deque)
-
- @property
- def value(self):
- return self.deque[-1]
-
- def __str__(self):
- return self.fmt.format(
- median=self.median,
- avg=self.avg,
- global_avg=self.global_avg,
- max=self.max,
- value=self.value)
这是一个Python类SmoothedValue的初始化部分,该类用于跟踪一系列数值并提供对这些数值的平滑处理(滑动窗口或全局平均值)。以下是这个类的主要属性和方法的解释:
__init__(self, window_size=20, fmt=None):初始化方法,用于创建SmoothedValue类的实例。参数window_size指定用于计算平均值的滑动窗口的大小,默认为20。参数fmt是一个格式化字符串,用于定义以字符串形式表示SmoothedValue实例时的格式,默认为"{median:.4f} ({global_avg:.4f})"。
update(self, value, n=1):更新方法,用于添加一个新的数值到SmoothedValue中。参数value是要添加的数值,参数n表示数值的数量(默认为1)。这个方法用于不断更新跟踪的数值序列。
- class SmoothedValue(object):
- """Track a series of values and provide access to smoothed values over a
- window or the global series average.
- """
-
- def __init__(self, window_size=20, fmt=None):
- if fmt is None:
- fmt = "{median:.4f} ({global_avg:.4f})"
- self.deque = deque(maxlen=window_size)
- self.total = 0.0
- self.count = 0
- self.fmt = fmt
- def update(self, value, n=1):
- self.deque.append(value)
- self.count += n
- self.total += value * n
这是SmoothedValue类的构造函数__init__和更新方法update的部分实现。
__init__(self, window_size=20, fmt=None):构造函数,用于初始化SmoothedValue对象的属性。参数window_size指定用于计算平均值的滑动窗口的大小,默认为20。参数fmt是一个格式化字符串,用于定义以字符串形式表示SmoothedValue实例时的格式,默认为"{median:.4f} ({global_avg:.4f})"。
self.deque:创建一个双端队列(deque),用于存储跟踪的数值序列。maxlen参数指定队列的最大长度,当队列超过这个长度时,最旧的元素将被自动删除。
self.total:初始化一个总和变量,用于跟踪所有添加数值的总和。
self.count:初始化一个计数变量,用于跟踪添加数值的数量。
self.fmt:初始化一个格式化字符串,用于定义以字符串形式表示SmoothedValue实例时的格式。
update(self, value, n=1):更新方法,用于添加一个新的数值到SmoothedValue中。参数value是要添加的数值,参数n表示数值的数量(默认为1)。
self.deque.append(value):将新的数值添加到双端队列中。
self.count += n:更新计数变量,记录添加的数值数量。
self.total += value * n:更新总和变量,将新的数值乘以数量n后添加到总和中,以跟踪所有数值的总和。
这两个方法的目的是在SmoothedValue对象中维护一个数值序列,并计算该序列的平均值和其他统计信息。update方法用于将新的数值添加到序列中,并更新总和和计数,以便后续计算平均值等统计信息。这些统计信息可以在深度学习训练过程中用于监控和记录模型的性能。
syncronize_between_processes- def synchronize_between_processes(self):
- """
- Warning: does not synchronize the deque!
- """
- if not is_dist_avail_and_initialized():
- return
- t = torch.tensor([self.count, self.total], dtype=torch.float64, device='cuda')
- dist.barrier()
- dist.all_reduce(t)
- t = t.tolist()
- self.count = int(t[0])
- self.total = t[1]
这是SmoothedValue类中的一个方法syncronize_between_processes,该方法用于在多个进程之间同步计数和总和。以下是这个方法的功能和步骤的解释:
if not is_dist_avail_and_initialized()::这是一个条件语句,用于检查分布式环境是否可用且已初始化。如果不是分布式环境,就不需要执行同步操作,所以直接返回。
t = torch.tensor([self.count, self.total], dtype=torch.float64, device='cuda'):创建一个包含两个数值的PyTorch张量,其中包括self.count和self.total。这两个数值分别表示计数和总和。
dist.barrier():在分布式环境中,这个函数用于同步所有进程,确保在执行后续操作之前,所有进程都已经完成了前面的工作。
dist.all_reduce(t):在分布式环境中,这个函数用于将所有进程中的torch.tensor的值相加并将结果广播给所有进程。这样可以确保所有进程都具有相同的计数和总和。
t = t.tolist():将PyTorch张量转换为Python列表,以便后续处理。
self.count = int(t[0]):将同步后的计数值设置为self.count属性。
self.total = t[1]:将同步后的总和值设置为self.total属性。
总之,syncronize_between_processes方法用于在分布式环境中同步计数和总和,以确保不同进程具有相同的数值统计信息。这对于分布式训练中的性能监控和日志记录非常有用。需要注意的是,这个方法不同步双端队列(deque)本身,只同步了计数和总和。
SmoothedValue类中的属性- @property
- def median(self):
- d = torch.tensor(list(self.deque))
- return d.median().item()
-
- @property
- def avg(self):
- d = torch.tensor(list(self.deque), dtype=torch.float32)
- return d.mean().item()
-
- @property
- def global_avg(self):
- return self.total / self.count
-
- @property
- def max(self):
- return max(self.deque)
这些是SmoothedValue类中的属性,用于计算并返回与数值序列相关的统计信息。以下是这些属性的解释:
median:返回数值序列的中位数。它通过首先将双端队列(deque)中的数值转换为PyTorch张量,然后计算该张量的中位数来实现。最后,使用item()方法将中位数值转换为Python标量。
avg:返回数值序列的平均值。与median类似,它首先将数值转换为PyTorch张量,然后计算该张量的平均值。
global_avg:返回数值序列的全局平均值,即所有数值的总和除以数值的数量。这个值可以用于跟踪整个序列的平均值。
max:返回数值序列中的最大值,即双端队列(deque)中的最大数值。
这些属性提供了一种方便的方式来获取与数值序列相关的统计信息,例如中位数、平均值、全局平均值和最大值。这些统计信息对于深度学习模型的性能监控和日志记录非常有用,可以帮助了解模型的训练进展和性能表现。
SmoothedValue类中的两个属性和一个特殊方法- @property
- def value(self):
- return self.deque[-1]
-
- def __str__(self):
- return self.fmt.format(
- median=self.median,
- avg=self.avg,
- global_avg=self.global_avg,
- max=self.max,
- value=self.value)
这是SmoothedValue类中的两个属性和一个特殊方法的实现:
value:这是一个属性,返回数值序列中的最新值,即双端队列(deque)中的最后一个数值。
__str__:这是一个特殊方法,用于将SmoothedValue实例转换为字符串表示。在这个方法中,使用了初始化时提供的格式化字符串(fmt)来格式化输出字符串。在格式化字符串中,可以使用占位符(如{median}、{avg}等)来代表不同的统计信息,例如中位数、平均值等。通过调用这个方法,可以将SmoothedValue实例的统计信息以人类可读的方式输出或记录。
这两个部分的实现增强了SmoothedValue类的可用性,使其更容易与其他代码集成,以便在深度学习训练中跟踪和记录数值统计信息。特别是__str__方法提供了一种友好的方式来格式化和呈现这些统计信息。
all_gather函数- def all_gather(data):
- """
- Run all_gather on arbitrary picklable data (not necessarily tensors)
- Args:
- data: any picklable object
- Returns:
- list[data]: list of data gathered from each rank
- """
- world_size = get_world_size()
- if world_size == 1:
- return [data]
-
- # serialized to a Tensor
- buffer = pickle.dumps(data)
- storage = torch.ByteStorage.from_buffer(buffer)
- tensor = torch.ByteTensor(storage).to("cuda")
-
- # obtain Tensor size of each rank
- local_size = torch.tensor([tensor.numel()], device="cuda")
- size_list = [torch.tensor([0], device="cuda") for _ in range(world_size)]
- dist.all_gather(size_list, local_size)
- size_list = [int(size.item()) for size in size_list]
- max_size = max(size_list)
-
- # receiving Tensor from all ranks
- # we pad the tensor because torch all_gather does not support
- # gathering tensors of different shapes
- tensor_list = []
- for _ in size_list:
- tensor_list.append(torch.empty((max_size,), dtype=torch.uint8, device="cuda"))
- if local_size != max_size:
- padding = torch.empty(size=(max_size - local_size,), dtype=torch.uint8, device="cuda")
- tensor = torch.cat((tensor, padding), dim=0)
- dist.all_gather(tensor_list, tensor)
-
- data_list = []
- for size, tensor in zip(size_list, tensor_list):
- buffer = tensor.cpu().numpy().tobytes()[:size]
- data_list.append(pickle.loads(buffer))
-
- return data_list
这是一个名为all_gather的函数,用于在分布式计算环境中执行数据的全局收集(all-gather)。这个函数可以用于收集任何可序列化(picklable)的数据,而不仅仅是张量。以下是这个函数的主要步骤和功能:
world_size = get_world_size():获取当前分布式环境中的总进程数。
if world_size == 1::如果只有一个进程,则无需执行全局收集,直接返回包含原始数据的列表。
buffer = pickle.dumps(data):将输入的数据data序列化为字节流,并存储在buffer中。
storage = torch.ByteStorage.from_buffer(buffer):创建一个PyTorch字节存储(ByteStorage)对象,用于存储序列化数据的字节。
tensor = torch.ByteTensor(storage).to("cuda"):将字节存储转换为PyTorch字节张量(ByteTensor)并将其移动到CUDA设备上。
local_size = torch.tensor([tensor.numel()], device="cuda"):创建一个包含当前进程张量大小的张量,并将其移到CUDA设备上。
size_list:创建一个包含每个进程张量大小的列表,其中每个元素都是一个包含零的张量,用于接收其他进程的张量大小信息。
dist.all_gather(size_list, local_size):使用分布式通信操作dist.all_gather,将每个进程的张量大小信息收集到size_list中。
size_list:将size_list中的每个张量大小转换为整数,并找到最大的大小。
tensor_list:创建一个用于接收其他进程张量的列表,每个元素都是一个具有最大大小的空张量。
如果当前进程的张量大小与最大大小不同,将填充张量(padding)添加到当前进程的张量,以使它们具有相同的大小。
使用dist.all_gather操作,将所有进程的张量收集到tensor_list中。
data_list:通过将每个进程的张量从tensor_list中提取并反序列化为原始数据,创建一个包含所有进程数据的列表。
最终,函数返回一个列表,其中包含了从每个进程收集到的原始数据。这个函数的主要目的是在分布式计算环境中方便地执行数据的全局收集,以便在不同进程之间共享信息。这对于在深度学习训练中收集和合并模型参数、梯度等信息非常有用。
- def all_gather(data):
- """
- Run all_gather on arbitrary picklable data (not necessarily tensors)
- Args:
- data: any picklable object
- Returns:
- list[data]: list of data gathered from each rank
- """
- world_size = get_world_size()
- if world_size == 1:
- return [data]
这是一个all_gather函数的部分实现,用于在分布式计算环境中执行数据的全局收集。以下是该部分实现的功能和步骤的解释:
world_size = get_world_size():获取当前分布式环境中的总进程数。
if world_size == 1::如果只有一个进程,则无需执行全局收集,直接返回包含原始数据的列表。这个条件检查确保只有在多进程环境中才执行全局收集,以避免不必要的通信开销。
总之,这部分实现是all_gather函数的一部分,用于处理特殊情况,即在单进程环境中,不需要进行全局收集,因此直接返回包含原始数据的列表。这有助于提高代码的效率和可读性。完整的all_gather函数会在多进程环境中执行全局收集。
data对象序列化为一个PyTorch字节张量- # serialized to a Tensor
- buffer = pickle.dumps(data)
- storage = torch.ByteStorage.from_buffer(buffer)
- tensor = torch.ByteTensor(storage).to("cuda")
这部分代码用于将输入的data对象序列化为一个PyTorch字节张量,并将该张量移动到CUDA设备上。以下是这些步骤的解释:
buffer = pickle.dumps(data):使用Python的pickle模块将输入的data对象序列化为一个字节流(byte stream),并将该字节流存储在buffer变量中。序列化是将数据对象转换为字节表示的过程,以便将其保存到文件或传输到其他进程。
storage = torch.ByteStorage.from_buffer(buffer):创建一个PyTorch字节存储(ByteStorage)对象,该对象从先前创建的字节流buffer中初始化。这个字节存储对象用于存储字节数据,并将在后续步骤中用于创建字节张量。
tensor = torch.ByteTensor(storage).to("cuda"):将字节存储对象storage转换为PyTorch字节张量(ByteTensor),然后使用.to("cuda")将该张量移动到CUDA设备上。这意味着该字节张量将存储在GPU上,以便在GPU上进行后续的操作。
总之,这部分代码的目标是将输入的数据对象序列化为一个GPU上的PyTorch字节张量,以便将其传递给其他分布式进程。在分布式计算中,数据通常需要在不同的GPU上共享,因此需要将数据对象转换为GPU上的张量。
- # obtain Tensor size of each rank
- local_size = torch.tensor([tensor.numel()], device="cuda")
- size_list = [torch.tensor([0], device="cuda") for _ in range(world_size)]
- dist.all_gather(size_list, local_size)
- size_list = [int(size.item()) for size in size_list]
- max_size = max(size_list)
这部分代码执行以下操作:
local_size = torch.tensor([tensor.numel()], device="cuda"):首先,它创建一个包含当前进程张量大小的PyTorch张量,并将其移到CUDA设备上。tensor.numel()用于获取张量中的元素数量,这里是张量的大小。
size_list = [torch.tensor([0], device="cuda") for _ in range(world_size)]:接下来,它创建一个包含与分布式环境中的每个进程相关的PyTorch张量的列表。每个张量都初始化为0,并且都存储在CUDA设备上。这些张量用于接收其他进程的张量大小信息。
dist.all_gather(size_list, local_size):使用PyTorch的分布式通信操作dist.all_gather,它将每个进程的local_size(即当前进程的张量大小)收集到size_list中。这会导致size_list包含了来自所有进程的张量大小信息。
size_list = [int(size.item()) for size in size_list]:最后,它将size_list中的每个张量从PyTorch张量转换为Python整数,并将这些整数存储在size_list中。这将为每个进程提供了一个整数值,表示其他进程的张量大小,以及一个整数值max_size,表示所有进程中最大的张量大小。
总之,这段代码的目标是获取分布式环境中每个进程的张量大小,并找到所有进程中最大的张量大小。这对于后续的数据收集和分配非常有用,因为它确保了张量的大小一致性,以便进行通信。
tensor_list- # receiving Tensor from all ranks
- # we pad the tensor because torch all_gather does not support
- # gathering tensors of different shapes
- tensor_list = []
- for _ in size_list:
- tensor_list.append(torch.empty((max_size,), dtype=torch.uint8, device="cuda"))
- if local_size != max_size:
- padding = torch.empty(size=(max_size - local_size,), dtype=torch.uint8, device="cuda")
- tensor = torch.cat((tensor, padding), dim=0)
- dist.all_gather(tensor_list, tensor)
这部分代码执行以下操作:
tensor_list = []:首先,它创建一个空列表tensor_list,用于存储从每个进程收集到的张量。列表的长度等于进程总数,每个元素都是一个具有相同最大大小的空PyTorch张量。这里的size_list包含了每个进程的张量大小信息。
for _ in size_list::然后,它使用size_list中的每个进程的张量大小信息,迭代地创建一个空的PyTorch张量,并将其添加到tensor_list中。这确保了tensor_list中的每个元素都是一个具有相同大小的张量。
if local_size != max_size::接下来,它检查当前进程的张量大小是否与最大大小max_size不同。
padding = torch.empty(size=(max_size - local_size,), dtype=torch.uint8, device="cuda"):如果当前进程的张量大小不等于最大大小,那么它创建一个名为padding的空PyTorch张量,其大小等于max_size - local_size,这是为了将当前进程的张量填充到与其他进程相同的大小。
tensor = torch.cat((tensor, padding), dim=0):然后,它使用torch.cat函数将当前进程的张量tensor与padding张量连接起来,以便它们具有相同的大小。这是为了确保所有进程的张量都具有相同的大小,以便进行全局收集。
dist.all_gather(tensor_list, tensor):最后,使用PyTorch的分布式通信操作dist.all_gather,它将每个进程的张量tensor收集到tensor_list中。由于所有张量的大小现在相同,可以安全地进行收集。
总之,这段代码的目标是创建一个用于接收从所有进程收集到的张量的列表tensor_list,并确保这些张量具有相同的大小,以便在进行全局收集时能够正确工作。这是由于PyTorch的dist.all_gather不支持收集具有不同形状的张量,因此需要进行必要的填充以使它们具有相同的大小。
tensor_list转换回原始数据对象的列表data_list- data_list = []
- for size, tensor in zip(size_list, tensor_list):
- buffer = tensor.cpu().numpy().tobytes()[:size]
- data_list.append(pickle.loads(buffer))
-
- return data_list
这部分代码的目标是将从所有进程收集到的数据张量列表tensor_list转换回原始数据对象的列表data_list。以下是它的执行步骤:
data_list = []:首先,它创建一个空列表data_list,用于存储从张量恢复的原始数据对象。
for size, tensor in zip(size_list, tensor_list)::然后,它迭代size_list和tensor_list中的元素,其中size是张量的大小,tensor是从其他进程收集到的数据张量。
buffer = tensor.cpu().numpy().tobytes()[:size]:对于每个张量,它首先使用.cpu()将张量从CUDA设备移到CPU上,然后使用.numpy()将其转换为NumPy数组,最后使用.tobytes()将NumPy数组转换为字节表示。由于在之前的填充步骤中,我们已经确保了每个张量的大小与原始数据相符,因此只需提取前size个字节,以避免多余的填充数据。
data_list.append(pickle.loads(buffer)):最后,它使用pickle.loads()将字节流buffer反序列化为原始数据对象,并将其添加到data_list中。
最终,函数返回data_list,其中包含了从所有进程收集到的原始数据对象的列表。
总之,这段代码的目标是将从其他进程收集到的数据张量转换回原始数据对象,并将这些对象存储在data_list中,以便在分布式计算中使用。这是在分布式环境中传递和共享数据的一种方法。
- def reduce_dict(input_dict, average=True):
- """
- Args:
- input_dict (dict): all the values will be reduced
- average (bool): whether to do average or sum
- Reduce the values in the dictionary from all processes so that all processes
- have the averaged results. Returns a dict with the same fields as
- input_dict, after reduction.
- """
- world_size = get_world_size()
- if world_size < 2:
- return input_dict
- with torch.no_grad():
- names = []
- values = []
- # sort the keys so that they are consistent across processes
- for k in sorted(input_dict.keys()):
- names.append(k)
- values.append(input_dict[k])
- values = torch.stack(values, dim=0)
- dist.all_reduce(values)
- if average:
- values /= world_size
- reduced_dict = {k: v for k, v in zip(names, values)}
- return reduced_dict
这段代码的作用是将一个字典中的值从所有进程中进行归约,以便所有进程都具有归约后的结果。具体来说,它执行以下操作:
world_size = get_world_size():首先,它获取当前分布式环境中的进程总数,以便确定是否需要进行归约。如果只有一个进程(即单进程模式),则不执行任何归约操作,直接返回输入的字典。
with torch.no_grad()::这是一个上下文管理器,用于确保在此代码块中不会创建梯度计算图。这对于执行纯粹的数据操作非常有用。
names = [] 和 values = []:创建两个空列表,用于分别存储字典中的键和对应的值。
for k in sorted(input_dict.keys())::迭代字典中的键,通过sorted函数对键进行排序,以确保它们在所有进程中的顺序一致。
names.append(k) 和 values.append(input_dict[k]):将每个键存储在names列表中,将对应的值存储在values列表中。
values = torch.stack(values, dim=0):将values列表中的值堆叠成一个张量,其中每一行对应一个进程的值。dim=0表示在张量的第一个维度上堆叠。
dist.all_reduce(values):使用PyTorch的分布式通信操作dist.all_reduce,将所有进程中的值进行归约。这意味着每个进程的值将与其他进程的值相加,从而在所有进程中获得了总和。
if average::如果average参数为True,则执行下面的操作。如果average为False,则跳过这一步,直接返回总和的值。
values /= world_size:将总和的值除以进程总数,以获得平均值。这是通过将值张量除以world_size来实现的。
reduced_dict = {k: v for k, v in zip(names, values)}:将键和归约后的值重新组成一个字典,并将其存储在reduced_dict中。
最终,函数返回reduced_dict,其中包含了从所有进程中归约得到的字典,这些字典的值已经被平均(或总和)处理,以便在分布式计算中使用。这是一种在分布式环境中合并和同步结果的方法。
MetricLogger的类,用于记录和打印指标数据- class MetricLogger(object):
- def __init__(self, delimiter="\t"):
- self.meters = defaultdict(SmoothedValue)
- self.delimiter = delimiter
-
- def update(self, **kwargs):
- for k, v in kwargs.items():
- if isinstance(v, torch.Tensor):
- v = v.item()
- assert isinstance(v, (float, int))
- self.meters[k].update(v)
-
- def __getattr__(self, attr):
- if attr in self.meters:
- return self.meters[attr]
- if attr in self.__dict__:
- return self.__dict__[attr]
- raise AttributeError("'{}' object has no attribute '{}'".format(
- type(self).__name__, attr))
-
- def __str__(self):
- loss_str = []
- for name, meter in self.meters.items():
- loss_str.append(
- "{}: {}".format(name, str(meter))
- )
- return self.delimiter.join(loss_str)
-
- def synchronize_between_processes(self):
- for meter in self.meters.values():
- meter.synchronize_between_processes()
-
- def add_meter(self, name, meter):
- self.meters[name] = meter
-
- def log_every(self, iterable, print_freq, header=None):
- i = 0
- if not header:
- header = ''
- start_time = time.time()
- end = time.time()
- iter_time = SmoothedValue(fmt='{avg:.4f}')a
- data_time = SmoothedValue(fmt='{avg:.4f}')
- space_fmt = ':' + str(len(str(len(iterable)))) + 'd'
- if torch.cuda.is_available():
- log_msg = self.delimiter.join([
- header,
- '[{0' + space_fmt + '}/{1}]',
- 'eta: {eta}',
- '{meters}',
- 'time: {time}',
- 'data: {data}',
- 'max mem: {memory:.0f}'
- ])
- else:
- log_msg = self.delimiter.join([
- header,
- '[{0' + space_fmt + '}/{1}]',
- 'eta: {eta}',
- '{meters}',
- 'time: {time}',
- 'data: {data}'
- ])
- MB = 1024.0 * 1024.0
- for obj in iterable:
- data_time.update(time.time() - end)
- yield obj
- iter_time.update(time.time() - end)
- if i % print_freq == 0 or i == len(iterable) - 1:
- eta_seconds = iter_time.global_avg * (len(iterable) - i)
- eta_string = str(datetime.timedelta(seconds=int(eta_seconds)))
- if torch.cuda.is_available():
- print(log_msg.format(
- i, len(iterable), eta=eta_string,
- meters=str(self),
- time=str(iter_time), data=str(data_time),
- memory=torch.cuda.max_memory_allocated() / MB))
- else:
- print(log_msg.format(
- i, len(iterable), eta=eta_string,
- meters=str(self),
- time=str(iter_time), data=str(data_time)))
- i += 1
- end = time.time()
- total_time = time.time() - start_time
- total_time_str = str(datetime.timedelta(seconds=int(total_time)))
- print('{} Total time: {} ({:.4f} s / it)'.format(
- header, total_time_str, total_time / len(iterable)))
该类的主要作用是方便记录和打印训练过程中的指标数据,特别适用于监视模型训练的进展和性能。
- class MetricLogger(object):
- def __init__(self, delimiter="\t"):
- self.meters = defaultdict(SmoothedValue)
- self.delimiter = delimiter
__init__(self, delimiter="\t"):这是类的构造函数,用于初始化一个MetricLogger对象。它接受一个可选的参数delimiter,用于指定打印指标数据时的分隔符,默认为制表符("\t")。
self.meters:这是一个字典,用于存储指标数据。字典的键是指标的名称,而值是与每个指标相关联的SmoothedValue对象,该对象用于平滑和跟踪指标的值。
self.delimiter:这是一个属性,表示在打印指标数据时使用的分隔符,默认为制表符("\t")。
该类的主要目的是为了方便记录和管理模型训练过程中的各种指标,例如损失、准确率、学习率等。通过使用MetricLogger,可以更容易地跟踪和可视化这些指标的变化。通常,训练循环中的代码会在每个训练步骤或周期中调用MetricLogger的方法来更新指标数据,并在需要时打印这些指标的值。
MetricLogger类中的一个重要方法,用于更新指标数据- def update(self, **kwargs):
- for k, v in kwargs.items():
- if isinstance(v, torch.Tensor):
- v = v.item()
- assert isinstance(v, (float, int))
- self.meters[k].update(v)
这是MetricLogger类中的一个重要方法,用于更新指标数据。它接受一个关键字参数kwargs,其中每个键值对表示一个指标的名称和相应的值。方法会遍历kwargs中的每个键值对,将指标的值添加到self.meters字典中的相应SmoothedValue对象中。
在更新指标值之前,方法会检查指标的值是否是torch.Tensor类型,如果是,则将其转换为标量(float或int)。这是因为通常指标的值以张量的形式存储,但在记录和显示时,通常希望将其表示为标量值。
最后,方法会调用self.meters[k].update(v),将值添加到相应指标的SmoothedValue对象中,用于平滑和跟踪指标的值。通过多次调用update方法,可以持续跟踪指标的变化,并计算平均值、中位数等统计信息。
MetricLogger类中的一个特殊方法__getattr__- def __getattr__(self, attr):
- if attr in self.meters:
- return self.meters[attr]
- if attr in self.__dict__:
- return self.__dict__[attr]
- raise AttributeError("'{}' object has no attribute '{}'".format(
- type(self).__name__, attr))
这是MetricLogger类中的一个特殊方法__getattr__,它用于处理对象的属性访问。具体来说,当你尝试访问MetricLogger对象的属性时,会调用__getattr__方法来确定应该返回什么值。
方法的主要作用是:
attr是否存在于self.meters字典中。如果存在,表示用户希望访问某个指标的值,于是返回该指标的SmoothedValue对象。attr不存在于self.meters字典中,继续检查是否存在于对象的__dict__属性中。这是因为除了指标之外,MetricLogger对象还可以有其他自定义属性。attr既不是指标也不是自定义属性,则抛出AttributeError异常,表示对象没有这个属性。这个方法的设计使得可以通过点号.来访问MetricLogger对象的指标,例如logger.loss,其中loss是指标的名称。这样可以方便地获取和记录指标值,而不需要直接访问self.meters字典。
__str__ 方法- def __str__(self):
- loss_str = []
- for name, meter in self.meters.items():
- loss_str.append(
- "{}: {}".format(name, str(meter))
- )
- return self.delimiter.join(loss_str)
__str__ 方法是 Python 中的特殊方法,用于返回对象的字符串表示形式。在 MetricLogger 类中,__str__ 方法被重写以自定义对象的字符串表示,以便用户可以通过 print 函数轻松地查看对象的状态。
具体来说,__str__ 方法执行以下操作:
创建一个空列表 loss_str 用于存储每个指标的字符串表示。
遍历 self.meters 字典中的每个指标(以指标名称为键,以 SmoothedValue 对象为值)。
对每个指标,使用 str(meter) 来获取 SmoothedValue 对象的字符串表示,该字符串表示包括指标的平均值、中位数等统计信息。
将指标的名称和相应的字符串表示形式组合成一个字符串,并将该字符串添加到 loss_str 列表中。
最后,使用 self.delimiter 连接 loss_str 列表中的所有字符串,形成一个以指定分隔符分隔的单个字符串,表示所有指标的状态。
通过这个方法,你可以通过 print(logger) 来打印 MetricLogger 对象的状态,从而方便地查看所有指标的值。这有助于监控训练过程中的指标变化和性能。
synchronize_between_processes 方法- def synchronize_between_processes(self):
- for meter in self.meters.values():
- meter.synchronize_between_processes()
MetricLogger 类中的 synchronize_between_processes 方法用于在多个进程之间同步指标数据。在分布式训练环境中,不同进程可能会计算不同部分的指标,因此需要将这些指标同步以获取全局统计信息。
具体来说,这个方法执行以下操作:
遍历 self.meters 字典中的每个指标名称和相应的 SmoothedValue 对象。
对每个 SmoothedValue 对象调用其自身的 synchronize_between_processes 方法,以确保该指标在所有进程之间同步。
synchronize_between_processes 方法在内部使用了 torch.distributed 库来执行同步操作。如果当前环境不是分布式训练(即没有启用多个进程),则方法不执行任何操作。
这个方法的主要目的是确保在分布式训练中,所有进程都具有相同的指标数据,以便能够计算全局统计信息。这对于监控和分析训练进程中的性能和指标非常有用。
add_meter 方法- def add_meter(self, name, meter):
- self.meters[name] = meter
add_meter 方法用于向 MetricLogger 对象中添加新的指标。它接受两个参数:
name:要添加的指标的名称,作为字符串。meter:一个 SmoothedValue 对象,用于跟踪和记录指标值。通过调用 add_meter 方法,您可以将新的指标添加到 MetricLogger 中,然后可以使用这些指标来跟踪和记录训练过程中的各种性能和损失指标。这对于监视模型的性能和训练进程中的各种指标非常有用。
log_every 方法用于在迭代训练中定期记录和打印训练进程的信息。- def log_every(self, iterable, print_freq, header=None):
- i = 0
- if not header:
- header = ''
- start_time = time.time()
- end = time.time()
- iter_time = SmoothedValue(fmt='{avg:.4f}')a
- data_time = SmoothedValue(fmt='{avg:.4f}')
- space_fmt = ':' + str(len(str(len(iterable)))) + 'd'
- if torch.cuda.is_available():
- log_msg = self.delimiter.join([
- header,
- '[{0' + space_fmt + '}/{1}]',
- 'eta: {eta}',
- '{meters}',
- 'time: {time}',
- 'data: {data}',
- 'max mem: {memory:.0f}'
- ])
- else:
- log_msg = self.delimiter.join([
- header,
- '[{0' + space_fmt + '}/{1}]',
- 'eta: {eta}',
- '{meters}',
- 'time: {time}',
- 'data: {data}'
- ])
- MB = 1024.0 * 1024.0
- for obj in iterable:
- data_time.update(time.time() - end)
- yield obj
- iter_time.update(time.time() - end)
- if i % print_freq == 0 or i == len(iterable) - 1:
- eta_seconds = iter_time.global_avg * (len(iterable) - i)
- eta_string = str(datetime.timedelta(seconds=int(eta_seconds)))
- if torch.cuda.is_available():
- print(log_msg.format(
- i, len(iterable), eta=eta_string,
- meters=str(self),
- time=str(iter_time), data=str(data_time),
- memory=torch.cuda.max_memory_allocated() / MB))
- else:
- print(log_msg.format(
- i, len(iterable), eta=eta_string,
- meters=str(self),
- time=str(iter_time), data=str(data_time)))
- i += 1
- end = time.time()
- total_time = time.time() - start_time
- total_time_str = str(datetime.timedelta(seconds=int(total_time)))
- print('{} Total time: {} ({:.4f} s / it)'.format(
- header, total_time_str, total_time / len(iterable)))
这个方法的主要目的是提供一个可视化和实时监控训练进程的方式,以便在训练模型时了解性能和损失指标的变化。
- def log_every(self, iterable, print_freq, header=None):
- i = 0
- if not header:
- header = ''
- start_time = time.time()
- end = time.time()
- iter_time = SmoothedValue(fmt='{avg:.4f}')a
- data_time = SmoothedValue(fmt='{avg:.4f}')
- space_fmt = ':' + str(len(str(len(iterable)))) + 'd'
- if torch.cuda.is_available():
- log_msg = self.delimiter.join([
- header,
- '[{0' + space_fmt + '}/{1}]',
- 'eta: {eta}',
- '{meters}',
- 'time: {time}',
- 'data: {data}',
- 'max mem: {memory:.0f}'
- ])
- else:
- log_msg = self.delimiter.join([
- header,
- '[{0' + space_fmt + '}/{1}]',
- 'eta: {eta}',
- '{meters}',
- 'time: {time}',
- 'data: {data}'
- ])
在 log_every 方法中,它会遍历 iterable,这是一个迭代器,通常用于迭代训练数据。以下是该方法的主要步骤:
i,如果没有提供 header,则将其设置为空字符串。start_time。iter_time 和 data_time)。end。iterable 的长度确定输出格式中迭代次数的显示宽度。log_msg)。接下来,进入迭代循环:
对于 iterable 中的每个对象,执行以下操作:
data_time,记录从上一次迭代到当前迭代的时间。yield obj 从迭代器中获取下一个对象,并将其返回。yield 关键字用于生成器函数,可以暂停函数的执行并返回一个值,直到下一次迭代被调用。iter_time,记录从上一次迭代到当前迭代的时间。检查是否达到了指定的 print_freq 或是否已经遍历完了 iterable 中的所有对象。如果是,执行以下操作:
eta_seconds):这是平均每次迭代花费的时间乘以剩余迭代次数。log_msg,将当前迭代次数、总迭代次数、估计的剩余时间、平滑值统计信息以及时间和数据加载时间插入到消息中。增加迭代计数器 i,更新 end 以记录当前时间,然后继续下一次迭代。
最后,计算总的训练时间,并打印总时间以及每次迭代的平均时间。
这个方法用于实时监控和记录训练进程中的信息,包括迭代次数、剩余时间、性能指标等,以便及时调整训练策略和分析模型的性能。
- MB = 1024.0 * 1024.0
- for obj in iterable:
- data_time.update(time.time() - end)
- yield obj
- iter_time.update(time.time() - end)
- if i % print_freq == 0 or i == len(iterable) - 1:
- eta_seconds = iter_time.global_avg * (len(iterable) - i)
- eta_string = str(datetime.timedelta(seconds=int(eta_seconds)))
- if torch.cuda.is_available():
- print(log_msg.format(
- i, len(iterable), eta=eta_string,
- meters=str(self),
- time=str(iter_time), data=str(data_time),
- memory=torch.cuda.max_memory_allocated() / MB))
- else:
- print(log_msg.format(
- i, len(iterable), eta=eta_string,
- meters=str(self),
- time=str(iter_time), data=str(data_time)))
- i += 1
- end = time.time()
- total_time = time.time() - start_time
- total_time_str = str(datetime.timedelta(seconds=int(total_time)))
- print('{} Total time: {} ({:.4f} s / it)'.format(
- header, total_time_str, total_time / len(iterable)))
这部分代码是 log_every 方法的主要迭代循环,它处理了每个迭代的逻辑和日志记录。让我为您解释其中的关键部分:
MB = 1024.0 * 1024.0:这是用于将字节转换为兆字节(MB)的常数,以便在日志消息中显示内存使用量。
for obj in iterable::这是迭代 iterable 中的对象的开始。iterable 可能是用于训练的数据加载器。
data_time.update(time.time() - end):这行代码记录了数据加载时间,它计算从上一次迭代到当前迭代的时间差并更新 data_time 的平滑值。
yield obj:这行代码从迭代器中获取下一个对象,并将其返回。使用 yield 关键字可以将当前函数变成一个生成器,它会在 yield 处暂停执行并将值传递给调用方。
iter_time.update(time.time() - end):这行代码记录了整个迭代的时间,类似于上一行,它计算了从上一次迭代到当前迭代的时间差,并更新了 iter_time 的平滑值。
if i % print_freq == 0 or i == len(iterable) - 1::这是一个条件语句,它检查是否达到了指定的 print_freq 或是否已经遍历完了 iterable 中的所有对象。如果是这两种情况之一,就会执行以下操作,用于记录日志信息:
eta_seconds = iter_time.global_avg * (len(iterable) - i):这行代码计算了估计的剩余时间(以秒为单位),这是平均每次迭代花费的时间乘以剩余迭代次数。
eta_string = str(datetime.timedelta(seconds=int(eta_seconds))):这行代码将估计的剩余时间转换为可读的时间格式(天、小时、分钟和秒)。
日志消息的构建:这部分根据是否有可用的 CUDA 设备来构建日志消息 log_msg,并将迭代次数、总迭代次数、估计的剩余时间、平滑值统计信息以及时间和数据加载时间插入到消息中。如果有可用的 CUDA 设备,还会包括最大内存使用量。
打印日志消息:最后,根据所构建的日志消息 log_msg 打印日志信息。
i += 1:迭代计数器 i 增加 1,用于跟踪当前迭代次数。
end = time.time():更新 end 以记录当前时间,为下一次迭代做准备。
最后,计算并打印总的训练时间以及每次迭代的平均时间。这些信息对于了解训练进度和性能非常有用。
这个 log_every 方法的主要目的是实时记录和显示训练进度和性能指标,以便让用户能够及时了解模型训练的状态。
- def get_sha():
- cwd = os.path.dirname(os.path.abspath(__file__))
-
- def _run(command):
- return subprocess.check_output(command, cwd=cwd).decode('ascii').strip()
- sha = 'N/A'
- diff = "clean"
- branch = 'N/A'
- try:
- sha = _run(['git', 'rev-parse', 'HEAD'])
- subprocess.check_output(['git', 'diff'], cwd=cwd)
- diff = _run(['git', 'diff-index', 'HEAD'])
- diff = "has uncommited changes" if diff else "clean"
- branch = _run(['git', 'rev-parse', '--abbrev-ref', 'HEAD'])
- except Exception:
- pass
- message = f"sha: {sha}, status: {diff}, branch: {branch}"
- return message
这段代码用于获取当前代码库的 Git 信息,包括提交的 SHA(commit hash)、工作目录的状态以及当前的分支。让我解释一下这段代码的关键部分:
cwd = os.path.dirname(os.path.abspath(__file__)):这一行获取了当前脚本文件的绝对路径,并使用 os.path.dirname 获取了它的父目录路径。这个路径将用于设置 Git 命令的工作目录。
_run(command):这是一个内部函数,用于运行命令并返回其输出。它接受一个命令作为参数,运行该命令,并返回命令的输出。在这里,它被用来运行 Git 命令。
sha = _run(['git', 'rev-parse', 'HEAD']):这一行运行 Git 命令 git rev-parse HEAD,以获取当前代码库的最新提交的 SHA(commit hash)。
subprocess.check_output(['git', 'diff'], cwd=cwd):这一行运行 Git 命令 git diff,以检查工作目录中是否有未提交的更改。如果有未提交的更改,会引发异常,否则不会有异常。
diff = _run(['git', 'diff-index', 'HEAD']):这一行运行 Git 命令 git diff-index HEAD,以获取有关工作目录更改的详细信息。如果有未提交的更改,diff 将包含有关这些更改的信息,否则它将是 "clean",表示工作目录是干净的。
branch = _run(['git', 'rev-parse', '--abbrev-ref', 'HEAD']):这一行运行 Git 命令 git rev-parse --abbrev-ref HEAD,以获取当前的 Git 分支名称。
message = f"sha: {sha}, status: {diff}, branch: {branch}":这一行构建了一个包含 SHA、工作目录状态和分支信息的字符串。
最后,函数返回了包含 Git 信息的消息字符串。
这个函数对于在代码中记录版本信息以及调试和问题排查时非常有用。它会返回包含 Git 信息的消息,以便您可以了解当前代码的状态和版本。
collate_fn函数- #如何把图片大小不一样的图片输入到网络中去
- def collate_fn(batch):
- batch = list(zip(*batch))
- batch[0] = nested_tensor_from_tensor_list(batch[0])#batch[0]是image,batch[1]是信息{boxes,labels,image_id,iscrowd,orig_size,size},可以点到指定类型看一下
- return tuple(batch)
这段代码中的collate_fn函数用于将具有不同大小的图像和与之相关的信息组合成一个批次(batch),以便输入到神经网络中。这是在使用 PyTorch 中的 DataLoader 来加载数据时常见的操作,特别是在处理目标检测等任务时,由于不同图像的大小不同,需要进行一些预处理以使它们具有相同的尺寸或格式。
batch是一个包含批次中的多个样本的列表,每个样本包含两个元素,第一个元素是图像(可能是不同大小的),第二个元素是与图像相关的信息(如边界框、标签等)。
batch = list(zip(*batch)):这行代码将批次中的样本重新排列,以便分别获取图像和信息,并将它们放入两个不同的列表中。现在batch[0]包含所有图像,batch[1]包含所有信息。
batch[0] = nested_tensor_from_tensor_list(batch[0]):这行代码使用函数 nested_tensor_from_tensor_list将图像列表转换为一个"嵌套张量",这个嵌套张量允许处理不同大小的图像。这是一种常见的技巧,通常涉及将不同大小的图像调整为相同的尺寸,然后堆叠它们以创建一个包含不同图像的批次。
最后,return tuple(batch)将组合后的批次返回。
max_by_axis 函数- def _max_by_axis(the_list):
- # type: (List[List[int]]) -> List[int]
- maxes = the_list[0]
- for sublist in the_list[1:]:
- for index, item in enumerate(sublist):
- maxes[index] = max(maxes[index], item)
- return maxes
这段代码定义了一个名为 _max_by_axis 的函数,该函数接受一个包含列表的列表作为输入,然后返回一个包含每个列(轴)的最大值的列表。这个函数的目的是找到输入列表中的每个列的最大值。
the_list 是一个包含列表的列表,每个内部列表表示一个轴。
maxes 初始化为第一个内部列表 the_list[0]。
然后,函数遍历 the_list 中的每个内部列表,通过比较每个内部列表的元素与 maxes 中相应位置的元素,来更新 maxes 中的值。如果内部列表的元素比 maxes 中对应位置的元素更大,就用内部列表的元素替换 maxes 中的值。
最后,函数返回 maxes,其中包含了每个列的最大值。
这个函数的主要作用是在给定多个列表的情况下,找到每个列的最大值。这对于处理表格数据或类似数据的任务很有用,因为它可以帮助你找到每个属性或特征的最大值。
NestedTensor 类- #NestedTensor的定义
- class NestedTensor(object):
- def __init__(self, tensors, mask: Optional[Tensor]): #初始化
- self.tensors = tensors
- self.mask = mask #掩码 在这里可查看一个例子(见PPT2)
-
- def to(self, device):
- # type: (Device) -> NestedTensor # noqa
- cast_tensor = self.tensors.to(device)
- mask = self.mask
- if mask is not None:
- assert mask is not None
- cast_mask = mask.to(device)
- else:
- cast_mask = None
- return NestedTensor(cast_tensor, cast_mask)
-
- def decompose(self):
- return self.tensors, self.mask
-
- def __repr__(self):
- return str(self.tensors)
这段代码定义了一个名为 NestedTensor 的类,用于表示嵌套的张量(Nested Tensor)。嵌套的张量通常是由一个主张量(例如图像)和一个与之关联的掩码张量组成,掩码张量用于指示主张量中的每个位置是否有效。
这个类的设计主要用于处理具有不同形状或分辨率的图像数据,其中主张量表示图像,而掩码张量用于处理不同位置的信息。例如,在分割任务中,掩码张量可以用于表示每个像素是否属于对象区域。
- class NestedTensor(object):
- def __init__(self, tensors, mask: Optional[Tensor]): #初始化
- self.tensors = tensors
- self.mask = mask #掩码 在这里可查看一个例子(见PPT2)
这个代码定义了一个名为 NestedTensor 的类。这个类的目的是将两个元素组合成一个对象:
tensors: 这是一个张量(通常是 PyTorch 张量),用于存储主要数据。这个张量可以包含任何你想要存储的数据,比如图像数据。
mask: 这是一个可选的张量,通常也是 PyTorch 张量,用于表示一个掩码或者标志位。这个掩码张量的形状通常与主张量相同,但它的值通常是二进制的,用于指示主张量中的哪些元素是有效的,哪些是无效的。例如,在图像处理中,可以使用掩码来标识图像中的对象区域。
这个类的构造函数将这两个元素作为参数传递,并将它们存储在 self.tensors 和 self.mask 中,以便后续操作使用。这种方式可以让你将主要数据和相关的掩码或标志位组合在一起,以便更容易地处理和传递这些信息。
to 方法- def to(self, device):
- # type: (Device) -> NestedTensor # noqa
- cast_tensor = self.tensors.to(device)
- mask = self.mask
- if mask is not None:
- assert mask is not None
- cast_mask = mask.to(device)
- else:
- cast_mask = None
- return NestedTensor(cast_tensor, cast_mask)
这个 to 方法用于将 NestedTensor 对象中的数据(主要张量和掩码张量)转移到指定的设备(如GPU或CPU)。它接受一个 device 参数,该参数指定了目标设备。
具体的操作如下:
cast_tensor = self.tensors.to(device): 这一行将主要张量 self.tensors 移动到指定的设备上。这是通过调用 PyTorch 的 .to(device) 方法来实现的,其中 device 是目标设备,可以是 'cuda'(GPU)或 'cpu'(CPU)等。
mask = self.mask: 这一行获取掩码张量,如果存在的话。
如果掩码张量存在,那么 cast_mask = mask.to(device) 将掩码张量移动到相同的目标设备上。
最后,通过 return NestedTensor(cast_tensor, cast_mask) 返回一个新的 NestedTensor 对象,其中包含移动后的主要张量和掩码张量。
这个方法的作用是使 NestedTensor 对象中的数据与目标设备兼容,以便后续的计算可以在指定设备上进行。
NestedTensor 对象的内部数据- def decompose(self):
- return self.tensors, self.mask
-
- def __repr__(self):
- return str(self.tensors)
这两个方法用于返回 NestedTensor 对象的内部数据。
decompose(self): 此方法返回一个包含两个元素的元组,第一个元素是主要张量 self.tensors,第二个元素是掩码张量 self.mask。这允许在需要时直接访问这两个部分。
__repr__(self): 此方法返回一个字符串表示,通常用于在打印 NestedTensor 对象时显示其内容。在这里,它返回主要张量的字符串表示,这可以让您在打印对象时查看主要张量的值。
这两个方法提供了一种方式来访问 NestedTensor 对象的内部数据,以及在打印对象时以友好的方式显示其内容。
- def nested_tensor_from_tensor_list(tensor_list: List[Tensor]):
- # TODO make this more general
- if tensor_list[0].ndim == 3:
- if torchvision._is_tracing():
- # nested_tensor_from_tensor_list() does not export well to ONNX
- # call _onnx_nested_tensor_from_tensor_list() instead
- return _onnx_nested_tensor_from_tensor_list(tensor_list)
-
- # TODO make it support different-sized images
- max_size = _max_by_axis([list(img.shape) for img in tensor_list]) #首先会获得最大的区域,即宽和高
- # min_size = tuple(min(s) for s in zip(*[img.shape for img in tensor_list]))
- batch_shape = [len(tensor_list)] + max_size #len(tensor_list)=2
- b, c, h, w = batch_shape
- dtype = tensor_list[0].dtype
- device = tensor_list[0].device
- tensor = torch.zeros(batch_shape, dtype=dtype, device=device)
- mask = torch.ones((b, h, w), dtype=torch.bool, device=device)
- for img, pad_img, m in zip(tensor_list, tensor, mask):
- pad_img[: img.shape[0], : img.shape[1], : img.shape[2]].copy_(img) #将img的数据传给pad_img
- m[: img.shape[1], :img.shape[2]] = False #真实有值的地方会设置为True,没值的地方会设置为false
- else:
- raise ValueError('not supported')
- return NestedTensor(tensor, mask)
这个函数的目的是将一个张量列表转换为 NestedTensor 对象,其中 NestedTensor 是一个包含主要张量和掩码张量的对象,用于处理不同大小的张量(例如图像)。
这个函数的关键是将不同大小的张量合并成一个统一大小的 NestedTensor,以便将它们传递给神经网络等任务。这种处理方式对于处理不同尺寸的输入数据非常有用。
- def nested_tensor_from_tensor_list(tensor_list: List[Tensor]):
- # TODO make this more general
- if tensor_list[0].ndim == 3:
- if torchvision._is_tracing():
- # nested_tensor_from_tensor_list() does not export well to ONNX
- # call _onnx_nested_tensor_from_tensor_list() instead
- return _onnx_nested_tensor_from_tensor_list(tensor_list)
这段代码首先检查输入张量列表中第一个张量的维度数是否为3,以确定输入张量是否为3D张量。如果是3D张量,则继续进行处理。否则,可能需要根据情况进行通用化处理(TODO表示待完成)。
在处理3D张量的情况下,代码检查当前是否处于ONNX追踪模式(tracing mode)。ONNX是一种用于导出模型的格式。如果当前代码处于ONNX追踪模式,那么它会调用 _onnx_nested_tensor_from_tensor_list(tensor_list) 函数来创建 NestedTensor 对象。这是因为在ONNX追踪模式下,nested_tensor_from_tensor_list() 函数的行为可能无法正确导出到ONNX格式。
总之,这段代码是为了确保在ONNX追踪模式下不会出现问题,并且需要更通用的处理方式(TODO)以应对其他情况。
batch_shape)和一个掩码张量(mask),以创建一个 NestedTensor 对象- # TODO make it support different-sized images
- max_size = _max_by_axis([list(img.shape) for img in tensor_list]) #首先会获得最大的区域,即宽和高
- # min_size = tuple(min(s) for s in zip(*[img.shape for img in tensor_list]))
- batch_shape = [len(tensor_list)] + max_size #len(tensor_list)=2
- b, c, h, w = batch_shape
- dtype = tensor_list[0].dtype
- device = tensor_list[0].device
- tensor = torch.zeros(batch_shape, dtype=dtype, device=device)
- mask = torch.ones((b, h, w), dtype=torch.bool, device=device)
- for img, pad_img, m in zip(tensor_list, tensor, mask):
- pad_img[: img.shape[0], : img.shape[1], : img.shape[2]].copy_(img) #将img的数据传给pad_img
- m[: img.shape[1], :img.shape[2]] = False #真实有值的地方会设置为True,没值的地方会设置为false
- else:
- raise ValueError('not supported')
- return NestedTensor(tensor, mask)
这段代码的主要目的是创建一个具有不同大小的图像的批次,并将这些图像放置在一个 NestedTensor 中。以下是逐行解释代码的功能:
max_size = _max_by_axis([list(img.shape) for img in tensor_list]):
tensor_list 中所有图像的最大尺寸,包括宽度和高度。_max_by_axis 函数用于找到每个轴(维度)的最大尺寸。batch_shape = [len(tensor_list)] + max_size:
len(tensor_list) 给出了图像列表中图像的数量,即批次大小。max_size 包含了最大宽度和最大高度,将其添加到批次大小后,得到 batch_shape,它表示了批次中每个图像的大小。b, c, h, w = batch_shape:
batch_shape 中的值解包到变量 b(批次大小)、c(通道数,一般为图像的通道数,如3表示RGB图像)、h(高度)和w(宽度)中。dtype = tensor_list[0].dtype:
torch.float32 或 torch.uint8。device = tensor_list[0].device:
tensor = torch.zeros(batch_shape, dtype=dtype, device=device):
tensor 的全零张量,其形状由 batch_shape 指定,数据类型由 dtype 指定,设备由 device 指定。mask = torch.ones((b, h, w), dtype=torch.bool, device=device):
mask 的全一张量,形状为 (b, h, w),数据类型为布尔型 (dtype=torch.bool),设备与 tensor 相同。使用 for 循环遍历 tensor_list 中的每个图像以及相应的 pad_img 和 m:
pad_img 是 tensor 中的一部分,用于存储图像数据。通过切片操作将图像数据从原始图像 img 复制到 pad_img 中,以适应不同大小的图像。m 是一个布尔掩码,用于表示图像的有效区域。通过将 m 的部分设置为 False,将未使用的部分标记为无效。最后,通过检查 tensor_list 中的图像是否都具有相同的尺寸,如果是则返回一个 NestedTensor 对象,否则引发 ValueError 异常。
总之,这段代码的功能是将不同大小的图像放置在一个批次中,使用全零张量来存储图像数据,并使用布尔掩码来表示图像的有效区域。这有助于处理不同大小的图像输入到神经网络中。
NestedTensor 对象- # _onnx_nested_tensor_from_tensor_list() is an implementation of
- # nested_tensor_from_tensor_list() that is supported by ONNX tracing.
- @torch.jit.unused
- def _onnx_nested_tensor_from_tensor_list(tensor_list: List[Tensor]) -> NestedTensor:
- max_size = []
- for i in range(tensor_list[0].dim()):
- max_size_i = torch.max(torch.stack([img.shape[i] for img in tensor_list]).to(torch.float32)).to(torch.int64)
- max_size.append(max_size_i)
- max_size = tuple(max_size)
-
- # work around for
- # pad_img[: img.shape[0], : img.shape[1], : img.shape[2]].copy_(img)
- # m[: img.shape[1], :img.shape[2]] = False
- # which is not yet supported in onnx
- padded_imgs = []
- padded_masks = []
- for img in tensor_list:
- padding = [(s1 - s2) for s1, s2 in zip(max_size, tuple(img.shape))]
- padded_img = torch.nn.functional.pad(img, (0, padding[2], 0, padding[1], 0, padding[0]))
- padded_imgs.append(padded_img)
-
- m = torch.zeros_like(img[0], dtype=torch.int, device=img.device)
- padded_mask = torch.nn.functional.pad(m, (0, padding[2], 0, padding[1]), "constant", 1)
- padded_masks.append(padded_mask.to(torch.bool))
-
- tensor = torch.stack(padded_imgs)
- mask = torch.stack(padded_masks)
-
- return NestedTensor(tensor, mask=mask)
_onnx_nested_tensor_from_tensor_list() 是一个 ONNX 跟踪(tracing)支持的 nested_tensor_from_tensor_list() 的实现版本。它的作用是将具有不同大小的图像张量列表转换为一个批量张量和一个掩码张量,以创建一个 NestedTensor 对象,同时考虑了 ONNX 跟踪的要求。最后,函数返回一个 NestedTensor 对象,其中包含批量张量和掩码张量,以表示具有不同大小的图像。这个实现考虑了 ONNX 跟踪的要求,以确保在 ONNX 中运行时能够正确处理图像。
- # _onnx_nested_tensor_from_tensor_list() is an implementation of
- # nested_tensor_from_tensor_list() that is supported by ONNX tracing.
- @torch.jit.unused
- def _onnx_nested_tensor_from_tensor_list(tensor_list: List[Tensor]) -> NestedTensor:
- max_size = []
- for i in range(tensor_list[0].dim()):
- max_size_i = torch.max(torch.stack([img.shape[i] for img in tensor_list]).to(torch.float32)).to(torch.int64)
- max_size.append(max_size_i)
- max_size = tuple(max_size)
_onnx_nested_tensor_from_tensor_list() 是一个 ONNX 跟踪(tracing)支持的函数,它用于从张量列表创建一个 NestedTensor 对象。它的工作方式是找到输入张量列表中的每个维度的最大大小,然后使用填充操作将所有张量调整为相同的大小,最后将它们组合成一个批量张量和一个掩码张量。
以下是函数的主要步骤:
创建一个空列表 max_size,用于存储每个维度的最大大小。
使用 for 循环遍历输入张量列表中的每个维度。对于每个维度,计算该维度上所有张量的最大值,并将结果存储为整数。这里使用 torch.max 函数来找到最大值。
将计算得到的每个维度的最大大小组合成一个元组 max_size。
创建两个空列表 padded_imgs 和 padded_masks,用于存储填充后的图像张量和掩码张量。
使用 for 循环遍历输入张量列表中的每个图像。对于每个图像,计算需要添加的填充量,并使用 torch.nn.functional.pad 函数将图像进行填充,以使其大小与 max_size 相同。同时,还创建一个掩码张量,并进行相应的填充,将空白区域标记为 True。
使用 torch.stack 函数将填充后的图像张量列表和掩码张量列表分别堆叠成批量张量和掩码张量。
返回一个包含批量张量和掩码张量的 NestedTensor 对象,表示具有不同大小的图像。
这个函数的关键点在于计算每个维度的最大大小,并进行填充操作,以确保所有图像具有相同的大小,以便在后续的处理中能够正确处理它们。这个函数被设计为支持 ONNX 跟踪,以确保在 ONNX 中运行时能够正常工作。
- # work around for
- # pad_img[: img.shape[0], : img.shape[1], : img.shape[2]].copy_(img)
- # m[: img.shape[1], :img.shape[2]] = False
- # which is not yet supported in onnx
- padded_imgs = []
- padded_masks = []
- for img in tensor_list:
- padding = [(s1 - s2) for s1, s2 in zip(max_size, tuple(img.shape))]
- padded_img = torch.nn.functional.pad(img, (0, padding[2], 0, padding[1], 0, padding[0]))
- padded_imgs.append(padded_img)
-
- m = torch.zeros_like(img[0], dtype=torch.int, device=img.device)
- padded_mask = torch.nn.functional.pad(m, (0, padding[2], 0, padding[1]), "constant", 1)
- padded_masks.append(padded_mask.to(torch.bool))
上述代码块是 _onnx_nested_tensor_from_tensor_list 函数的一部分,用于处理张量的填充,以确保它们具有相同的大小。以下是代码的解释:
创建两个空列表 padded_imgs 和 padded_masks,用于存储填充后的图像张量和掩码张量。
使用 for 循环遍历输入张量列表 tensor_list 中的每个图像。
对于每个图像 img,计算需要添加的填充量 padding。padding 是一个包含三个元素的列表,分别表示在图像的三个维度(高度、宽度和通道数)上需要添加的填充量。这些填充量是通过将 max_size 减去当前图像的形状来计算的,确保所有图像都将填充为相同的大小。
使用 torch.nn.functional.pad 函数对图像 img 进行填充,以使其大小与 max_size 相同。填充的方式是在图像的高度、宽度和通道维度上分别添加填充量。填充后的图像被添加到 padded_imgs 列表中。
创建一个与图像 img 相同形状的全零张量 m,并将其数据类型设置为整数(dtype=torch.int)。这个张量将用作掩码。
使用 torch.nn.functional.pad 函数对掩码 m 进行填充,只在高度和宽度维度上添加填充。填充的方式是使用常数填充,将空白区域标记为 1(True),表示这些区域没有值。
将填充后的图像张量和填充后的掩码张量添加到 padded_imgs 和 padded_masks 列表中,并将掩码张量的数据类型转换为布尔型(to(torch.bool))。
这样,padded_imgs 中的图像张量都具有相同的大小,并且可以在后续的处理中使用,而 padded_masks 中的掩码张量用于标记图像中的有效区域。这个处理过程是为了适应 ONNX 的需求,因为 ONNX 不支持直接的张量切片和掩码操作。
tensor 和掩码张量 mask- tensor = torch.stack(padded_imgs)
- mask = torch.stack(padded_masks)
-
- return NestedTensor(tensor, mask=mask)
在这段代码中,首先使用 torch.stack 函数将填充后的图像张量列表 padded_imgs 和填充后的掩码张量列表 padded_masks 合并成一个新的张量 tensor 和掩码张量 mask。
tensor 是一个包含了所有填充后的图像的张量,它们具有相同的大小。这个张量用于表示一批图像,每个图像具有相同的尺寸。mask 是一个包含了所有填充后的掩码的张量,它们也具有相同的大小。这个掩码张量用于标记图像中的有效区域,即哪些像素是有值的,哪些像素是填充的(没有值)。最后,函数返回一个 NestedTensor 对象,该对象包含了合并后的张量 tensor 和掩码张量 mask。这个 NestedTensor 对象可以用于表示一批具有不同大小的图像,并在模型中进行处理,同时保持了相同的大小以便于处理。
- def setup_for_distributed(is_master):
- """
- This function disables printing when not in master process
- """
- import builtins as __builtin__
- builtin_print = __builtin__.print
-
- def print(*args, **kwargs):
- force = kwargs.pop('force', False)
- if is_master or force:
- builtin_print(*args, **kwargs)
-
- __builtin__.print = print
这个函数用于在分布式环境中设置打印行为,以便在非主进程中禁用打印。以下是该函数的主要功能:
首先,它导入了内置的 builtins 模块,并将内置的 print 函数引用命名为 builtin_print。
然后,它定义了一个名为 print 的新函数,该函数接受与内置的 print 函数相同的参数和关键字参数。但它还接受一个名为 force 的关键字参数,该参数默认值为 False。
在新的 print 函数中,它首先检查 is_master 的值(该值表示当前进程是否为主进程)。如果 is_master 为 True 或者 force 参数为 True,则调用内置的 print 函数 builtin_print 打印传入的参数。
最后,它将新的 print 函数赋值给内置的 print 函数,从而在之后的代码中使用新的 print 函数来实现打印操作。
通过这种方式,可以在分布式环境中控制哪些进程可以进行打印操作,以避免在非主进程中产生大量的输出信息。
dist_avail_and_initialized() 函数- def is_dist_avail_and_initialized():
- if not dist.is_available():
- return False
- if not dist.is_initialized():
- return False
- return True
这个函数用于检查当前环境是否支持分布式计算,并且是否已经初始化了分布式计算环境。下面是函数的主要功能:
首先,它使用 dist.is_available() 函数来检查当前环境是否支持分布式计算。如果不支持,返回 False。
接着,它使用 dist.is_initialized() 函数来检查分布式计算环境是否已经初始化。如果未初始化,返回 False。
最后,如果分布式计算环境已经初始化且当前环境支持分布式计算,那么函数返回 True,表示分布式环境已经准备好使用。
这个函数通常用于在进行分布式训练等任务之前,先检查是否满足了必要的分布式计算条件。如果条件不满足,可以采取相应的措施或者提供错误信息。
ist.get_world_size() 函数- def get_world_size():
- if not is_dist_avail_and_initialized():
- return 1
- return dist.get_world_size()
这个函数用于获取分布式计算环境中的世界大小(world size)。世界大小表示了分布式计算中的进程数量,也就是同时运行的任务或者计算节点的数量。下面是函数的主要功能:
首先,它调用了 is_dist_avail_and_initialized() 函数,以检查当前环境是否支持分布式计算并且是否已经初始化。如果不满足这些条件,函数返回默认的世界大小为1,表示单机环境。
如果当前环境支持分布式计算且已经初始化,那么函数使用 dist.get_world_size() 函数来获取实际的世界大小,即分布式计算中的进程数量。
这个函数的目的是根据当前的分布式计算环境动态获取世界大小,以便在分布式计算任务中正确配置和协调不同的进程。如果分布式计算环境未初始化,它会默认将世界大小设置为1,以便在单机环境中运行。
- def get_rank():
- if not is_dist_avail_and_initialized():
- return 0
- return dist.get_rank()
这个函数用于获取当前进程在分布式计算环境中的排名(rank)。排名表示当前进程在分布式计算中的唯一标识,通常从0开始递增,表示不同的计算节点或任务。
函数的主要功能如下:
首先,它调用了 is_dist_avail_and_initialized() 函数,以检查当前环境是否支持分布式计算并且是否已经初始化。如果不满足这些条件,函数返回默认的排名为0,表示单机环境中的唯一进程。
如果当前环境支持分布式计算且已经初始化,那么函数使用 dist.get_rank() 函数来获取当前进程的排名。
这个函数的目的是根据当前的分布式计算环境获取当前进程的排名,以便在分布式计算任务中知道每个进程的唯一标识。如果分布式计算环境未初始化,它会默认将排名设置为0,以便在单机环境中运行。
- def is_main_process():
- return get_rank() == 0
这个函数的目的是检查当前进程是否是主进程。在分布式计算任务中,通常只有一个进程被指定为主进程,用于执行一些全局性的任务,例如模型的保存和日志记录。其他进程则通常用于模型的训练和推理等任务。
这个函数的实现方式很简单,它通过调用 get_rank() 函数来获取当前进程的排名,然后将排名与0进行比较。如果排名为0,说明当前进程是主进程,函数返回True;否则,返回False。
在分布式计算中,通常只有排名为0的进程被认定为主进程,因此这个函数用于确定当前进程是否是主进程,以便在需要执行全局任务时进行判断
save_on_master ()函数- def save_on_master(*args, **kwargs):
- if is_main_process():
- torch.save(*args, **kwargs)
save_on_master 函数用于在主进程上保存模型或其他对象。在分布式计算中,通常只有主进程负责保存模型参数和其他重要信息,以确保保存的模型是完整的且不会发生冲突。
这个函数的实现非常简单,它首先调用 is_main_process() 函数来检查当前进程是否是主进程。如果当前进程是主进程,它会调用 torch.save(*args, **kwargs) 来保存模型或其他对象。否则,如果当前进程不是主进程,函数将不执行任何操作,从而避免了在非主进程上保存模型。
这个函数的目的是确保只有主进程保存模型,以防止多个进程同时保存模型时可能发生的冲突或文件覆盖问题。
init_distributed_mode() 函数- def init_distributed_mode(args):
- if 'RANK' in os.environ and 'WORLD_SIZE' in os.environ:
- args.rank = int(os.environ["RANK"])
- args.world_size = int(os.environ['WORLD_SIZE'])
- args.gpu = int(os.environ['LOCAL_RANK'])
- elif 'SLURM_PROCID' in os.environ:
- args.rank = int(os.environ['SLURM_PROCID'])
- args.gpu = args.rank % torch.cuda.device_count()
- else:
- print('Not using distributed mode')
- args.distributed = False
- return
上面的代码片段是 init_distributed_mode 函数的一部分,用于在确定了分布式计算环境后进行相关的初始化操作。具体操作包括:
将 args.distributed 设置为 True,表示当前处于分布式模式。
使用 torch.cuda.set_device(args.gpu) 将当前 GPU 设备设置为 args.gpu。这是为了确保每个进程使用正确的 GPU 设备。
设置 args.dist_backend 为 'nccl',这是 PyTorch 中用于分布式计算的后端。
使用 torch.distributed.init_process_group 初始化分布式进程组。这个函数会根据传入的参数初始化分布式计算的通信机制,包括进程排名、通信后端、通信初始化方法、世界大小等。
调用 torch.distributed.barrier() 来确保所有进程都已初始化完成。分布式计算中,通常需要所有进程都达到某个同步点后才能继续执行后续操作。
最后,调用 setup_for_distributed 函数来设置打印输出,确保只有主进程(排名为0的进程)才会打印信息,而其他进程不会打印。
这些操作都是为了确保在分布式计算环境中的初始化和同步,以便后续的分布式训练能够顺利进行。
init_distributed_mode 函数的一部分- args.distributed = True
-
- torch.cuda.set_device(args.gpu)
- args.dist_backend = 'nccl'
- print('| distributed init (rank {}): {}'.format(
- args.rank, args.dist_url), flush=True)
- torch.distributed.init_process_group(backend=args.dist_backend, init_method=args.dist_url,
- world_size=args.world_size, rank=args.rank)
- torch.distributed.barrier()
- setup_for_distributed(args.rank == 0)
上面的代码片段是 init_distributed_mode 函数的一部分,用于在确定了分布式计算环境后进行相关的初始化操作。具体操作包括:
将 args.distributed 设置为 True,表示当前处于分布式模式。
使用 torch.cuda.set_device(args.gpu) 将当前 GPU 设备设置为 args.gpu。这是为了确保每个进程使用正确的 GPU 设备。
设置 args.dist_backend 为 'nccl',这是 PyTorch 中用于分布式计算的后端。
使用 torch.distributed.init_process_group 初始化分布式进程组。这个函数会根据传入的参数初始化分布式计算的通信机制,包括进程排名、通信后端、通信初始化方法、世界大小等。
调用 torch.distributed.barrier() 来确保所有进程都已初始化完成。分布式计算中,通常需要所有进程都达到某个同步点后才能继续执行后续操作。
最后,调用 setup_for_distributed 函数来设置打印输出,确保只有主进程(排名为0的进程)才会打印信息,而其他进程不会打印。
这些操作都是为了确保在分布式计算环境中的初始化和同步,以便后续的分布式训练能够顺利进行。
- @torch.no_grad()
- def accuracy(output, target, topk=(1,)):
- """Computes the precision@k for the specified values of k"""
- if target.numel() == 0:
- return [torch.zeros([], device=output.device)]
- maxk = max(topk)
- batch_size = target.size(0)
-
- _, pred = output.topk(maxk, 1, True, True)
- pred = pred.t()
- correct = pred.eq(target.view(1, -1).expand_as(pred))
-
- res = []
- for k in topk:
- correct_k = correct[:k].view(-1).float().sum(0)
- res.append(correct_k.mul_(100.0 / batch_size))
- return res
这段代码计算了模型的预测输出 output 与真实标签 target 之间的精度(accuracy),并且支持不同的精度计算,即可以计算前k个预测的精度。以下是对每行代码的详细解释:
def accuracy(output, target, topk=(1,)):
accuracy 的函数,它接受三个参数:output(模型的预测输出)、target(真实标签)、topk(一个元组,包含要计算的精度值,默认为1)。if target.numel() == 0:
target 是否为空(即没有元素)。如果没有真实标签,意味着没有可以计算精度的数据,因此返回一个包含零值的张量作为结果,其设备与 output 相同。maxk = max(topk)
topk 中的最大值,表示要计算的精度的最大值。batch_size = target.size(0)
target 的第一维大小,通常表示批处理的大小。_, pred = output.topk(maxk, 1, True, True)
torch.topk 函数找到 output 张量中每个样本的前 maxk 个预测值(按值降序排列)和相应的索引。这里的 _ 是一个占位符,因为我们不需要关注预测值,只需要索引。pred = pred.t()
maxk 个预测的类别索引。correct = pred.eq(target.view(1, -1).expand_as(pred))
correct,其形状与 pred 相同,用于表示每个预测是否正确。这里的操作将 target 变换为与 pred 相同的形状,然后与 pred 逐元素比较,返回一个布尔张量,表示哪些预测是正确的。res = []
res 以存储精度结果。for k in topk:
correct_k = correct[:k].view(-1).float().sum(0)
correct 布尔张量中选择前 k 个预测的结果,并将其展平成一维张量。然后将布尔值转换为浮点数(True 变为1,False 变为0),并计算它们的和,表示正确的预测数量。res.append(correct_k.mul_(100.0 / batch_size))
res 列表中。首先将正确的预测数量除以批处理大小,然后乘以100,以计算精度百分比。这个值被添加到 res 列表中。return res
res,每个元素表示对应精度的百分比精度值。综上所述,这段代码计算了模型的精度,并且通过参数 topk 可以选择计算不同精度值,最终返回一个列表,包含了不同精度值的百分比精度。这在深度学习中常用于模型性能的评估。
- def interpolate(input, size=None, scale_factor=None, mode="nearest", align_corners=None):
- # type: (Tensor, Optional[List[int]], Optional[float], str, Optional[bool]) -> Tensor
- """
- Equivalent to nn.functional.interpolate, but with support for empty batch sizes.
- This will eventually be supported natively by PyTorch, and this
- class can go away.
- """
- if version.parse(torchvision.__version__) < version.parse('0.7'):
- if input.numel() > 0:
- return torch.nn.functional.interpolate(
- input, size, scale_factor, mode, align_corners
- )
-
- output_shape = _output_size(2, input, size, scale_factor)
- output_shape = list(input.shape[:-2]) + list(output_shape)
- return _new_empty_tensor(input, output_shape)
- else:
- return torchvision.ops.misc.interpolate(input, size, scale_factor, mode, align_corners)
这是一个用于插值的函数,通常用于调整图像或特征图的尺寸。以下是函数的主要参数和功能:
input(Tensor):输入张量,可以是图像或特征图。
size(Optional[List[int]]):目标输出的空间大小,通常表示为 [H, W],其中 H 表示高度,W 表示宽度。这是一个可选参数。
scale_factor(Optional[float]):尺度因子,用于确定输出尺寸相对于输入尺寸的比例。例如,如果 scale_factor=0.5,则输出尺寸将是输入尺寸的一半。这是一个可选参数。
mode(str):插值模式,用于确定如何进行插值。常见的模式包括:
"nearest":最近邻插值,使用最近的像素值进行插值。"bilinear":双线性插值,使用四个最近的像素值进行插值。"bicubic":双三次插值,使用16个最近的像素值进行插值。align_corners(Optional[bool]):一个布尔值,确定是否要对齐角点。通常在双线性插值中使用,以确定是否将插值网格的四个角点对齐到输入和输出的角点。这是一个可选参数。
函数的主要作用是根据输入参数调整输入张量的尺寸,并返回一个新的张量,该张量具有指定的输出尺寸和插值模式。如果输入张量的元素数量为0(即空输入),则函数将返回一个具有指定输出尺寸的空张量。
需要注意的是,函数的行为在不同版本的PyTorch和TorchVision中可能会有所不同,因此根据使用的库版本选择适当的用法。