diff --git a/models/swin_transformer.py b/models/swin_transformer.py index dde06bc5b..d3345c3ab 100644 --- a/models/swin_transformer.py +++ b/models/swin_transformer.py @@ -162,13 +162,13 @@ def flops(self, N): # calculate flops for 1 window with token length of N flops = 0 # qkv = self.qkv(x) - flops += N * self.dim * 3 * self.dim + flops += N * self.dim * 3 * self.dim * 2 # attn = (q @ k.transpose(-2, -1)) - flops += self.num_heads * N * (self.dim // self.num_heads) * N + flops += self.num_heads * N * (self.dim // self.num_heads) * N * 2 # x = (attn @ v) - flops += self.num_heads * N * N * (self.dim // self.num_heads) + flops += self.num_heads * N * N * (self.dim // self.num_heads) * 2 # x = self.proj(x) - flops += N * self.dim * self.dim + flops += N * self.dim * self.dim * 2 return flops @@ -306,7 +306,7 @@ def flops(self): nW = H * W / self.window_size / self.window_size flops += nW * self.attn.flops(self.window_size * self.window_size) # mlp - flops += 2 * H * W * self.dim * self.dim * self.mlp_ratio + flops += 2 * H * W * self.dim * self.dim * self.mlp_ratio * 2 # norm2 flops += self.dim * H * W return flops @@ -357,7 +357,7 @@ def extra_repr(self) -> str: def flops(self): H, W = self.input_resolution flops = H * W * self.dim - flops += (H // 2) * (W // 2) * 4 * self.dim * 2 * self.dim + flops += (H // 2) * (W // 2) * 4 * self.dim * 2 * self.dim * 2 return flops @@ -476,7 +476,7 @@ def forward(self, x): def flops(self): Ho, Wo = self.patches_resolution - flops = Ho * Wo * self.embed_dim * self.in_chans * (self.patch_size[0] * self.patch_size[1]) + flops = Ho * Wo * self.embed_dim * self.in_chans * (self.patch_size[0] * self.patch_size[1]) * 2 if self.norm is not None: flops += Ho * Wo * self.embed_dim return flops