# https://arxiv.org/abs/2510.14657
# but instead of updating their decorrelation module with SGD, remove all the projections and simply return a decorrelation auxiliary loss

import torch
from torch import nn, stack
import torch.nn.functional as F
from torch.nn import Module, ModuleList

from einops import rearrange, repeat, reduce, einsum, pack, unpack
from einops.layers.torch import Rearrange

# helpers

def exists(v):
    return v is not None

def default(v, d):
    return v if exists(v) else d

def pair(t):
    return t if isinstance(t, tuple) else (t, t)

# decorr loss
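# the tokens given to the loss below are the pre-norm hidden states collected from
# every attention and feedforward block; minimizing the squared off-diagonals of
# their feature covariance pushes each layer input toward decorrelated features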

class DecorrelationLoss(Module):
    def __init__(
        self,
        sample_frac = 1.
    ):
        super().__init__()
        assert 0. <= sample_frac <= 1.
        self.need_sample = sample_frac < 1.
        self.sample_frac = sample_frac

    def forward(
        self,
        tokens
    ):
        batch, seq_len, dim, device = *tokens.shape[-3:], tokens.device

        if self.need_sample:
            num_sampled = int(seq_len * self.sample_frac)
            assert num_sampled >= 2

            # pack any leading dims (layers, batch) into a single batch dim

            tokens, packed_shape = pack([tokens], '* n d')

            # for each batch element, randomly choose a subset of sequence positions

            indices = torch.randn(tokens.shape[:2], device = device).argsort(dim = -1)[..., :num_sampled]

            batch_arange = torch.arange(tokens.shape[0], device = device)
            batch_arange = rearrange(batch_arange, 'b -> b 1')

            tokens = tokens[batch_arange, indices]
            tokens, = unpack(tokens, packed_shape, '* n d')

        # uncentered feature covariance across the sequence dimension

        cov = einsum(tokens, tokens, '... n d, ... n e -> ... d e') / tokens.shape[-2]
        eye = torch.eye(dim, device = device)

        # penalize squared off-diagonal entries, normalized by their count of d * (d - 1)

        loss = cov.pow(2) * (1. - eye) / ((dim - 1) * dim)

        loss = reduce(loss, '... b d e -> b', 'sum')
        return loss.mean()

# classes

class FeedForward(Module):
    def __init__(self, dim, hidden_dim, dropout = 0.):
        super().__init__()
        self.norm = nn.LayerNorm(dim)

        self.net = nn.Sequential(
            nn.Linear(dim, hidden_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, dim),
            nn.Dropout(dropout)
        )

    def forward(self, x):
        normed = self.norm(x)
        return self.net(normed), normed

class Attention(Module):
    def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0.):
        super().__init__()
        inner_dim = dim_head * heads
        project_out = not (heads == 1 and dim_head == dim)

        self.norm = nn.LayerNorm(dim)
        self.heads = heads
        self.scale = dim_head ** -0.5

        self.attend = nn.Softmax(dim = -1)
        self.dropout = nn.Dropout(dropout)

        self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)

        self.to_out = nn.Sequential(
            nn.Linear(inner_dim, dim),
            nn.Dropout(dropout)
        ) if project_out else nn.Identity()

    def forward(self, x):
        normed = self.norm(x)

        qkv = self.to_qkv(normed).chunk(3, dim = -1)
        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = self.heads), qkv)

        dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale

        attn = self.attend(dots)
        attn = self.dropout(attn)

        out = torch.matmul(attn, v)
        out = rearrange(out, 'b h n d -> b n (h d)')

        return self.to_out(out), normed

class Transformer(Module):
    def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout = 0.):
        super().__init__()
        self.norm = nn.LayerNorm(dim)
        self.layers = ModuleList([])

        for _ in range(depth):
            self.layers.append(ModuleList([
                Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout),
                FeedForward(dim, mlp_dim, dropout = dropout)
            ]))

    def forward(self, x):

        normed_inputs = []

        for attn, ff in self.layers:
            attn_out, attn_normed_inp = attn(x)
            x = attn_out + x

            ff_out, ff_normed_inp = ff(x)
            x = ff_out + x

            normed_inputs.append(attn_normed_inp)
            normed_inputs.append(ff_normed_inp)

        # return final embedding plus the pre-norm layer inputs, stacked to (2 * depth, b, n, d)

        return self.norm(x), stack(normed_inputs)

class ViT(Module):
    def __init__(self, *, image_size, patch_size, num_classes, dim, depth, heads, mlp_dim, pool = 'cls', channels = 3, dim_head = 64, dropout = 0., emb_dropout = 0., decorr_sample_frac = 1.):
        super().__init__()
        image_height, image_width = pair(image_size)
        patch_height, patch_width = pair(patch_size)

        assert image_height % patch_height == 0 and image_width % patch_width == 0, 'Image dimensions must be divisible by the patch size.'

        num_patches = (image_height // patch_height) * (image_width // patch_width)
        patch_dim = channels * patch_height * patch_width
        assert pool in {'cls', 'mean'}, 'pool type must be either cls (cls token) or mean (mean pooling)'

        self.to_patch_embedding = nn.Sequential(
            Rearrange('b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1 = patch_height, p2 = patch_width),
            nn.LayerNorm(patch_dim),
            nn.Linear(patch_dim, dim),
            nn.LayerNorm(dim),
        )

        self.pos_embedding = nn.Parameter(torch.randn(1, num_patches + 1, dim))
        self.cls_token = nn.Parameter(torch.randn(1, 1, dim))
        self.dropout = nn.Dropout(emb_dropout)

        self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim, dropout)

        self.pool = pool
        self.to_latent = nn.Identity()

        self.mlp_head = nn.Linear(dim, num_classes)

        # decorrelation loss related

        self.has_decorr_loss = decorr_sample_frac > 0.

        if self.has_decorr_loss:
            self.decorr_loss = DecorrelationLoss(decorr_sample_frac)

        self.register_buffer('zero', torch.tensor(0.), persistent = False)

    def forward(
        self,
        img,
        return_decorr_aux_loss = None
    ):
        return_decorr_aux_loss = default(return_decorr_aux_loss, self.training) and self.has_decorr_loss

        x = self.to_patch_embedding(img)
        b, n, _ = x.shape

        cls_tokens = repeat(self.cls_token, '1 1 d -> b 1 d', b = b)
        x = torch.cat((cls_tokens, x), dim = 1)
        x += self.pos_embedding[:, :(n + 1)]
        x = self.dropout(x)

        x, normed_layer_inputs = self.transformer(x)

        # maybe return the decorrelation auxiliary loss

        decorr_aux_loss = self.zero

        if return_decorr_aux_loss:
            decorr_aux_loss = self.decorr_loss(normed_layer_inputs)

        x = x.mean(dim = 1) if self.pool == 'mean' else x[:, 0]

        x = self.to_latent(x)
        return self.mlp_head(x), decorr_aux_loss

# quick test

if __name__ == '__main__':
    decorr_loss = DecorrelationLoss(0.1)

    hiddens = torch.randn(6, 2, 512, 256)

    decorr_loss(hiddens)
    decorr_loss(hiddens[0])
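
    # sanity check (illustrative, not from the referenced paper): perfectly
    # decorrelated features should score lower than fully duplicated ones

    full_decorr_loss = DecorrelationLoss()

    ortho_feats = repeat(torch.eye(16), 'n d -> b n d', b = 2)      # identity tokens -> zero off-diagonal covariance
    dup_feats = repeat(torch.randn(2, 16), 'b n -> b n d', d = 16)  # identical features per token -> maximal correlation

    assert full_decorr_loss(ortho_feats) < full_decorr_loss(dup_feats)

    # minimal end-to-end training-step sketch - all hyperparameters below are assumed placeholders

    vit = ViT(
        image_size = 32,
        patch_size = 4,
        num_classes = 10,
        dim = 64,
        depth = 2,
        heads = 4,
        mlp_dim = 128,
        decorr_sample_frac = 0.5
    )

    images = torch.randn(2, 3, 32, 32)
    labels = torch.randint(0, 10, (2,))

    logits, decorr_aux_loss = vit(images)

    # combine the task loss with the weighted auxiliary loss - the 0.1 weight is an assumption

    total_loss = F.cross_entropy(logits, labels) + 0.1 * decorr_aux_loss
    total_loss.backward()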