diff --git a/unimatch/unimatch.py b/unimatch/unimatch.py index 96db16e..ac2618e 100755 --- a/unimatch/unimatch.py +++ b/unimatch/unimatch.py @@ -12,8 +12,11 @@ from .reg_refine import BasicUpdateBlock from .utils import normalize_img, feature_add_position, upsample_flow_with_mask +from huggingface_hub import PyTorchModelHubMixin -class UniMatch(nn.Module): + +class UniMatch(nn.Module, PyTorchModelHubMixin, repo_url="https://github.com/autonomousvision/unimatch", + pipeline_tag="any-to-any", tags=["depth-estimation", "optical-flow-estimation", "disparity-estimation"], license="mit") def __init__(self, num_scales=1, feature_channels=128,