Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
39 changes: 39 additions & 0 deletions config/TPN.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
backbone:
name: Conv64F
kwargs:
is_flatten: false
is_feature: false
leaky_relu: false
negative_slope: 0.2
last_pool: true
maxpool_last2: true
use_running_statistics: true

classifier:
name: TPN
kwargs:
topk: 20
sigma: 0.25
alpha: 0.99
rn: 300

way_num: 5
shot_num: 1
query_num: 15

epoch: 1000
test_epoch: 100
train_episode: 100
test_episode: 100
episode_size: 5

optimizer:
name: Adam
kwargs:
lr: 1e-3

lr_scheduler:
name: StepLR
kwargs:
step_size: 10000
gamma: 0.5
1 change: 0 additions & 1 deletion core/model/backbone/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@
from .resnet_bdc import resnet12Bdc, resnet18Bdc
from core.model.backbone.utils.maml_module import convert_maml_module


def get_backbone(config):
"""Get the backbone according to the config dict.

Expand Down
3 changes: 2 additions & 1 deletion core/model/metric/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,4 +13,5 @@
from .deepbdc import DeepBDC
from .frn import FRN
from .meta_baseline import MetaBaseline
from .mcl import MCL
from .mcl import MCL
from .tpn import TPN
197 changes: 197 additions & 0 deletions core/model/metric/tpn.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,197 @@
"""
@misc{liu2019learningpropagatelabelstransductive,
title={Learning to Propagate Labels: Transductive Propagation Network for Few-shot Learning},
author={Yanbin Liu and Juho Lee and Minseop Park and Saehoon Kim and Eunho Yang and Sung Ju Hwang and Yi Yang},
year={2019},
eprint={1805.10002},
archivePrefix={arXiv},
primaryClass={cs.LG},
url={https://arxiv.org/abs/1805.10002},
}

Adapted From https://github.com/csyanbin/TPN-pytorch
"""

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

from .metric_model import MetricModel


class RelationNetwork(nn.Module):
"""Graph Construction Module"""

def __init__(self):
super(RelationNetwork, self).__init__()

self.layer1 = nn.Sequential(
nn.Conv2d(64, 64, kernel_size=3, padding=1),
nn.BatchNorm2d(64),
nn.ReLU(),
nn.MaxPool2d(kernel_size=2, padding=1),
)
self.layer2 = nn.Sequential(
nn.Conv2d(64, 1, kernel_size=3, padding=1),
nn.BatchNorm2d(1),
nn.ReLU(),
nn.MaxPool2d(kernel_size=2, padding=1),
)

self.fc3 = nn.Linear(2 * 2, 8)
self.fc4 = nn.Linear(8, 1)

def forward(self, x):
x = x.view(-1, 64, 5, 5)

out = self.layer1(x)
out = self.layer2(out)

out = out.view(out.size(0), -1)
out = F.relu(self.fc3(out))
out = self.fc4(out)

out = out.view(out.size(0), -1)

return out


class TPN(MetricModel):
def __init__(self, alpha, **kwargs):
super().__init__(**kwargs)

self.relation = RelationNetwork()
self.eps = torch.finfo(torch.float32).eps

if self.rn == 300:
self.alpha = torch.tensor([alpha], requires_grad=False).to(self.device)
elif self.rn == 30:
self.alpha = nn.Parameter(
torch.tensor([alpha]).to(self.device), requires_grad=True
)

def labels_to_onehot(self, labels):
batch_size = labels.size(0)
one_hot = torch.zeros(batch_size, self.way_num).to(self.device)
one_hot.scatter_(1, labels.unsqueeze(1), 1)

return one_hot

def label_propagation(self, support, query, support_label, query_label):
input_feat = torch.cat((support, query), 0)
embedding_all = self.emb_func(input_feat).view(-1, 1600)
num_nodes = embedding_all.shape[0]

if self.rn in [30, 300]:
self.sigma = self.relation(embedding_all)
embedding_all = embedding_all / (self.sigma + self.eps)

weight_matrix = torch.cdist(embedding_all, embedding_all, p=2) ** 2
weight_matrix = torch.exp(-weight_matrix / 2)

if self.topk > 0:
topk_values, topk_indices = torch.topk(weight_matrix, self.topk)
mask = torch.zeros_like(weight_matrix)
mask = mask.scatter(1, topk_indices, 1)
mask = (mask + mask.t()) > 0
weight_matrix = weight_matrix * mask

degree_matrix = weight_matrix.sum(0)
degree_sqrt_inv = torch.rsqrt(degree_matrix + self.eps)

symmetric_matrix = weight_matrix * degree_sqrt_inv.unsqueeze(0) * degree_sqrt_inv.unsqueeze(1)

support_labels = support_label
unlabeled_query = torch.zeros(self.way_num * self.query_num, self.way_num,
device=self.device, dtype=support_labels.dtype)
combined_labels = torch.cat((support_labels, unlabeled_query), 0)

identity_matrix = torch.eye(num_nodes, device=self.device)
A = identity_matrix - self.alpha * symmetric_matrix
propagated_labels = torch.linalg.solve(A, combined_labels)

query_predictions = propagated_labels[self.way_num * self.shot_num:, :]

all_labels = torch.cat((support_label, query_label), 0)
ground_truth = torch.argmax(all_labels, 1)
criterion = nn.CrossEntropyLoss()
loss = criterion(propagated_labels, ground_truth)

predicted_query = torch.argmax(query_predictions, 1)
ground_truth_query = torch.argmax(query_label, 1)

accuracy = (predicted_query == ground_truth_query).float().mean()

return loss, accuracy

def set_forward_loss(self, batch):
image, global_target = batch
image = image.to(self.device)

episode_size = image.size(0) // (
self.way_num * (self.shot_num + self.query_num)
)

(
support_image,
query_image,
support_target,
query_target,
) = self.split_by_episode(image, mode=2)

loss_list = torch.zeros(episode_size, device=self.device)
acc_list = torch.zeros(episode_size, device=self.device)

for i in range(episode_size):
support_label_onehot = self.labels_to_onehot(support_target[i])
query_label_onehot = self.labels_to_onehot(query_target[i])

loss, acc = self.label_propagation(
support_image[i],
query_image[i],
support_label_onehot,
query_label_onehot,
)

loss_list[i] = loss
acc_list[i] = acc

loss = loss_list.mean()
acc = acc_list.mean() * 100.0

return None, acc, loss

def set_forward(self, batch):
image, global_target = batch
image = image.to(self.device)

episode_size = image.size(0) // (
self.way_num * (self.shot_num + self.query_num)
)

(
support_image,
query_image,
support_target,
query_target,
) = self.split_by_episode(image, mode=2)

acc_list = torch.zeros(episode_size, device=self.device)

for i in range(episode_size):
support_label_onehot = self.labels_to_onehot(support_target[i])
query_label_onehot = self.labels_to_onehot(query_target[i])

_, acc = self.label_propagation(
support_image[i],
query_image[i],
support_label_onehot,
query_label_onehot,
)

acc_list[i] = acc

acc = acc_list.mean() * 100.0

return None, acc
17 changes: 16 additions & 1 deletion core/test.py
Original file line number Diff line number Diff line change
Expand Up @@ -193,7 +193,22 @@ def _init_files(self, config):
rank=self.rank,
)

state_dict_path = os.path.join(result_path, "checkpoints", "model_best.pth")
checkpoint_type = config.get("checkpoint_type", "best")
if checkpoint_type == "best":
checkpoint_filename = "model_best.pth"
elif checkpoint_type == "last":
checkpoint_filename = "model_last.pth"
elif isinstance(checkpoint_type, int) or checkpoint_type.isdigit():
epoch_num = int(checkpoint_type)
checkpoint_filename = f"model_{epoch_num:05d}.pth"
else:
print(
f"Warning: Invalid checkpoint_type '{checkpoint_type}', using 'best'",
level="warning",
)
checkpoint_filename = "model_best.pth"

state_dict_path = os.path.join(result_path, "checkpoints", checkpoint_filename)
if self.rank == 0:
create_dirs([result_path, log_path, viz_path])

Expand Down
45 changes: 45 additions & 0 deletions reproduce/TPN/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
# TPN Reproduction

## Introduction

| Name: | [TPN](https://arxiv.org/abs/1805.10002) |
| ------- | ---------------------------------------------------------- |
| Embed.: | Conv64F |
| Type: | Metric |
| Venue: | ICLR2019 |
| Codes: | [**TPN-pytorch**](https://github.com/csyanbin/TPN-pytorch) |

Cite this work with (template):

```bibtex
@misc{liu2019learningpropagatelabelstransductive,
title={Learning to Propagate Labels: Transductive Propagation Network for Few-shot Learning},
author={Yanbin Liu and Juho Lee and Minseop Park and Saehoon Kim and Eunho Yang and Sung Ju Hwang and Yi Yang},
year={2019},
eprint={1805.10002},
archivePrefix={arXiv},
primaryClass={cs.LG},
url={https://arxiv.org/abs/1805.10002},
}
```

---

## Results and Models

All the results are tested under the best model. Checkpoints of different epochs are also provided.

**TPN Result**

| dataset/task | 5way-1shot | 5way-5shot | 10way-1shot | 10-way-5shot |
| -------------- | ------------------------------------------------------------ | ------------------------------------------------------------ | ------------------------------------------------------------ | ------------------------------------------------------------ |
| miniImageNet | 54.06±0.38 [:arrow_down:](https://drive.google.com/drive/folders/1Y14e-h_DcwfyxwU71GZIXQ2G39ZS8oys) | 69.27±0.30 [:arrow_down:](https://drive.google.com/drive/folders/1DIBmJ8a_GZIlEmUW0KEf_awTaB-8GB7i) | 37.36±0.23 [:arrow_down:](https://drive.google.com/drive/folders/1OeO3K7wY4y-UN979eRUQo1vvVvin3vF-) | 53.62±0.20 [:arrow_down:](https://drive.google.com/drive/folders/1c8yd0rMQhAytePcrfLq2nEYaxPDFHc8f) |
| tieredImageNet | 53.36±0.42 [:arrow_down:](https://drive.google.com/drive/folders/1_C0VA1LirJ5l3kYqEBY8HfHmXOzCkZog) | 69.83±0.35 [:arrow_down:](https://drive.google.com/drive/folders/1anwG8tjvaXQ5oq9BBjcTQjY9OdKmDYf1) | 40.29±0.28 [:arrow_down:](https://drive.google.com/drive/folders/1HnJfDHHE78YzaGvhmHHvaG4ymVhdPkrh) | 57.53±0.25 [:arrow_down:](https://drive.google.com/drive/folders/1NBaxyY60rJIxnoeOV0UAAt6KORPvZyfJ) |

**Higher Shot TPN Result**

| dataset/task | 5way-1shot | 5way-5shot | 10way-1shot | 10-way-5shot |
| -------------- | ------------------------------------------------------------ | ------------------------------------------------------------ | ------------------------------------------------------------ | ------------------------------------------------------------ |
| miniImageNet | 55.37±0.39 [:arrow_down:](https://drive.google.com/drive/folders/1sMAcvy817oBzMybkO_HVQ0AMBhJW0kbr?usp=drive_link) | 68.80±0.30 [:arrow_down:](https://drive.google.com/drive/folders/1G1WMnHbsgJcgbdSWoVqRWkQJldDvbDL1?usp=drive_link) | 38.62±0.22 [:arrow_down:](https://drive.google.com/drive/folders/1aXKD2GBbsM7Ql0FdpZV1NJeEdu5QZMyV?usp=drive_link) | 53.68±0.21 [:arrow_down:](https://drive.google.com/drive/folders/1mi2r9Kbl6SvSdbn5kKRd3DaqS354vsxe?usp=drive_link) |
| tieredImageNet | 57.17±0.42 [:arrow_down:](https://drive.google.com/drive/folders/1G-zBq1Hp0zrPIgN1UUlOIeuXnusRJfmB?usp=drive_link) | 69.67±0.34 [:arrow_down:](https://drive.google.com/drive/folders/1-1e22Tjs1YtVuh6HujQ5hMCXWORlKVBR?usp=drive_link) | 43.40±0.28 [:arrow_down:](https://drive.google.com/drive/folders/1hY0q3Nd84SqBSTVmF1ZveKkxc5ZH5raf?usp=drive_link) | 57.71±0.25 [:arrow_down:](https://drive.google.com/drive/folders/1bwoHxwTBSBtPHOyv-0B1v2ZS2IMbAlO1?usp=drive_link) |

26 changes: 26 additions & 0 deletions reproduce/TPN/TPN-miniImageNet--ravi-10-1-Table1.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
includes:
- headers/data.yaml
- headers/device.yaml
- headers/misc.yaml
- headers/model.yaml
- headers/optimizer.yaml
- TPN.yaml

way_num: 10
shot_num: 1
query_num: 15

data_root: /data/fewshot/miniImageNet--ravi
use_memory: false

seed: 0

n_gpu: 1
device_ids: 0

log_interval: 100
log_level: info
log_name: TPN-miniImageNet--ravi-10-1-Table1

result_root: ./results
save_interval: 100
27 changes: 27 additions & 0 deletions reproduce/TPN/TPN-miniImageNet--ravi-10-1-highershot-Table1.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
includes:
- headers/data.yaml
- headers/device.yaml
- headers/misc.yaml
- headers/model.yaml
- headers/optimizer.yaml
- TPN.yaml

way_num: 10
shot_num: 5
query_num: 15
test_shot: 1

data_root: /data/fewshot/miniImageNet--ravi
use_memory: false

seed: 0

n_gpu: 1
device_ids: 0

log_interval: 100
log_level: info
log_name: TPN-miniImageNet--ravi-10-1-highershot-Table1

result_root: ./results
save_interval: 100
Loading