"""
SwiftFormer
"""
import os
import copy
import torch
import torch.nn as nn
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from timm.models.layers import DropPath, trunc_normal_
from timm.models.registry import register_model
from timm.models.layers.helpers import to_2tuple
import einops
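
try:
    # Only needed for the mmdetection / mmsegmentation backbone use case
    # (`fork_feat=True` with pre-trained weights), where `SwiftFormer.init_weights`
    # calls `get_root_logger` and `_load_checkpoint`. These are assumed to be
    # provided by mmseg/mmcv in that setting; the plain ImageNet-classification
    # path does not need them.
    from mmseg.utils import get_root_logger
    from mmcv.runner import _load_checkpoint
except ImportError:
    pass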


# Channel widths and block counts for the four stages of each SwiftFormer variant.
SwiftFormer_width = {
    'XS': [48, 56, 112, 220],
    'S': [48, 64, 168, 224],
    'l1': [48, 96, 192, 384],
    'l3': [64, 128, 320, 512],
}

SwiftFormer_depth = {
    'XS': [3, 3, 6, 4],
    'S': [3, 3, 9, 6],
    'l1': [4, 3, 10, 5],
    'l3': [4, 4, 12, 6],
}


def stem(in_chs, out_chs):
    """
    Stem layer implemented with two stride-2 convolutions.
    Output: sequence of layers producing a feature map of shape [B, C, H/4, W/4]
    """
    return nn.Sequential(
        nn.Conv2d(in_chs, out_chs // 2, kernel_size=3, stride=2, padding=1),
        nn.BatchNorm2d(out_chs // 2),
        nn.ReLU(),
        nn.Conv2d(out_chs // 2, out_chs, kernel_size=3, stride=2, padding=1),
        nn.BatchNorm2d(out_chs),
        nn.ReLU(),
    )


class Embedding(nn.Module):
    """
    Patch embedding implemented by a single convolution layer.
    Input: tensor in shape [B, C, H, W]
    Output: tensor in shape [B, C, H/stride, W/stride]
    """

    def __init__(self, patch_size=16, stride=16, padding=0,
                 in_chans=3, embed_dim=768, norm_layer=nn.BatchNorm2d):
        super().__init__()
        patch_size = to_2tuple(patch_size)
        stride = to_2tuple(stride)
        padding = to_2tuple(padding)
        self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size,
                              stride=stride, padding=padding)
        self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()

    def forward(self, x):
        x = self.proj(x)
        x = self.norm(x)
        return x


class ConvEncoder(nn.Module):
    """
    Implementation of ConvEncoder with 3x3 depth-wise and 1x1 point-wise convolutions.
    Input: tensor with shape [B, C, H, W]
    Output: tensor with shape [B, C, H, W]
    """

    def __init__(self, dim, hidden_dim=64, kernel_size=3, drop_path=0., use_layer_scale=True):
        super().__init__()
        self.dwconv = nn.Conv2d(dim, dim, kernel_size=kernel_size, padding=kernel_size // 2, groups=dim)
        self.norm = nn.BatchNorm2d(dim)
        self.pwconv1 = nn.Conv2d(dim, hidden_dim, kernel_size=1)
        self.act = nn.GELU()
        self.pwconv2 = nn.Conv2d(hidden_dim, dim, kernel_size=1)
        self.drop_path = DropPath(drop_path) if drop_path > 0. \
            else nn.Identity()
        self.use_layer_scale = use_layer_scale
        if use_layer_scale:
            self.layer_scale = nn.Parameter(torch.ones(dim).unsqueeze(-1).unsqueeze(-1), requires_grad=True)
        self.apply(self._init_weights)

    def _init_weights(self, m):
        if isinstance(m, nn.Conv2d):
            trunc_normal_(m.weight, std=.02)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)

    def forward(self, x):
        input = x
        x = self.dwconv(x)
        x = self.norm(x)
        x = self.pwconv1(x)
        x = self.act(x)
        x = self.pwconv2(x)
        if self.use_layer_scale:
            x = input + self.drop_path(self.layer_scale * x)
        else:
            x = input + self.drop_path(x)
        return x


class Mlp(nn.Module):
    """
    Implementation of an MLP layer with 1x1 convolutions.
    Input: tensor with shape [B, C, H, W]
    Output: tensor with shape [B, C, H, W]
    """

    def __init__(self, in_features, hidden_features=None,
                 out_features=None, act_layer=nn.GELU, drop=0.):
        super().__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        self.norm1 = nn.BatchNorm2d(in_features)
        self.fc1 = nn.Conv2d(in_features, hidden_features, 1)
        self.act = act_layer()
        self.fc2 = nn.Conv2d(hidden_features, out_features, 1)
        self.drop = nn.Dropout(drop)
        self.apply(self._init_weights)

    def _init_weights(self, m):
        if isinstance(m, nn.Conv2d):
            trunc_normal_(m.weight, std=.02)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)

    def forward(self, x):
        x = self.norm1(x)
        x = self.fc1(x)
        x = self.act(x)
        x = self.drop(x)
        x = self.fc2(x)
        x = self.drop(x)
        return x


class EfficientAdditiveAttnetion(nn.Module):
    """
    Efficient Additive Attention module for SwiftFormer.
    Input: tensor in shape [B, N, D]
    Output: tensor in shape [B, N, D]
    """

    def __init__(self, in_dims=512, token_dim=256, num_heads=2):
        super().__init__()

        self.to_query = nn.Linear(in_dims, token_dim * num_heads)
        self.to_key = nn.Linear(in_dims, token_dim * num_heads)

        self.w_g = nn.Parameter(torch.randn(token_dim * num_heads, 1))
        self.scale_factor = token_dim ** -0.5
        self.Proj = nn.Linear(token_dim * num_heads, token_dim * num_heads)
        self.final = nn.Linear(token_dim * num_heads, token_dim)

    def forward(self, x):
        query = self.to_query(x)
        key = self.to_key(x)

        query = torch.nn.functional.normalize(query, dim=-1)  # BxNxD
        key = torch.nn.functional.normalize(key, dim=-1)  # BxNxD

        # Per-token attention weights computed from the learned vector w_g.
        query_weight = query @ self.w_g  # BxNx1 (BxNxD @ Dx1)
        A = query_weight * self.scale_factor  # BxNx1
        A = torch.nn.functional.normalize(A, dim=1)  # BxNx1

        # Global query vector: attention-weighted sum over the token dimension.
        G = torch.sum(A * query, dim=1)  # BxD

        # Broadcast the global query to every token position.
        G = einops.repeat(
            G, "b d -> b repeat d", repeat=key.shape[1]
        )  # BxNxD

        out = self.Proj(G * key) + query  # BxNxD
        out = self.final(out)  # BxNxD

        return out
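
# A minimal shape sanity-check sketch for the additive attention block
# (example values assumed for illustration, not part of the reference code):
#
#   attn = EfficientAdditiveAttnetion(in_dims=64, token_dim=64, num_heads=1)
#   out = attn(torch.randn(2, 196, 64))   # -> shape [2, 196, 64]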


class SwiftFormerLocalRepresentation(nn.Module):
    """
    Local Representation module for SwiftFormer, implemented with 3x3 depth-wise and 1x1 point-wise convolutions.
    Input: tensor in shape [B, C, H, W]
    Output: tensor in shape [B, C, H, W]
    """

    def __init__(self, dim, kernel_size=3, drop_path=0., use_layer_scale=True):
        super().__init__()
        self.dwconv = nn.Conv2d(dim, dim, kernel_size=kernel_size, padding=kernel_size // 2, groups=dim)
        self.norm = nn.BatchNorm2d(dim)
        self.pwconv1 = nn.Conv2d(dim, dim, kernel_size=1)
        self.act = nn.GELU()
        self.pwconv2 = nn.Conv2d(dim, dim, kernel_size=1)
        self.drop_path = DropPath(drop_path) if drop_path > 0. \
            else nn.Identity()
        self.use_layer_scale = use_layer_scale
        if use_layer_scale:
            self.layer_scale = nn.Parameter(torch.ones(dim).unsqueeze(-1).unsqueeze(-1), requires_grad=True)
        self.apply(self._init_weights)

    def _init_weights(self, m):
        if isinstance(m, nn.Conv2d):
            trunc_normal_(m.weight, std=.02)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)

    def forward(self, x):
        input = x
        x = self.dwconv(x)
        x = self.norm(x)
        x = self.pwconv1(x)
        x = self.act(x)
        x = self.pwconv2(x)
        if self.use_layer_scale:
            x = input + self.drop_path(self.layer_scale * x)
        else:
            x = input + self.drop_path(x)
        return x


class SwiftFormerEncoder(nn.Module):
    """
    SwiftFormer encoder block. It consists of (1) a local representation module, (2) EfficientAdditiveAttention, and (3) an MLP block.
    Input: tensor in shape [B, C, H, W]
    Output: tensor in shape [B, C, H, W]
    """

    def __init__(self, dim, mlp_ratio=4.,
                 act_layer=nn.GELU,
                 drop=0., drop_path=0.,
                 use_layer_scale=True, layer_scale_init_value=1e-5):

        super().__init__()

        self.local_representation = SwiftFormerLocalRepresentation(dim=dim, kernel_size=3, drop_path=0.,
                                                                   use_layer_scale=True)
        self.attn = EfficientAdditiveAttnetion(in_dims=dim, token_dim=dim, num_heads=1)
        self.linear = Mlp(in_features=dim, hidden_features=int(dim * mlp_ratio), act_layer=act_layer, drop=drop)
        self.drop_path = DropPath(drop_path) if drop_path > 0. \
            else nn.Identity()
        self.use_layer_scale = use_layer_scale
        if use_layer_scale:
            self.layer_scale_1 = nn.Parameter(
                layer_scale_init_value * torch.ones(dim).unsqueeze(-1).unsqueeze(-1), requires_grad=True)
            self.layer_scale_2 = nn.Parameter(
                layer_scale_init_value * torch.ones(dim).unsqueeze(-1).unsqueeze(-1), requires_grad=True)

    def forward(self, x):
        x = self.local_representation(x)
        B, C, H, W = x.shape
        # Flatten [B, C, H, W] -> [B, H*W, C] for attention, then restore the spatial layout.
        if self.use_layer_scale:
            x = x + self.drop_path(
                self.layer_scale_1 * self.attn(x.permute(0, 2, 3, 1).reshape(B, H * W, C)).reshape(B, H, W, C).permute(
                    0, 3, 1, 2))
            x = x + self.drop_path(self.layer_scale_2 * self.linear(x))
        else:
            x = x + self.drop_path(
                self.attn(x.permute(0, 2, 3, 1).reshape(B, H * W, C)).reshape(B, H, W, C).permute(0, 3, 1, 2))
            x = x + self.drop_path(self.linear(x))
        return x


def Stage(dim, index, layers, mlp_ratio=4.,
          act_layer=nn.GELU,
          drop_rate=.0, drop_path_rate=0.,
          use_layer_scale=True, layer_scale_init_value=1e-5, vit_num=1):
    """
    Implementation of a single SwiftFormer stage. The last `vit_num` blocks of each stage are SwiftFormerEncoder blocks, while ConvEncoder is used for the remaining blocks.
    Input: tensor in shape [B, C, H, W]
    Output: tensor in shape [B, C, H, W]
    """
    blocks = []

    for block_idx in range(layers[index]):
        # Stochastic depth rate grows linearly with the overall block index.
        block_dpr = drop_path_rate * (block_idx + sum(layers[:index])) / (sum(layers) - 1)

        if layers[index] - block_idx <= vit_num:
            blocks.append(SwiftFormerEncoder(
                dim, mlp_ratio=mlp_ratio,
                act_layer=act_layer, drop_path=block_dpr,
                use_layer_scale=use_layer_scale,
                layer_scale_init_value=layer_scale_init_value))
        else:
            blocks.append(ConvEncoder(dim=dim, hidden_dim=int(mlp_ratio * dim), kernel_size=3))

    blocks = nn.Sequential(*blocks)
    return blocks


class SwiftFormer(nn.Module):

    def __init__(self, layers, embed_dims=None,
                 mlp_ratios=4, downsamples=None,
                 act_layer=nn.GELU,
                 num_classes=1000,
                 down_patch_size=3, down_stride=2, down_pad=1,
                 drop_rate=0., drop_path_rate=0.,
                 use_layer_scale=True, layer_scale_init_value=1e-5,
                 fork_feat=False,
                 init_cfg=None,
                 pretrained=None,
                 vit_num=1,
                 distillation=True,
                 **kwargs):
        super().__init__()

        if not fork_feat:
            self.num_classes = num_classes
        self.fork_feat = fork_feat

        self.patch_embed = stem(3, embed_dims[0])

        network = []
        for i in range(len(layers)):
            stage = Stage(embed_dims[i], i, layers, mlp_ratio=mlp_ratios,
                          act_layer=act_layer,
                          drop_rate=drop_rate,
                          drop_path_rate=drop_path_rate,
                          use_layer_scale=use_layer_scale,
                          layer_scale_init_value=layer_scale_init_value,
                          vit_num=vit_num)
            network.append(stage)
            if i >= len(layers) - 1:
                break
            if downsamples[i] or embed_dims[i] != embed_dims[i + 1]:
                # downsampling between two stages
                network.append(
                    Embedding(
                        patch_size=down_patch_size, stride=down_stride,
                        padding=down_pad,
                        in_chans=embed_dims[i], embed_dim=embed_dims[i + 1]
                    )
                )

        self.network = nn.ModuleList(network)

        if self.fork_feat:
            # add a norm layer for each output
            self.out_indices = [0, 2, 4, 6]
            for i_emb, i_layer in enumerate(self.out_indices):
                if i_emb == 0 and os.environ.get('FORK_LAST3', None):
                    layer = nn.Identity()
                else:
                    layer = nn.BatchNorm2d(embed_dims[i_emb])
                layer_name = f'norm{i_layer}'
                self.add_module(layer_name, layer)
        else:
            # Classifier head
            self.norm = nn.BatchNorm2d(embed_dims[-1])
            self.head = nn.Linear(
                embed_dims[-1], num_classes) if num_classes > 0 \
                else nn.Identity()
            self.dist = distillation
            if self.dist:
                self.dist_head = nn.Linear(
                    embed_dims[-1], num_classes) if num_classes > 0 \
                    else nn.Identity()

        # self.apply(self.cls_init_weights)
        self.apply(self._init_weights)

        self.init_cfg = copy.deepcopy(init_cfg)
        # load pre-trained model
        if self.fork_feat and (
                self.init_cfg is not None or pretrained is not None):
            self.init_weights()

    # init for mmdetection or mmsegmentation by loading
    # imagenet pre-trained weights
    def init_weights(self, pretrained=None):
        logger = get_root_logger()
        if self.init_cfg is None and pretrained is None:
            logger.warning(f'No pre-trained weights for '
                           f'{self.__class__.__name__}, '
                           f'training starts from scratch')
            pass
        else:
            assert 'checkpoint' in self.init_cfg, f'Only support ' \
                                                  f'specify `Pretrained` in ' \
                                                  f'`init_cfg` in ' \
                                                  f'{self.__class__.__name__} '
            if self.init_cfg is not None:
                ckpt_path = self.init_cfg['checkpoint']
            elif pretrained is not None:
                ckpt_path = pretrained

            ckpt = _load_checkpoint(
                ckpt_path, logger=logger, map_location='cpu')
            if 'state_dict' in ckpt:
                _state_dict = ckpt['state_dict']
            elif 'model' in ckpt:
                _state_dict = ckpt['model']
            else:
                _state_dict = ckpt

            state_dict = _state_dict
            missing_keys, unexpected_keys = \
                self.load_state_dict(state_dict, False)

    def _init_weights(self, m):
        if isinstance(m, (nn.Conv2d, nn.Linear)):
            trunc_normal_(m.weight, std=.02)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

    def forward_tokens(self, x):
        outs = []
        for idx, block in enumerate(self.network):
            x = block(x)
            if self.fork_feat and idx in self.out_indices:
                norm_layer = getattr(self, f'norm{idx}')
                x_out = norm_layer(x)
                outs.append(x_out)
        if self.fork_feat:
            return outs
        return x

    def forward(self, x):
        x = self.patch_embed(x)
        x = self.forward_tokens(x)
        if self.fork_feat:
            # Output features of four stages for dense prediction
            return x

        x = self.norm(x)
        if self.dist:
            cls_out = self.head(x.flatten(2).mean(-1)), self.dist_head(x.flatten(2).mean(-1))
            if not self.training:
                cls_out = (cls_out[0] + cls_out[1]) / 2
        else:
            cls_out = self.head(x.flatten(2).mean(-1))
        # For image classification
        return cls_out


def _cfg(url='', **kwargs):
    return {
        'url': url,
        'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': None,
        'crop_pct': .95, 'interpolation': 'bicubic',
        'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD,
        'classifier': 'head',
        **kwargs
    }


@register_model
def SwiftFormer_XS(pretrained=False, **kwargs):
    model = SwiftFormer(
        layers=SwiftFormer_depth['XS'],
        embed_dims=SwiftFormer_width['XS'],
        downsamples=[True, True, True, True],
        vit_num=1,
        **kwargs)
    model.default_cfg = _cfg(crop_pct=0.9)
    return model


@register_model
def SwiftFormer_S(pretrained=False, **kwargs):
    model = SwiftFormer(
        layers=SwiftFormer_depth['S'],
        embed_dims=SwiftFormer_width['S'],
        downsamples=[True, True, True, True],
        vit_num=1,
        **kwargs)
    model.default_cfg = _cfg(crop_pct=0.9)
    return model


@register_model
def SwiftFormer_L1(pretrained=False, **kwargs):
    model = SwiftFormer(
        layers=SwiftFormer_depth['l1'],
        embed_dims=SwiftFormer_width['l1'],
        downsamples=[True, True, True, True],
        vit_num=1,
        **kwargs)
    model.default_cfg = _cfg(crop_pct=0.9)
    return model


@register_model
def SwiftFormer_L3(pretrained=False, **kwargs):
    model = SwiftFormer(
        layers=SwiftFormer_depth['l3'],
        embed_dims=SwiftFormer_width['l3'],
        downsamples=[True, True, True, True],
        vit_num=1,
        **kwargs)
    model.default_cfg = _cfg(crop_pct=0.9)
    return model
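

if __name__ == "__main__":
    # Minimal smoke-test sketch (not part of the reference implementation):
    # build the smallest variant and push a random batch through it to check
    # the classification output shape.
    m = SwiftFormer_XS()
    m.eval()
    with torch.no_grad():
        logits = m(torch.randn(1, 3, 224, 224))
    print(logits.shape)  # expected: torch.Size([1, 1000])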