
Commit 5cf8384

committed
add a vit with decorrelation auxiliary losses for mha and feedforwards, right after prenorm - this is in line with a paper from the netherlands, but without extra parameters or their manual sgd update scheme
1 parent f7d59ce commit 5cf8384
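
In short: every attention and feedforward block returns its prenormed input, and an auxiliary loss penalizes the squared off-diagonal entries of the feature covariance of those activations across tokens. Below is a minimal standalone sketch of that penalty; the helper name `decorr_penalty` is illustrative and not part of the commit, and the committed `DecorrelationLoss` further down additionally supports subsampling a fraction of the tokens.

```python
import torch

def decorr_penalty(tokens):
    # tokens: (..., num_tokens, dim) - prenormed inputs gathered from a block
    n, d = tokens.shape[-2], tokens.shape[-1]

    # uncentered feature covariance across the token dimension
    cov = tokens.transpose(-1, -2) @ tokens / n

    # keep only squared off-diagonal entries, normalized by the number of feature pairs
    off_diag = cov.pow(2) * (1. - torch.eye(d, device = tokens.device))

    return off_diag.sum(dim = (-1, -2)).mean() / (d * (d - 1))

# toy check on random "prenormed" activations
aux = decorr_penalty(torch.randn(2, 64, 128))
```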

4 files changed: 342 additions & 1 deletion


README.md

Lines changed: 12 additions & 0 deletions

@@ -2201,4 +2201,16 @@ Coming from computer vision and new to transformers? Here are some resources tha
 }
 ```

+```bibtex
+@misc{carrigg2025decorrelationspeedsvisiontransformers,
+    title         = {Decorrelation Speeds Up Vision Transformers},
+    author        = {Kieran Carrigg and Rob van Gastel and Melda Yeghaian and Sander Dalm and Faysal Boughorbel and Marcel van Gerven},
+    year          = {2025},
+    eprint        = {2510.14657},
+    archivePrefix = {arXiv},
+    primaryClass  = {cs.CV},
+    url           = {https://arxiv.org/abs/2510.14657},
+}
+```
+
 *I visualise a time when we will be to robots what dogs are to humans, and I’m rooting for the machines.* — Claude Shannon

pyproject.toml

Lines changed: 1 addition & 1 deletion

@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"

 [project]
 name = "vit-pytorch"
-version = "1.14.5"
+version = "1.15.2"
 description = "Vision Transformer (ViT) - Pytorch"
 readme = { file = "README.md", content-type = "text/markdown" }
 license = { file = "LICENSE" }

train_vit_decorr.py

Lines changed: 107 additions & 0 deletions

@@ -0,0 +1,107 @@
# /// script
# dependencies = [
#     "accelerate",
#     "vit-pytorch",
#     "wandb"
# ]
# ///

import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader

import torchvision.transforms as T
from torchvision.datasets import CIFAR100

# constants

BATCH_SIZE = 32
LEARNING_RATE = 3e-4
EPOCHS = 10
DECORR_LOSS_WEIGHT = 1e-1

TRACK_EXPERIMENT_ONLINE = False

# helpers

def exists(v):
    return v is not None

# data

transform = T.Compose([
    T.ToTensor(),
    T.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])

dataset = CIFAR100(
    root = 'data',
    download = True,
    train = True,
    transform = transform
)

dataloader = DataLoader(dataset, batch_size = BATCH_SIZE, shuffle = True)

# model

from vit_pytorch.vit_with_decorr import ViT

vit = ViT(
    dim = 128,
    num_classes = 100,
    image_size = 32,
    patch_size = 4,
    depth = 6,
    heads = 8,
    dim_head = 64,
    mlp_dim = 128 * 4,
    decorr_sample_frac = 1. # use all tokens
)

# optim

from torch.optim import Adam

optim = Adam(vit.parameters(), lr = LEARNING_RATE)

# prepare

from accelerate import Accelerator

accelerator = Accelerator()

vit, optim, dataloader = accelerator.prepare(vit, optim, dataloader)

# experiment

import wandb

wandb.init(
    project = 'vit-decorr',
    mode = 'disabled' if not TRACK_EXPERIMENT_ONLINE else 'online'
)

wandb.run.name = 'baseline'

# loop

for _ in range(EPOCHS):
    for images, labels in dataloader:

        logits, decorr_aux_loss = vit(images)
        loss = F.cross_entropy(logits, labels)

        total_loss = (
            loss +
            decorr_aux_loss * DECORR_LOSS_WEIGHT
        )

        wandb.log(dict(loss = loss, decorr_loss = decorr_aux_loss))

        accelerator.print(f'loss: {loss.item():.3f} | decorr aux loss: {decorr_aux_loss.item():.3f}')

        accelerator.backward(total_loss)
        optim.step()
        optim.zero_grad()
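
The `# /// script` block at the top of this file is PEP 723 inline script metadata, so the example should be runnable directly with a compatible runner, for example `uv run train_vit_decorr.py` (assuming `uv` is installed); otherwise, installing `accelerate`, `vit-pytorch`, and `wandb` and invoking it with plain `python` works as well.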

vit_pytorch/vit_with_decorr.py

Lines changed: 222 additions & 0 deletions

@@ -0,0 +1,222 @@
# https://arxiv.org/abs/2510.14657
# but instead of their decorr module updated with SGD, remove all projections and just return a decorrelation auxiliary loss

import torch
from torch import nn, stack
import torch.nn.functional as F
from torch.nn import Module, ModuleList

from einops import rearrange, repeat, reduce, einsum, pack, unpack
from einops.layers.torch import Rearrange

# helpers

def exists(v):
    return v is not None

def default(v, d):
    return v if exists(v) else d

def pair(t):
    return t if isinstance(t, tuple) else (t, t)

# decorr loss

class DecorrelationLoss(Module):
    def __init__(
        self,
        sample_frac = 1.
    ):
        super().__init__()
        assert 0. <= sample_frac <= 1.
        self.need_sample = sample_frac < 1.
        self.sample_frac = sample_frac

    def forward(
        self,
        tokens
    ):
        batch, seq_len, dim, device = *tokens.shape[-3:], tokens.device

        if self.need_sample:
            num_sampled = int(seq_len * self.sample_frac)
            assert num_sampled >= 2

            tokens, packed_shape = pack([tokens], '* n d')

            # randomly select a fraction of the tokens along the sequence dimension

            indices = torch.randn(tokens.shape[:2], device = device).argsort(dim = -1)[..., :num_sampled]

            batch_arange = torch.arange(tokens.shape[0], device = device)
            batch_arange = rearrange(batch_arange, 'b -> b 1')

            tokens = tokens[batch_arange, indices]
            tokens, = unpack(tokens, packed_shape, '* n d')

        dist = einsum(tokens, tokens, '... n d, ... n e -> ... d e') / tokens.shape[-2]
        eye = torch.eye(dim, device = device)

        loss = dist.pow(2) * (1. - eye) / ((dim - 1) * dim)

        loss = reduce(loss, '... b d e -> b', 'sum')
        return loss.mean()

# classes

class FeedForward(Module):
    def __init__(self, dim, hidden_dim, dropout = 0.):
        super().__init__()
        self.norm = nn.LayerNorm(dim)

        self.net = nn.Sequential(
            nn.Linear(dim, hidden_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, dim),
            nn.Dropout(dropout)
        )

    def forward(self, x):
        normed = self.norm(x)
        return self.net(normed), normed

class Attention(Module):
    def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0.):
        super().__init__()
        inner_dim = dim_head * heads
        project_out = not (heads == 1 and dim_head == dim)

        self.norm = nn.LayerNorm(dim)
        self.heads = heads
        self.scale = dim_head ** -0.5

        self.attend = nn.Softmax(dim = -1)
        self.dropout = nn.Dropout(dropout)

        self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)

        self.to_out = nn.Sequential(
            nn.Linear(inner_dim, dim),
            nn.Dropout(dropout)
        ) if project_out else nn.Identity()

    def forward(self, x):
        normed = self.norm(x)

        qkv = self.to_qkv(normed).chunk(3, dim = -1)
        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = self.heads), qkv)

        dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale

        attn = self.attend(dots)
        attn = self.dropout(attn)

        out = torch.matmul(attn, v)
        out = rearrange(out, 'b h n d -> b n (h d)')

        return self.to_out(out), normed

class Transformer(Module):
    def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout = 0.):
        super().__init__()
        self.norm = nn.LayerNorm(dim)
        self.layers = ModuleList([])

        for _ in range(depth):
            self.layers.append(ModuleList([
                Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout),
                FeedForward(dim, mlp_dim, dropout = dropout)
            ]))

    def forward(self, x):

        normed_inputs = []

        for attn, ff in self.layers:
            attn_out, attn_normed_inp = attn(x)
            x = attn_out + x

            ff_out, ff_normed_inp = ff(x)
            x = ff_out + x

            normed_inputs.append(attn_normed_inp)
            normed_inputs.append(ff_normed_inp)

        return self.norm(x), stack(normed_inputs)

class ViT(Module):
    def __init__(self, *, image_size, patch_size, num_classes, dim, depth, heads, mlp_dim, pool = 'cls', channels = 3, dim_head = 64, dropout = 0., emb_dropout = 0., decorr_sample_frac = 1.):
        super().__init__()
        image_height, image_width = pair(image_size)
        patch_height, patch_width = pair(patch_size)

        assert image_height % patch_height == 0 and image_width % patch_width == 0, 'Image dimensions must be divisible by the patch size.'

        num_patches = (image_height // patch_height) * (image_width // patch_width)
        patch_dim = channels * patch_height * patch_width
        assert pool in {'cls', 'mean'}, 'pool type must be either cls (cls token) or mean (mean pooling)'

        self.to_patch_embedding = nn.Sequential(
            Rearrange('b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1 = patch_height, p2 = patch_width),
            nn.LayerNorm(patch_dim),
            nn.Linear(patch_dim, dim),
            nn.LayerNorm(dim),
        )

        self.pos_embedding = nn.Parameter(torch.randn(1, num_patches + 1, dim))
        self.cls_token = nn.Parameter(torch.randn(1, 1, dim))
        self.dropout = nn.Dropout(emb_dropout)

        self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim, dropout)

        self.pool = pool
        self.to_latent = nn.Identity()

        self.mlp_head = nn.Linear(dim, num_classes)

        # decorrelation loss related

        self.has_decorr_loss = decorr_sample_frac > 0.

        if self.has_decorr_loss:
            self.decorr_loss = DecorrelationLoss(decorr_sample_frac)

        self.register_buffer('zero', torch.tensor(0.), persistent = False)

    def forward(
        self,
        img,
        return_decorr_aux_loss = None
    ):
        return_decorr_aux_loss = default(return_decorr_aux_loss, self.training) and self.has_decorr_loss

        x = self.to_patch_embedding(img)
        b, n, _ = x.shape

        cls_tokens = repeat(self.cls_token, '1 1 d -> b 1 d', b = b)
        x = torch.cat((cls_tokens, x), dim = 1)
        x += self.pos_embedding[:, :(n + 1)]
        x = self.dropout(x)

        x, normed_layer_inputs = self.transformer(x)

        # maybe return decorrelation aux loss

        decorr_aux_loss = self.zero

        if return_decorr_aux_loss:
            decorr_aux_loss = self.decorr_loss(normed_layer_inputs)

        x = x.mean(dim = 1) if self.pool == 'mean' else x[:, 0]

        x = self.to_latent(x)
        return self.mlp_head(x), decorr_aux_loss

# quick test

if __name__ == '__main__':
    decorr_loss = DecorrelationLoss(0.1)

    hiddens = torch.randn(6, 2, 512, 256)

    decorr_loss(hiddens)
    decorr_loss(hiddens[0])
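
As a usage sketch (not part of the commit), here is how the full model's interface behaves under the hyperparameters from the training script above: during training the forward pass returns the auxiliary decorrelation loss alongside the logits, while in eval mode `return_decorr_aux_loss` defaults to `self.training`, so the second return value is just the registered zero buffer.

```python
import torch
from vit_pytorch.vit_with_decorr import ViT

vit = ViT(
    image_size = 32,
    patch_size = 4,
    num_classes = 100,
    dim = 128,
    depth = 6,
    heads = 8,
    mlp_dim = 512,
    decorr_sample_frac = 0.25  # decorrelate on a quarter of the tokens
)

images = torch.randn(2, 3, 32, 32)

# training mode - auxiliary decorrelation loss is computed
vit.train()
logits, decorr_aux_loss = vit(images)

# eval mode - the aux loss falls back to the zero buffer
vit.eval()
logits, zero = vit(images)
assert zero.item() == 0.
```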
