Skip to content

Commit 3cf2fb5

Browse files
author
WenkelF
committed
Updating get_checkpoint_path
1 parent 8f1ddfb commit 3cf2fb5

File tree

1 file changed

+9
-9
lines changed

1 file changed

+9
-9
lines changed

graphium/config/_loader.py

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -589,15 +589,15 @@ def get_checkpoint_path(config: Union[omegaconf.DictConfig, Dict[str, Any]]) ->
589589

590590
cfg_trainer = config["trainer"]
591591

592-
if "model_checkpoint" in cfg_trainer.keys():
593-
dirpath = cfg_trainer["model_checkpoint"]["dirpath"] + str(cfg_trainer["seed"]) + "/"
594-
filename = config.get("ckpt_name_for_testing", "last") + ".ckpt"
595-
else:
596-
raise ValueError("Empty checkpoint section in config file")
592+
path = config.get("ckpt_name_for_testing", "last.ckpt")
593+
if path in GRAPHIUM_PRETRAINED_MODELS_DICT or fs.exists(path):
594+
return path
597595

598-
checkpoint_path = fs.join(dirpath, filename)
596+
if "model_checkpoint" in cfg_trainer.keys():
597+
dirpath = cfg_trainer["model_checkpoint"]["dirpath"]
598+
path = fs.join(dirpath, path)
599599

600-
if not fs.exists(checkpoint_path):
601-
raise ValueError(f"Checkpoint path `{checkpoint_path}` does not exist")
600+
if not fs.exists(path):
601+
raise ValueError(f"Checkpoint path `{path}` does not exist")
602602

603-
return checkpoint_path
603+
return path

0 commit comments

Comments
 (0)