Tora is an AI video generation framework from the Alibaba team, built on a trajectory-oriented Diffusion Transformer (DiT).
During generation, Tora accepts several kinds of input, including text descriptions, images, and object motion trajectories, and uses them to produce realistic, smooth videos.
By introducing a trajectory control mechanism, Tora can control the motion of objects in a video more precisely, addressing the difficulty existing models have in generating precise, consistent motion.
Tora is trained in two stages: it is first trained on dense optical flow and then fine-tuned on sparse trajectories, which improves its adaptability to different kinds of trajectory data.
The model supports videos of up to 204 frames at 720p resolution, and is suited to film and TV production, animation, virtual reality (VR), augmented reality (AR), game development, and other fields.
GitHub project: https://github.com/alibaba/Tora
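Conceptually, a trajectory input is just a sequence of (x, y) positions an object should follow across the frames of a video. The snippet below is a minimal, hypothetical sketch of turning a few sparse keypoints into a dense per-frame path by linear interpolation, only to illustrate the idea of a trajectory; it is not Tora's internal trajectory processing.

import numpy as np

def densify_trajectory(keypoints, num_frames):
    """Linearly interpolate sparse [x, y] keypoints into one position per frame."""
    keypoints = np.asarray(keypoints, dtype=np.float32)
    key_times = np.linspace(0, num_frames - 1, num=len(keypoints))  # where each keypoint sits on the timeline
    frame_times = np.arange(num_frames)
    xs = np.interp(frame_times, key_times, keypoints[:, 0])
    ys = np.interp(frame_times, key_times, keypoints[:, 1])
    return np.stack([xs, ys], axis=1)  # shape [num_frames, 2]

# Hypothetical example: an object sweeping across a 720x480 frame over 52 frames.
path = densify_trajectory([[50, 240], [360, 120], [670, 240]], num_frames=52)
print(path.shape)  # (52, 2)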
1. Python environment
Python 3.10 or later is recommended.
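A quick sanity check of the interpreter (a trivial sketch, nothing Tora-specific):

import sys

# Abort early if the interpreter is older than the recommended version.
assert sys.version_info >= (3, 10), f"Python 3.10+ is recommended, found {sys.version}"
print(sys.version)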
2. Installing the pip dependencies (run from the root of the cloned Tora repository)
pip install torch==2.4.0+cu118 torchvision==0.19.0+cu118 torchaudio==2.4.0 --extra-index-url https://download.pytorch.org/whl/cu118
cd modules/SwissArmyTransformer
pip install -e .
cd ../../sat
pip install -r requirements.txt -i https://pypi.tuna.tsinghua.edu.cn/simple
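After installation it is worth confirming that the CUDA build of PyTorch was actually picked up; the version strings below assume the cu118 wheels requested by the command above:

import torch
import torchvision

# Expect something like 2.4.0+cu118 / 0.19.0+cu118 and True on a machine with CUDA 11.8 set up.
print(torch.__version__, torchvision.__version__)
print("CUDA available:", torch.cuda.is_available())
if torch.cuda.is_available():
    print("Device:", torch.cuda.get_device_name(0))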
3. Download the CogVideoX-5b model:
git lfs install
git clone https://www.modelscope.cn/AI-ModelScope/CogVideoX-5b.git
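If git lfs is inconvenient, the same weights can usually be fetched with the ModelScope SDK instead. This is an alternative sketch, not a step from the Tora README; it assumes the modelscope package is installed and reuses the model id from the clone URL above:

# pip install modelscope   (assumed; not part of the requirements above)
from modelscope import snapshot_download

# Downloads the model repository into the local ModelScope cache and returns its path.
local_dir = snapshot_download("AI-ModelScope/CogVideoX-5b")
print("CogVideoX-5b downloaded to:", local_dir)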
4. Download the Tora t2v model:
https://cloudbook-public-daily.oss-cn-hangzhou.aliyuncs.com/Tora_t2v/mp_rank_00_model_states.pt
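Since the checkpoint is a single file, a plain HTTP download is enough. Below is a small sketch using only the standard library; the destination directory ckpts/tora/t2v is an arbitrary assumption, so point your inference configuration at wherever you actually save the file:

import os
from urllib.request import urlretrieve

url = "https://cloudbook-public-daily.oss-cn-hangzhou.aliyuncs.com/Tora_t2v/mp_rank_00_model_states.pt"
dest_dir = "ckpts/tora/t2v"  # arbitrary location, adjust to your setup
os.makedirs(dest_dir, exist_ok=True)
dest = os.path.join(dest_dir, "mp_rank_00_model_states.pt")

print("Downloading the Tora t2v checkpoint (several GB, this may take a while)...")
urlretrieve(url, dest)
print("Saved to", dest)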
1. Running a test:
(1) Calling the model from Python code
import argparse
import gc
import json
import math
import os
import pickle
from pathlib import Path
from typing import List, Union

import cv2
import imageio
import numpy as np
import torch
import torchvision.transforms as TT
from arguments import get_args
from diffusion_video import SATVideoDiffusionEngine
from einops import rearrange, repeat
from omegaconf import ListConfig
from torchvision.io import write_video
from torchvision.transforms import InterpolationMode
from torchvision.transforms.functional import resize
from torchvision.utils import flow_to_image
from tqdm import tqdm
from utils.flow_utils import process_traj
from utils.misc import vis_tensor

from sat import mpu
from sat.arguments import set_random_seed
from sat.model.base_model import get_model
from sat.training.model_io import load_checkpoint


def read_from_cli():
    cnt = 0
    try:
        while True:
            x = input("Please input English text (Ctrl-D quit): ")
            yield x.strip(), cnt
            cnt += 1
    except EOFError as e:
        pass


def read_from_file(p, rank=0, world_size=1):
    with open(p, "r") as fin:
        cnt = -1
        for l in fin:
            cnt += 1
            if cnt % world_size != rank:
                continue
            yield l.strip(), cnt


def get_unique_embedder_keys_from_conditioner(conditioner):
    return list(set([x.input_key for x in conditioner.embedders]))


def get_batch(keys, value_dict, N: Union[List, ListConfig], T=None, device="cuda"):
    batch = {}
    batch_uc = {}

    for key in keys:
        if key == "txt":
            batch["txt"] = np.repeat([value_dict["prompt"]], repeats=math.prod(N)).reshape(N).tolist()
            batch_uc["txt"] = np.repeat([value_dict["negative_prompt"]], repeats=math.prod(N)).reshape(N).tolist()
        else:
            batch[key] = value_dict[key]

    if T is not None:
        batch["num_video_frames"] = T

    for key in batch.keys():
        if key not in batch_uc and isinstance(batch[key], torch.Tensor):
            batch_uc[key] = torch.clone(batch[key])
    return batch, batch_uc


def draw_points(video, points):
    """
    Draw points onto video frames.
    Parameters:
        video (torch.tensor): Video tensor with shape [T, H, W, C], where T is the number of frames,
            H is the height, W is the width, and C is the number of channels.
        points (list): Positions of points to be drawn as a tensor with shape [N, T, 2],
            each point contains x and y coordinates.
    Returns:
        torch.tensor: The video tensor after drawing points, maintaining the same shape [T, H, W, C].
    """

    T = video.shape[0]
    N = len(points)
    device = video.device
    dtype = video.dtype
    video = video.cpu().numpy().copy()
    traj = np.zeros(video.shape[-3:], dtype=np.uint8)  # [H, W, C]
    for n in range(N):
        for t in range(1, T):
            cv2.line(traj, tuple(points[n][t - 1]), tuple(points[n][t]), (255, 1, 1), 2)
    for t in range(T):
        mask = traj[..., -1] > 0
        mask = repeat(mask, "h w -> h w c", c=3)
        alpha = 0.7
        video[t][mask] = video[t][mask] * (1 - alpha) + traj[mask] * alpha
        for n in range(N):
            cv2.circle(video[t], tuple(points[n][t]), 3, (160, 230, 100), -1)
    video = torch.from_numpy(video).to(device, dtype)
    return video


def save_video_as_grid_and_mp4(
    video_batch: torch.Tensor,
    save_path: str,
    name: str,
    fps: int = 5,
    args=None,
    key=None,
    traj_points=None,
    prompt="",
):
    os.makedirs(save_path, exist_ok=True)
    p = Path(save_path)

    for i, vid in enumerate(video_batch):
        x = rearrange(vid, "t c h w -> t h w c")
        x = x.mul(255).add(0.5).clamp(0, 255).to("cpu", torch.uint8)  # [T H W C]
        os.makedirs(p / "video", exist_ok=True)
        os.makedirs(p / "prompt", exist_ok=True)
        if traj_points is not None:
            os.makedirs(p / "traj", exist_ok=True)
            os.makedirs(p / "traj_video", exist_ok=True)
            write_video(
                p / "video" / f"{name}_{i:06d}.mp4",
                x,
                fps=fps,
                video_codec="libx264",
                options={"crf": "18"},
            )
            with open(p / "traj" / f"{name}_{i:06d}.pkl", "wb") as f:
                pickle.dump(traj_points, f)
            x = draw_points(x, traj_points)
            write_video(
                p / "traj_video" / f"{name}_{i:06d}.mp4",
                x,
                fps=fps,
                video_codec="libx264",
                options={"crf": "18"},
            )
        else:
            write_video(
                p / "video" / f"{name}_{i:06d}.mp4",
                x,
                fps=fps,
                video_codec="libx264",
                options={"crf": "18"},
            )
        with open(p / "prompt" / f"{name}_{i:06d}.txt", "w") as f:
            f.write(prompt)


def resize_for_rectangle_crop(arr, image_size, reshape_mode="random"):
    if arr.shape[3] / arr.shape[2] > image_size[1] / image_size[0]:
        arr = resize(
            arr,
            size=[image_size[0], int(arr.shape[3] * image_size[0] / arr.shape[2])],
            interpolation=InterpolationMode.BICUBIC,
        )
    else:
        arr = resize(
            arr,
            size=[int(arr.shape[2] * image_size[1] / arr.shape[3]), image_size[1]],
            interpolation=InterpolationMode.BICUBIC,
        )

    h, w = arr.shape[2], arr.shape[3]
    arr = arr.squeeze(0)

    delta_h = h - image_size[0]
    delta_w = w - image_size[1]

    if reshape_mode == "random" or reshape_mode == "none":
        top = np.random.randint(0, delta_h + 1)
        left = np.random.randint(0, delta_w + 1)
    elif reshape_mode == "center":
        top, left = delta_h // 2, delta_w // 2
    else:
        raise NotImplementedError
    arr = TT.functional.crop(arr, top=top, left=left, height=image_size[0], width=image_size[1])
    return arr


def sampling_main(args, model_cls):
    if isinstance(model_cls, type):
        model = get_model(args, model_cls)
    else:
        model = model_cls

    load_checkpoint(model, args)
    model.eval()

    if args.input_type == "cli":
        data_iter = read_from_cli()
    elif args.input_type == "txt":
        rank, world_size = mpu.get_data_parallel_rank(), mpu.get_data_parallel_world_size()
        print("rank and world_size", rank, world_size)
        data_iter = read_from_file(args.input_file, rank=rank, world_size=world_size)
    else:
        raise NotImplementedError

    image_size = [480, 720]

    sample_func = model.sample
    T, H, W, C, F = args.sampling_num_frames, image_size[0], image_size[1], args.latent_channels, 8
    num_samples = [1]
    force_uc_zero_embeddings = ["txt"]
    device = model.device
    with torch.no_grad():
        for text, cnt in tqdm(data_iter):
            set_random_seed(args.seed)
            if args.flow_from_prompt:
                text, flow_files = text.split("\t")
            total_num_frames = (T - 1) * 4 + 1  # T is the video latent size, 13 * 4 = 52
            if args.no_flow_injection:
                video_flow = None
            elif args.flow_from_prompt:
                assert args.flow_path is not None, "Flow path must be provided if flow_from_prompt is True"
                p = os.path.join(args.flow_path, flow_files)
                print(f"Flow path: {p}")
                video_flow = (
                    torch.load(p, map_location="cpu", weights_only=True)[:total_num_frames].unsqueeze_(0).cuda()
                )
            elif args.flow_path:
                print(f"Flow path: {args.flow_path}")
                video_flow = torch.load(args.flow_path, map_location=device, weights_only=True)[
                    :total_num_frames
                ].unsqueeze_(0)
            elif args.point_path:
                if type(args.point_path) == str:
                    args.point_path = json.loads(args.point_path)
                print(f"Point path: {args.point_path}")
                video_flow, points = process_traj(args.point_path, total_num_frames, image_size, device=device)
                video_flow = video_flow.unsqueeze_(0)
            else:
                print("No flow injection")
                video_flow = None

            if video_flow is not None:
                model.to("cpu")  # move model to cpu, run vae on gpu only.
                tmp = rearrange(video_flow[0], "T H W C -> T C H W")
                video_flow = flow_to_image(tmp).unsqueeze_(0).to("cuda")  # [1 T C H W]
                if args.vis_traj_features:
                    os.makedirs("samples/flow", exist_ok=True)
                    vis_tensor(tmp, *tmp.shape[-2:], "samples/flow/flow1_vis.gif")
                    imageio.mimwrite(
                        "samples/flow/flow2_vis.gif",
                        rearrange(video_flow[0], "T C H W -> T H W C").cpu(),
                        fps=8,
                        loop=0,
                    )
                del tmp
                video_flow = (
                    rearrange(video_flow / 255.0 * 2 - 1, "B T C H W -> B C T H W").contiguous().to(torch.bfloat16)
                )
                torch.cuda.empty_cache()
                video_flow = video_flow.repeat(2, 1, 1, 1, 1).contiguous()  # for uncondition
                model.first_stage_model.to(device)
                video_flow = model.encode_first_stage(video_flow, None)
                video_flow = video_flow.permute(0, 2, 1, 3, 4).contiguous()
                model.to(device)

            print("rank:", rank, "start to process", text, cnt)
            # TODO: broadcast image2video
            value_dict = {
                "prompt": text,
                "negative_prompt": "",
                "num_frames": torch.tensor(T).unsqueeze(0),
            }

            batch, batch_uc = get_batch(
                get_unique_embedder_keys_from_conditioner(model.conditioner), value_dict, num_samples
            )
            for key in batch:
                if isinstance(batch[key], torch.Tensor):
                    print(key, batch[key].shape)
                elif isinstance(batch[key], list):
                    print(key, [len(l) for l in batch[key]])
                else:
                    print(key, batch[key])
            c, uc = model.conditioner.get_unconditional_conditioning(
                batch,
                batch_uc=batch_uc,
                force_uc_zero_embeddings=force_uc_zero_embeddings,
            )

            for k in c:
                if not k == "crossattn":
                    c[k], uc[k] = map(lambda y: y[k][: math.prod(num_samples)].to("cuda"), (c, uc))
            for index in range(args.batch_size):
                # reload model on GPU
                model.to(device)
                samples_z = sample_func(
                    c,
                    uc=uc,
                    batch_size=1,
                    shape=(T, C, H // F, W // F),
                    video_flow=video_flow,
                )
                samples_z = samples_z.permute(0, 2, 1, 3, 4).contiguous()

                # Unload the model from GPU to save GPU memory
                model.to("cpu")
                torch.cuda.empty_cache()
                first_stage_model = model.first_stage_model
                first_stage_model = first_stage_model.to(device)

                latent = 1.0 / model.scale_factor * samples_z

                # Decode latent serial to save GPU memory
                recons = []
                loop_num = (T - 1) // 2
                for i in range(loop_num):
                    if i == 0:
                        start_frame, end_frame = 0, 3
                    else:
                        start_frame, end_frame = i * 2 + 1, i * 2 + 3
                    if i == loop_num - 1:
                        clear_fake_cp_cache = True
                    else:
                        clear_fake_cp_cache = False
                    with torch.no_grad():
                        recon = first_stage_model.decode(
                            latent[:, :, start_frame:end_frame].contiguous(), clear_fake_cp_cache=clear_fake_cp_cache
                        )

                    recons.append(recon)

                recon = torch.cat(recons, dim=2).to(torch.float32)
                samples_x = recon.permute(0, 2, 1, 3, 4).contiguous()
                samples = torch.clamp((samples_x + 1.0) / 2.0, min=0.0, max=1.0).cpu()

                save_path = args.output_dir
                name = str(cnt) + "_" + text.replace(" ", "_").replace("/", "")[:60] + f"_{index}_seed{args.seed}"
                if args.flow_from_prompt:
                    name = Path(flow_files).stem
                if mpu.get_model_parallel_rank() == 0:
                    save_video_as_grid_and_mp4(
                        samples,
                        save_path,
                        name,
                        fps=args.sampling_fps,
                        traj_points=locals().get("points", None),
                        prompt=text,
                    )
            del samples_z, samples_x, samples, video_flow, latent, recon, recons, c, uc, batch, batch_uc
            gc.collect()


if __name__ == "__main__":
    if "OMPI_COMM_WORLD_LOCAL_RANK" in os.environ:
        os.environ["LOCAL_RANK"] = os.environ["OMPI_COMM_WORLD_LOCAL_RANK"]
        os.environ["WORLD_SIZE"] = os.environ["OMPI_COMM_WORLD_SIZE"]
        os.environ["RANK"] = os.environ["OMPI_COMM_WORLD_RANK"]
    py_parser = argparse.ArgumentParser(add_help=False)
    known, args_list = py_parser.parse_known_args()

    args = get_args(args_list)
    args = argparse.Namespace(**vars(args), **vars(known))
    del args.deepspeed_config
    args.model_config.first_stage_config.params.cp_size = 1
    args.model_config.network_config.params.transformer_args.model_parallel_size = 1
    args.model_config.network_config.params.transformer_args.checkpoint_activations = False
    args.model_config.loss_fn_config.params.sigma_sampler_config.params.uniform_sampling = False
    args.model_config.en_and_decode_n_samples_a_time = 1

    sampling_main(args, model_cls=SATVideoDiffusionEngine)
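In the script above, a trajectory can be supplied through args.point_path, which is parsed with json.loads when given as a string and then handed to process_traj together with the total frame count and the 480x720 image size. The exact schema is defined by process_traj in utils/flow_utils.py; the snippet below only sketches a plausible shape for such an input (a list of trajectories, each a list of [x, y] keypoints) and should be checked against the repository before use:

import json

# Hypothetical trajectory: one object sweeping left to right with a dip in the middle.
# Each inner pair is an [x, y] keypoint on the 720x480 canvas; the outer list allows
# several independent trajectories.
trajectories = [
    [[50, 240], [200, 150], [400, 300], [670, 240]],
]

point_path_arg = json.dumps(trajectories)
print(point_path_arg)
# This string would then be passed to the sampling script as the value of --point_path,
# alongside the usual SAT config and checkpoint arguments.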
To be continued...
For more details, follow: 杰哥新技术