diff --git a/README.md b/README.md index 63571ba..44688d1 100644 --- a/README.md +++ b/README.md @@ -37,6 +37,15 @@ Conference ## Description What it does +## Use From Torch Hub +```python +import torch +model = torch.hub.load( + "pytorchlightning/deep-learning-project-template", + "lit_classifier", +) +``` + ## How to run First, install dependencies ```bash diff --git a/hubconf.py b/hubconf.py new file mode 100644 index 0000000..aa8f256 --- /dev/null +++ b/hubconf.py @@ -0,0 +1,8 @@ +from project.lit_mnist import LitClassifier + + +# any function defined here will be able to load a model like: +# torch.hub.load("username/repo", "lit_classifier", *args, **kwargs) +# can put logic here for loading pretrained weights +def lit_classifier(*args, **kwargs): + return LitClassifier(*args, **kwargs)