[Training] [0/n] Add preprocessing pipeline #442
Conversation
Need to merge #438 first, because this PR requires v1/datasets
Force-pushed from 93fd4ac to 5c3ec38
local_dir=os.path.join(
    'data', BASE_MODEL_PATH))
I'm not sure why we should use local_dir; this can make cache sharing more complicated. I've opened a PR to remove it.
It may have been me who added the hardcoded path. @JerryZhou54 perhaps let's just use the model_path arg here and have the registry detect the correct pipeline config directly from the HF string instead of a local path?
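A rough sketch of that direction, assuming a helper along the lines of the maybe_download_model mentioned below (the name resolve_model_path and the exact logic are illustrative, not the project's actual code): accept either a local directory or an HF repo id, and skip local_dir so downloads land in the shared default HF cache.

```python
import os

from huggingface_hub import snapshot_download


def resolve_model_path(model_path: str) -> str:
    """Return a local directory for model_path (a local dir or an HF repo id)."""
    if os.path.isdir(model_path):
        return model_path
    # No local_dir: files go to the shared HF cache (~/.cache/huggingface/hub,
    # or wherever HF_HOME points) and are reused across runs and scripts.
    return snapshot_download(repo_id=model_path)
```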
# export WANDB_MODE="offline"
GPU_NUM=1 # 2,4,8
MODEL_PATH="Wan-AI/Wan2.1-T2V-1.3B-Diffusers"
TEXT_ENCODER_PATH="/Wan-AI/Wan2.1-T2V-1.3B-Diffusers/tokenizer"
Instead of hardcoding this TEXT_ENCODER_PATH, we can simply do:
path = maybe_download_model(args.model_path)
encoder_path = os.path.join(path, 'tokenizer')
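A minimal self-contained sketch of the same idea, using snapshot_download in place of the project's maybe_download_model helper and assuming the standard Diffusers repo layout (a tokenizer/ subfolder) of Wan-AI/Wan2.1-T2V-1.3B-Diffusers:

```python
import os

from huggingface_hub import snapshot_download

# Resolve the repo once, then derive subfolder paths from the snapshot directory
# instead of hardcoding a second absolute path.
model_dir = snapshot_download(repo_id="Wan-AI/Wan2.1-T2V-1.3B-Diffusers")
encoder_path = os.path.join(model_dir, "tokenizer")
```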
logger = init_logger(__name__)

BASE_MODEL_PATH = "/workspace/data/Wan-AI/Wan2.1-T2V-1.3B-Diffusers"
let's just use args.model_path here instead?
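As a sketch of why args.model_path is enough here, with plain diffusers standing in for the project's own pipeline registry (which may resolve models differently): from_pretrained-style loaders accept either an HF repo id or a local directory, so the hardcoded /workspace constant can be dropped without other changes.

```python
import torch
from diffusers import DiffusionPipeline


def load_pipeline(model_path: str) -> DiffusionPipeline:
    # model_path may be "Wan-AI/Wan2.1-T2V-1.3B-Diffusers" or a local checkout;
    # the pipeline class and subcomponents are resolved from the model's config.
    return DiffusionPipeline.from_pretrained(model_path, torch_dtype=torch.bfloat16)
```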
feel free to merge after addressing my comments
Force-pushed from 3e9c40a to 31c34d0
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    # dataset & dataloader
    parser.add_argument("--model_path", type=str, default="data/mochi")
It's better not to use data as the default; except for testers running on RunPod, just use HF's default cache path.
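A small sketch of that suggestion (the default repo id shown is illustrative): keep --model_path pointing at an HF repo id so weights resolve through the standard HF cache (HF_HOME / ~/.cache/huggingface/hub), and let RunPod testers pass a local data/ directory explicitly.

```python
import argparse

parser = argparse.ArgumentParser()
parser.add_argument(
    "--model_path",
    type=str,
    # HF repo id resolved via the default cache; a local dir such as data/mochi
    # can still be passed explicitly on RunPod test machines.
    default="Wan-AI/Wan2.1-T2V-1.3B-Diffusers",
    help="HF repo id or local model directory",
)
args = parser.parse_args()
```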