README.md: 11 additions & 10 deletions
```diff
@@ -694,21 +694,22 @@ from vit_pytorch import ViT
 from vit_pytorch.simmim import SimMIM
 
 v = ViT(
-    image_size = 256,
-    patch_size = 32,
-    num_classes = 1000,
-    dim = 1024,
-    depth = 6,
-    heads = 8,
-    mlp_dim = 2048
+    image_size=224,
+    patch_size=16,
+    num_classes=1000,
+    dim=1024,
+    depth=6,
+    heads=8,
+    mlp_dim=2048,
 )
 
 mim = SimMIM(
-    encoder = v,
-    masking_ratio = 0.5  # they found 50% to yield the best results
+    encoder=v,
+    encoder_stride=16,  # for swin transformer, it should be 32
+    masking_ratio=0.5,  # they found 50% to yield the best results
 )
 
-images = torch.randn(8, 3, 256, 256)
+images = torch.randn(8, 3, 224, 224)
 
 loss = mim(images)
 loss.backward()
```
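Note on the README change: with `image_size=224` and `patch_size=16`, the ViT emits a 14×14 token grid, and `encoder_stride=16` is the factor by which the new PixelShuffle head upsamples that grid back to the input resolution. A quick sanity check of that arithmetic (illustrative only, not part of the PR):

```python
# For a plain ViT, encoder_stride should equal patch_size so that
# (tokens per side) * stride == image_size; Swin progressively halves
# its grid, hence the stride of 32 mentioned in the comment above.
image_size, patch_size = 224, 16
grid = image_size // patch_size             # 14 tokens per side
encoder_stride = patch_size                 # 16 for this ViT
assert grid * encoder_stride == image_size  # 14 * 16 == 224
```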
vit_pytorch/simmim.py: 98 additions & 31 deletions
```diff
@@ -1,17 +1,15 @@
 import torch
-from torch import nn
 import torch.nn.functional as F
 from einops import repeat
+from torch import nn
 
+
 class SimMIM(nn.Module):
-    def __init__(
-        self,
-        *,
-        encoder,
-        masking_ratio = 0.5
-    ):
+    def __init__(self, *, encoder, encoder_stride, in_chans=3, masking_ratio=0.5):
         super().__init__()
-        assert masking_ratio > 0 and masking_ratio < 1, 'masking ratio must be kept between 0 and 1'
+        assert (
+            masking_ratio > 0 and masking_ratio < 1
+        ), "masking ratio must be kept between 0 and 1"
         self.masking_ratio = masking_ratio
 
         # extract some hyperparameters and functions from encoder (vision transformer to be trained)
@@ -21,6 +19,16 @@ def __init__(
         self.to_patch, self.patch_to_emb = encoder.to_patch_embedding[:2]
         pixel_values_per_patch = self.patch_to_emb.weight.shape[-1]
 
+        self.in_chans = in_chans
+        self.encoder_stride = encoder_stride
+        self.decoder = nn.Sequential(
+            nn.Conv2d(
+                in_channels=encoder_dim,
+                out_channels=self.encoder_stride ** 2 * 3,
+                kernel_size=1,
+            ),
+            nn.PixelShuffle(self.encoder_stride),
+        )
         # simple linear head
 
         self.mask_token = nn.Parameter(torch.randn(encoder_dim))
```
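The new decoder is SimMIM's one-layer prediction head: a 1×1 convolution expands each token to `encoder_stride**2 * 3` values, and `nn.PixelShuffle` rearranges those channels into a stride×stride pixel block. A shape walk-through under the README's settings (my sketch, not part of the diff):

```python
import torch
from torch import nn

encoder_dim, stride = 1024, 16
decoder = nn.Sequential(
    # (B, 1024, 14, 14) -> (B, 16*16*3, 14, 14): one predicted pixel block per token
    nn.Conv2d(encoder_dim, stride ** 2 * 3, kernel_size=1),
    # (B, 768, 14, 14) -> (B, 3, 224, 224): channels rearranged into pixels
    nn.PixelShuffle(stride),
)
z = torch.randn(2, encoder_dim, 14, 14)
print(decoder(z).shape)  # torch.Size([2, 3, 224, 224])
```

Note that the output channel count is hard-coded to 3 even though the constructor accepts `in_chans`; swapping the literal for `in_chans` would make the head consistent with that parameter.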
```diff
@@ -36,11 +44,11 @@ def forward(self, img):
 
         # for indexing purposes
 
-        batch_range = torch.arange(batch, device = device)[:, None]
+        batch_range = torch.arange(batch, device=device)[:, None]
 
         # get positions
 
-        pos_emb = self.encoder.pos_embedding[:, 1:(num_patches + 1)]
+        pos_emb = self.encoder.pos_embedding[:, 1 : (num_patches + 1)]
 
         # patch to encoder tokens and add positions
 
```
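Both changes in this hunk are formatting only. For context, the `pos_embedding` slice skips the class token's position, and `batch_range[:, None]` exists so that advanced indexing can gather a different set of patches per sample, as the (now removed) `encoded[batch_range, masked_indices]` did. A minimal illustration of that indexing pattern, with hypothetical shapes:

```python
import torch

tokens = torch.arange(2 * 4 * 3).reshape(2, 4, 3)  # (batch, num_patches, dim)
batch_range = torch.arange(2)[:, None]             # (2, 1), broadcasts over columns
indices = torch.tensor([[0, 2], [1, 3]])           # per-sample patch ids, (2, 2)
picked = tokens[batch_range, indices]              # (2, 2, 3): row i uses indices[i]
print(picked.shape)
```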
```diff
@@ -49,36 +57,95 @@ def forward(self, img):
 
         # prepare mask tokens
 
-        mask_tokens = repeat(self.mask_token, 'd -> b n d', b = batch, n = num_patches)
+        mask_tokens = repeat(self.mask_token, "d -> b n d", b=batch, n=num_patches)
         mask_tokens = mask_tokens + pos_emb
 
         # calculate of patches needed to be masked, and get positions (indices) to be masked
 
         num_masked = int(self.masking_ratio * num_patches)
-        masked_indices = torch.rand(batch, num_patches, device = device).topk(k = num_masked, dim = -1).indices
-        masked_bool_mask = torch.zeros((batch, num_patches), device = device).scatter_(-1, masked_indices, 1).bool()
+        masked_indices = (
+            torch.rand(batch, num_patches, device=device)
+            .topk(k=num_masked, dim=-1)
+            .indices
+        )
+        masked_bool_mask = (
+            torch.zeros((batch, num_patches), device=device)
+            .scatter_(-1, masked_indices, 1)
+            .bool()
+        )
 
         # mask tokens
 
         tokens = torch.where(masked_bool_mask[..., None], mask_tokens, tokens)
 
         # attend with vision transformer
 
         encoded = self.encoder.transformer(tokens)
 
-        # get the masked tokens
-
-        encoded_mask_tokens = encoded[batch_range, masked_indices]
-
-        # small linear projection for predicted pixel values
-
-        pred_pixel_values = self.to_pixels(encoded_mask_tokens)
-
-        # get the masked patches for the final reconstruction loss
-
-        masked_patches = patches[batch_range, masked_indices]
-
-        # calculate reconstruction loss
-
-        recon_loss = F.l1_loss(pred_pixel_values, masked_patches) / num_masked
-        return recon_loss
+        # encoded = encoded[:, 1:]
+        B, L, C = encoded.shape
+        H = W = int(L ** 0.5)
+        z = encoded.permute(0, 2, 1).reshape(B, C, H, W)
+
+        x_rec = self.decoder(z)
+
+        loss_recon = F.l1_loss(img, x_rec)
+        return loss_recon / self.in_chans
+        # mask weight
+        # patch_size = int(num_patches ** 0.5)
+        # mask_lst = []
+        # for i, masked_indice in enumerate(masked_indices):
+        #     mask = torch.ones(num_patches)
+        #     mask[masked_indice] = 0
+        #     mask_lst.append(mask.view(patch_size, patch_size))
+        # mask = torch.stack(mask_lst, dim=0).to(device)
+
+        # mask = (
+        #     mask.repeat_interleave(patch_size, 1)
+        #     .repeat_interleave(patch_size, 2)
+        #     .unsqueeze(1)
+        #     .contiguous()
+        # )
+        # loss_recon = F.l1_loss(img, x_rec, reduction="none")
+        # return (loss_recon * mask).sum() / (mask.sum() + 1e-5) / self.in_chans
+
+        # # simple head
+        # # get the masked tokens
+        # encoded_mask_tokens = encoded[batch_range, masked_indices]
+        # # small linear projection for predicted pixel values
+        # pred_pixel_values = self.to_pixels(encoded_mask_tokens)
+
+        # # get the masked patches for the final reconstruction loss
+
+        # masked_patches = patches[batch_range, masked_indices]
+
+        # # calculate reconstruction loss
+
+        # recon_loss = F.l1_loss(pred_pixel_values, masked_patches) / num_masked
+        # return recon_loss
+
+
+if __name__ == "__main__":
+    import torch
+
+    from vit import ViT
+
+    v = ViT(
+        image_size=224,
+        patch_size=16,
+        num_classes=1000,
+        dim=1024,
+        depth=6,
+        heads=8,
+        mlp_dim=2048,
+    )
+
+    mim = SimMIM(
+        encoder=v,
+        encoder_stride=16,  # for swin transformer, it should be 32
+        masking_ratio=0.5,  # they found 50% to yield the best results
+    )
+
+    images = torch.randn(8, 3, 224, 224)
+
+    loss = mim(images)
+    loss.backward()
+    print(loss)
```
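Worth flagging for review: the committed `forward` now returns a plain L1 over the entire reconstructed image (divided by `in_chans`), whereas the SimMIM paper computes the loss on masked pixels only. The mask-weighted variant is left commented out above, and as written it appears inverted: it zeroes the weights at the masked positions, so it would average the loss over visible patches. A cleaned-up sketch of the paper-style masked loss, under my assumptions of a square patch grid and a mask upsampled by the pixel size of a patch:

```python
import torch
import torch.nn.functional as F

B, C, H, W = 8, 3, 224, 224
grid = 14                                # 14 x 14 = 196 patches
scale = H // grid                        # pixels per patch side (16)
img, x_rec = torch.randn(B, C, H, W), torch.randn(B, C, H, W)
masked_indices = torch.rand(B, grid * grid).topk(98, dim=-1).indices

# pixel-level mask: 1 on masked patches, 0 on visible ones
mask = torch.zeros(B, grid * grid)
mask.scatter_(1, masked_indices, 1.0)
mask = mask.view(B, grid, grid)
mask = mask.repeat_interleave(scale, 1).repeat_interleave(scale, 2).unsqueeze(1)

# L1 averaged over masked pixels only, as in the SimMIM paper
loss = F.l1_loss(img, x_rec, reduction="none")
loss = (loss * mask).sum() / (mask.sum() + 1e-5) / C
print(loss)
```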