# asmo_vhead/models/swiftformer_temporal.py
"""
SwiftFormerTemporal: Temporal extension of SwiftFormer for frame prediction
"""
import torch
import torch.nn as nn
from .swiftformer import (
SwiftFormer, SwiftFormer_depth, SwiftFormer_width,
stem, Embedding, Stage
)
from timm.layers import DropPath, trunc_normal_
class DecoderBlock(nn.Module):
    """Upsample with a transposed conv, then refine with two 3x3 convs.

    Each of the three convolutions is followed by BatchNorm and ReLU. There is
    no residual connection. With the default arguments the spatial size
    doubles: (H - 1) * 2 - 2 + 3 + 1 = 2H.

    Args:
        in_channels: channels of the incoming feature map.
        out_channels: channels produced by every layer of the block.
        kernel_size, stride, padding, output_padding: forwarded to the
            ``nn.ConvTranspose2d`` that performs the upsampling.
    """

    def __init__(self, in_channels, out_channels, kernel_size=3, stride=2, padding=1, output_padding=1):
        super().__init__()
        # No conv biases: every conv feeds straight into BatchNorm, whose own
        # shift makes a bias redundant.
        self.conv_transpose = nn.ConvTranspose2d(
            in_channels,
            out_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            output_padding=output_padding,
            bias=False,
        )
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.conv1 = nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(out_channels)
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1, bias=False)
        self.bn3 = nn.BatchNorm2d(out_channels)
        self.activation = nn.ReLU(inplace=True)
        self._init_weights()

    def _conv_bn_pairs(self):
        """Return the (conv, norm) pairs in the order ``forward`` applies them."""
        return (
            (self.conv_transpose, self.bn1),
            (self.conv1, self.bn2),
            (self.conv2, self.bn3),
        )

    def _init_weights(self):
        """Kaiming-init every conv (fan_out, ReLU gain); reset BN to identity."""
        pairs = self._conv_bn_pairs()
        for conv, _ in pairs:
            nn.init.kaiming_normal_(conv.weight, mode='fan_out', nonlinearity='relu')
        for _, norm in pairs:
            nn.init.constant_(norm.weight, 1)
            nn.init.constant_(norm.bias, 0)

    def forward(self, x):
        """Apply conv -> BN -> ReLU three times, upsampling on the first pass."""
        for conv, norm in self._conv_bn_pairs():
            x = self.activation(norm(conv(x)))
        return x
