• RobustVideoMatting Training and Inference Notes


    The script was modified to run inference on still images; every image in the input folder must have the same dimensions, otherwise an error is raised (a pre-resize workaround is sketched below these notes).

    Matting quality is also poor in complex scenes, for example when the subject is occluded by another person, wearing a scarf, carrying a bag, or holding a child.
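
    Since ImageSequenceReader feeds frames through a DataLoader batch, images of different sizes cannot be stacked into one tensor, hence the size constraint above. A minimal pre-processing sketch for making a folder uniform, assuming Pillow is installed; the folder names and target size are placeholders, not from the original post:

    import os
    from PIL import Image

    SRC_DIR = 'frames_raw'        # hypothetical input folder
    DST_DIR = 'frames_uniform'    # hypothetical output folder
    TARGET = (1920, 1080)         # (width, height) shared by every frame

    os.makedirs(DST_DIR, exist_ok=True)
    for name in sorted(os.listdir(SRC_DIR)):
        img = Image.open(os.path.join(SRC_DIR, name)).convert('RGB')
        img.resize(TARGET, Image.BILINEAR).save(os.path.join(DST_DIR, name))

    Alternatively, the script's own --input-resize W H flag resizes every frame to one size at load time, which avoids the mismatch without touching the files.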

    1. """
    2. python inference.py \
    3. --variant mobilenetv3 \
    4. --checkpoint "CHECKPOINT" \
    5. --device cuda \
    6. --input-source "input.mp4" \
    7. --output-type video \
    8. --output-composition "composition.mp4" \
    9. --output-alpha "alpha.mp4" \
    10. --output-foreground "foreground.mp4" \
    11. --output-video-mbps 4 \
    12. --seq-chunk 1
    13. """
    14. import torch
    15. import os
    16. from torch.utils.data import DataLoader
    17. from torchvision import transforms
    18. from typing import Optional, Tuple
    19. from tqdm.auto import tqdm
    20. from inference_utils import VideoReader, VideoWriter, ImageSequenceReader, ImageSequenceWriter
    def convert_video(model,
                      input_source: str,
                      input_resize: Optional[Tuple[int, int]] = None,
                      downsample_ratio: Optional[float] = None,
                      output_type: str = 'video',
                      output_composition: Optional[str] = None,
                      output_alpha: Optional[str] = None,
                      output_foreground: Optional[str] = None,
                      output_video_mbps: Optional[float] = None,
                      seq_chunk: int = 1,
                      num_workers: int = 0,
                      progress: bool = True,
                      device: Optional[str] = None,
                      dtype: Optional[torch.dtype] = None):

        assert downsample_ratio is None or (downsample_ratio > 0 and downsample_ratio <= 1), 'Downsample ratio must be between 0 (exclusive) and 1 (inclusive).'
        assert any([output_composition, output_alpha, output_foreground]), 'Must provide at least one output.'
        assert output_type in ['video', 'png_sequence'], 'Only support "video" and "png_sequence" output modes.'
        assert seq_chunk >= 1, 'Sequence chunk must be >= 1'
        assert num_workers >= 0, 'Number of workers must be >= 0'

        # Initialize transform
        if input_resize is not None:
            transform = transforms.Compose([
                transforms.Resize(input_resize[::-1]),  # CLI gives (w, h); Resize expects (h, w)
                transforms.ToTensor()
            ])
        else:
            transform = transforms.ToTensor()

        # Initialize reader: a video file, or a folder of equally sized images
        if os.path.isfile(input_source):
            source = VideoReader(input_source, transform)
        else:
            source = ImageSequenceReader(input_source, transform)
        reader = DataLoader(source, batch_size=seq_chunk, pin_memory=True, num_workers=num_workers)

        # Initialize writers
        if output_type == 'video':
            frame_rate = source.frame_rate if isinstance(source, VideoReader) else 30
            output_video_mbps = 1 if output_video_mbps is None else output_video_mbps
            if output_composition is not None:
                writer_com = VideoWriter(
                    path=output_composition,
                    frame_rate=frame_rate,
                    bit_rate=int(output_video_mbps * 1000000))
            if output_alpha is not None:
                writer_pha = VideoWriter(
                    path=output_alpha,
                    frame_rate=frame_rate,
                    bit_rate=int(output_video_mbps * 1000000))
            if output_foreground is not None:
                writer_fgr = VideoWriter(
                    path=output_foreground,
                    frame_rate=frame_rate,
                    bit_rate=int(output_video_mbps * 1000000))
        else:
            if output_composition is not None:
                writer_com = ImageSequenceWriter(output_composition, 'png')
            if output_alpha is not None:
                writer_pha = ImageSequenceWriter(output_alpha, 'png')
            if output_foreground is not None:
                writer_fgr = ImageSequenceWriter(output_foreground, 'png')

        # Inference
        model = model.eval()
        if device is None or dtype is None:
            param = next(model.parameters())
            dtype = param.dtype
            device = param.device

        if (output_composition is not None) and (output_type == 'video'):
            # Solid green background for the composited video
            bgr = torch.tensor([120, 255, 155], device=device, dtype=dtype).div(255).view(1, 1, 3, 1, 1)

        try:
            with torch.no_grad():
                bar = tqdm(total=len(source), disable=not progress, dynamic_ncols=True)
                rec = [None] * 4  # recurrent states, carried across chunks
                for src in reader:

                    if downsample_ratio is None:
                        downsample_ratio = auto_downsample_ratio(*src.shape[2:])

                    src = src.to(device, dtype, non_blocking=True).unsqueeze(0)  # [B, T, C, H, W]
                    fgr, pha, *rec = model(src, *rec, downsample_ratio)

                    if output_foreground is not None:
                        writer_fgr.write(fgr[0])
                    if output_alpha is not None:
                        writer_pha.write(pha[0])
                    if output_composition is not None:
                        if output_type == 'video':
                            com = fgr * pha + bgr * (1 - pha)
                        else:
                            # RGBA output: zero foreground where alpha is 0, append alpha as 4th channel
                            fgr = fgr * pha.gt(0)
                            com = torch.cat([fgr, pha], dim=-3)
                        writer_com.write(com[0])

                    bar.update(src.size(1))

        finally:
            # Clean up
            if output_composition is not None:
                writer_com.close()
            if output_alpha is not None:
                writer_pha.close()
            if output_foreground is not None:
                writer_fgr.close()


    def auto_downsample_ratio(h, w):
        """
        Automatically find a downsample ratio so that the largest side of the resolution be 512px.
        """
        return min(512 / max(h, w), 1)
    class Converter:
        def __init__(self, variant: str, checkpoint: str, device: str):
            self.model = MattingNetwork(variant).eval().to(device)
            self.model.load_state_dict(torch.load(checkpoint, map_location=device))
            # TorchScript + freeze for faster inference
            self.model = torch.jit.script(self.model)
            self.model = torch.jit.freeze(self.model)
            self.device = device

        def convert(self, *args, **kwargs):
            convert_video(self.model, device=self.device, dtype=torch.float32, *args, **kwargs)


    if __name__ == '__main__':
        import argparse
        from model import MattingNetwork

    134. """
    135. python inference.py \
    136. --variant mobilenetv3 \
    137. --checkpoint "CHECKPOINT" \
    138. --device cuda \
    139. --input-source "input.mp4" \
    140. --output-type video \
    141. --output-composition "composition.mp4" \
    142. --output-alpha "alpha.mp4" \
    143. --output-foreground "foreground.mp4" \
    144. --output-video-mbps 4 \
    145. --seq-chunk 1
    146. """
        parser = argparse.ArgumentParser()
        parser.add_argument('--variant', type=str, default='resnet50', choices=['mobilenetv3', 'resnet50'])
        parser.add_argument('--checkpoint', type=str, default=r'D:\project\fenge\jacke121-rvm_128_json\model_a\rvm_resnet50.pth')
        parser.add_argument('--device', type=str, default='cuda')
        parser.add_argument('--input-source', type=str, default=r'C:\Users\Administrator\Documents\WeChat Files\libanggeng\FileStorage\File\2023-11\koutu\weilanliandai\aa')
        parser.add_argument('--input-resize', type=int, default=None, nargs=2)
        parser.add_argument('--downsample-ratio', type=float)
        parser.add_argument('--output-composition', type=str, default='output-composition')
        parser.add_argument('--output-alpha', type=str, default='output-alpha')
        parser.add_argument('--output-foreground', type=str, default='output-foreground')
        parser.add_argument('--output-type', type=str, default='png_sequence', choices=['video', 'png_sequence'])
        parser.add_argument('--output-video-mbps', type=int, default=1)
        parser.add_argument('--seq-chunk', type=int, default=1)
        parser.add_argument('--num-workers', type=int, default=0)
        parser.add_argument('--disable-progress', action='store_true')
        args = parser.parse_args()

        converter = Converter(args.variant, args.checkpoint, args.device)
        converter.convert(
            input_source=args.input_source,
            input_resize=args.input_resize,
            downsample_ratio=args.downsample_ratio,
            output_type=args.output_type,
            output_composition=args.output_composition,
            output_alpha=args.output_alpha,
            output_foreground=args.output_foreground,
            output_video_mbps=args.output_video_mbps,
            seq_chunk=args.seq_chunk,
            num_workers=args.num_workers,
            progress=not args.disable_progress
        )
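
    For reference, auto_downsample_ratio above caps the longest side at 512 px for the segmentation pass: a 1920×1080 input gets min(512 / 1920, 1) ≈ 0.27, while inputs already 512 px or smaller keep a ratio of 1.

    convert_video can also be called directly from Python rather than through the CLI. A minimal sketch, assuming the RVM repo layout (model.py and inference.py importable) and a locally downloaded checkpoint; the paths are placeholders:

    import torch
    from model import MattingNetwork
    from inference import convert_video

    model = MattingNetwork('mobilenetv3').eval().cuda()
    model.load_state_dict(torch.load('rvm_mobilenetv3.pth'))

    convert_video(
        model,
        input_source='frames_uniform',   # folder of same-size images
        output_type='png_sequence',
        output_composition='out/com',
        output_alpha='out/pha',
        output_foreground='out/fgr',
        seq_chunk=12,                    # frames processed per forward pass
    )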

  • Original article: https://blog.csdn.net/jacke121/article/details/134432534