# asmo_vhead/models/swiftformer_temporal.py (481 lines, 20 KiB) — copied from a repository web viewer

"""
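SwiftFormerTemporal: temporal extension of SwiftFormer for frame prediction.

Several single-channel (Y) frames are stacked as input channels of a SwiftFormer
encoder; the encoder bottleneck is decoded (optionally with encoder skip
connections) into a predicted frame.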
"""
import torch
import torch.nn as nn
from .swiftformer import (
SwiftFormer, SwiftFormer_depth, SwiftFormer_width,
stem, Embedding, Stage
)
from timm.layers import DropPath, trunc_normal_
class DecoderBlock(nn.Module):
    """Upsampling block for the frame-prediction decoder, with a residual connection.

    Main path: ConvTranspose2d -> LeakyReLU -> Conv3x3 -> LeakyReLU -> Conv3x3.
    Shortcut: identity when neither channels nor resolution change, a 1x1 Conv2d
    when only the channel count changes (stride == 1), otherwise a 1x1
    ConvTranspose2d. Both paths are summed and passed through a final LeakyReLU.

    Args:
        in_channels: channels of the input tensor.
        out_channels: channels of the output tensor.
        kernel_size, stride, padding, output_padding: parameters of the
            main-path ConvTranspose2d. For ``stride != 1`` the shortcut's
            output_padding is derived from them so both paths produce the same
            spatial size. With the defaults (3, 2, 1, 1) H and W are doubled.

    Raises:
        ValueError: if ``stride != 1`` and no valid shortcut output_padding
            (``0 <= p < stride``) can match the main path's output size.
    """

    def __init__(self, in_channels, out_channels, kernel_size=3, stride=2, padding=1, output_padding=1):
        super().__init__()
        # Main path: transposed conv (upsampling) + two 3x3 convs.
        # Biases are enabled because the block has no normalization layers.
        self.conv_transpose = nn.ConvTranspose2d(
            in_channels, out_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            output_padding=output_padding,
            bias=True,
        )
        self.conv1 = nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1, bias=True)
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1, bias=True)

        # Residual path: only needs parameters when channels or resolution change.
        self.shortcut = nn.Identity()
        if in_channels != out_channels or stride != 1:
            if stride == 1:
                # Same resolution: a 1x1 conv only adjusts the channel count.
                self.shortcut = nn.Conv2d(in_channels, out_channels, kernel_size=1, bias=True)
            else:
                # Upsampling: 1x1 transposed conv with an output_padding chosen so
                # its output size equals the main path's for any kernel/padding.
                self.shortcut = nn.ConvTranspose2d(
                    in_channels, out_channels,
                    kernel_size=1,
                    stride=stride,
                    padding=0,
                    output_padding=self._shortcut_output_padding(
                        kernel_size, stride, padding, output_padding),
                    bias=True,
                )

        # LeakyReLU keeps a small gradient for negative inputs (avoids dead units).
        self.activation = nn.LeakyReLU(0.2, inplace=True)
        self._init_weights()

    @staticmethod
    def _shortcut_output_padding(kernel_size, stride, padding, output_padding):
        """Return the output_padding that makes a 1x1 stride-``stride`` ConvTranspose2d
        produce the same spatial size as the main-path ConvTranspose2d.

        Main path: (H - 1) * stride - 2 * padding + kernel_size + output_padding
        Shortcut:  (H - 1) * stride + 1 + shortcut_output_padding
        """
        def as_pair(value):
            return tuple(value) if isinstance(value, (tuple, list)) else (value, value)

        result = []
        for k, s, p, op in zip(as_pair(kernel_size), as_pair(stride),
                               as_pair(padding), as_pair(output_padding)):
            shortcut_op = op + k - 2 * p - 1
            if not 0 <= shortcut_op < s:
                raise ValueError(
                    "Residual shortcut cannot match the main path for "
                    f"kernel_size={kernel_size}, stride={stride}, "
                    f"padding={padding}, output_padding={output_padding}"
                )
            result.append(shortcut_op)
        # Collapse to an int when both spatial dims agree (the common case).
        return result[0] if len(set(result)) == 1 else tuple(result)

    def _init_weights(self):
        """Kaiming-normal (fan_out) init for every conv weight; zero biases."""
        # self.modules() yields conv_transpose, conv1, conv2, shortcut in
        # registration order, the same order the weights were initialized before.
        for module in self.modules():
            if isinstance(module, (nn.Conv2d, nn.ConvTranspose2d)):
                nn.init.kaiming_normal_(module.weight, mode='fan_out', nonlinearity='leaky_relu')
                if module.bias is not None:
                    nn.init.constant_(module.bias, 0)

    def forward(self, x):
        """Upsample ``x`` [B, in_channels, H, W] to [B, out_channels, H_out, W_out]."""
        identity = self.shortcut(x)
        # Main path
        x = self.conv_transpose(x)
        x = self.activation(x)
        x = self.conv1(x)
        x = self.activation(x)
        x = self.conv2(x)
        # Residual connection
        x = x + identity
        x = self.activation(x)
        return x
class DecoderBlockWithSkip(nn.Module):
    """Residual upsampling decoder block that can concatenate an encoder skip feature.

    The optional skip feature is resized (bilinear) to the input's spatial size
    and concatenated on the channel axis. The result then goes through
    ConvTranspose2d -> LeakyReLU -> Conv3x3 -> LeakyReLU -> Conv3x3, is summed
    with a shortcut (identity / 1x1 Conv2d / 1x1 ConvTranspose2d), and passed
    through a final LeakyReLU.

    Args:
        in_channels: channels of the main input ``x``.
        out_channels: channels of the output.
        skip_channels: channels of the skip feature expected by ``forward``
            (0 means the block takes no skip feature).
        kernel_size, stride, padding, output_padding: main-path ConvTranspose2d
            parameters. For ``stride != 1`` the shortcut's output_padding is
            derived from them so both paths produce the same spatial size.
            With the defaults (3, 2, 1, 1) H and W are doubled.

    Raises:
        ValueError: if ``stride != 1`` and no valid shortcut output_padding exists.
    """

    def __init__(self, in_channels, out_channels, skip_channels=0, kernel_size=3, stride=2, padding=1, output_padding=1):
        super().__init__()
        # Remembered so a missing skip feature can be replaced by zeros in forward().
        self.skip_channels = skip_channels
        # Total input channels = main input channels + skip channels.
        total_in_channels = in_channels + skip_channels

        # Main path: transposed conv (upsampling) + two 3x3 convs (no normalization, so bias=True).
        self.conv_transpose = nn.ConvTranspose2d(
            total_in_channels, out_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            output_padding=output_padding,
            bias=True,
        )
        self.conv1 = nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1, bias=True)
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1, bias=True)

        # Residual path: only needs parameters when channels or resolution change.
        self.shortcut = nn.Identity()
        if total_in_channels != out_channels or stride != 1:
            if stride == 1:
                self.shortcut = nn.Conv2d(total_in_channels, out_channels, kernel_size=1, bias=True)
            else:
                # output_padding chosen so the shortcut matches the main path's size.
                self.shortcut = nn.ConvTranspose2d(
                    total_in_channels, out_channels,
                    kernel_size=1,
                    stride=stride,
                    padding=0,
                    output_padding=self._shortcut_output_padding(
                        kernel_size, stride, padding, output_padding),
                    bias=True,
                )

        # LeakyReLU keeps a small gradient for negative inputs (avoids dead units).
        self.activation = nn.LeakyReLU(0.2, inplace=True)
        self._init_weights()

    @staticmethod
    def _shortcut_output_padding(kernel_size, stride, padding, output_padding):
        """Return the output_padding that makes a 1x1 stride-``stride`` ConvTranspose2d
        produce the same spatial size as the main-path ConvTranspose2d.

        Main path: (H - 1) * stride - 2 * padding + kernel_size + output_padding
        Shortcut:  (H - 1) * stride + 1 + shortcut_output_padding
        """
        def as_pair(value):
            return tuple(value) if isinstance(value, (tuple, list)) else (value, value)

        result = []
        for k, s, p, op in zip(as_pair(kernel_size), as_pair(stride),
                               as_pair(padding), as_pair(output_padding)):
            shortcut_op = op + k - 2 * p - 1
            if not 0 <= shortcut_op < s:
                raise ValueError(
                    "Residual shortcut cannot match the main path for "
                    f"kernel_size={kernel_size}, stride={stride}, "
                    f"padding={padding}, output_padding={output_padding}"
                )
            result.append(shortcut_op)
        return result[0] if len(set(result)) == 1 else tuple(result)

    def _init_weights(self):
        """Kaiming-normal (fan_out) init for every conv weight; zero biases."""
        # Same visiting order as explicit per-layer init: conv_transpose, conv1, conv2, shortcut.
        for module in self.modules():
            if isinstance(module, (nn.Conv2d, nn.ConvTranspose2d)):
                nn.init.kaiming_normal_(module.weight, mode='fan_out', nonlinearity='leaky_relu')
                if module.bias is not None:
                    nn.init.constant_(module.bias, 0)

    def forward(self, x, skip_feature=None):
        """
        Args:
            x: tensor [B, in_channels, H, W].
            skip_feature: optional tensor [B, skip_channels, H', W'], resized to
                (H, W) when its spatial size differs. If omitted while
                ``skip_channels > 0``, an all-zero feature is used so the
                concatenated input still has the channel count the convs expect.

        Returns:
            Tensor [B, out_channels, H_out, W_out] (H_out = 2H with the defaults).
        """
        if skip_feature is None and self.skip_channels > 0:
            # e.g. the caller padded its skip list with None
            skip_feature = x.new_zeros((x.shape[0], self.skip_channels, *x.shape[2:]))
        if skip_feature is not None:
            # Match the skip feature's spatial size to x before concatenating.
            if skip_feature.shape[2:] != x.shape[2:]:
                skip_feature = torch.nn.functional.interpolate(
                    skip_feature,
                    size=x.shape[2:],
                    mode='bilinear',
                    align_corners=False,
                )
            x = torch.cat([x, skip_feature], dim=1)
        identity = self.shortcut(x)
        # Main path
        x = self.conv_transpose(x)
        x = self.activation(x)
        x = self.conv1(x)
        x = self.activation(x)
        x = self.conv2(x)
        # Residual connection
        x = x + identity
        x = self.activation(x)
        return x
class FramePredictionDecoder(nn.Module):
    """Decoder that upsamples the encoder bottleneck to a predicted frame.

    Four residual upsampling blocks with strides 4, 2, 2, 2 (each block scales
    H and W by exactly its stride, so 32x in total) are followed by a
    resolution-preserving refinement head mapping 64 channels to
    ``output_channels``. No output activation is applied.

    Args:
        embed_dims: encoder stage widths ordered shallow -> deep; four stages
            are assumed (embed_dims[0..3] are indexed).
        output_channels: channels of the predicted frame.
        use_skip: if True, each block concatenates one encoder stage feature.
    """

    def __init__(self, embed_dims, output_channels=1, use_skip=False):
        super().__init__()
        self.use_skip = use_skip
        # Decoder walks the encoder widths in reverse (deep -> shallow).
        decoder_dims = embed_dims[::-1]

        # (in_channels, out_channels, stride). The first block uses stride 4 to jump
        # from the bottleneck directly; output_padding = stride - 1 with k=3, p=1
        # makes every block scale H and W by exactly `stride`.
        block_specs = [
            (decoder_dims[0], decoder_dims[1], 4),
            (decoder_dims[1], decoder_dims[2], 2),
            (decoder_dims[2], decoder_dims[3], 2),
            (decoder_dims[3], 64, 2),
        ]
        self.blocks = nn.ModuleList()
        for i, (in_ch, out_ch, stride) in enumerate(block_specs):
            if use_skip:
                # Block i consumes encoder stage (3 - i): stage3, stage2, stage1, stage0.
                block = DecoderBlockWithSkip(
                    in_ch, out_ch,
                    skip_channels=embed_dims[3 - i],
                    kernel_size=3, stride=stride, padding=1, output_padding=stride - 1,
                )
            else:
                block = DecoderBlock(
                    in_ch, out_ch,
                    kernel_size=3, stride=stride, padding=1, output_padding=stride - 1,
                )
            self.blocks.append(block)

        # Refinement head: 3x3 convs with padding 1 keep the spatial size (the input
        # is already at target resolution) and only mix/reduce channels.
        self.final_block = nn.Sequential(
            nn.Conv2d(64, 64, kernel_size=3, padding=1, bias=True),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(64, 32, kernel_size=3, padding=1, bias=True),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(32, output_channels, kernel_size=3, padding=1, bias=True),
            # No Tanh: the output range is left to the loss / normalization.
        )

    def forward(self, x, skip_features=None):
        """
        Args:
            x: bottleneck tensor [B, embed_dims[-1], h, w].
            skip_features: required when ``use_skip``; a sequence of exactly four
                encoder features ordered shallow -> deep [stage0, stage1, stage2,
                stage3]. Entries may be None. Each feature is resized by its
                block to that block's input resolution, so no pre-resizing is done here.

        Returns:
            Tensor [B, output_channels, 32*h, 32*w].

        Raises:
            ValueError: if ``use_skip`` and skip_features is missing or not length 4.
        """
        if not self.use_skip:
            for block in self.blocks:
                x = block(x)
            return self.final_block(x)

        if skip_features is None:
            raise ValueError("skip_features must be provided when use_skip=True")
        if len(skip_features) != 4:
            raise ValueError(f"Need 4 skip features, got {len(skip_features)}")

        # Deepest feature goes to the first block. Skips are passed unresized:
        # pre-resizing them to bottleneck-based targets (as before) forced lossy
        # down/up resampling and crushed stage0's high-resolution detail.
        for block, skip in zip(self.blocks, list(skip_features)[::-1]):
            x = block(x, skip)
        return self.final_block(x)
class SwiftFormerTemporal(nn.Module):
    """
    SwiftFormer with temporal input for frame prediction.

    The input frames (Y channel only) are stacked on the channel axis, so the
    stem receives ``num_frames`` input channels.

    Input: [B, num_frames, H, W] (Y channel only)
    Output: predicted frame [B, 1, H, W] and optional representation
        (NOTE(review): the decoder upsamples the bottleneck by 32x, so H and W
        presumably need to be multiples of 32 — confirm against the stem.)

    Args:
        model_name: key into ``SwiftFormer_depth`` / ``SwiftFormer_width``.
        num_frames: number of stacked input frames (stem input channels).
        use_decoder: build the frame-prediction decoder.
        use_skip: feed the encoder stage outputs to the decoder as skip features.
        use_representation_head: build a pooled MLP head on the normalized
            bottleneck producing a ``representation_dim`` vector.
        representation_dim: width of the representation head.
        return_features: additionally return the list of encoder stage outputs.
        **kwargs: accepted for call-site compatibility and ignored.
    """

    def __init__(self,
                 model_name='XS',
                 num_frames=3,
                 use_decoder=True,
                 use_skip=True,  # whether the decoder uses skip connections
                 use_representation_head=False,
                 representation_dim=128,
                 return_features=False,
                 **kwargs):
        super().__init__()
        # Model configuration
        layers = SwiftFormer_depth[model_name]
        embed_dims = SwiftFormer_width[model_name]

        self.num_frames = num_frames
        self.use_decoder = use_decoder
        self.use_skip = use_skip
        self.use_representation_head = use_representation_head
        self.return_features = return_features

        # Stem takes the stacked frames (Y channel only) as input channels.
        self.patch_embed = stem(num_frames, embed_dims[0])

        # Encoder: Stage modules, with an Embedding between stages whose widths differ.
        network = []
        stage_indices = []
        for i in range(len(layers)):
            # Record where each Stage lands in `network`; Embeddings are only
            # inserted when widths change, so positions cannot be hard-coded.
            stage_indices.append(len(network))
            network.append(Stage(embed_dims[i], i, layers, mlp_ratio=4,
                                 act_layer=nn.GELU,
                                 drop_rate=0., drop_path_rate=0.,
                                 use_layer_scale=True,
                                 layer_scale_init_value=1e-5,
                                 vit_num=1))
            if i >= len(layers) - 1:
                break
            if embed_dims[i] != embed_dims[i + 1]:
                network.append(
                    Embedding(
                        patch_size=3, stride=2, padding=1,
                        in_chans=embed_dims[i], embed_dim=embed_dims[i + 1]
                    )
                )
        self.network = nn.ModuleList(network)
        # Indices into self.network whose outputs are collected as stage features.
        self.stage_indices = tuple(stage_indices)
        self.norm = nn.BatchNorm2d(embed_dims[-1])

        # Frame prediction decoder
        if use_decoder:
            self.decoder = FramePredictionDecoder(
                embed_dims,
                output_channels=1,
                use_skip=use_skip,
            )

        # Representation head for pose/velocity prediction
        if use_representation_head:
            self.representation_head = nn.Sequential(
                nn.AdaptiveAvgPool2d(1),
                nn.Flatten(),
                nn.Linear(embed_dims[-1], representation_dim),
                nn.ReLU(),
                nn.Linear(representation_dim, representation_dim)
            )
        else:
            self.representation_head = None

        self.apply(self._init_weights)

    def _init_weights(self, m):
        """Per-module init applied to the whole model via ``self.apply``.

        Conv / transposed-conv / linear weights: Kaiming normal (fan_out,
        leaky_relu gain), zero bias. LayerNorm / BatchNorm2d: weight 1, bias 0.
        """
        if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d, nn.Linear)):
            nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='leaky_relu')
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, (nn.LayerNorm, nn.BatchNorm2d)):
            # Norm layers built without affine parameters have weight/bias = None.
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
            if m.weight is not None:
                nn.init.constant_(m.weight, 1.0)

    def forward_tokens(self, x):
        """Run the encoder network.

        Returns ``(x, features)`` — final output plus the list of Stage outputs
        (shallow -> deep) — when ``return_features`` or ``use_skip`` is set,
        otherwise just the final output.
        """
        if not (self.return_features or self.use_skip):
            for block in self.network:
                x = block(x)
            return x

        features = []
        for idx, block in enumerate(self.network):
            x = block(x)
            if idx in self.stage_indices:
                features.append(x)
        return x, features

    def forward(self, x):
        """
        Args:
            x: input frames of shape [B, num_frames, H, W]
        Returns:
            If return_features is False:
                pred_frame: predicted frame [B, 1, H, W] (or None)
                representation: optional representation vector [B, representation_dim] (or None)
            If return_features is True:
                pred_frame, representation, features (list of stage features, pre-norm)
        """
        # Encode
        x = self.patch_embed(x)
        features = None
        if self.return_features or self.use_skip:
            x, features = self.forward_tokens(x)
        else:
            x = self.forward_tokens(x)
        x = self.norm(x)

        # Representation (computed from the normalized bottleneck)
        representation = None
        if self.representation_head is not None:
            representation = self.representation_head(x)

        # Decode to frame
        pred_frame = None
        if self.use_decoder:
            if self.use_skip:
                # Decoder expects exactly four stage features (stage0..stage3);
                # pad with None if the encoder produced fewer.
                skip_features = list(features[:4])
                skip_features += [None] * (4 - len(skip_features))
                pred_frame = self.decoder(x, skip_features)
            else:
                pred_frame = self.decoder(x)

        if self.return_features:
            return pred_frame, representation, features
        return pred_frame, representation
# Factory functions for different model sizes
def SwiftFormerTemporal_XS(num_frames=3, use_skip=True, **kwargs):
    """Build the 'XS' SwiftFormerTemporal; extra keyword arguments go to the constructor."""
    return SwiftFormerTemporal(model_name='XS', num_frames=num_frames, use_skip=use_skip, **kwargs)
def SwiftFormerTemporal_S(num_frames=3, use_skip=True, **kwargs):
    """Build the 'S' SwiftFormerTemporal; extra keyword arguments go to the constructor."""
    return SwiftFormerTemporal(model_name='S', num_frames=num_frames, use_skip=use_skip, **kwargs)
def SwiftFormerTemporal_L1(num_frames=3, use_skip=True, **kwargs):
    """Build the 'l1' SwiftFormerTemporal; extra keyword arguments go to the constructor."""
    return SwiftFormerTemporal(model_name='l1', num_frames=num_frames, use_skip=use_skip, **kwargs)
def SwiftFormerTemporal_L3(num_frames=3, use_skip=True, **kwargs):
    """Build the 'l3' SwiftFormerTemporal; extra keyword arguments go to the constructor."""
    return SwiftFormerTemporal(model_name='l3', num_frames=num_frames, use_skip=use_skip, **kwargs)