class FramePredictionDecoder(nn.Module):
    """Decode the deepest encoder feature map back to a frame (no skip connections).

    Layout:
      * three ``DecoderBlock`` stages with stride 2, each halving the channel count;
      * one ``DecoderBlock`` with stride 4 mapping to 64 channels;
      * a stride-1 refinement stack (64 -> 64 -> 32 -> ``output_channels``)
        ending in ``Tanh``, so outputs lie in [-1, 1].

    The total upsampling factor is 2 * 2 * 2 * 4 = 32.

    Args:
        embed_dims: encoder stage widths. Only ``embed_dims[-1]``, the channel
            count of the decoder input, is read.
        output_channels: channels of the predicted frame.

    Raises:
        ValueError: if ``embed_dims[-1] < 8``. The channel halving would then
            reach 0 channels.
    """

    def __init__(self, embed_dims, output_channels=1):
        super().__init__()
        start_dim = embed_dims[-1]
        if start_dim < 8:
            raise ValueError(
                "FramePredictionDecoder needs embed_dims[-1] >= 8 so every "
                f"decoder stage keeps a positive channel count, got {start_dim}"
            )
        # Channel widths of the stride-2 stages, halving each time,
        # e.g. [220, 110, 55, 27] for XS.
        decoder_dims = [start_dim // (2 ** i) for i in range(4)]

        # Stages 0-2: 2x upsampling each (decoder_dims[i] -> decoder_dims[i + 1]).
        self.blocks = nn.ModuleList([
            DecoderBlock(c_in, c_out, kernel_size=3, stride=2, padding=1, output_padding=1)
            for c_in, c_out in zip(decoder_dims[:-1], decoder_dims[1:])
        ])
        # Stage 3: 4x upsampling to 64 channels.
        # Output size is (H - 1) * 4 - 2 + 3 + 3 = 4H.
        # NOTE(review): kernel_size (3) < stride (4). Output rows/columns whose
        # index is 2 (mod 4), plus the last row/column, receive no contribution
        # from the bias-free transposed conv, so they are zero before BN.
        # The following 3x3 convs only partly fill these gaps. Using
        # kernel_size=4, padding=0, output_padding=0 would avoid them, but it
        # changes parameter shapes and breaks existing checkpoints.
        self.blocks.append(DecoderBlock(
            decoder_dims[3], 64,
            kernel_size=3, stride=4, padding=1, output_padding=3,
        ))

        # Feature refinement at full resolution (no further upsampling).
        self.final_block = nn.Sequential(
            nn.Conv2d(64, 64, kernel_size=3, padding=1, bias=True),
            nn.ReLU(inplace=True),
            nn.Conv2d(64, 32, kernel_size=3, padding=1, bias=True),
            nn.ReLU(inplace=True),
            nn.Conv2d(32, output_channels, kernel_size=3, padding=1, bias=True),
            nn.Tanh(),
        )

    def forward(self, x):
        """
        Args:
            x: input tensor of shape [B, embed_dims[-1], H/32, W/32]

        Returns:
            Tensor of shape [B, output_channels, H, W] with values in [-1, 1].
        """
        for block in self.blocks:
            x = block(x)
        return self.final_block(x)
class SwiftFormerTemporal(nn.Module):
    """SwiftFormer encoder with multi-frame input, plus an optional frame decoder.

    The ``num_frames`` stacked single-channel (Y) frames become the stem's
    input channels. The encoder stages follow the SwiftFormer layout for
    ``model_name``. The optional decoder maps the final feature map back to a
    single-channel frame.

    Input:  ``[B, num_frames, H, W]`` (Y channel only).
    Output: predicted frame ``[B, 1, H, W]``, or ``None`` when ``use_decoder``
    is False. No encoder representation is returned. The decoder upsamples 32x
    from its input, so H and W are presumably expected to be multiples of 32
    (TODO confirm against ``stem``).

    Args:
        model_name: key into ``SwiftFormer_depth`` / ``SwiftFormer_width``
            (the factories below use ``'XS'``, ``'S'``, ``'l1'``, ``'l3'``).
        num_frames: number of stacked input frames.
        use_decoder: whether to build and run ``FramePredictionDecoder``.
        **kwargs: accepted for call-site compatibility but unused. A
            ``UserWarning`` names any that are passed.

    Raises:
        KeyError: if ``model_name`` is not a known SwiftFormer configuration.
    """

    def __init__(self,
                 model_name='XS',
                 num_frames=3,
                 use_decoder=True,
                 **kwargs):
        super().__init__()
        # Fail early, listing the valid names, instead of with a bare KeyError.
        if model_name not in SwiftFormer_depth or model_name not in SwiftFormer_width:
            raise KeyError(
                f"unknown SwiftFormer model_name {model_name!r}; "
                f"expected one of {list(SwiftFormer_depth)}"
            )
        if kwargs:
            # These options have no effect; say so rather than dropping them silently.
            import warnings
            warnings.warn(
                "SwiftFormerTemporal ignores unsupported keyword arguments: "
                f"{sorted(kwargs)}",
                stacklevel=2,
            )

        layers = SwiftFormer_depth[model_name]
        embed_dims = SwiftFormer_width[model_name]

        self.num_frames = num_frames
        self.use_decoder = use_decoder

        # The stem consumes the stacked Y-channel frames as its input channels.
        self.patch_embed = stem(num_frames, embed_dims[0])

        # Encoder (same layout as SwiftFormer): one Stage per depth entry.
        # A stride-2 Embedding sits between consecutive stages whose widths differ.
        network = []
        num_stages = len(layers)
        for i in range(num_stages):
            network.append(Stage(embed_dims[i], i, layers, mlp_ratio=4,
                                 act_layer=nn.GELU,
                                 drop_rate=0., drop_path_rate=0.,
                                 use_layer_scale=True,
                                 layer_scale_init_value=1e-5,
                                 vit_num=1))
            has_next_stage = i + 1 < num_stages
            if has_next_stage and embed_dims[i] != embed_dims[i + 1]:
                network.append(Embedding(
                    patch_size=3, stride=2, padding=1,
                    in_chans=embed_dims[i], embed_dim=embed_dims[i + 1],
                ))
        self.network = nn.ModuleList(network)
        self.norm = nn.BatchNorm2d(embed_dims[-1])

        # Frame prediction decoder.
        if use_decoder:
            self.decoder = FramePredictionDecoder(embed_dims, output_channels=1)

        # Runs after every submodule is built, so it overrides any
        # initialization those modules performed themselves.
        self.apply(self._init_weights)

    def _init_weights(self, m):
        """Initialize one module (used via ``self.apply``)."""
        if isinstance(m, (nn.Conv2d, nn.Linear, nn.ConvTranspose2d)):
            # Kaiming init with ReLU gain, including for transposed convs.
            # NOTE(review): the encoder stages use GELU (act_layer above);
            # the ReLU gain is an approximation there.
            nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, (nn.LayerNorm, nn.BatchNorm2d)):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

    def forward_tokens(self, x):
        """Run the encoder stages and inter-stage embeddings in order."""
        for block in self.network:
            x = block(x)
        return x

    def forward(self, x):
        """Predict a frame from a stack of input frames.

        Args:
            x: input frames of shape [B, num_frames, H, W]

        Returns:
            pred_frame: predicted frame [B, 1, H, W], or None when the model
            was built with ``use_decoder=False``.
        """
        # Encode
        x = self.patch_embed(x)
        x = self.forward_tokens(x)
        x = self.norm(x)
        if not self.use_decoder:
            # NOTE(review): the encoder still runs here even though its output
            # is discarded.
            return None
        # Decode to frame
        return self.decoder(x)
