# asmo_vhead/models/swiftformer.py
"""
SwiftFormer
"""
import os
import copy
import torch
import torch.nn as nn
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from timm.models.layers import DropPath, trunc_normal_
from timm.models.registry import register_model
from timm.models.layers.helpers import to_2tuple
import einops

# Optional dependencies (assumed from the upstream SwiftFormer repo): only
# needed when fork_feat=True, i.e. when loading mm-style checkpoints for
# dense prediction. Guarded so the classification path works without them.
try:
    from mmseg.utils import get_root_logger
    from mmcv.runner import _load_checkpoint
except ImportError:
    get_root_logger = None
    _load_checkpoint = None
SwiftFormer_width = {
'XS': [48, 56, 112, 220],
'S': [48, 64, 168, 224],
'l1': [48, 96, 192, 384],
'l3': [64, 128, 320, 512],
}
SwiftFormer_depth = {
'XS': [3, 3, 6, 4],
'S': [3, 3, 9, 6],
'l1': [4, 3, 10, 5],
'l3': [4, 4, 12, 6],
}
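# Each list gives per-stage values for the four stages built in SwiftFormer():
# SwiftFormer_width[v][i] is the channel width and SwiftFormer_depth[v][i] the
# number of blocks of stage i for variant v.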
def stem(in_chs, out_chs):
"""
Stem Layer that is implemented by two layers of conv.
Output: sequence of layers with final shape of [B, C, H/4, W/4]
"""
return nn.Sequential(
nn.Conv2d(in_chs, out_chs // 2, kernel_size=3, stride=2, padding=1),
nn.BatchNorm2d(out_chs // 2),
nn.ReLU(),
nn.Conv2d(out_chs // 2, out_chs, kernel_size=3, stride=2, padding=1),
nn.BatchNorm2d(out_chs),
nn.ReLU(), )
class Embedding(nn.Module):
"""
Patch Embedding that is implemented by a layer of conv.
Input: tensor in shape [B, C, H, W]
Output: tensor in shape [B, C, H/stride, W/stride]
"""
def __init__(self, patch_size=16, stride=16, padding=0,
in_chans=3, embed_dim=768, norm_layer=nn.BatchNorm2d):
super().__init__()
patch_size = to_2tuple(patch_size)
stride = to_2tuple(stride)
padding = to_2tuple(padding)
self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size,
stride=stride, padding=padding)
self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()
def forward(self, x):
x = self.proj(x)
x = self.norm(x)
return x
class ConvEncoder(nn.Module):
"""
Implementation of ConvEncoder with 3*3 and 1*1 convolutions.
Input: tensor with shape [B, C, H, W]
Output: tensor with shape [B, C, H, W]
"""
def __init__(self, dim, hidden_dim=64, kernel_size=3, drop_path=0., use_layer_scale=True):
super().__init__()
self.dwconv = nn.Conv2d(dim, dim, kernel_size=kernel_size, padding=kernel_size // 2, groups=dim)
self.norm = nn.BatchNorm2d(dim)
self.pwconv1 = nn.Conv2d(dim, hidden_dim, kernel_size=1)
self.act = nn.GELU()
self.pwconv2 = nn.Conv2d(hidden_dim, dim, kernel_size=1)
self.drop_path = DropPath(drop_path) if drop_path > 0. \
else nn.Identity()
self.use_layer_scale = use_layer_scale
if use_layer_scale:
self.layer_scale = nn.Parameter(torch.ones(dim).unsqueeze(-1).unsqueeze(-1), requires_grad=True)
self.apply(self._init_weights)
def _init_weights(self, m):
if isinstance(m, nn.Conv2d):
trunc_normal_(m.weight, std=.02)
if m.bias is not None:
nn.init.constant_(m.bias, 0)
def forward(self, x):
input = x
x = self.dwconv(x)
x = self.norm(x)
x = self.pwconv1(x)
x = self.act(x)
x = self.pwconv2(x)
if self.use_layer_scale:
x = input + self.drop_path(self.layer_scale * x)
else:
x = input + self.drop_path(x)
return x
class Mlp(nn.Module):
"""
Implementation of MLP layer with 1*1 convolutions.
Input: tensor with shape [B, C, H, W]
Output: tensor with shape [B, C, H, W]
"""
def __init__(self, in_features, hidden_features=None,
out_features=None, act_layer=nn.GELU, drop=0.):
super().__init__()
out_features = out_features or in_features
hidden_features = hidden_features or in_features
self.norm1 = nn.BatchNorm2d(in_features)
self.fc1 = nn.Conv2d(in_features, hidden_features, 1)
self.act = act_layer()
self.fc2 = nn.Conv2d(hidden_features, out_features, 1)
self.drop = nn.Dropout(drop)
self.apply(self._init_weights)
def _init_weights(self, m):
if isinstance(m, nn.Conv2d):
trunc_normal_(m.weight, std=.02)
if m.bias is not None:
nn.init.constant_(m.bias, 0)
def forward(self, x):
x = self.norm1(x)
x = self.fc1(x)
x = self.act(x)
x = self.drop(x)
x = self.fc2(x)
x = self.drop(x)
return x
class EfficientAdditiveAttnetion(nn.Module):
"""
Efficient Additive Attention module for SwiftFormer.
Input: tensor in shape [B, N, D]
Output: tensor in shape [B, N, D]
"""
def __init__(self, in_dims=512, token_dim=256, num_heads=2):
super().__init__()
self.to_query = nn.Linear(in_dims, token_dim * num_heads)
self.to_key = nn.Linear(in_dims, token_dim * num_heads)
self.w_g = nn.Parameter(torch.randn(token_dim * num_heads, 1))
self.scale_factor = token_dim ** -0.5
self.Proj = nn.Linear(token_dim * num_heads, token_dim * num_heads)
self.final = nn.Linear(token_dim * num_heads, token_dim)
    def forward(self, x):
        query = self.to_query(x)
        key = self.to_key(x)

        query = torch.nn.functional.normalize(query, dim=-1)  # BxNxD
        key = torch.nn.functional.normalize(key, dim=-1)  # BxNxD

        query_weight = query @ self.w_g  # BxNx1 (BxNxD @ Dx1)
        A = query_weight * self.scale_factor  # BxNx1

        A = torch.nn.functional.normalize(A, dim=1)  # BxNx1

        G = torch.sum(A * query, dim=1)  # BxD

        G = einops.repeat(
            G, "b d -> b repeat d", repeat=key.shape[1]
        )  # BxNxD

        out = self.Proj(G * key) + query  # BxNxD

        out = self.final(out)  # BxNxD

        return out
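# Illustrative usage (assumed shapes, not part of the original file): with
# token_dim == in_dims and num_heads=1, as used inside SwiftFormerEncoder,
# the module maps [B, N, D] -> [B, N, D]:
#   attn = EfficientAdditiveAttnetion(in_dims=64, token_dim=64, num_heads=1)
#   y = attn(torch.randn(2, 196, 64))  # -> torch.Size([2, 196, 64])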
class SwiftFormerLocalRepresentation(nn.Module):
"""
Local Representation module for SwiftFormer that is implemented by 3*3 depth-wise and point-wise convolutions.
Input: tensor in shape [B, C, H, W]
Output: tensor in shape [B, C, H, W]
"""
def __init__(self, dim, kernel_size=3, drop_path=0., use_layer_scale=True):
super().__init__()
self.dwconv = nn.Conv2d(dim, dim, kernel_size=kernel_size, padding=kernel_size // 2, groups=dim)
self.norm = nn.BatchNorm2d(dim)
self.pwconv1 = nn.Conv2d(dim, dim, kernel_size=1)
self.act = nn.GELU()
self.pwconv2 = nn.Conv2d(dim, dim, kernel_size=1)
self.drop_path = DropPath(drop_path) if drop_path > 0. \
else nn.Identity()
self.use_layer_scale = use_layer_scale
if use_layer_scale:
self.layer_scale = nn.Parameter(torch.ones(dim).unsqueeze(-1).unsqueeze(-1), requires_grad=True)
self.apply(self._init_weights)
def _init_weights(self, m):
if isinstance(m, nn.Conv2d):
trunc_normal_(m.weight, std=.02)
if m.bias is not None:
nn.init.constant_(m.bias, 0)
def forward(self, x):
input = x
x = self.dwconv(x)
x = self.norm(x)
x = self.pwconv1(x)
x = self.act(x)
x = self.pwconv2(x)
if self.use_layer_scale:
x = input + self.drop_path(self.layer_scale * x)
else:
x = input + self.drop_path(x)
return x
class SwiftFormerEncoder(nn.Module):
"""
SwiftFormer Encoder Block for SwiftFormer. It consists of (1) Local representation module, (2) EfficientAdditiveAttention, and (3) MLP block.
Input: tensor in shape [B, C, H, W]
Output: tensor in shape [B, C, H, W]
"""
def __init__(self, dim, mlp_ratio=4.,
act_layer=nn.GELU,
drop=0., drop_path=0.,
use_layer_scale=True, layer_scale_init_value=1e-5):
super().__init__()
self.local_representation = SwiftFormerLocalRepresentation(dim=dim, kernel_size=3, drop_path=0.,
use_layer_scale=True)
self.attn = EfficientAdditiveAttnetion(in_dims=dim, token_dim=dim, num_heads=1)
self.linear = Mlp(in_features=dim, hidden_features=int(dim * mlp_ratio), act_layer=act_layer, drop=drop)
self.drop_path = DropPath(drop_path) if drop_path > 0. \
else nn.Identity()
self.use_layer_scale = use_layer_scale
if use_layer_scale:
self.layer_scale_1 = nn.Parameter(
layer_scale_init_value * torch.ones(dim).unsqueeze(-1).unsqueeze(-1), requires_grad=True)
self.layer_scale_2 = nn.Parameter(
layer_scale_init_value * torch.ones(dim).unsqueeze(-1).unsqueeze(-1), requires_grad=True)
def forward(self, x):
x = self.local_representation(x)
B, C, H, W = x.shape
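        # Flatten spatial dims into tokens ([B, C, H, W] -> [B, H*W, C]) for the
        # additive attention, then restore the [B, C, H, W] layout afterwards.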
if self.use_layer_scale:
x = x + self.drop_path(
self.layer_scale_1 * self.attn(x.permute(0, 2, 3, 1).reshape(B, H * W, C)).reshape(B, H, W, C).permute(
0, 3, 1, 2))
x = x + self.drop_path(self.layer_scale_2 * self.linear(x))
else:
x = x + self.drop_path(
self.attn(x.permute(0, 2, 3, 1).reshape(B, H * W, C)).reshape(B, H, W, C).permute(0, 3, 1, 2))
x = x + self.drop_path(self.linear(x))
return x
def Stage(dim, index, layers, mlp_ratio=4.,
act_layer=nn.GELU,
drop_rate=.0, drop_path_rate=0.,
use_layer_scale=True, layer_scale_init_value=1e-5, vit_num=1):
"""
    Implementation of each SwiftFormer stage. SwiftFormerEncoder is used for the last `vit_num` block(s) of every stage, while ConvEncoder is used for the remaining blocks.
Input: tensor in shape [B, C, H, W]
Output: tensor in shape [B, C, H, W]
"""
blocks = []
for block_idx in range(layers[index]):
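        # Stochastic depth: the drop-path rate grows linearly with the global
        # block index, from 0 up to drop_path_rate at the last block.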
block_dpr = drop_path_rate * (block_idx + sum(layers[:index])) / (sum(layers) - 1)
if layers[index] - block_idx <= vit_num:
blocks.append(SwiftFormerEncoder(
dim, mlp_ratio=mlp_ratio,
act_layer=act_layer, drop_path=block_dpr,
use_layer_scale=use_layer_scale,
layer_scale_init_value=layer_scale_init_value))
else:
blocks.append(ConvEncoder(dim=dim, hidden_dim=int(mlp_ratio * dim), kernel_size=3))
blocks = nn.Sequential(*blocks)
return blocks
class SwiftFormer(nn.Module):
def __init__(self, layers, embed_dims=None,
mlp_ratios=4, downsamples=None,
act_layer=nn.GELU,
num_classes=1000,
down_patch_size=3, down_stride=2, down_pad=1,
drop_rate=0., drop_path_rate=0.,
use_layer_scale=True, layer_scale_init_value=1e-5,
fork_feat=False,
init_cfg=None,
pretrained=None,
vit_num=1,
distillation=True,
**kwargs):
super().__init__()
if not fork_feat:
self.num_classes = num_classes
self.fork_feat = fork_feat
self.patch_embed = stem(3, embed_dims[0])
network = []
for i in range(len(layers)):
stage = Stage(embed_dims[i], i, layers, mlp_ratio=mlp_ratios,
act_layer=act_layer,
drop_rate=drop_rate,
drop_path_rate=drop_path_rate,
use_layer_scale=use_layer_scale,
layer_scale_init_value=layer_scale_init_value,
vit_num=vit_num)
network.append(stage)
if i >= len(layers) - 1:
break
if downsamples[i] or embed_dims[i] != embed_dims[i + 1]:
# downsampling between two stages
network.append(
Embedding(
patch_size=down_patch_size, stride=down_stride,
padding=down_pad,
in_chans=embed_dims[i], embed_dim=embed_dims[i + 1]
)
)
self.network = nn.ModuleList(network)
if self.fork_feat:
# add a norm layer for each output
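            # Stage outputs live at even indices of self.network because a
            # downsampling Embedding layer is inserted between consecutive stages.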
self.out_indices = [0, 2, 4, 6]
for i_emb, i_layer in enumerate(self.out_indices):
if i_emb == 0 and os.environ.get('FORK_LAST3', None):
layer = nn.Identity()
else:
layer = nn.BatchNorm2d(embed_dims[i_emb])
layer_name = f'norm{i_layer}'
self.add_module(layer_name, layer)
else:
# Classifier head
self.norm = nn.BatchNorm2d(embed_dims[-1])
self.head = nn.Linear(
embed_dims[-1], num_classes) if num_classes > 0 \
else nn.Identity()
self.dist = distillation
if self.dist:
self.dist_head = nn.Linear(
embed_dims[-1], num_classes) if num_classes > 0 \
else nn.Identity()
# self.apply(self.cls_init_weights)
self.apply(self._init_weights)
self.init_cfg = copy.deepcopy(init_cfg)
# load pre-trained model
if self.fork_feat and (
self.init_cfg is not None or pretrained is not None):
self.init_weights()
# init for mmdetection or mmsegmentation by loading
# imagenet pre-trained weights
def init_weights(self, pretrained=None):
logger = get_root_logger()
if self.init_cfg is None and pretrained is None:
            logger.warning(f'No pre-trained weights for '
                           f'{self.__class__.__name__}, '
                           f'training starts from scratch')
pass
else:
assert 'checkpoint' in self.init_cfg, f'Only support ' \
f'specify `Pretrained` in ' \
f'`init_cfg` in ' \
f'{self.__class__.__name__} '
if self.init_cfg is not None:
ckpt_path = self.init_cfg['checkpoint']
elif pretrained is not None:
ckpt_path = pretrained
ckpt = _load_checkpoint(
ckpt_path, logger=logger, map_location='cpu')
if 'state_dict' in ckpt:
_state_dict = ckpt['state_dict']
elif 'model' in ckpt:
_state_dict = ckpt['model']
else:
_state_dict = ckpt
state_dict = _state_dict
missing_keys, unexpected_keys = \
self.load_state_dict(state_dict, False)
def _init_weights(self, m):
if isinstance(m, (nn.Conv2d, nn.Linear)):
trunc_normal_(m.weight, std=.02)
if m.bias is not None:
nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
nn.init.constant_(m.bias, 0)
nn.init.constant_(m.weight, 1.0)
def forward_tokens(self, x):
outs = []
for idx, block in enumerate(self.network):
x = block(x)
if self.fork_feat and idx in self.out_indices:
norm_layer = getattr(self, f'norm{idx}')
x_out = norm_layer(x)
outs.append(x_out)
if self.fork_feat:
return outs
return x
def forward(self, x):
x = self.patch_embed(x)
x = self.forward_tokens(x)
if self.fork_feat:
# Output features of four stages for dense prediction
return x
x = self.norm(x)
if self.dist:
cls_out = self.head(x.flatten(2).mean(-1)), self.dist_head(x.flatten(2).mean(-1))
if not self.training:
cls_out = (cls_out[0] + cls_out[1]) / 2
else:
            cls_out = self.head(x.flatten(2).mean(-1))
# For image classification
return cls_out
def _cfg(url='', **kwargs):
return {
'url': url,
'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': None,
'crop_pct': .95, 'interpolation': 'bicubic',
'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD,
'classifier': 'head',
**kwargs
}
@register_model
def SwiftFormer_XS(pretrained=False, **kwargs):
model = SwiftFormer(
layers=SwiftFormer_depth['XS'],
embed_dims=SwiftFormer_width['XS'],
downsamples=[True, True, True, True],
vit_num=1,
**kwargs)
model.default_cfg = _cfg(crop_pct=0.9)
return model
@register_model
def SwiftFormer_S(pretrained=False, **kwargs):
model = SwiftFormer(
layers=SwiftFormer_depth['S'],
embed_dims=SwiftFormer_width['S'],
downsamples=[True, True, True, True],
vit_num=1,
**kwargs)
model.default_cfg = _cfg(crop_pct=0.9)
return model
@register_model
def SwiftFormer_L1(pretrained=False, **kwargs):
model = SwiftFormer(
layers=SwiftFormer_depth['l1'],
embed_dims=SwiftFormer_width['l1'],
downsamples=[True, True, True, True],
vit_num=1,
**kwargs)
model.default_cfg = _cfg(crop_pct=0.9)
return model
@register_model
def SwiftFormer_L3(pretrained=False, **kwargs):
model = SwiftFormer(
layers=SwiftFormer_depth['l3'],
embed_dims=SwiftFormer_width['l3'],
downsamples=[True, True, True, True],
vit_num=1,
**kwargs)
model.default_cfg = _cfg(crop_pct=0.9)
return model
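

# Minimal smoke test (illustrative only, not part of the upstream file); it
# assumes a standard 224x224 ImageNet-style input and the default
# classification head with distillation averaging at inference time.
if __name__ == "__main__":
    model = SwiftFormer_XS(num_classes=1000)
    model.eval()
    with torch.no_grad():
        logits = model(torch.randn(1, 3, 224, 224))
    print(logits.shape)  # expected: torch.Size([1, 1000])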