
Commit 93fd4ac

Add preprocessing
Co-authored-by: Yongqi Chen <[email protected]>
1 parent 6ef8fcb commit 93fd4ac

File tree

4 files changed: +724 -4 lines changed

Lines changed: 122 additions & 0 deletions
@@ -0,0 +1,122 @@
import argparse
import json
import os

import torch
import torch.distributed as dist

from fastvideo.v1.logger import init_logger
from fastvideo.v1.utils import maybe_download_model, shallow_asdict
from fastvideo.v1.distributed import init_distributed_environment, initialize_model_parallel
from fastvideo.v1.fastvideo_args import FastVideoArgs
from fastvideo.v1.configs.models.vaes import WanVAEConfig
from fastvideo import PipelineConfig
from fastvideo.v1.pipelines.preprocess_pipeline import PreprocessPipeline

logger = init_logger(__name__)

BASE_MODEL_PATH = "/workspace/data/Wan-AI/Wan2.1-T2V-1.3B-Diffusers"
MODEL_PATH = maybe_download_model(BASE_MODEL_PATH,
                                  local_dir=os.path.join(
                                      'data', BASE_MODEL_PATH))


def main(args):
    # Assume the script is launched with torchrun, which sets RANK,
    # LOCAL_RANK and WORLD_SIZE in the environment.
    local_rank = int(os.getenv("LOCAL_RANK", 0))
    rank = int(os.getenv("RANK", 0))
    world_size = int(os.getenv("WORLD_SIZE", 1))
    init_distributed_environment(world_size=world_size, rank=rank, local_rank=local_rank)
    initialize_model_parallel(tensor_model_parallel_size=world_size,
                              sequence_model_parallel_size=world_size)
    torch.cuda.set_device(local_rank)
    if not dist.is_initialized():
        dist.init_process_group(backend="nccl", init_method="env://",
                                world_size=world_size, rank=rank)

    # Build the pipeline config; only the VAE encoder is needed for preprocessing.
    pipeline_config = PipelineConfig.from_pretrained(MODEL_PATH)
    kwargs = {
        "use_cpu_offload": False,
        "vae_precision": "fp32",
        "vae_config": WanVAEConfig(load_encoder=True, load_decoder=False),
    }
    pipeline_config_args = shallow_asdict(pipeline_config)
    pipeline_config_args.update(kwargs)
    fastvideo_args = FastVideoArgs(model_path=MODEL_PATH,
                                   num_gpus=world_size,
                                   device_str="cuda",
                                   **pipeline_config_args)
    fastvideo_args.check_fastvideo_args()
    fastvideo_args.device = torch.device(f"cuda:{local_rank}")

    # Run the preprocessing pipeline over the dataset described by args.
    pipeline = PreprocessPipeline(MODEL_PATH, fastvideo_args)
    pipeline.forward(batch=None, fastvideo_args=fastvideo_args, args=args)


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    # dataset & dataloader
    parser.add_argument("--model_path", type=str, default="data/mochi")
    parser.add_argument("--model_type", type=str, default="mochi")
    parser.add_argument("--data_merge_path", type=str, required=True)
    parser.add_argument("--validation_prompt_txt", type=str)
    parser.add_argument("--num_frames", type=int, default=163)
    parser.add_argument(
        "--dataloader_num_workers",
        type=int,
        default=1,
        help="Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process.",
    )
    parser.add_argument(
        "--preprocess_video_batch_size",
        type=int,
        default=2,
        help="Batch size (per device) for video preprocessing.",
    )
    parser.add_argument(
        "--preprocess_text_batch_size",
        type=int,
        default=8,
        help="Batch size (per device) for text preprocessing.",
    )
    parser.add_argument("--samples_per_file", type=int, default=64)
    parser.add_argument(
        "--flush_frequency",
        type=int,
        default=256,
        help="How often to flush buffered samples to Parquet files.",
    )
    parser.add_argument("--num_latent_t", type=int, default=28, help="Number of latent timesteps.")
    parser.add_argument("--max_height", type=int, default=480)
    parser.add_argument("--max_width", type=int, default=848)
    parser.add_argument("--video_length_tolerance_range", type=float, default=2.0)
    parser.add_argument("--group_frame", action="store_true")  # TODO
    parser.add_argument("--group_resolution", action="store_true")  # TODO
    parser.add_argument("--dataset", default="t2v")
    parser.add_argument("--train_fps", type=int, default=30)
    parser.add_argument("--use_image_num", type=int, default=0)
    parser.add_argument("--text_max_length", type=int, default=256)
    parser.add_argument("--speed_factor", type=float, default=1.0)
    parser.add_argument("--drop_short_ratio", type=float, default=1.0)
    # text encoder & vae & diffusion model
    parser.add_argument("--text_encoder_name", type=str, default="google/t5-v1_1-xxl")
    parser.add_argument("--cache_dir", type=str, default="./cache_dir")
    parser.add_argument("--cfg", type=float, default=0.0)
    parser.add_argument(
        "--output_dir",
        type=str,
        default=None,
        help="The output directory where the model predictions and checkpoints will be written.",
    )
    parser.add_argument(
        "--logging_dir",
        type=str,
        default="logs",
        help=("[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to"
              " *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***."),
    )

    args = parser.parse_args()
    main(args)
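
The script expects a torchrun launch (it reads RANK, LOCAL_RANK and WORLD_SIZE from the environment). A minimal single-GPU launch sketch; the script file name and the data/output paths below are placeholders, since the diff does not show the file path:

# hypothetical script name and paths
torchrun --nproc_per_node=1 preprocess_wan.py \
    --data_merge_path data/merge.txt \
    --output_dir data/outputs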
