import argparse
import json
import os

import torch
import torch.distributed as dist

from fastvideo.v1.logger import init_logger
from fastvideo.v1.utils import maybe_download_model, shallow_asdict
from fastvideo.v1.distributed import init_distributed_environment, initialize_model_parallel
from fastvideo.v1.fastvideo_args import FastVideoArgs
from fastvideo.v1.configs.models.vaes import WanVAEConfig
from fastvideo import PipelineConfig
from fastvideo.v1.pipelines.preprocess_pipeline import PreprocessPipeline

logger = init_logger(__name__)

BASE_MODEL_PATH = "/workspace/data/Wan-AI/Wan2.1-T2V-1.3B-Diffusers"
MODEL_PATH = maybe_download_model(BASE_MODEL_PATH,
                                  local_dir=os.path.join(
                                      'data', BASE_MODEL_PATH))
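# Note: BASE_MODEL_PATH is an absolute path, so os.path.join('data', BASE_MODEL_PATH)
# resolves back to BASE_MODEL_PATH itself (os.path.join drops earlier components
# when a later one is absolute); the 'data/' prefix would only take effect if
# BASE_MODEL_PATH were a relative repo id such as "Wan-AI/Wan2.1-T2V-1.3B-Diffusers".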


def main(args):
    # Assume using torchrun: RANK, LOCAL_RANK, and WORLD_SIZE are set in the
    # environment by the launcher.
    local_rank = int(os.getenv("LOCAL_RANK", 0))
    rank = int(os.environ.get("RANK", 0))
    world_size = int(os.getenv("WORLD_SIZE", 1))
    init_distributed_environment(world_size=world_size, rank=rank, local_rank=local_rank)
    initialize_model_parallel(tensor_model_parallel_size=world_size, sequence_model_parallel_size=world_size)
    torch.cuda.set_device(local_rank)
    if not dist.is_initialized():
        # Fallback in case init_distributed_environment did not create the
        # default process group; use the global rank here, not the local one.
        dist.init_process_group(backend="nccl", init_method="env://", world_size=world_size, rank=rank)
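
    # Note: tensor and sequence model parallel sizes are both set to world_size
    # above, i.e. a single group spanning all ranks; how the dataset itself is
    # sharded across ranks is presumably handled inside the preprocess pipeline
    # rather than here.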

    pipeline_config = PipelineConfig.from_pretrained(MODEL_PATH)
    kwargs = {
        "use_cpu_offload": False,
        "vae_precision": "fp32",
        "vae_config": WanVAEConfig(load_encoder=True, load_decoder=False),
    }
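    # Overrides for preprocessing: only the VAE encoder is needed here (videos
    # are encoded to latents, never decoded), it is kept in fp32, and CPU
    # offload is disabled.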
    pipeline_config_args = shallow_asdict(pipeline_config)
    pipeline_config_args.update(kwargs)
    fastvideo_args = FastVideoArgs(model_path=MODEL_PATH,
                                   num_gpus=world_size,
                                   device_str="cuda",
                                   **pipeline_config_args,
                                   )
    fastvideo_args.check_fastvideo_args()
    fastvideo_args.device = torch.device(f"cuda:{local_rank}")

    pipeline = PreprocessPipeline(MODEL_PATH, fastvideo_args)
    pipeline.forward(batch=None, fastvideo_args=fastvideo_args, args=args)
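    # batch=None: the preprocess pipeline is expected to build its own
    # dataloader from the CLI `args` (data_merge_path, batch sizes, resolution
    # limits, ...) rather than consuming a caller-provided batch.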


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    # dataset & dataloader
    parser.add_argument("--model_path", type=str, default="data/mochi")
    parser.add_argument("--model_type", type=str, default="mochi")
    parser.add_argument("--data_merge_path", type=str, required=True)
    parser.add_argument("--validation_prompt_txt", type=str)
    parser.add_argument("--num_frames", type=int, default=163)
    parser.add_argument(
        "--dataloader_num_workers",
        type=int,
        default=1,
        help="Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process.",
    )
    parser.add_argument(
        "--preprocess_video_batch_size",
        type=int,
        default=2,
        help="Batch size (per device) for the video preprocessing dataloader.",
    )
    parser.add_argument(
        "--preprocess_text_batch_size",
        type=int,
        default=8,
        help="Batch size (per device) for the text preprocessing dataloader.",
    )
    parser.add_argument(
        "--samples_per_file",
        type=int,
        default=64,
        help="Number of samples to write into each output parquet file.",
    )
    parser.add_argument(
        "--flush_frequency",
        type=int,
        default=256,
        help="How often to flush accumulated samples to parquet files.",
    )
    parser.add_argument("--num_latent_t", type=int, default=28, help="Number of latent timesteps.")
    parser.add_argument("--max_height", type=int, default=480)
    parser.add_argument("--max_width", type=int, default=848)
    parser.add_argument("--video_length_tolerance_range", type=float, default=2.0)
    parser.add_argument("--group_frame", action="store_true")  # TODO
    parser.add_argument("--group_resolution", action="store_true")  # TODO
    parser.add_argument("--dataset", default="t2v")
    parser.add_argument("--train_fps", type=int, default=30)
    parser.add_argument("--use_image_num", type=int, default=0)
    parser.add_argument("--text_max_length", type=int, default=256)
    parser.add_argument("--speed_factor", type=float, default=1.0)
    parser.add_argument("--drop_short_ratio", type=float, default=1.0)
    # text encoder & vae & diffusion model
    parser.add_argument("--text_encoder_name", type=str, default="google/t5-v1_1-xxl")
    parser.add_argument("--cache_dir", type=str, default="./cache_dir")
    parser.add_argument("--cfg", type=float, default=0.0)
    parser.add_argument(
        "--output_dir",
        type=str,
        default=None,
        help="The output directory where the preprocessed data will be written.",
    )
    parser.add_argument(
        "--logging_dir",
        type=str,
        default="logs",
        help=("[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to"
              " *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***."),
    )

    args = parser.parse_args()
    main(args)
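
# Example launch (illustrative; <...> values are placeholders, not defaults):
#   torchrun --nproc_per_node=<NUM_GPUS> <this_script>.py \
#       --data_merge_path <path/to/merge.txt> \
#       --output_dir <path/to/output_dir>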