"""
SwiftFormerTemporal: Temporal extension of SwiftFormer for frame prediction
"""

import torch
import torch.nn as nn

from .swiftformer import (
    SwiftFormer, SwiftFormer_depth, SwiftFormer_width,
    stem, Embedding, Stage
)
from timm.layers import DropPath, trunc_normal_


class DecoderBlock(nn.Module):
    """Upsampling block for the frame-prediction decoder, with a residual connection."""

    def __init__(self, in_channels, out_channels, kernel_size=3, stride=2, padding=1, output_padding=1):
        super().__init__()
        # Main path: transposed convolution followed by two conv layers.
        self.conv_transpose = nn.ConvTranspose2d(
            in_channels, out_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            output_padding=output_padding,
            bias=False  # no bias needed, since BatchNorm follows
        )
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.conv1 = nn.Conv2d(out_channels, out_channels,
                               kernel_size=3, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(out_channels)
        self.conv2 = nn.Conv2d(out_channels, out_channels,
                               kernel_size=3, padding=1, bias=False)
        self.bn3 = nn.BatchNorm2d(out_channels)

        # Shortcut path: adjust channels and/or spatial size when they change.
        self.shortcut = nn.Identity()
        if in_channels != out_channels or stride != 1:
            if stride == 1:
                # 1x1 convolution to match the channel count.
                self.shortcut = nn.Sequential(
                    nn.Conv2d(in_channels, out_channels, kernel_size=1, bias=False),
                    nn.BatchNorm2d(out_channels)
                )
            else:
                # Transposed convolution to match both the channel count and
                # the upsampled spatial size.
                self.shortcut = nn.Sequential(
                    nn.ConvTranspose2d(
                        in_channels, out_channels,
                        kernel_size=1,
                        stride=stride,
                        padding=0,
                        output_padding=output_padding,
                        bias=False
                    ),
                    nn.BatchNorm2d(out_channels)
                )

        self.activation = nn.ReLU(inplace=True)

        self._init_weights()

    def _init_weights(self):
        # Kaiming initialization for the transposed conv and both conv layers.
        nn.init.kaiming_normal_(self.conv_transpose.weight, mode='fan_out', nonlinearity='relu')
        nn.init.kaiming_normal_(self.conv1.weight, mode='fan_out', nonlinearity='relu')
        nn.init.kaiming_normal_(self.conv2.weight, mode='fan_out', nonlinearity='relu')

        # The shortcut is a Sequential when it is not an Identity; initialize
        # its conv layers as well.
        if not isinstance(self.shortcut, nn.Identity):
            for module in self.shortcut:
                if isinstance(module, (nn.Conv2d, nn.ConvTranspose2d)):
                    nn.init.kaiming_normal_(module.weight, mode='fan_out', nonlinearity='relu')

        # BatchNorm layers: unit weight, zero bias.
        for m in self.modules():
            if isinstance(m, nn.BatchNorm2d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

    def forward(self, x):
        identity = self.shortcut(x)

        # Main path: upsample, then refine with two 3x3 convolutions.
        x = self.conv_transpose(x)
        x = self.bn1(x)
        x = self.activation(x)

        x = self.conv1(x)
        x = self.bn2(x)
        x = self.activation(x)

        x = self.conv2(x)
        x = self.bn3(x)

        # Residual connection.
        x = x + identity
        x = self.activation(x)
        return x
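

# A minimal sketch of the transposed-convolution size arithmetic DecoderBlock
# relies on (illustrative helper, not used by the model). For dilation=1,
# PyTorch's nn.ConvTranspose2d produces
#     out = (in - 1) * stride - 2 * padding + kernel_size + output_padding
# so kernel_size=3, stride=2, padding=1, output_padding=1 exactly doubles the
# spatial size, and kernel_size=3, stride=4, padding=1, output_padding=3
# exactly quadruples it.
def _deconv_out_size(size, kernel_size=3, stride=2, padding=1, output_padding=1):
    return (size - 1) * stride - 2 * padding + kernel_size + output_padding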


class FramePredictionDecoder(nn.Module):
    """Decoder that upsamples encoder features back to the input resolution."""

    def __init__(self, embed_dims, output_channels=1):
        super().__init__()
        # Reverse the embed_dims for the decoder.
        decoder_dims = embed_dims[::-1]

        self.blocks = nn.ModuleList()
        # Upsampling schedule: 2x, 2x, 2x, 4x (32x total, undoing the encoder's
        # 32x downsampling), with the stride-4 block placed second-to-last,
        # just before the final refinement block. The channel counts in the
        # comments below are for the XS config.
        # Block 1: stride=2 (channels 220 -> 112)
        self.blocks.append(DecoderBlock(
            decoder_dims[0], decoder_dims[1],
            kernel_size=3, stride=2, padding=1, output_padding=1
        ))
        # Block 2: stride=2 (channels 112 -> 56)
        self.blocks.append(DecoderBlock(
            decoder_dims[1], decoder_dims[2],
            kernel_size=3, stride=2, padding=1, output_padding=1
        ))
        # Block 3: stride=2 (channels 56 -> 48)
        self.blocks.append(DecoderBlock(
            decoder_dims[2], decoder_dims[3],
            kernel_size=3, stride=2, padding=1, output_padding=1
        ))
        # Block 4: stride=4 (channels 48 -> 64)
        self.blocks.append(DecoderBlock(
            decoder_dims[3], 64,
            kernel_size=3, stride=4, padding=1, output_padding=3
        ))

        self.final_block = nn.Sequential(
            nn.Conv2d(64, 64, kernel_size=3, padding=1, bias=True),
            nn.ReLU(inplace=True),
            nn.Conv2d(64, 32, kernel_size=3, padding=1, bias=True),
            nn.ReLU(inplace=True),
            nn.Conv2d(32, output_channels, kernel_size=3, padding=1, bias=True),
            nn.Tanh()  # constrain the output to [-1, 1]
        )

    def forward(self, x):
        """
        Args:
            x: encoder features of shape [B, embed_dims[-1], H/32, W/32]
        Returns:
            predicted frame of shape [B, output_channels, H, W]
        """
        # No skip connections: decode from the bottleneck features alone.
        for block in self.blocks:
            x = block(x)

        # Final head: feature refinement only, no further upsampling.
        x = self.final_block(x)
        return x
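

# Illustrative sanity check for the decoder (call it manually; it is not
# executed on import). The widths [48, 56, 112, 220] are the XS config that
# the channel comments above refer to; a 224x224 input gives a 7x7 bottleneck.
def _check_decoder_shapes():
    dec = FramePredictionDecoder([48, 56, 112, 220], output_channels=1)
    y = dec(torch.randn(1, 220, 7, 7))
    assert y.shape == (1, 1, 224, 224)  # 7 -> 14 -> 28 -> 56 -> 224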


class SwiftFormerTemporal(nn.Module):
    """
    SwiftFormer with temporal input for frame prediction.

    Input:  [B, num_frames, H, W] (Y channel only)
    Output: predicted frame [B, 1, H, W], or None when use_decoder=False
    """

    def __init__(self,
                 model_name='XS',
                 num_frames=3,
                 use_decoder=True,
                 **kwargs):
        super().__init__()

        # Get model configuration
        layers = SwiftFormer_depth[model_name]
        embed_dims = SwiftFormer_width[model_name]

        # Store configuration
        self.num_frames = num_frames
        self.use_decoder = use_decoder

        # Modify stem to accept multiple frames (Y channel only)
        in_channels = num_frames
        self.patch_embed = stem(in_channels, embed_dims[0])

        # Build encoder network (same as SwiftFormer)
        network = []
        for i in range(len(layers)):
            stage = Stage(embed_dims[i], i, layers, mlp_ratio=4,
                          act_layer=nn.GELU,
                          drop_rate=0., drop_path_rate=0.,
                          use_layer_scale=True,
                          layer_scale_init_value=1e-5,
                          vit_num=1)
            network.append(stage)
            if i >= len(layers) - 1:
                break
            if embed_dims[i] != embed_dims[i + 1]:
                network.append(
                    Embedding(
                        patch_size=3, stride=2, padding=1,
                        in_chans=embed_dims[i], embed_dim=embed_dims[i + 1]
                    )
                )

        self.network = nn.ModuleList(network)
        self.norm = nn.BatchNorm2d(embed_dims[-1])

        # Frame prediction decoder
        if use_decoder:
            self.decoder = FramePredictionDecoder(
                embed_dims,
                output_channels=1
            )

        # apply() recurses into submodules, so decoder layers are
        # re-initialized here as well.
        self.apply(self._init_weights)

    def _init_weights(self, m):
        if isinstance(m, (nn.Conv2d, nn.Linear)):
            # Kaiming initialization, suited to ReLU-family activations.
            nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.ConvTranspose2d):
            # Same Kaiming scheme for transposed convolutions.
            nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, (nn.LayerNorm, nn.BatchNorm2d)):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

    def forward_tokens(self, x):
        for block in self.network:
            x = block(x)
        return x

    def forward(self, x):
        """
        Args:
            x: input frames of shape [B, num_frames, H, W]
        Returns:
            pred_frame: predicted frame [B, 1, H, W], or None when
                use_decoder=False
        """
        # Encode
        x = self.patch_embed(x)
        x = self.forward_tokens(x)
        x = self.norm(x)

        # Decode to frame
        pred_frame = None
        if self.use_decoder:
            pred_frame = self.decoder(x)

        return pred_frame


# Factory functions for different model sizes
def SwiftFormerTemporal_XS(num_frames=3, **kwargs):
    return SwiftFormerTemporal('XS', num_frames=num_frames, **kwargs)


def SwiftFormerTemporal_S(num_frames=3, **kwargs):
    return SwiftFormerTemporal('S', num_frames=num_frames, **kwargs)


def SwiftFormerTemporal_L1(num_frames=3, **kwargs):
    return SwiftFormerTemporal('l1', num_frames=num_frames, **kwargs)


def SwiftFormerTemporal_L3(num_frames=3, **kwargs):
    return SwiftFormerTemporal('l3', num_frames=num_frames, **kwargs)
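

# Minimal smoke test (illustrative assumptions: 224x224 inputs and the XS
# config; any H and W divisible by 32 should work the same way). Because of
# the relative import at the top, run this as a module, e.g.
# `python -m <package>.<this_module>`.
if __name__ == "__main__":
    model = SwiftFormerTemporal_XS(num_frames=3)
    frames = torch.randn(2, 3, 224, 224)  # [B, num_frames, H, W], Y channel only
    pred = model(frames)
    print(pred.shape)  # expected: torch.Size([2, 1, 224, 224])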