• Deploying Tora, Alibaba's Newly Open-Sourced Video Generation Framework


    Tora is an AI video generation framework from the Alibaba team, built on trajectory-oriented Diffusion Transformer (DiT) technology.

    During generation, Tora accepts multiple forms of input, including text descriptions, images, and object motion trajectories, and produces realistic, smooth videos from them.

    By introducing a trajectory control mechanism, Tora can control the motion patterns of objects in a video more precisely, addressing the difficulty existing models have with generating accurate, consistent motion.

    Tora uses a two-stage training process: it is first trained with dense optical flow and then fine-tuned with sparse trajectories, which improves the model's adaptability to different kinds of trajectory data.

    The model supports videos of up to 204 frames at 720p resolution, making it suitable for film and television production, animation, virtual reality (VR), augmented reality (AR), game development, and other fields.

    GitHub repository: https://github.com/alibaba/Tora

    I. Environment Installation

    1. Python environment

    Python 3.10 or later is recommended.
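    If you manage environments with conda, a minimal sketch of setting one up is shown below; the environment name "tora" is just an example, not something the project prescribes:

    conda create -n tora python=3.10 -y
    conda activate tora
    python --version    # should report 3.10 or newer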

    2. Installing pip packages

    pip install torch==2.4.0+cu118 torchvision==0.19.0+cu118 torchaudio==2.4.0 --extra-index-url https://download.pytorch.org/whl/cu118

    cd modules/SwissArmyTransformer

    pip install -e .

    cd ../../sat

    pip install -r requirements.txt -i https://pypi.tuna.tsinghua.edu.cn/simple
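    After installation, a quick sanity check (not part of the official steps) confirms that the CUDA 11.8 build of PyTorch was installed and can see the GPU:

    python -c "import torch; print(torch.__version__, torch.cuda.is_available())"
    # expected output looks like: 2.4.0+cu118 True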

    3. Downloading the CogVideoX-5b model

    git lfs install

    git clone https://www.modelscope.cn/AI-ModelScope/CogVideoX-5b.git
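    If the clone finishes but the large weight files are only a few hundred bytes, they are LFS pointer files; pulling them explicitly is a general git-lfs remedy rather than a Tora-specific step:

    cd CogVideoX-5b
    git lfs pull    # fetch the actual weight files behind the LFS pointers
    cd ..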

    4. Downloading the Tora t2v model

    https://cloudbook-public-daily.oss-cn-hangzhou.aliyuncs.com/Tora_t2v/mp_rank_00_model_states.pt
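    The checkpoint can be fetched directly with wget. The target directory ckpts/tora/t2v below is only an illustration; put the file wherever your inference config expects it:

    mkdir -p ckpts/tora/t2v
    wget -P ckpts/tora/t2v https://cloudbook-public-daily.oss-cn-hangzhou.aliyuncs.com/Tora_t2v/mp_rank_00_model_states.pt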

    II. Functional Testing

    1. Run test

    (1) Invoking via Python code (a sample launch command follows the listing)

    import argparse
    import gc
    import json
    import math
    import os
    import pickle
    from pathlib import Path
    from typing import List, Union

    import cv2
    import imageio
    import numpy as np
    import torch
    import torchvision.transforms as TT
    from arguments import get_args
    from diffusion_video import SATVideoDiffusionEngine
    from einops import rearrange, repeat
    from omegaconf import ListConfig
    from torchvision.io import write_video
    from torchvision.transforms import InterpolationMode
    from torchvision.transforms.functional import resize
    from torchvision.utils import flow_to_image
    from tqdm import tqdm
    from utils.flow_utils import process_traj
    from utils.misc import vis_tensor

    from sat import mpu
    from sat.arguments import set_random_seed
    from sat.model.base_model import get_model
    from sat.training.model_io import load_checkpoint


    def read_from_cli():
        cnt = 0
        try:
            while True:
                x = input("Please input English text (Ctrl-D quit): ")
                yield x.strip(), cnt
                cnt += 1
        except EOFError:
            pass


    def read_from_file(p, rank=0, world_size=1):
        with open(p, "r") as fin:
            cnt = -1
            for l in fin:
                cnt += 1
                if cnt % world_size != rank:
                    continue
                yield l.strip(), cnt


    def get_unique_embedder_keys_from_conditioner(conditioner):
        return list(set([x.input_key for x in conditioner.embedders]))


    def get_batch(keys, value_dict, N: Union[List, ListConfig], T=None, device="cuda"):
        batch = {}
        batch_uc = {}

        for key in keys:
            if key == "txt":
                batch["txt"] = np.repeat([value_dict["prompt"]], repeats=math.prod(N)).reshape(N).tolist()
                batch_uc["txt"] = np.repeat([value_dict["negative_prompt"]], repeats=math.prod(N)).reshape(N).tolist()
            else:
                batch[key] = value_dict[key]

        if T is not None:
            batch["num_video_frames"] = T

        for key in batch.keys():
            if key not in batch_uc and isinstance(batch[key], torch.Tensor):
                batch_uc[key] = torch.clone(batch[key])
        return batch, batch_uc


    def draw_points(video, points):
        """
        Draw points onto video frames.

        Parameters:
            video (torch.tensor): Video tensor with shape [T, H, W, C], where T is the number of frames,
                                  H is the height, W is the width, and C is the number of channels.
            points (list): Positions of points to be drawn as a tensor with shape [N, T, 2],
                           each point contains x and y coordinates.

        Returns:
            torch.tensor: The video tensor after drawing points, maintaining the same shape [T, H, W, C].
        """
        T = video.shape[0]
        N = len(points)
        device = video.device
        dtype = video.dtype
        video = video.cpu().numpy().copy()
        traj = np.zeros(video.shape[-3:], dtype=np.uint8)  # [H, W, C]
        for n in range(N):
            for t in range(1, T):
                cv2.line(traj, tuple(points[n][t - 1]), tuple(points[n][t]), (255, 1, 1), 2)
        for t in range(T):
            mask = traj[..., -1] > 0
            mask = repeat(mask, "h w -> h w c", c=3)
            alpha = 0.7
            video[t][mask] = video[t][mask] * (1 - alpha) + traj[mask] * alpha
            for n in range(N):
                cv2.circle(video[t], tuple(points[n][t]), 3, (160, 230, 100), -1)
        video = torch.from_numpy(video).to(device, dtype)
        return video


    def save_video_as_grid_and_mp4(
        video_batch: torch.Tensor,
        save_path: str,
        name: str,
        fps: int = 5,
        args=None,
        key=None,
        traj_points=None,
        prompt="",
    ):
        os.makedirs(save_path, exist_ok=True)
        p = Path(save_path)

        for i, vid in enumerate(video_batch):
            x = rearrange(vid, "t c h w -> t h w c")
            x = x.mul(255).add(0.5).clamp(0, 255).to("cpu", torch.uint8)  # [T H W C]
            os.makedirs(p / "video", exist_ok=True)
            os.makedirs(p / "prompt", exist_ok=True)
            if traj_points is not None:
                os.makedirs(p / "traj", exist_ok=True)
                os.makedirs(p / "traj_video", exist_ok=True)
                write_video(
                    p / "video" / f"{name}_{i:06d}.mp4",
                    x,
                    fps=fps,
                    video_codec="libx264",
                    options={"crf": "18"},
                )
                with open(p / "traj" / f"{name}_{i:06d}.pkl", "wb") as f:
                    pickle.dump(traj_points, f)
                x = draw_points(x, traj_points)
                write_video(
                    p / "traj_video" / f"{name}_{i:06d}.mp4",
                    x,
                    fps=fps,
                    video_codec="libx264",
                    options={"crf": "18"},
                )
            else:
                write_video(
                    p / "video" / f"{name}_{i:06d}.mp4",
                    x,
                    fps=fps,
                    video_codec="libx264",
                    options={"crf": "18"},
                )
            with open(p / "prompt" / f"{name}_{i:06d}.txt", "w") as f:
                f.write(prompt)


    def resize_for_rectangle_crop(arr, image_size, reshape_mode="random"):
        if arr.shape[3] / arr.shape[2] > image_size[1] / image_size[0]:
            arr = resize(
                arr,
                size=[image_size[0], int(arr.shape[3] * image_size[0] / arr.shape[2])],
                interpolation=InterpolationMode.BICUBIC,
            )
        else:
            arr = resize(
                arr,
                size=[int(arr.shape[2] * image_size[1] / arr.shape[3]), image_size[1]],
                interpolation=InterpolationMode.BICUBIC,
            )

        h, w = arr.shape[2], arr.shape[3]
        arr = arr.squeeze(0)

        delta_h = h - image_size[0]
        delta_w = w - image_size[1]

        if reshape_mode == "random" or reshape_mode == "none":
            top = np.random.randint(0, delta_h + 1)
            left = np.random.randint(0, delta_w + 1)
        elif reshape_mode == "center":
            top, left = delta_h // 2, delta_w // 2
        else:
            raise NotImplementedError
        arr = TT.functional.crop(arr, top=top, left=left, height=image_size[0], width=image_size[1])
        return arr


    def sampling_main(args, model_cls):
        if isinstance(model_cls, type):
            model = get_model(args, model_cls)
        else:
            model = model_cls

        load_checkpoint(model, args)
        model.eval()

        if args.input_type == "cli":
            data_iter = read_from_cli()
            rank, world_size = 0, 1  # single-process CLI mode
        elif args.input_type == "txt":
            rank, world_size = mpu.get_data_parallel_rank(), mpu.get_data_parallel_world_size()
            print("rank and world_size", rank, world_size)
            data_iter = read_from_file(args.input_file, rank=rank, world_size=world_size)
        else:
            raise NotImplementedError

        image_size = [480, 720]

        sample_func = model.sample
        T, H, W, C, F = args.sampling_num_frames, image_size[0], image_size[1], args.latent_channels, 8
        num_samples = [1]
        force_uc_zero_embeddings = ["txt"]
        device = model.device

        with torch.no_grad():
            for text, cnt in tqdm(data_iter):
                set_random_seed(args.seed)
                if args.flow_from_prompt:
                    text, flow_files = text.split("\t")
                total_num_frames = (T - 1) * 4 + 1  # T is the video latent size, 13 * 4 = 52

                # Build the trajectory conditioning (video_flow) from whichever source was provided.
                if args.no_flow_injection:
                    video_flow = None
                elif args.flow_from_prompt:
                    assert args.flow_path is not None, "Flow path must be provided if flow_from_prompt is True"
                    p = os.path.join(args.flow_path, flow_files)
                    print(f"Flow path: {p}")
                    video_flow = (
                        torch.load(p, map_location="cpu", weights_only=True)[:total_num_frames].unsqueeze_(0).cuda()
                    )
                elif args.flow_path:
                    print(f"Flow path: {args.flow_path}")
                    video_flow = torch.load(args.flow_path, map_location=device, weights_only=True)[
                        :total_num_frames
                    ].unsqueeze_(0)
                elif args.point_path:
                    if type(args.point_path) == str:
                        args.point_path = json.loads(args.point_path)
                    print(f"Point path: {args.point_path}")
                    video_flow, points = process_traj(args.point_path, total_num_frames, image_size, device=device)
                    video_flow = video_flow.unsqueeze_(0)
                else:
                    print("No flow injection")
                    video_flow = None

                if video_flow is not None:
                    model.to("cpu")  # move model to cpu, run vae on gpu only.
                    tmp = rearrange(video_flow[0], "T H W C -> T C H W")
                    video_flow = flow_to_image(tmp).unsqueeze_(0).to("cuda")  # [1 T C H W]
                    if args.vis_traj_features:
                        os.makedirs("samples/flow", exist_ok=True)
                        vis_tensor(tmp, *tmp.shape[-2:], "samples/flow/flow1_vis.gif")
                        imageio.mimwrite(
                            "samples/flow/flow2_vis.gif",
                            rearrange(video_flow[0], "T C H W -> T H W C").cpu(),
                            fps=8,
                            loop=0,
                        )
                    del tmp
                    video_flow = (
                        rearrange(video_flow / 255.0 * 2 - 1, "B T C H W -> B C T H W").contiguous().to(torch.bfloat16)
                    )
                    torch.cuda.empty_cache()
                    video_flow = video_flow.repeat(2, 1, 1, 1, 1).contiguous()  # for uncondition
                    model.first_stage_model.to(device)
                    video_flow = model.encode_first_stage(video_flow, None)
                    video_flow = video_flow.permute(0, 2, 1, 3, 4).contiguous()
                    model.to(device)

                print("rank:", rank, "start to process", text, cnt)
                # TODO: broadcast image2video
                value_dict = {
                    "prompt": text,
                    "negative_prompt": "",
                    "num_frames": torch.tensor(T).unsqueeze(0),
                }

                batch, batch_uc = get_batch(
                    get_unique_embedder_keys_from_conditioner(model.conditioner), value_dict, num_samples
                )
                for key in batch:
                    if isinstance(batch[key], torch.Tensor):
                        print(key, batch[key].shape)
                    elif isinstance(batch[key], list):
                        print(key, [len(l) for l in batch[key]])
                    else:
                        print(key, batch[key])
                c, uc = model.conditioner.get_unconditional_conditioning(
                    batch,
                    batch_uc=batch_uc,
                    force_uc_zero_embeddings=force_uc_zero_embeddings,
                )

                for k in c:
                    if not k == "crossattn":
                        c[k], uc[k] = map(lambda y: y[k][: math.prod(num_samples)].to("cuda"), (c, uc))

                for index in range(args.batch_size):
                    # reload model on GPU
                    model.to(device)
                    samples_z = sample_func(
                        c,
                        uc=uc,
                        batch_size=1,
                        shape=(T, C, H // F, W // F),
                        video_flow=video_flow,
                    )
                    samples_z = samples_z.permute(0, 2, 1, 3, 4).contiguous()

                    # Unload the model from GPU to save GPU memory
                    model.to("cpu")
                    torch.cuda.empty_cache()
                    first_stage_model = model.first_stage_model
                    first_stage_model = first_stage_model.to(device)

                    latent = 1.0 / model.scale_factor * samples_z

                    # Decode latent serial to save GPU memory
                    recons = []
                    loop_num = (T - 1) // 2
                    for i in range(loop_num):
                        if i == 0:
                            start_frame, end_frame = 0, 3
                        else:
                            start_frame, end_frame = i * 2 + 1, i * 2 + 3
                        if i == loop_num - 1:
                            clear_fake_cp_cache = True
                        else:
                            clear_fake_cp_cache = False
                        with torch.no_grad():
                            recon = first_stage_model.decode(
                                latent[:, :, start_frame:end_frame].contiguous(), clear_fake_cp_cache=clear_fake_cp_cache
                            )

                        recons.append(recon)

                    recon = torch.cat(recons, dim=2).to(torch.float32)
                    samples_x = recon.permute(0, 2, 1, 3, 4).contiguous()
                    samples = torch.clamp((samples_x + 1.0) / 2.0, min=0.0, max=1.0).cpu()

                    save_path = args.output_dir
                    name = str(cnt) + "_" + text.replace(" ", "_").replace("/", "")[:60] + f"_{index}_seed{args.seed}"
                    if args.flow_from_prompt:
                        name = Path(flow_files).stem
                    if mpu.get_model_parallel_rank() == 0:
                        save_video_as_grid_and_mp4(
                            samples,
                            save_path,
                            name,
                            fps=args.sampling_fps,
                            traj_points=locals().get("points", None),
                            prompt=text,
                        )
                del samples_z, samples_x, samples, video_flow, latent, recon, recons, c, uc, batch, batch_uc
                gc.collect()


    if __name__ == "__main__":
        if "OMPI_COMM_WORLD_LOCAL_RANK" in os.environ:
            os.environ["LOCAL_RANK"] = os.environ["OMPI_COMM_WORLD_LOCAL_RANK"]
            os.environ["WORLD_SIZE"] = os.environ["OMPI_COMM_WORLD_SIZE"]
            os.environ["RANK"] = os.environ["OMPI_COMM_WORLD_RANK"]
        py_parser = argparse.ArgumentParser(add_help=False)
        known, args_list = py_parser.parse_known_args()
        args = get_args(args_list)
        args = argparse.Namespace(**vars(args), **vars(known))
        del args.deepspeed_config
        args.model_config.first_stage_config.params.cp_size = 1
        args.model_config.network_config.params.transformer_args.model_parallel_size = 1
        args.model_config.network_config.params.transformer_args.checkpoint_activations = False
        args.model_config.loss_fn_config.params.sigma_sampler_config.params.uniform_sampling = False
        args.model_config.en_and_decode_n_samples_a_time = 1
        sampling_main(args, model_cls=SATVideoDiffusionEngine)
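    Saved as sample_video.py inside the repository's sat directory, the script is driven by SAT's argument parser and is normally launched from the command line. The command below is only a sketch: the config file names, checkpoint path, and flag spellings are assumptions inferred from the arguments the code reads (input_type, input_file, output_dir, point_path), and the trajectory JSON (a list of per-trajectory (x, y) waypoints) is likewise a guess at the format process_traj expects; consult the inference scripts shipped with the Tora repository for the authoritative invocation.

    cd sat
    # hypothetical launch command -- config names, paths, and flag spellings are assumptions
    python sample_video.py \
        --base configs/cogvideox_5b_tora.yaml configs/inference.yaml \
        --load ckpts/tora/t2v \
        --input_type txt \
        --input_file prompts.txt \
        --output_dir samples \
        --point_path '[[[100, 240], [620, 240]]]'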

    To be continued...

    For more details, follow: 杰哥新技术

  • Original article: https://blog.csdn.net/m0_71062934/article/details/143273678