@@ -249,6 +249,7 @@ def cli():
     parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
     parser.add_argument("audio", nargs="+", type=str, help="audio file(s) to transcribe")
     parser.add_argument("--model", default="small", choices=available_models(), help="name of the Whisper model to use")
+    parser.add_argument("--model_dir", type=str, default=None, help="the path to save model files; uses ~/.cache/whisper by default")
     parser.add_argument("--device", default="cuda" if torch.cuda.is_available() else "cpu", help="device to use for PyTorch inference")
     parser.add_argument("--output_dir", "-o", type=str, default=".", help="directory to save the outputs")
     parser.add_argument("--verbose", type=str2bool, default=True, help="whether to print out the progress and debug messages")
@@ -274,6 +275,7 @@ def cli():
 
     args = parser.parse_args().__dict__
     model_name: str = args.pop("model")
+    model_dir: str = args.pop("model_dir")
     output_dir: str = args.pop("output_dir")
     device: str = args.pop("device")
     os.makedirs(output_dir, exist_ok=True)
@@ -290,7 +292,7 @@ def cli():
         temperature = [temperature]
 
     from . import load_model
-    model = load_model(model_name, device=device)
+    model = load_model(model_name, device=device, download_root=model_dir)
 
     for audio_path in args.pop("audio"):
         result = transcribe(model, audio_path, temperature=temperature, **args)
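For context, a minimal sketch of how the new option flows into load_model when the package is used directly rather than through the CLI. The cache directory and audio path below are illustrative, and the package is assumed to be importable as whisper; when download_root is left as None, the default ~/.cache/whisper location is used.

import whisper

# Roughly equivalent to: whisper audio.mp3 --model small --model_dir /path/to/cache
# download_root tells load_model where to download and look for model weights.
model = whisper.load_model("small", device="cpu", download_root="/path/to/cache")
result = model.transcribe("audio.mp3")
print(result["text"])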