# Factory functions for different model sizes
def SwiftFormerTemporal_XS(num_frames=3, **kwargs):
    """Build a SwiftFormerTemporal using the ``'XS'`` configuration.

    Args:
        num_frames: number of stacked input frames.
        **kwargs: forwarded unchanged to ``SwiftFormerTemporal``.
    """
    config_name = 'XS'
    return SwiftFormerTemporal(config_name, num_frames=num_frames, **kwargs)
def SwiftFormerTemporal_S(num_frames=3, **kwargs):
    """Build a SwiftFormerTemporal using the ``'S'`` configuration.

    Args:
        num_frames: number of stacked input frames.
        **kwargs: forwarded unchanged to ``SwiftFormerTemporal``.
    """
    config_name = 'S'
    return SwiftFormerTemporal(config_name, num_frames=num_frames, **kwargs)
def SwiftFormerTemporal_L1(num_frames=3, **kwargs):
    """Build a SwiftFormerTemporal using the ``'l1'`` configuration.

    Args:
        num_frames: number of stacked input frames.
        **kwargs: forwarded unchanged to ``SwiftFormerTemporal``.
    """
    config_name = 'l1'
    return SwiftFormerTemporal(config_name, num_frames=num_frames, **kwargs)
def SwiftFormerTemporal_L3(num_frames=3, **kwargs):
    """Build a SwiftFormerTemporal using the ``'l3'`` configuration.

    Args:
        num_frames: number of stacked input frames.
        **kwargs: forwarded unchanged to ``SwiftFormerTemporal``.
    """
    config_name = 'l3'
    return SwiftFormerTemporal(config_name, num_frames=num_frames, **kwargs)