diff --git a/config/TPN.yaml b/config/TPN.yaml new file mode 100644 index 00000000..9c4ff735 --- /dev/null +++ b/config/TPN.yaml @@ -0,0 +1,39 @@ +backbone: + name: Conv64F + kwargs: + is_flatten: false + is_feature: false + leaky_relu: false + negative_slope: 0.2 + last_pool: true + maxpool_last2: true + use_running_statistics: true + +classifier: + name: TPN + kwargs: + topk: 20 + sigma: 0.25 + alpha: 0.99 + rn: 300 + +way_num: 5 +shot_num: 1 +query_num: 15 + +epoch: 1000 +test_epoch: 100 +train_episode: 100 +test_episode: 100 +episode_size: 5 + +optimizer: + name: Adam + kwargs: + lr: 1e-3 + +lr_scheduler: + name: StepLR + kwargs: + step_size: 10000 + gamma: 0.5 diff --git a/core/model/backbone/__init__.py b/core/model/backbone/__init__.py index 47aa8388..19077ab9 100644 --- a/core/model/backbone/__init__.py +++ b/core/model/backbone/__init__.py @@ -11,7 +11,6 @@ from .resnet_bdc import resnet12Bdc, resnet18Bdc from core.model.backbone.utils.maml_module import convert_maml_module - def get_backbone(config): """Get the backbone according to the config dict. diff --git a/core/model/metric/__init__.py b/core/model/metric/__init__.py index 459d042e..1f97fa09 100644 --- a/core/model/metric/__init__.py +++ b/core/model/metric/__init__.py @@ -13,4 +13,5 @@ from .deepbdc import DeepBDC from .frn import FRN from .meta_baseline import MetaBaseline -from .mcl import MCL \ No newline at end of file +from .mcl import MCL +from .tpn import TPN \ No newline at end of file diff --git a/core/model/metric/tpn.py b/core/model/metric/tpn.py new file mode 100644 index 00000000..c7c62aa5 --- /dev/null +++ b/core/model/metric/tpn.py @@ -0,0 +1,197 @@ +""" +@misc{liu2019learningpropagatelabelstransductive, + title={Learning to Propagate Labels: Transductive Propagation Network for Few-shot Learning}, + author={Yanbin Liu and Juho Lee and Minseop Park and Saehoon Kim and Eunho Yang and Sung Ju Hwang and Yi Yang}, + year={2019}, + eprint={1805.10002}, + archivePrefix={arXiv}, + primaryClass={cs.LG}, + url={https://arxiv.org/abs/1805.10002}, +} + +Adapted From https://github.com/csyanbin/TPN-pytorch +""" + +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F + +from .metric_model import MetricModel + + +class RelationNetwork(nn.Module): + """Graph Construction Module""" + + def __init__(self): + super(RelationNetwork, self).__init__() + + self.layer1 = nn.Sequential( + nn.Conv2d(64, 64, kernel_size=3, padding=1), + nn.BatchNorm2d(64), + nn.ReLU(), + nn.MaxPool2d(kernel_size=2, padding=1), + ) + self.layer2 = nn.Sequential( + nn.Conv2d(64, 1, kernel_size=3, padding=1), + nn.BatchNorm2d(1), + nn.ReLU(), + nn.MaxPool2d(kernel_size=2, padding=1), + ) + + self.fc3 = nn.Linear(2 * 2, 8) + self.fc4 = nn.Linear(8, 1) + + def forward(self, x): + x = x.view(-1, 64, 5, 5) + + out = self.layer1(x) + out = self.layer2(out) + + out = out.view(out.size(0), -1) + out = F.relu(self.fc3(out)) + out = self.fc4(out) + + out = out.view(out.size(0), -1) + + return out + + +class TPN(MetricModel): + def __init__(self, alpha, **kwargs): + super().__init__(**kwargs) + + self.relation = RelationNetwork() + self.eps = torch.finfo(torch.float32).eps + + if self.rn == 300: + self.alpha = torch.tensor([alpha], requires_grad=False).to(self.device) + elif self.rn == 30: + self.alpha = nn.Parameter( + torch.tensor([alpha]).to(self.device), requires_grad=True + ) + + def labels_to_onehot(self, labels): + batch_size = labels.size(0) + one_hot = torch.zeros(batch_size, self.way_num).to(self.device) + one_hot.scatter_(1, labels.unsqueeze(1), 1) + + return one_hot + + def label_propagation(self, support, query, support_label, query_label): + input_feat = torch.cat((support, query), 0) + embedding_all = self.emb_func(input_feat).view(-1, 1600) + num_nodes = embedding_all.shape[0] + + if self.rn in [30, 300]: + self.sigma = self.relation(embedding_all) + embedding_all = embedding_all / (self.sigma + self.eps) + + weight_matrix = torch.cdist(embedding_all, embedding_all, p=2) ** 2 + weight_matrix = torch.exp(-weight_matrix / 2) + + if self.topk > 0: + topk_values, topk_indices = torch.topk(weight_matrix, self.topk) + mask = torch.zeros_like(weight_matrix) + mask = mask.scatter(1, topk_indices, 1) + mask = (mask + mask.t()) > 0 + weight_matrix = weight_matrix * mask + + degree_matrix = weight_matrix.sum(0) + degree_sqrt_inv = torch.rsqrt(degree_matrix + self.eps) + + symmetric_matrix = weight_matrix * degree_sqrt_inv.unsqueeze(0) * degree_sqrt_inv.unsqueeze(1) + + support_labels = support_label + unlabeled_query = torch.zeros(self.way_num * self.query_num, self.way_num, + device=self.device, dtype=support_labels.dtype) + combined_labels = torch.cat((support_labels, unlabeled_query), 0) + + identity_matrix = torch.eye(num_nodes, device=self.device) + A = identity_matrix - self.alpha * symmetric_matrix + propagated_labels = torch.linalg.solve(A, combined_labels) + + query_predictions = propagated_labels[self.way_num * self.shot_num:, :] + + all_labels = torch.cat((support_label, query_label), 0) + ground_truth = torch.argmax(all_labels, 1) + criterion = nn.CrossEntropyLoss() + loss = criterion(propagated_labels, ground_truth) + + predicted_query = torch.argmax(query_predictions, 1) + ground_truth_query = torch.argmax(query_label, 1) + + accuracy = (predicted_query == ground_truth_query).float().mean() + + return loss, accuracy + + def set_forward_loss(self, batch): + image, global_target = batch + image = image.to(self.device) + + episode_size = image.size(0) // ( + self.way_num * (self.shot_num + self.query_num) + ) + + ( + support_image, + query_image, + support_target, + query_target, + ) = self.split_by_episode(image, mode=2) + + loss_list = torch.zeros(episode_size, device=self.device) + acc_list = torch.zeros(episode_size, device=self.device) + + for i in range(episode_size): + support_label_onehot = self.labels_to_onehot(support_target[i]) + query_label_onehot = self.labels_to_onehot(query_target[i]) + + loss, acc = self.label_propagation( + support_image[i], + query_image[i], + support_label_onehot, + query_label_onehot, + ) + + loss_list[i] = loss + acc_list[i] = acc + + loss = loss_list.mean() + acc = acc_list.mean() * 100.0 + + return None, acc, loss + + def set_forward(self, batch): + image, global_target = batch + image = image.to(self.device) + + episode_size = image.size(0) // ( + self.way_num * (self.shot_num + self.query_num) + ) + + ( + support_image, + query_image, + support_target, + query_target, + ) = self.split_by_episode(image, mode=2) + + acc_list = torch.zeros(episode_size, device=self.device) + + for i in range(episode_size): + support_label_onehot = self.labels_to_onehot(support_target[i]) + query_label_onehot = self.labels_to_onehot(query_target[i]) + + _, acc = self.label_propagation( + support_image[i], + query_image[i], + support_label_onehot, + query_label_onehot, + ) + + acc_list[i] = acc + + acc = acc_list.mean() * 100.0 + + return None, acc diff --git a/core/test.py b/core/test.py index 0da3167a..79b78e23 100644 --- a/core/test.py +++ b/core/test.py @@ -193,7 +193,22 @@ def _init_files(self, config): rank=self.rank, ) - state_dict_path = os.path.join(result_path, "checkpoints", "model_best.pth") + checkpoint_type = config.get("checkpoint_type", "best") + if checkpoint_type == "best": + checkpoint_filename = "model_best.pth" + elif checkpoint_type == "last": + checkpoint_filename = "model_last.pth" + elif isinstance(checkpoint_type, int) or checkpoint_type.isdigit(): + epoch_num = int(checkpoint_type) + checkpoint_filename = f"model_{epoch_num:05d}.pth" + else: + print( + f"Warning: Invalid checkpoint_type '{checkpoint_type}', using 'best'", + level="warning", + ) + checkpoint_filename = "model_best.pth" + + state_dict_path = os.path.join(result_path, "checkpoints", checkpoint_filename) if self.rank == 0: create_dirs([result_path, log_path, viz_path]) diff --git a/reproduce/TPN/README.md b/reproduce/TPN/README.md new file mode 100644 index 00000000..367c5349 --- /dev/null +++ b/reproduce/TPN/README.md @@ -0,0 +1,45 @@ +# TPN Reproduction + +## Introduction + +| Name: | [TPN](https://arxiv.org/abs/1805.10002) | +| ------- | ---------------------------------------------------------- | +| Embed.: | Conv64F | +| Type: | Metric | +| Venue: | ICLR2019 | +| Codes: | [**TPN-pytorch**](https://github.com/csyanbin/TPN-pytorch) | + +Cite this work with (template): + +```bibtex +@misc{liu2019learningpropagatelabelstransductive, + title={Learning to Propagate Labels: Transductive Propagation Network for Few-shot Learning}, + author={Yanbin Liu and Juho Lee and Minseop Park and Saehoon Kim and Eunho Yang and Sung Ju Hwang and Yi Yang}, + year={2019}, + eprint={1805.10002}, + archivePrefix={arXiv}, + primaryClass={cs.LG}, + url={https://arxiv.org/abs/1805.10002}, +} +``` + +--- + +## Results and Models + +All the results are tested under the best model. Checkpoints of different epochs are also provided. + +**TPN Result** + +| dataset/task | 5way-1shot | 5way-5shot | 10way-1shot | 10-way-5shot | +| -------------- | ------------------------------------------------------------ | ------------------------------------------------------------ | ------------------------------------------------------------ | ------------------------------------------------------------ | +| miniImageNet | 54.06±0.38 [:arrow_down:](https://drive.google.com/drive/folders/1Y14e-h_DcwfyxwU71GZIXQ2G39ZS8oys) | 69.27±0.30 [:arrow_down:](https://drive.google.com/drive/folders/1DIBmJ8a_GZIlEmUW0KEf_awTaB-8GB7i) | 37.36±0.23 [:arrow_down:](https://drive.google.com/drive/folders/1OeO3K7wY4y-UN979eRUQo1vvVvin3vF-) | 53.62±0.20 [:arrow_down:](https://drive.google.com/drive/folders/1c8yd0rMQhAytePcrfLq2nEYaxPDFHc8f) | +| tieredImageNet | 53.36±0.42 [:arrow_down:](https://drive.google.com/drive/folders/1_C0VA1LirJ5l3kYqEBY8HfHmXOzCkZog) | 69.83±0.35 [:arrow_down:](https://drive.google.com/drive/folders/1anwG8tjvaXQ5oq9BBjcTQjY9OdKmDYf1) | 40.29±0.28 [:arrow_down:](https://drive.google.com/drive/folders/1HnJfDHHE78YzaGvhmHHvaG4ymVhdPkrh) | 57.53±0.25 [:arrow_down:](https://drive.google.com/drive/folders/1NBaxyY60rJIxnoeOV0UAAt6KORPvZyfJ) | + +**Higher Shot TPN Result** + +| dataset/task | 5way-1shot | 5way-5shot | 10way-1shot | 10-way-5shot | +| -------------- | ------------------------------------------------------------ | ------------------------------------------------------------ | ------------------------------------------------------------ | ------------------------------------------------------------ | +| miniImageNet | 55.37±0.39 [:arrow_down:](https://drive.google.com/drive/folders/1sMAcvy817oBzMybkO_HVQ0AMBhJW0kbr?usp=drive_link) | 68.80±0.30 [:arrow_down:](https://drive.google.com/drive/folders/1G1WMnHbsgJcgbdSWoVqRWkQJldDvbDL1?usp=drive_link) | 38.62±0.22 [:arrow_down:](https://drive.google.com/drive/folders/1aXKD2GBbsM7Ql0FdpZV1NJeEdu5QZMyV?usp=drive_link) | 53.68±0.21 [:arrow_down:](https://drive.google.com/drive/folders/1mi2r9Kbl6SvSdbn5kKRd3DaqS354vsxe?usp=drive_link) | +| tieredImageNet | 57.17±0.42 [:arrow_down:](https://drive.google.com/drive/folders/1G-zBq1Hp0zrPIgN1UUlOIeuXnusRJfmB?usp=drive_link) | 69.67±0.34 [:arrow_down:](https://drive.google.com/drive/folders/1-1e22Tjs1YtVuh6HujQ5hMCXWORlKVBR?usp=drive_link) | 43.40±0.28 [:arrow_down:](https://drive.google.com/drive/folders/1hY0q3Nd84SqBSTVmF1ZveKkxc5ZH5raf?usp=drive_link) | 57.71±0.25 [:arrow_down:](https://drive.google.com/drive/folders/1bwoHxwTBSBtPHOyv-0B1v2ZS2IMbAlO1?usp=drive_link) | + diff --git a/reproduce/TPN/TPN-miniImageNet--ravi-10-1-Table1.yaml b/reproduce/TPN/TPN-miniImageNet--ravi-10-1-Table1.yaml new file mode 100644 index 00000000..b3dd7d12 --- /dev/null +++ b/reproduce/TPN/TPN-miniImageNet--ravi-10-1-Table1.yaml @@ -0,0 +1,26 @@ +includes: + - headers/data.yaml + - headers/device.yaml + - headers/misc.yaml + - headers/model.yaml + - headers/optimizer.yaml + - TPN.yaml + +way_num: 10 +shot_num: 1 +query_num: 15 + +data_root: /data/fewshot/miniImageNet--ravi +use_memory: false + +seed: 0 + +n_gpu: 1 +device_ids: 0 + +log_interval: 100 +log_level: info +log_name: TPN-miniImageNet--ravi-10-1-Table1 + +result_root: ./results +save_interval: 100 diff --git a/reproduce/TPN/TPN-miniImageNet--ravi-10-1-highershot-Table1.yaml b/reproduce/TPN/TPN-miniImageNet--ravi-10-1-highershot-Table1.yaml new file mode 100644 index 00000000..44f6c676 --- /dev/null +++ b/reproduce/TPN/TPN-miniImageNet--ravi-10-1-highershot-Table1.yaml @@ -0,0 +1,27 @@ +includes: + - headers/data.yaml + - headers/device.yaml + - headers/misc.yaml + - headers/model.yaml + - headers/optimizer.yaml + - TPN.yaml + +way_num: 10 +shot_num: 5 +query_num: 15 +test_shot: 1 + +data_root: /data/fewshot/miniImageNet--ravi +use_memory: false + +seed: 0 + +n_gpu: 1 +device_ids: 0 + +log_interval: 100 +log_level: info +log_name: TPN-miniImageNet--ravi-10-1-highershot-Table1 + +result_root: ./results +save_interval: 100 diff --git a/reproduce/TPN/TPN-miniImageNet--ravi-10-5-Table1.yaml b/reproduce/TPN/TPN-miniImageNet--ravi-10-5-Table1.yaml new file mode 100644 index 00000000..b7ff1e8d --- /dev/null +++ b/reproduce/TPN/TPN-miniImageNet--ravi-10-5-Table1.yaml @@ -0,0 +1,26 @@ +includes: + - headers/data.yaml + - headers/device.yaml + - headers/misc.yaml + - headers/model.yaml + - headers/optimizer.yaml + - TPN.yaml + +way_num: 10 +shot_num: 5 +query_num: 15 + +data_root: /data/fewshot/miniImageNet--ravi +use_memory: false + +seed: 0 + +n_gpu: 1 +device_ids: 0 + +log_interval: 100 +log_level: info +log_name: TPN-miniImageNet--ravi-10-5-Table1 + +result_root: ./results +save_interval: 100 diff --git a/reproduce/TPN/TPN-miniImageNet--ravi-10-5-highershot-Table1.yaml b/reproduce/TPN/TPN-miniImageNet--ravi-10-5-highershot-Table1.yaml new file mode 100644 index 00000000..9d40db2e --- /dev/null +++ b/reproduce/TPN/TPN-miniImageNet--ravi-10-5-highershot-Table1.yaml @@ -0,0 +1,27 @@ +includes: + - headers/data.yaml + - headers/device.yaml + - headers/misc.yaml + - headers/model.yaml + - headers/optimizer.yaml + - TPN.yaml + +way_num: 10 +shot_num: 10 +query_num: 15 +test_shot: 5 + +data_root: /data/fewshot/miniImageNet--ravi +use_memory: false + +seed: 0 + +n_gpu: 1 +device_ids: 0 + +log_interval: 100 +log_level: info +log_name: TPN-miniImageNet--ravi-10-5-highershot-Table1 + +result_root: ./results +save_interval: 100 diff --git a/reproduce/TPN/TPN-miniImageNet--ravi-5-1-Table1.yaml b/reproduce/TPN/TPN-miniImageNet--ravi-5-1-Table1.yaml new file mode 100644 index 00000000..723fd6d5 --- /dev/null +++ b/reproduce/TPN/TPN-miniImageNet--ravi-5-1-Table1.yaml @@ -0,0 +1,26 @@ +includes: + - headers/data.yaml + - headers/device.yaml + - headers/misc.yaml + - headers/model.yaml + - headers/optimizer.yaml + - TPN.yaml + +way_num: 5 +shot_num: 1 +query_num: 15 + +data_root: /data/fewshot/miniImageNet--ravi +use_memory: false + +seed: 0 + +n_gpu: 1 +device_ids: 0 + +log_interval: 100 +log_level: info +log_name: TPN-miniImageNet--ravi-5-1-Table1 + +result_root: ./results +save_interval: 100 diff --git a/reproduce/TPN/TPN-miniImageNet--ravi-5-1-highershot-Table1.yaml b/reproduce/TPN/TPN-miniImageNet--ravi-5-1-highershot-Table1.yaml new file mode 100644 index 00000000..f68c56f0 --- /dev/null +++ b/reproduce/TPN/TPN-miniImageNet--ravi-5-1-highershot-Table1.yaml @@ -0,0 +1,27 @@ +includes: + - headers/data.yaml + - headers/device.yaml + - headers/misc.yaml + - headers/model.yaml + - headers/optimizer.yaml + - TPN.yaml + +way_num: 5 +shot_num: 5 +query_num: 15 +test_shot: 1 + +data_root: /data/fewshot/miniImageNet--ravi +use_memory: false + +seed: 0 + +n_gpu: 1 +device_ids: 0 + +log_interval: 100 +log_level: info +log_name: TPN-miniImageNet--ravi-5-1-highershot-Table1 + +result_root: ./results +save_interval: 100 diff --git a/reproduce/TPN/TPN-miniImageNet--ravi-5-5-Table1.yaml b/reproduce/TPN/TPN-miniImageNet--ravi-5-5-Table1.yaml new file mode 100644 index 00000000..c595f518 --- /dev/null +++ b/reproduce/TPN/TPN-miniImageNet--ravi-5-5-Table1.yaml @@ -0,0 +1,26 @@ +includes: + - headers/data.yaml + - headers/device.yaml + - headers/misc.yaml + - headers/model.yaml + - headers/optimizer.yaml + - TPN.yaml + +way_num: 5 +shot_num: 5 +query_num: 15 + +data_root: /data/fewshot/miniImageNet--ravi +use_memory: false + +seed: 0 + +n_gpu: 1 +device_ids: 0 + +log_interval: 100 +log_level: info +log_name: TPN-miniImageNet--ravi-5-5-Table1 + +result_root: ./results +save_interval: 100 diff --git a/reproduce/TPN/TPN-miniImageNet--ravi-5-5-highershot-Table1.yaml b/reproduce/TPN/TPN-miniImageNet--ravi-5-5-highershot-Table1.yaml new file mode 100644 index 00000000..569ee0eb --- /dev/null +++ b/reproduce/TPN/TPN-miniImageNet--ravi-5-5-highershot-Table1.yaml @@ -0,0 +1,27 @@ +includes: + - headers/data.yaml + - headers/device.yaml + - headers/misc.yaml + - headers/model.yaml + - headers/optimizer.yaml + - TPN.yaml + +way_num: 5 +shot_num: 10 +query_num: 15 +test_shot: 5 + +data_root: /data/fewshot/miniImageNet--ravi +use_memory: false + +seed: 0 + +n_gpu: 1 +device_ids: 0 + +log_interval: 100 +log_level: info +log_name: TPN-miniImageNet--ravi-5-5-highershot-Table1 + +result_root: ./results +save_interval: 100 diff --git a/reproduce/TPN/TPN-tieredImageNet-10-1-Table2.yaml b/reproduce/TPN/TPN-tieredImageNet-10-1-Table2.yaml new file mode 100644 index 00000000..2142c735 --- /dev/null +++ b/reproduce/TPN/TPN-tieredImageNet-10-1-Table2.yaml @@ -0,0 +1,32 @@ +includes: + - headers/data.yaml + - headers/device.yaml + - headers/misc.yaml + - headers/model.yaml + - headers/optimizer.yaml + - TPN.yaml + +lr_scheduler: + name: StepLR + kwargs: + step_size: 25000 + gamma: 0.5 + +way_num: 10 +shot_num: 1 +query_num: 15 + +data_root: /data/fewshot/tiered_imagenet +use_memory: false + +seed: 0 + +n_gpu: 1 +device_ids: 0 + +log_interval: 100 +log_level: info +log_name: TPN-tieredImageNet-10-1-Table2 + +result_root: ./results +save_interval: 100 diff --git a/reproduce/TPN/TPN-tieredImageNet-10-1-highershot-Table2.yaml b/reproduce/TPN/TPN-tieredImageNet-10-1-highershot-Table2.yaml new file mode 100644 index 00000000..94764a1f --- /dev/null +++ b/reproduce/TPN/TPN-tieredImageNet-10-1-highershot-Table2.yaml @@ -0,0 +1,33 @@ +includes: + - headers/data.yaml + - headers/device.yaml + - headers/misc.yaml + - headers/model.yaml + - headers/optimizer.yaml + - TPN.yaml + +lr_scheduler: + name: StepLR + kwargs: + step_size: 25000 + gamma: 0.5 + +way_num: 10 +shot_num: 5 +query_num: 15 +test_shot: 1 + +data_root: /data/fewshot/tiered_imagenet +use_memory: false + +seed: 0 + +n_gpu: 1 +device_ids: 0 + +log_interval: 100 +log_level: info +log_name: TPN-tieredImageNet-10-1-highershot-Table2 + +result_root: ./results +save_interval: 100 diff --git a/reproduce/TPN/TPN-tieredImageNet-10-5-Table2.yaml b/reproduce/TPN/TPN-tieredImageNet-10-5-Table2.yaml new file mode 100644 index 00000000..1d4df020 --- /dev/null +++ b/reproduce/TPN/TPN-tieredImageNet-10-5-Table2.yaml @@ -0,0 +1,32 @@ +includes: + - headers/data.yaml + - headers/device.yaml + - headers/misc.yaml + - headers/model.yaml + - headers/optimizer.yaml + - TPN.yaml + +lr_scheduler: + name: StepLR + kwargs: + step_size: 25000 + gamma: 0.5 + +way_num: 10 +shot_num: 5 +query_num: 15 + +data_root: /data/fewshot/tiered_imagenet +use_memory: false + +seed: 0 + +n_gpu: 1 +device_ids: 0 + +log_interval: 100 +log_level: info +log_name: TPN-tieredImageNet-10-5-Table2 + +result_root: ./results +save_interval: 100 diff --git a/reproduce/TPN/TPN-tieredImageNet-10-5-highershot-Table2.yaml b/reproduce/TPN/TPN-tieredImageNet-10-5-highershot-Table2.yaml new file mode 100644 index 00000000..500866c7 --- /dev/null +++ b/reproduce/TPN/TPN-tieredImageNet-10-5-highershot-Table2.yaml @@ -0,0 +1,33 @@ +includes: + - headers/data.yaml + - headers/device.yaml + - headers/misc.yaml + - headers/model.yaml + - headers/optimizer.yaml + - TPN.yaml + +lr_scheduler: + name: StepLR + kwargs: + step_size: 25000 + gamma: 0.5 + +way_num: 10 +shot_num: 10 +query_num: 15 +test_shot: 5 + +data_root: /data/fewshot/tiered_imagenet +use_memory: false + +seed: 0 + +n_gpu: 1 +device_ids: 0 + +log_interval: 100 +log_level: info +log_name: TPN-tieredImageNet-10-5-highershot-Table2 + +result_root: ./results +save_interval: 100 diff --git a/reproduce/TPN/TPN-tieredImageNet-5-1-Table2.yaml b/reproduce/TPN/TPN-tieredImageNet-5-1-Table2.yaml new file mode 100644 index 00000000..dc8640b2 --- /dev/null +++ b/reproduce/TPN/TPN-tieredImageNet-5-1-Table2.yaml @@ -0,0 +1,32 @@ +includes: + - headers/data.yaml + - headers/device.yaml + - headers/misc.yaml + - headers/model.yaml + - headers/optimizer.yaml + - TPN.yaml + +lr_scheduler: + name: StepLR + kwargs: + step_size: 25000 + gamma: 0.5 + +way_num: 5 +shot_num: 1 +query_num: 15 + +data_root: /data/fewshot/tiered_imagenet +use_memory: false + +seed: 0 + +n_gpu: 1 +device_ids: 0 + +log_interval: 100 +log_level: info +log_name: TPN-tieredImageNet-5-1-Table2 + +result_root: ./results +save_interval: 100 diff --git a/reproduce/TPN/TPN-tieredImageNet-5-1-highershot-Table2.yaml b/reproduce/TPN/TPN-tieredImageNet-5-1-highershot-Table2.yaml new file mode 100644 index 00000000..f5851483 --- /dev/null +++ b/reproduce/TPN/TPN-tieredImageNet-5-1-highershot-Table2.yaml @@ -0,0 +1,33 @@ +includes: + - headers/data.yaml + - headers/device.yaml + - headers/misc.yaml + - headers/model.yaml + - headers/optimizer.yaml + - TPN.yaml + +lr_scheduler: + name: StepLR + kwargs: + step_size: 25000 + gamma: 0.5 + +way_num: 5 +shot_num: 5 +query_num: 15 +test_shot: 1 + +data_root: /data/fewshot/tiered_imagenet +use_memory: false + +seed: 0 + +n_gpu: 1 +device_ids: 0 + +log_interval: 100 +log_level: info +log_name: TPN-tieredImageNet-5-1-highershot-Table2 + +result_root: ./results +save_interval: 100 diff --git a/reproduce/TPN/TPN-tieredImageNet-5-5-Table2.yaml b/reproduce/TPN/TPN-tieredImageNet-5-5-Table2.yaml new file mode 100644 index 00000000..61700646 --- /dev/null +++ b/reproduce/TPN/TPN-tieredImageNet-5-5-Table2.yaml @@ -0,0 +1,32 @@ +includes: + - headers/data.yaml + - headers/device.yaml + - headers/misc.yaml + - headers/model.yaml + - headers/optimizer.yaml + - TPN.yaml + +lr_scheduler: + name: StepLR + kwargs: + step_size: 25000 + gamma: 0.5 + +way_num: 5 +shot_num: 5 +query_num: 15 + +data_root: /data/fewshot/tiered_imagenet +use_memory: false + +seed: 0 + +n_gpu: 1 +device_ids: 0 + +log_interval: 100 +log_level: info +log_name: TPN-tieredImageNet-5-5-Table2 + +result_root: ./results +save_interval: 100 diff --git a/reproduce/TPN/TPN-tieredImageNet-5-5-highershot-Table2.yaml b/reproduce/TPN/TPN-tieredImageNet-5-5-highershot-Table2.yaml new file mode 100644 index 00000000..ffce4d1e --- /dev/null +++ b/reproduce/TPN/TPN-tieredImageNet-5-5-highershot-Table2.yaml @@ -0,0 +1,33 @@ +includes: + - headers/data.yaml + - headers/device.yaml + - headers/misc.yaml + - headers/model.yaml + - headers/optimizer.yaml + - TPN.yaml + +lr_scheduler: + name: StepLR + kwargs: + step_size: 25000 + gamma: 0.5 + +way_num: 5 +shot_num: 10 +query_num: 15 +test_shot: 5 + +data_root: /data/fewshot/tiered_imagenet +use_memory: false + +seed: 0 + +n_gpu: 1 +device_ids: 0 + +log_interval: 100 +log_level: info +log_name: TPN-tieredImageNet-5-5-highershot-Table2 + +result_root: ./results +save_interval: 100 diff --git a/run_test.py b/run_test.py index 958c87f2..4d5d159b 100644 --- a/run_test.py +++ b/run_test.py @@ -16,6 +16,7 @@ "n_gpu": 2, "test_episode": 600, "episode_size": 2, + "checkpoint_type": "best", # best, last or an epoch number }