Skip to content
Draft
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 6 additions & 1 deletion .github/workflows/unit-tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,7 @@ jobs:

- name: Run Mypy
# https://github.com/pytorch/ignite/pull/2780
#
#
if: ${{ matrix.os == 'ubuntu-latest' && matrix.pytorch-channel == 'pytorch-nightly'}}
run: |
bash ./tests/run_code_style.sh mypy
Expand Down Expand Up @@ -189,3 +189,8 @@ jobs:
run: |
# Super-Resolution
python examples/super_resolution/main.py --upscale_factor 3 --crop_size 180 --batch_size 4 --test_batch_size 100 --n_epochs 1 --lr 0.001 --threads 2 --debug
- name: Run Siamese Network Example
if: ${{ matrix.os == 'ubuntu-latest' }}
run: |
# Siamese Network
python examples/siamese_network/siamese_network.py --batch-size 256 --test-batch-size 256 --epochs 4 --lr 0.95 --gamma 0.97 --num-workers 5
31 changes: 27 additions & 4 deletions examples/siamese_network/README.md
Original file line number Diff line number Diff line change
@@ -1,10 +1,33 @@
# Siamese Network example on MNIST dataset
# Siamese Network example on CIFAR10 dataset

This example is ported over from [pytorch/examples/siamese_network](https://github.com/pytorch/examples/tree/main/siamese_network)
This example is inspired from [pytorch/examples/siamese_network](https://github.com/pytorch/examples/tree/main/siamese_network). It illustrates the implementation of Siamese Network for checking image similarity in CIFAR10 dataset.

Usage:
## Usage:

```
pip install -r requirements.txt
python siamese_network.py
python siamese_network.py [-h] [--batch-size BATCHSIZE] [--test-batch-size TESTBATCHSIZE] [--epochs EPOCHS]
[--lr LEARNINGRATE] [--gamma GAMMA] [--no-cuda][--no-mps] [--dry-run]
[--seed SEED] [--log-interval LOGINTERVAL] [--save-model] [--num-workers NUMWORKERS]

optional arguments:
-h, --help shows usage and exits
--batch-size sets training batch size
--test-batch-size sets testing batch size
--epochs sets number of training epochs
--lr sets learning rate
--gamma sets gamma parameter for LR Scheduler
--no-cuda disables CUDA training
--no-mps disables macOS GPU training
--dry-run runs model over a single pass
--seed sets random seed
--log-interval sets number of epochs before logging results
--save-model saves current model
--num-workers sets number of processes generating parallel batches
```

## Example Usage:

```
python siamese_network.py --batch-size 64 --test-batch-size 256 --epochs 14 --lr 0.95 --gamma 0.97 --num-workers 5
```
22 changes: 15 additions & 7 deletions examples/siamese_network/siamese_network.py
Original file line number Diff line number Diff line change
Expand Up @@ -260,15 +260,15 @@ def main():
# adds training defaults and support for terminal arguments
parser = argparse.ArgumentParser(description="PyTorch Siamese network Example")
parser.add_argument(
"--batch-size", type=int, default=256, metavar="N", help="input batch size for training (default: 64)"
"--batch-size", type=int, default=256, metavar="N", help="input batch size for training (default: 256)"
)
parser.add_argument(
"--test-batch-size", type=int, default=256, metavar="N", help="input batch size for testing (default: 1000)"
"--test-batch-size", type=int, default=256, metavar="N", help="input batch size for testing (default: 256)"
)
parser.add_argument("--epochs", type=int, default=10, metavar="N", help="number of epochs to train (default: 14)")
parser.add_argument("--epochs", type=int, default=10, metavar="N", help="number of epochs to train (default: 10)")
parser.add_argument("--lr", type=float, default=1.0, metavar="LR", help="learning rate (default: 1.0)")
parser.add_argument(
"--gamma", type=float, default=0.95, metavar="M", help="Learning rate step gamma (default: 0.7)"
"--gamma", type=float, default=0.95, metavar="M", help="Learning rate step gamma (default: 0.95)"
)
parser.add_argument("--no-cuda", action="store_true", default=False, help="disables CUDA training")
parser.add_argument("--no-mps", action="store_true", default=False, help="disables macOS GPU training")
Expand All @@ -281,15 +281,23 @@ def main():
metavar="N",
help="how many batches to wait before logging training status",
)
parser.add_argument("--save-model", action="store_true", default=False, help="For Saving the current Model")
parser.add_argument("--num-workers", default=4, help="number of processes generating parallel batches")
parser.add_argument("--save-model", action="store_true", default=False, help="saves model parameters")
parser.add_argument("--num-workers", type=int, default=4, help="number of parallel batches (default: 4)")
args = parser.parse_args()

# set manual seed
manual_seed(args.seed)

# set device
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
use_cuda = not args.no_cuda and torch.cuda.is_available()
use_mps = not args.no_mps and torch.backends.mps.is_available()

if use_cuda:
device = torch.device("cuda")
elif use_mps:
device = torch.device("mps")
else:
device = torch.device("cpu")

# data loading
train_dataset = MatcherDataset("../data", train=True, download=True)
Expand Down