""" SwiftFormer """ import os import copy import torch import torch.nn as nn from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD from timm.models.layers import DropPath, trunc_normal_ from timm.models.registry import register_model from timm.models.layers.helpers import to_2tuple import einops SwiftFormer_width = { 'XS': [48, 56, 112, 220], 'S': [48, 64, 168, 224], 'l1': [48, 96, 192, 384], 'l3': [64, 128, 320, 512], } SwiftFormer_depth = { 'XS': [3, 3, 6, 4], 'S': [3, 3, 9, 6], 'l1': [4, 3, 10, 5], 'l3': [4, 4, 12, 6], } def stem(in_chs, out_chs): """ Stem Layer that is implemented by two layers of conv. Output: sequence of layers with final shape of [B, C, H/4, W/4] """ return nn.Sequential( nn.Conv2d(in_chs, out_chs // 2, kernel_size=3, stride=2, padding=1), nn.BatchNorm2d(out_chs // 2), nn.ReLU(), nn.Conv2d(out_chs // 2, out_chs, kernel_size=3, stride=2, padding=1), nn.BatchNorm2d(out_chs), nn.ReLU(), ) class Embedding(nn.Module): """ Patch Embedding that is implemented by a layer of conv. Input: tensor in shape [B, C, H, W] Output: tensor in shape [B, C, H/stride, W/stride] """ def __init__(self, patch_size=16, stride=16, padding=0, in_chans=3, embed_dim=768, norm_layer=nn.BatchNorm2d): super().__init__() patch_size = to_2tuple(patch_size) stride = to_2tuple(stride) padding = to_2tuple(padding) self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=stride, padding=padding) self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity() def forward(self, x): x = self.proj(x) x = self.norm(x) return x class ConvEncoder(nn.Module): """ Implementation of ConvEncoder with 3*3 and 1*1 convolutions. Input: tensor with shape [B, C, H, W] Output: tensor with shape [B, C, H, W] """ def __init__(self, dim, hidden_dim=64, kernel_size=3, drop_path=0., use_layer_scale=True): super().__init__() self.dwconv = nn.Conv2d(dim, dim, kernel_size=kernel_size, padding=kernel_size // 2, groups=dim) self.norm = nn.BatchNorm2d(dim) self.pwconv1 = nn.Conv2d(dim, hidden_dim, kernel_size=1) self.act = nn.GELU() self.pwconv2 = nn.Conv2d(hidden_dim, dim, kernel_size=1) self.drop_path = DropPath(drop_path) if drop_path > 0. \ else nn.Identity() self.use_layer_scale = use_layer_scale if use_layer_scale: self.layer_scale = nn.Parameter(torch.ones(dim).unsqueeze(-1).unsqueeze(-1), requires_grad=True) self.apply(self._init_weights) def _init_weights(self, m): if isinstance(m, nn.Conv2d): trunc_normal_(m.weight, std=.02) if m.bias is not None: nn.init.constant_(m.bias, 0) def forward(self, x): input = x x = self.dwconv(x) x = self.norm(x) x = self.pwconv1(x) x = self.act(x) x = self.pwconv2(x) if self.use_layer_scale: x = input + self.drop_path(self.layer_scale * x) else: x = input + self.drop_path(x) return x class Mlp(nn.Module): """ Implementation of MLP layer with 1*1 convolutions. 
class Mlp(nn.Module):
    """
    Implementation of an MLP layer with 1x1 convolutions.
    Input: tensor with shape [B, C, H, W]
    Output: tensor with shape [B, C, H, W]
    """

    def __init__(self, in_features, hidden_features=None,
                 out_features=None, act_layer=nn.GELU, drop=0.):
        super().__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        self.norm1 = nn.BatchNorm2d(in_features)
        self.fc1 = nn.Conv2d(in_features, hidden_features, 1)
        self.act = act_layer()
        self.fc2 = nn.Conv2d(hidden_features, out_features, 1)
        self.drop = nn.Dropout(drop)
        self.apply(self._init_weights)

    def _init_weights(self, m):
        if isinstance(m, nn.Conv2d):
            trunc_normal_(m.weight, std=.02)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)

    def forward(self, x):
        x = self.norm1(x)
        x = self.fc1(x)
        x = self.act(x)
        x = self.drop(x)
        x = self.fc2(x)
        x = self.drop(x)
        return x


class EfficientAdditiveAttention(nn.Module):
    """
    Efficient Additive Attention module for SwiftFormer.
    Input: tensor in shape [B, N, D]
    Output: tensor in shape [B, N, D]
    """

    def __init__(self, in_dims=512, token_dim=256, num_heads=2):
        super().__init__()
        self.to_query = nn.Linear(in_dims, token_dim * num_heads)
        self.to_key = nn.Linear(in_dims, token_dim * num_heads)
        self.w_g = nn.Parameter(torch.randn(token_dim * num_heads, 1))
        self.scale_factor = token_dim ** -0.5
        self.Proj = nn.Linear(token_dim * num_heads, token_dim * num_heads)
        self.final = nn.Linear(token_dim * num_heads, token_dim)

    def forward(self, x):
        query = self.to_query(x)
        key = self.to_key(x)

        query = torch.nn.functional.normalize(query, dim=-1)  # BxNxD
        key = torch.nn.functional.normalize(key, dim=-1)      # BxNxD

        query_weight = query @ self.w_g        # BxNx1 (BxNxD @ Dx1)
        A = query_weight * self.scale_factor   # BxNx1
        A = torch.nn.functional.normalize(A, dim=1)  # BxNx1

        G = torch.sum(A * query, dim=1)        # BxD
        G = einops.repeat(
            G, "b d -> b repeat d", repeat=key.shape[1]
        )                                      # BxNxD

        out = self.Proj(G * key) + query       # BxNxD
        out = self.final(out)                  # BxNxD
        return out


class SwiftFormerLocalRepresentation(nn.Module):
    """
    Local Representation module for SwiftFormer, implemented by 3x3 depth-wise
    and point-wise convolutions.
    Input: tensor in shape [B, C, H, W]
    Output: tensor in shape [B, C, H, W]
    """

    def __init__(self, dim, kernel_size=3, drop_path=0., use_layer_scale=True):
        super().__init__()
        self.dwconv = nn.Conv2d(dim, dim, kernel_size=kernel_size,
                                padding=kernel_size // 2, groups=dim)
        self.norm = nn.BatchNorm2d(dim)
        self.pwconv1 = nn.Conv2d(dim, dim, kernel_size=1)
        self.act = nn.GELU()
        self.pwconv2 = nn.Conv2d(dim, dim, kernel_size=1)
        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
        self.use_layer_scale = use_layer_scale
        if use_layer_scale:
            self.layer_scale = nn.Parameter(
                torch.ones(dim).unsqueeze(-1).unsqueeze(-1), requires_grad=True)
        self.apply(self._init_weights)

    def _init_weights(self, m):
        if isinstance(m, nn.Conv2d):
            trunc_normal_(m.weight, std=.02)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)

    def forward(self, x):
        input = x
        x = self.dwconv(x)
        x = self.norm(x)
        x = self.pwconv1(x)
        x = self.act(x)
        x = self.pwconv2(x)
        if self.use_layer_scale:
            x = input + self.drop_path(self.layer_scale * x)
        else:
            x = input + self.drop_path(x)
        return x
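
# A minimal usage sketch (an illustrative addition, not part of the original
# code): additive attention pools the normalized queries into a single global
# vector G via the learned gate w_g, then broadcasts G over the keys, so the
# cost is linear in the token count N rather than quadratic. Sizes below are
# arbitrary assumptions.
def _additive_attention_example():
    attn = EfficientAdditiveAttention(in_dims=96, token_dim=96, num_heads=1)
    tokens = torch.randn(2, 14 * 14, 96)  # [B, N, D]
    out = attn(tokens)
    assert out.shape == tokens.shape      # token count and dim are preserved
    return out
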
class SwiftFormerEncoder(nn.Module):
    """
    SwiftFormer Encoder Block. It consists of (1) a local representation
    module, (2) EfficientAdditiveAttention, and (3) an MLP block.
    Input: tensor in shape [B, C, H, W]
    Output: tensor in shape [B, C, H, W]
    """

    def __init__(self, dim, mlp_ratio=4., act_layer=nn.GELU,
                 drop=0., drop_path=0.,
                 use_layer_scale=True, layer_scale_init_value=1e-5):
        super().__init__()
        self.local_representation = SwiftFormerLocalRepresentation(
            dim=dim, kernel_size=3, drop_path=0., use_layer_scale=True)
        self.attn = EfficientAdditiveAttention(in_dims=dim, token_dim=dim, num_heads=1)
        self.linear = Mlp(in_features=dim, hidden_features=int(dim * mlp_ratio),
                          act_layer=act_layer, drop=drop)
        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
        self.use_layer_scale = use_layer_scale
        if use_layer_scale:
            self.layer_scale_1 = nn.Parameter(
                layer_scale_init_value * torch.ones(dim).unsqueeze(-1).unsqueeze(-1),
                requires_grad=True)
            self.layer_scale_2 = nn.Parameter(
                layer_scale_init_value * torch.ones(dim).unsqueeze(-1).unsqueeze(-1),
                requires_grad=True)

    def forward(self, x):
        x = self.local_representation(x)
        B, C, H, W = x.shape
        if self.use_layer_scale:
            x = x + self.drop_path(
                self.layer_scale_1 *
                self.attn(x.permute(0, 2, 3, 1).reshape(B, H * W, C))
                .reshape(B, H, W, C).permute(0, 3, 1, 2))
            x = x + self.drop_path(self.layer_scale_2 * self.linear(x))
        else:
            x = x + self.drop_path(
                self.attn(x.permute(0, 2, 3, 1).reshape(B, H * W, C))
                .reshape(B, H, W, C).permute(0, 3, 1, 2))
            x = x + self.drop_path(self.linear(x))
        return x


def Stage(dim, index, layers, mlp_ratio=4., act_layer=nn.GELU,
          drop_rate=.0, drop_path_rate=0.,
          use_layer_scale=True, layer_scale_init_value=1e-5, vit_num=1):
    """
    Implementation of each SwiftFormer stage. SwiftFormerEncoder is used as the
    last `vit_num` block(s) of every stage, while ConvEncoder is used for the
    remaining blocks.
    Input: tensor in shape [B, C, H, W]
    Output: tensor in shape [B, C, H, W]
    """
    blocks = []
    for block_idx in range(layers[index]):
        block_dpr = drop_path_rate * (block_idx + sum(layers[:index])) / (sum(layers) - 1)
        if layers[index] - block_idx <= vit_num:
            blocks.append(SwiftFormerEncoder(
                dim, mlp_ratio=mlp_ratio, act_layer=act_layer,
                drop_path=block_dpr,
                use_layer_scale=use_layer_scale,
                layer_scale_init_value=layer_scale_init_value))
        else:
            blocks.append(ConvEncoder(dim=dim, hidden_dim=int(mlp_ratio * dim), kernel_size=3))
    return nn.Sequential(*blocks)
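
# A quick sketch (an illustrative addition, not part of the original code) of
# how Stage mixes block types: with vit_num=1, only the final block of a stage
# is a SwiftFormerEncoder and the rest are ConvEncoders. The depths below use
# the 'XS' configuration.
def _stage_example():
    layers = SwiftFormer_depth['XS']  # [3, 3, 6, 4]
    stage = Stage(dim=48, index=0, layers=layers, vit_num=1)
    kinds = [type(b).__name__ for b in stage]
    # The first stage has depth 3: two conv blocks, then one attention block.
    assert kinds == ['ConvEncoder', 'ConvEncoder', 'SwiftFormerEncoder']
    return stage
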
class SwiftFormer(nn.Module):
    def __init__(self, layers, embed_dims=None,
                 mlp_ratios=4, downsamples=None,
                 act_layer=nn.GELU,
                 num_classes=1000,
                 down_patch_size=3, down_stride=2, down_pad=1,
                 drop_rate=0., drop_path_rate=0.,
                 use_layer_scale=True, layer_scale_init_value=1e-5,
                 fork_feat=False,
                 init_cfg=None,
                 pretrained=None,
                 vit_num=1,
                 distillation=True,
                 **kwargs):
        super().__init__()

        if not fork_feat:
            self.num_classes = num_classes
        self.fork_feat = fork_feat

        self.patch_embed = stem(3, embed_dims[0])

        network = []
        for i in range(len(layers)):
            stage = Stage(embed_dims[i], i, layers, mlp_ratio=mlp_ratios,
                          act_layer=act_layer,
                          drop_rate=drop_rate,
                          drop_path_rate=drop_path_rate,
                          use_layer_scale=use_layer_scale,
                          layer_scale_init_value=layer_scale_init_value,
                          vit_num=vit_num)
            network.append(stage)
            if i >= len(layers) - 1:
                break
            if downsamples[i] or embed_dims[i] != embed_dims[i + 1]:
                # downsampling between two stages
                network.append(
                    Embedding(
                        patch_size=down_patch_size, stride=down_stride,
                        padding=down_pad,
                        in_chans=embed_dims[i],
                        embed_dim=embed_dims[i + 1]
                    )
                )

        self.network = nn.ModuleList(network)

        if self.fork_feat:
            # add a norm layer for each output
            self.out_indices = [0, 2, 4, 6]
            for i_emb, i_layer in enumerate(self.out_indices):
                if i_emb == 0 and os.environ.get('FORK_LAST3', None):
                    layer = nn.Identity()
                else:
                    layer = nn.BatchNorm2d(embed_dims[i_emb])
                layer_name = f'norm{i_layer}'
                self.add_module(layer_name, layer)
        else:
            # classifier head
            self.norm = nn.BatchNorm2d(embed_dims[-1])
            self.head = nn.Linear(embed_dims[-1], num_classes) if num_classes > 0 \
                else nn.Identity()
            self.dist = distillation
            if self.dist:
                self.dist_head = nn.Linear(embed_dims[-1], num_classes) if num_classes > 0 \
                    else nn.Identity()

        self.apply(self._init_weights)
        self.init_cfg = copy.deepcopy(init_cfg)

        # load pre-trained weights when used as a backbone for
        # mmdetection or mmsegmentation
        if self.fork_feat and (self.init_cfg is not None or pretrained is not None):
            self.init_weights()

    def init_weights(self, pretrained=None):
        """Initialize the backbone from ImageNet pre-trained weights for
        mmdetection / mmsegmentation; requires the optional mmseg/mmcv
        dependencies imported at the top of this file."""
        logger = get_root_logger()
        if self.init_cfg is None and pretrained is None:
            logger.warning(f'No pre-trained weights for '
                           f'{self.__class__.__name__}, '
                           f'training starts from scratch')
        else:
            assert 'checkpoint' in self.init_cfg, \
                f'Only support specifying `Pretrained` in `init_cfg` in ' \
                f'{self.__class__.__name__}'
            if self.init_cfg is not None:
                ckpt_path = self.init_cfg['checkpoint']
            elif pretrained is not None:
                ckpt_path = pretrained

            ckpt = _load_checkpoint(ckpt_path, logger=logger, map_location='cpu')
            if 'state_dict' in ckpt:
                _state_dict = ckpt['state_dict']
            elif 'model' in ckpt:
                _state_dict = ckpt['model']
            else:
                _state_dict = ckpt

            state_dict = _state_dict
            missing_keys, unexpected_keys = self.load_state_dict(state_dict, False)

    def _init_weights(self, m):
        if isinstance(m, (nn.Conv2d, nn.Linear)):
            trunc_normal_(m.weight, std=.02)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

    def forward_tokens(self, x):
        outs = []
        for idx, block in enumerate(self.network):
            x = block(x)
            if self.fork_feat and idx in self.out_indices:
                norm_layer = getattr(self, f'norm{idx}')
                x_out = norm_layer(x)
                outs.append(x_out)
        if self.fork_feat:
            return outs
        return x

    def forward(self, x):
        x = self.patch_embed(x)
        x = self.forward_tokens(x)
        if self.fork_feat:
            # output features of four stages for dense prediction
            return x
        x = self.norm(x)
        if self.dist:
            # global average pool over the spatial dims, then two heads
            cls_out = (self.head(x.flatten(2).mean(-1)),
                       self.dist_head(x.flatten(2).mean(-1)))
            if not self.training:
                cls_out = (cls_out[0] + cls_out[1]) / 2
        else:
            cls_out = self.head(x.flatten(2).mean(-1))
        # for image classification
        return cls_out
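
# An illustrative sketch (not part of the original code): the same backbone
# can run in classification mode (returning logits) or, with fork_feat=True,
# as a dense-prediction backbone returning the four stage feature maps. The
# tiny configuration below is an assumption chosen to keep the example fast.
def _swiftformer_example():
    cfg = dict(layers=[2, 2, 2, 2], embed_dims=[32, 48, 64, 96],
               downsamples=[True, True, True, True], vit_num=1)
    clf = SwiftFormer(num_classes=10, **cfg).eval()
    x = torch.randn(1, 3, 224, 224)
    with torch.no_grad():
        logits = clf(x)  # eval mode averages head and dist_head
        feats = SwiftFormer(fork_feat=True, **cfg).eval()(x)
    assert logits.shape == (1, 10)
    assert len(feats) == 4  # one feature map per stage
    return logits, feats
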
def _cfg(url='', **kwargs):
    return {
        'url': url,
        'num_classes': 1000,
        'input_size': (3, 224, 224),
        'pool_size': None,
        'crop_pct': .95,
        'interpolation': 'bicubic',
        'mean': IMAGENET_DEFAULT_MEAN,
        'std': IMAGENET_DEFAULT_STD,
        'classifier': 'head',
        **kwargs
    }


@register_model
def SwiftFormer_XS(pretrained=False, **kwargs):
    model = SwiftFormer(
        layers=SwiftFormer_depth['XS'], embed_dims=SwiftFormer_width['XS'],
        downsamples=[True, True, True, True],
        vit_num=1,
        **kwargs)
    model.default_cfg = _cfg(crop_pct=0.9)
    return model


@register_model
def SwiftFormer_S(pretrained=False, **kwargs):
    model = SwiftFormer(
        layers=SwiftFormer_depth['S'], embed_dims=SwiftFormer_width['S'],
        downsamples=[True, True, True, True],
        vit_num=1,
        **kwargs)
    model.default_cfg = _cfg(crop_pct=0.9)
    return model


@register_model
def SwiftFormer_L1(pretrained=False, **kwargs):
    model = SwiftFormer(
        layers=SwiftFormer_depth['l1'], embed_dims=SwiftFormer_width['l1'],
        downsamples=[True, True, True, True],
        vit_num=1,
        **kwargs)
    model.default_cfg = _cfg(crop_pct=0.9)
    return model


@register_model
def SwiftFormer_L3(pretrained=False, **kwargs):
    model = SwiftFormer(
        layers=SwiftFormer_depth['l3'], embed_dims=SwiftFormer_width['l3'],
        downsamples=[True, True, True, True],
        vit_num=1,
        **kwargs)
    model.default_cfg = _cfg(crop_pct=0.9)
    return model
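
# A minimal smoke test (an illustrative addition; the batch size and input
# resolution are arbitrary). Running this file directly checks that a
# registered variant builds and produces ImageNet-sized logits.
if __name__ == '__main__':
    model = SwiftFormer_XS().eval()
    dummy = torch.randn(1, 3, 224, 224)
    with torch.no_grad():
        out = model(dummy)
    print('SwiftFormer_XS output:', out.shape)  # expected: torch.Size([1, 1000])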