• pytorch Nvidia 数据预处理加速


    目录

    安装 不支持Windows:

    官方说明:

    预处理加速:

    学习笔记:


    参考:

    深度学习预处理工具---DALI详解_nvidia.dali.fn_扫地的小何尚的博客-CSDN博客

    安装 不支持Windows:

    官方说明:

    Installation — NVIDIA DALI 1.30.0 documentation

    pip install nvidia-pyindex
    pip install nvidia-dali-cuda110


    import nvidia.dali.ops
    import nvidia.dali.types
     
    from nvidia.dali.pipeline import Pipeline
    from nvidia.dali.plugin.pytorch import DALIGenericIterator
     

    官网下载地址:看起来没有Windows版本,

    Index of /compute/redist///nvidia-dali-cuda110

    预处理加速:

    Nvidia Dali: 强大的数据增强库_笔记大全_设计学院

    学习笔记:

    对于深度学习任务,训练速度决定了模型的迭代速度,而训练速度又取决于数据预处理和网络的前向和后向耗时。
    对于识别任务,batch size通常较大,并且需要做数据增强,因此常常导致训练速度的瓶颈在数据读取和预处理上,尤其对于小网络而言。
    对于数据读取耗时的提升,粗暴且有效的解决办法是使用固态硬盘,或者将数据直接拷贝至/tmp文件夹(内存空间换时间)。
    对于数据预处理的耗时,则可以通过使用Nvidia官方开发的Dali预处理加速工具包,将预处理放在cpu/gpu上进行加速。pytorch1.6版本内置了Dali,无需自己安装。

    官方的Dali教程较为简单,实际训练通常要根据任务需要自定义Dataloader,并与分布式训练结合使用。这里将展示一个使用Dali定义DataLoader的例子,功能是返回序列图像,并对序列图像做常见的统一预处理操作。
    `

    1. from nvidia.dali.plugin.pytorch import DALIGenericIterator
    2. from nvidia.dali.types import DALIImageType
    3. import cv2
    4. from nvidia.dali.plugin.pytorch import DALIClassificationIterator
    5. from nvidia.dali.pipeline import Pipeline
    6. import nvidia.dali.ops as ops
    7. import nvidia.dali.types as types
    8. from sklearn.utils import shuffle
    9. import numpy as np
    10. from torchvision import transforms
    11. import torch.utils.data as torchdata
    12. import random
    13. from pathlib import Path
    14. import torch
    15. class TRAIN_INPUT_ITER(object):
    16. def __init__(self, batch_size, num_class,seq_len, sample_rate, num_shards=1, shard_id=0,root_dir=Path('') ,list_file='', is_training=True):
    17. self.batch_size = batch_size
    18. self.num_class = num_class
    19. self.seq_len = seq_len
    20. self.sample_rate = sample_rate
    21. self.num_shards = num_shards
    22. self.shard_id = shard_id
    23. self.train = is_training
    24. self.image_name_formatter = lambda x: f'image_{x:05d}.jpg'
    25. self.root_dir = root_dir
    26. with open(list_file,'r') as f:
    27. self.ori_lines = f.readlines()
    28. def __iter__(self):
    29. self.i = 0
    30. bucket = len(self.ori_lines)//self.num_shards
    31. self.n = bucket
    32. return self
    33. def __next__(self):
    34. batch = [[] for _ in range(self.seq_len)]
    35. labels = []
    36. for _ in range(self.batch_size):
    37. # self.sample_rate = random.randint(1,2)
    38. if self.train and self.i % self.n == 0:
    39. bucket = len(self.ori_lines)//self.num_shards
    40. self.ori_lines= shuffle(self.ori_lines, random_state=0)
    41. self.lines = self.ori_lines[self.shard_id*bucket:(self.shard_id+1)*bucket]
    42. line = self.lines[self.i].strip()
    43. dir_name,start_f,end_f, label = line.split(' ')
    44. start_f = int(start_f)
    45. end_f = int(end_f)
    46. label = int(label)
    47. begin_frame = random.randint(start_f,max(end_f-self.sample_rate*self.seq_len,start_f))
    48. begin_frame = max(1,begin_frame)
    49. last_frame = None
    50. for k in range(self.seq_len):
    51. filename = self.root_dir/dir_name/self.image_name_formatter(begin_frame+self.sample_rate*k)
    52. if filename.exists():
    53. f = open(filename,'rb')
    54. last_frame = filename
    55. elif last_frame is not None:
    56. f = open(last_frame,'rb')
    57. else:
    58. print('{} does not exist'.format(filename))
    59. raise IOError
    60. batch[k].append(np.frombuffer(f.read(), dtype = np.uint8))
    61. if random.randint(0,1)%2 == 0:
    62. end_frame = start_f + random.randint(0,self.sample_rate*self.seq_len//2)
    63. begin_frame = max(1,end_frame-self.sample_rate*self.seq_len)
    64. else:
    65. begin_frame = end_f - random.randint(0,self.sample_rate*self.seq_len//2)
    66. begin_frame = max(1,begin_frame)
    67. end_frame = begin_frame + self.sample_rate*self.seq_len
    68. last_frame = None
    69. for k in range(self.seq_len):
    70. filename = self.root_dir/dir_name/self.image_name_formatter(begin_frame+self.sample_rate*k)
    71. if filename.exists():
    72. f = open(filename,'rb')
    73. last_frame = filename
    74. elif last_frame is not None:
    75. f = open(last_frame,'rb')
    76. else:
    77. print('{} does not exist'.format(filename))
    78. raise IOError
    79. batch[k].append(np.frombuffer(f.read(), dtype = np.uint8))
    80. labels.append(np.array([label], dtype = np.uint8))
    81. if label==8 or label == 9:
    82. labels.append(np.array([label], dtype = np.uint8))
    83. else:
    84. labels.append(np.array([self.num_class-1], dtype = np.uint8))
    85. self.i = (self.i + 1) % self.n
    86. return (batch, labels)
    87. next = __next__
    88. class VAL_INPUT_ITER(object):
    89. def __init__(self, batch_size, num_class,seq_len, sample_rate, num_shards=1, shard_id=0,root_dir=Path('') ,list_file='', is_training=False):
    90. self.batch_size = batch_size
    91. self.num_class = num_class
    92. self.seq_len = seq_len
    93. self.sample_rate = sample_rate
    94. self.num_shards = num_shards
    95. self.shard_id = shard_id
    96. self.train = is_training
    97. self.image_name_formatter = lambda x: f'image_{x:05d}.jpg'
    98. self.root_dir = root_dir
    99. with open(list_file,'r') as f:
    100. self.ori_lines = f.readlines()
    101. self.ori_lines= shuffle(self.ori_lines, random_state=0)
    102. def __iter__(self):
    103. self.i = 0
    104. bucket= len(self.ori_lines)//self.num_shards
    105. self.n = bucket
    106. return self
    107. def __next__(self):
    108. batch = [[] for _ in range(self.seq_len)]
    109. labels = []
    110. for _ in range(self.batch_size):
    111. # self.sample_rate = random.randint(1,2)
    112. if self.train and self.i % self.n == 0:
    113. bucket = len(self.ori_lines)//self.num_shards
    114. self.ori_lines= shuffle(self.ori_lines, random_state=0)
    115. self.lines = self.ori_lines[self.shard_id*bucket:(self.shard_id+1)*bucket]
    116. if self.i % self.n == 0:
    117. bucket = len(self.ori_lines)//self.num_shards
    118. self.lines = self.ori_lines[self.shard_id*bucket:(self.shard_id+1)*bucket]
    119. line = self.lines[self.i].strip()
    120. dir_name,start_f,end_f, label = line.split(' ')
    121. start_f = int(start_f)
    122. end_f = int(end_f)
    123. label = int(label)
    124. begin_frame = random.randint(start_f,max(end_f-self.sample_rate*self.seq_len,start_f))
    125. begin_frame = max(1,begin_frame)
    126. last_frame = None
    127. for k in range(self.seq_len):
    128. filename = self.root_dir/dir_name/self.image_name_formatter(begin_frame+self.sample_rate*k)
    129. if filename.exists():
    130. f = open(filename,'rb')
    131. last_frame = filename
    132. elif last_frame is not None:
    133. f = open(last_frame,'rb')
    134. else:
    135. print('{} does not exist'.format(filename))
    136. raise IOError
    137. batch[k].append(np.frombuffer(f.read(), dtype = np.uint8))
    138. labels.append(np.array([label], dtype = np.uint8))
    139. self.i = (self.i + 1) % self.n
    140. return (batch, labels)
    141. next = __next__
    142. class HybridPipe(Pipeline):
    143. def __init__(self, batch_size, num_class,seq_len, sample_rate, num_shards,shard_id,root_dir, list_file, num_threads, device_id=0, dali_cpu=True,size = (224,224),is_gray = True,is_training = True):
    144. super(HybridPipe, self).__init__(batch_size, num_threads, device_id, seed=12 + device_id)
    145. if is_training:
    146. self.external_data = TRAIN_INPUT_ITER(batch_size//2, num_class,seq_len,sample_rate,num_shards,shard_id,root_dir, list_file,is_training)
    147. else:
    148. self.external_data = VAL_INPUT_ITER(batch_size, num_class,seq_len,sample_rate,num_shards,shard_id,root_dir, list_file,is_training)
    149. # self.external_data = VAL_INPUT_ITER(batch_size, num_class,seq_len,sample_rate,num_shards,shard_id,root_dir, list_file,is_training)
    150. self.seq_len = seq_len
    151. self.training = is_training
    152. self.iterator = iter(self.external_data)
    153. self.inputs = [ops.ExternalSource() for _ in range(seq_len)]
    154. self.input_labels = ops.ExternalSource()
    155. self.is_gray = is_gray
    156. decoder_device = 'cpu' if dali_cpu else 'mixed'
    157. self.decode = ops.ImageDecoder(device=decoder_device, output_type=types.RGB)
    158. if self.is_gray:
    159. self.space_converter = ops.ColorSpaceConversion(device='gpu',image_type=types.RGB,output_type=types.GRAY)
    160. self.resize = ops.Resize(device='gpu', size=size)
    161. self.cast_fp32 = ops.Cast(device='gpu',dtype = types.FLOAT)
    162. if self.training:
    163. self.crop_coin = ops.CoinFlip(probability=0.5)
    164. self.crop_pos_x = ops.Uniform(range=(0., 1.))
    165. self.crop_pos_y = ops.Uniform(range=(0., 1.))
    166. self.crop_h = ops.Uniform(range=(256*0.85,256))
    167. self.crop_w = ops.Uniform(range=(256*0.85,256))
    168. self.crmn = ops.CropMirrorNormalize(device="gpu",output_layout=types.NHWC)
    169. self.u_rotate = ops.Uniform(range=(-8, 8))
    170. self.rotate = ops.Rotate(device='gpu',keep_size=True)
    171. self.brightness = ops.Uniform(range=(0.9,1.1))
    172. self.contrast = ops.Uniform(range=(0.9,1.1))
    173. self.saturation = ops.Uniform(range=(0.9,1.1))
    174. self.hue = ops.Uniform(range=(-0.3,0.3))
    175. self.color_jitter = ops.ColorTwist(device='gpu')
    176. else:
    177. self.crmn = ops.CropMirrorNormalize(device="gpu",crop=(224,224),output_layout=types.NHWC)
    178. def define_graph(self):
    179. self.batch_data = [i() for i in self.inputs]
    180. self.labels = self.input_labels()
    181. out = self.decode(self.batch_data)
    182. out = [out_elem.gpu() for out_elem in out]
    183. if self.training:
    184. out = self.color_jitter(out,brightness=self.brightness(),contrast=self.contrast())
    185. if self.is_gray:
    186. out = self.space_converter(out)
    187. if self.training:
    188. out = self.rotate(out,angle=self.u_rotate())
    189. out = self.crmn(out,crop_h=self.crop_h(),crop_w=self.crop_w(),crop_pos_x=self.crop_pos_x(),crop_pos_y=self.crop_pos_y(),mirror=self.crop_coin())
    190. else:
    191. out = self.crmn(out)
    192. out = self.resize(out)
    193. if not self.training:
    194. out = self.cast_fp32(out)
    195. return (*out, self.labels)
    196. def iter_setup(self):
    197. try:
    198. (batch_data, labels) = self.iterator.next()
    199. for i in range(self.seq_len):
    200. self.feed_input(self.batch_data[i], batch_data[i])
    201. self.feed_input(self.labels, labels)
    202. except StopIteration:
    203. self.iterator = iter(self.external_data)
    204. raise StopIteration
    205. def dali_loader(batch_size,
    206. num_class,
    207. seq_len,
    208. sample_rate,
    209. num_shards,
    210. shard_id,
    211. root_dir,
    212. list_file,
    213. num_workers,
    214. device_id,
    215. dali_cpu=True,
    216. size = (224,224),
    217. is_gray = True,
    218. is_training=True):
    219. print('##########',root_dir)
    220. pipe = HybridPipe(batch_size,num_class,seq_len,sample_rate,num_shards,shard_id,root_dir,
    221. list_file,num_workers,device_id=device_id,
    222. dali_cpu=dali_cpu,size = size,is_gray=is_gray,is_training=is_training)
    223. # pipe.build()
    224. names = []
    225. for i in range(seq_len):
    226. names.append(f'data{i}')
    227. names.append('label')
    228. print('##############',names)
    229. loader = DALIGenericIterator(pipe,names,pipe.external_data.n,last_batch_padded=True, fill_last_batch=True)
    230. return loade
  • 相关阅读:
    如何查询IP地址的位置?
    如何搭建一部引人入胜的短剧小程序
    【Linux】进程控制基础知识
    Spring注解驱动之声明式事务源码分析
    408真题-2021
    【第3章】MyBatis-Plus持久层接口之Service Interface(上)
    优先级队列(堆)【Java】
    [补题记录] Complete the Permutation(贪心、set)
    金仓数据库 KingbaseES 插件参考手册(17. dbms_metadata)
    1459. 矩形面积
  • 原文地址:https://blog.csdn.net/jacke121/article/details/133692941