"""
SwiftFormerTemporal: Temporal extension of SwiftFormer for frame prediction.

The encoder is a SwiftFormer whose stem receives ``num_frames`` stacked
single-channel (Y) frames as input channels. An optional convolutional decoder
predicts one output frame, and an optional head produces a representation
vector.
"""
import torch
import torch.nn as nn
from .swiftformer import (
    SwiftFormer, SwiftFormer_depth, SwiftFormer_width,
    stem, Embedding, Stage
)
from timm.layers import DropPath, trunc_normal_


def _kaiming_init_(module):
    """Kaiming-normal init (fan_out, leaky_relu) of ``module.weight``; zero ``module.bias``."""
    nn.init.kaiming_normal_(module.weight, mode='fan_out', nonlinearity='leaky_relu')
    if module.bias is not None:
        nn.init.constant_(module.bias, 0)


class _ResidualUpsampleBlock(nn.Module):
    """Residual upsampling body shared by ``DecoderBlock`` and ``DecoderBlockWithSkip``.

    Main path:  ConvTranspose2d -> LeakyReLU -> Conv3x3 -> LeakyReLU -> Conv3x3
    Output:     LeakyReLU(main(x) + shortcut(x))

    The shortcut is:
      * ``nn.Identity`` when neither channels nor spatial size change,
      * a 1x1 ``Conv2d`` when only the channel count changes,
      * a 1x1 ``ConvTranspose2d`` with the same stride/output_padding otherwise.

    Note: the 1x1 transposed shortcut yields the same spatial size as the main
    path only when ``kernel_size - 2 * padding == 1`` (true for the
    kernel_size=3, padding=1 configuration used in this file).
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride, padding, output_padding):
        super().__init__()
        # Main path. Bias is enabled on every conv because no normalization
        # layer follows them.
        self.conv_transpose = nn.ConvTranspose2d(
            in_channels, out_channels, kernel_size=kernel_size, stride=stride,
            padding=padding, output_padding=output_padding, bias=True
        )
        self.conv1 = nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1, bias=True)
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1, bias=True)

        # Residual path: project only when channels or spatial size change.
        if in_channels == out_channels and stride == 1:
            self.shortcut = nn.Identity()
        elif stride == 1:
            self.shortcut = nn.Conv2d(in_channels, out_channels, kernel_size=1, bias=True)
        else:
            self.shortcut = nn.ConvTranspose2d(
                in_channels, out_channels, kernel_size=1, stride=stride,
                padding=0, output_padding=output_padding, bias=True
            )

        # LeakyReLU rather than ReLU to avoid dead units.
        self.activation = nn.LeakyReLU(0.2, inplace=True)

        self._init_weights()

    def _init_weights(self):
        """Kaiming-initialize every conv in the block, including a projected shortcut."""
        for layer in (self.conv_transpose, self.conv1, self.conv2, self.shortcut):
            if not isinstance(layer, nn.Identity):
                _kaiming_init_(layer)

    def _residual_forward(self, x):
        """Apply main path + shortcut to ``x`` (already containing any concatenated skip)."""
        identity = self.shortcut(x)
        out = self.activation(self.conv_transpose(x))
        out = self.activation(self.conv1(out))
        out = self.conv2(out)
        # In-place activation is safe here: ``out + identity`` is a fresh tensor.
        return self.activation(out + identity)


class DecoderBlock(_ResidualUpsampleBlock):
    """Upsampling block for the frame-prediction decoder with a residual connection.

    Args:
        in_channels: channels of the input feature map.
        out_channels: channels of the output feature map.
        kernel_size, stride, padding, output_padding: parameters of the
            main-path ``ConvTranspose2d``. The output spatial size is
            ``(H - 1) * stride - 2 * padding + kernel_size + output_padding``;
            the defaults double H and W.
    """

    def __init__(self, in_channels, out_channels, kernel_size=3, stride=2,
                 padding=1, output_padding=1):
        super().__init__(in_channels, out_channels, kernel_size, stride,
                         padding, output_padding)

    def forward(self, x):
        """Upsample ``x`` of shape ``[B, in_channels, H, W]`` to ``[B, out_channels, H', W']``."""
        return self._residual_forward(x)


class DecoderBlockWithSkip(_ResidualUpsampleBlock):
    """Decoder block that concatenates an encoder skip feature before upsampling.

    Args:
        in_channels: channels of the decoder input.
        out_channels: channels of the output.
        skip_channels: channels of the skip feature concatenated onto the
            input along dim 1 (0 means no skip is expected).
        kernel_size, stride, padding, output_padding: see ``DecoderBlock``.
    """

    def __init__(self, in_channels, out_channels, skip_channels=0, kernel_size=3,
                 stride=2, padding=1, output_padding=1):
        # The convs see the decoder input and the skip feature concatenated.
        super().__init__(in_channels + skip_channels, out_channels,
                         kernel_size, stride, padding, output_padding)
        self.skip_channels = skip_channels

    def forward(self, x, skip_feature=None):
        """
        Args:
            x: decoder input of shape ``[B, in_channels, H, W]``.
            skip_feature: encoder feature ``[B, skip_channels, H', W']``. It is
                bilinearly resized to ``(H, W)`` when the sizes differ.
                Required when the block was built with ``skip_channels > 0``.

        Raises:
            ValueError: if ``skip_feature`` is None but the block expects one.
        """
        if skip_feature is None:
            if self.skip_channels > 0:
                raise ValueError(
                    f"skip_feature is required: block was built with "
                    f"skip_channels={self.skip_channels}"
                )
        else:
            # Match the skip's spatial size to the decoder input before concatenating.
            if skip_feature.shape[2:] != x.shape[2:]:
                skip_feature = torch.nn.functional.interpolate(
                    skip_feature, size=x.shape[2:], mode='bilinear', align_corners=False
                )
            x = torch.cat([x, skip_feature], dim=1)
        return self._residual_forward(x)


class FramePredictionDecoder(nn.Module):
    """Convolutional decoder mapping the encoder bottleneck back to a frame.

    Four residual upsampling blocks (x4, x2, x2, x2 = x32 total) are followed
    by a non-upsampling refinement head that outputs ``output_channels``.

    Args:
        embed_dims: encoder stage widths, shallow to deep (4 entries expected).
        output_channels: channels of the predicted frame.
        use_skip: if True, each block concatenates one encoder stage feature.
    """

    def __init__(self, embed_dims, output_channels=1, use_skip=False):
        super().__init__()
        self.use_skip = use_skip

        # Decoder runs deep -> shallow, so walk the widths in reverse.
        decoder_dims = embed_dims[::-1]

        # (in_channels, out_channels, skip_channels, stride, output_padding).
        # With kernel_size=3 and padding=1, output_padding = stride - 1 makes
        # each block scale H and W by exactly ``stride``.
        specs = [
            (decoder_dims[0], decoder_dims[1], embed_dims[3], 4, 3),  # skip: stage3
            (decoder_dims[1], decoder_dims[2], embed_dims[2], 2, 1),  # skip: stage2
            (decoder_dims[2], decoder_dims[3], embed_dims[1], 2, 1),  # skip: stage1
            (decoder_dims[3], 64, embed_dims[0], 2, 1),               # skip: stage0
        ]

        self.blocks = nn.ModuleList()
        for in_ch, out_ch, skip_ch, stride, out_pad in specs:
            if use_skip:
                block = DecoderBlockWithSkip(
                    in_ch, out_ch, skip_channels=skip_ch,
                    kernel_size=3, stride=stride, padding=1, output_padding=out_pad
                )
            else:
                block = DecoderBlock(
                    in_ch, out_ch,
                    kernel_size=3, stride=stride, padding=1, output_padding=out_pad
                )
            self.blocks.append(block)

        # Final refinement at full resolution (no upsampling). There is no
        # output activation, so the prediction range is unconstrained.
        self.final_block = nn.Sequential(
            nn.Conv2d(64, 64, kernel_size=3, padding=1, bias=True),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(64, 32, kernel_size=3, padding=1, bias=True),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(32, output_channels, kernel_size=3, padding=1, bias=True)
        )

    def forward(self, x, skip_features=None):
        """
        Args:
            x: bottleneck tensor ``[B, embed_dims[-1], h, w]``. The decoder
                upsamples by 32, so presumably ``h = H/32, w = W/32`` for an
                ``H x W`` target (depends on the encoder's stride).
            skip_features: required when ``use_skip``. Sequence of 4 encoder
                stage features ordered shallow -> deep ``[stage0, stage1,
                stage2, stage3]``, each ``[B, C_i, H_i, W_i]`` with
                ``C_i == embed_dims[i]``. Features are consumed deep -> shallow,
                and each block resizes its skip to its own input resolution.

        Returns:
            Tensor ``[B, output_channels, 32 * h, 32 * w]``.

        Raises:
            ValueError: if skips are required but missing, of the wrong count,
                or contain ``None``.
        """
        if self.use_skip:
            if skip_features is None:
                raise ValueError("skip_features must be provided when use_skip=True")
            if len(skip_features) != len(self.blocks):
                raise ValueError(
                    f"Need {len(self.blocks)} skip features, got {len(skip_features)}"
                )
            if any(skip is None for skip in skip_features):
                raise ValueError("skip_features must not contain None")
            # Pass each skip at its native resolution. Pre-resizing relative to
            # the bottleneck size would shrink the shallow, high-resolution
            # stages and lose exactly the detail skips are meant to carry.
            for block, skip in zip(self.blocks, reversed(skip_features)):
                x = block(x, skip)
        else:
            for block in self.blocks:
                x = block(x)

        return self.final_block(x)


class SwiftFormerTemporal(nn.Module):
    """
    SwiftFormer with temporal input for frame prediction.

    Input: ``[B, num_frames, H, W]`` (Y channel only; frames stacked as channels).
    Output: predicted frame ``[B, 1, H, W]`` (when the encoder downsamples by
    32, matching the decoder's x32 upsampling) and an optional representation.

    Args:
        model_name: key into ``SwiftFormer_depth`` / ``SwiftFormer_width``.
        num_frames: number of stacked input frames (stem input channels).
        use_decoder: build the frame-prediction decoder.
        use_skip: feed encoder stage features to the decoder as skips.
        use_representation_head: build a pooled MLP representation head.
        representation_dim: width of the representation head.
        return_features: also return the list of encoder stage features.
        **kwargs: accepted and ignored.
    """

    def __init__(self,
                 model_name='XS',
                 num_frames=3,
                 use_decoder=True,
                 use_skip=True,
                 use_representation_head=False,
                 representation_dim=128,
                 return_features=False,
                 **kwargs):
        super().__init__()

        layers = SwiftFormer_depth[model_name]
        embed_dims = SwiftFormer_width[model_name]

        self.num_frames = num_frames
        self.use_decoder = use_decoder
        self.use_skip = use_skip
        self.use_representation_head = use_representation_head
        self.return_features = return_features

        # The stem takes the stacked single-channel frames as its input channels.
        self.patch_embed = stem(num_frames, embed_dims[0])

        # Encoder: same layout as SwiftFormer. A downsampling Embedding sits
        # between stages only when the width changes, so Stage positions inside
        # the ModuleList are recorded rather than assumed.
        network = []
        stage_indices = []
        for i in range(len(layers)):
            stage_indices.append(len(network))
            network.append(Stage(embed_dims[i], i, layers, mlp_ratio=4,
                                 act_layer=nn.GELU,
                                 drop_rate=0., drop_path_rate=0.,
                                 use_layer_scale=True,
                                 layer_scale_init_value=1e-5,
                                 vit_num=1))
            if i >= len(layers) - 1:
                break
            if embed_dims[i] != embed_dims[i + 1]:
                network.append(
                    Embedding(
                        patch_size=3, stride=2, padding=1,
                        in_chans=embed_dims[i], embed_dim=embed_dims[i + 1]
                    )
                )
        self.network = nn.ModuleList(network)
        # Indices into ``self.network`` whose outputs are the stage features.
        self._stage_indices = tuple(stage_indices)
        self.norm = nn.BatchNorm2d(embed_dims[-1])

        if use_decoder:
            self.decoder = FramePredictionDecoder(
                embed_dims,
                output_channels=1,
                use_skip=use_skip
            )

        # Optional head (e.g. for pose/velocity prediction) on pooled features.
        if use_representation_head:
            self.representation_head = nn.Sequential(
                nn.AdaptiveAvgPool2d(1),
                nn.Flatten(),
                nn.Linear(embed_dims[-1], representation_dim),
                nn.ReLU(),
                nn.Linear(representation_dim, representation_dim)
            )
        else:
            self.representation_head = None

        self.apply(self._init_weights)

    def _init_weights(self, m):
        """Kaiming init for conv/linear layers; identity init for norm layers."""
        if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d, nn.Linear)):
            _kaiming_init_(m)
        elif isinstance(m, (nn.LayerNorm, nn.BatchNorm2d)):
            # Non-affine norm layers have weight/bias set to None.
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
            if m.weight is not None:
                nn.init.constant_(m.weight, 1.0)

    def forward_tokens(self, x):
        """Run the encoder network.

        Returns:
            ``(x, features)`` when ``return_features`` or ``use_skip`` is set,
            where ``features`` holds each Stage's output in order (shallow ->
            deep); otherwise just the final encoder output ``x``.
        """
        collect = self.return_features or self.use_skip
        features = []
        for idx, block in enumerate(self.network):
            x = block(x)
            if collect and idx in self._stage_indices:
                features.append(x)
        if collect:
            return x, features
        return x

    def forward(self, x):
        """
        Args:
            x: input frames of shape ``[B, num_frames, H, W]``.

        Returns:
            If ``return_features`` is False:
                ``(pred_frame, representation)``: the predicted frame
                ``[B, 1, H, W]`` (or None without a decoder) and the
                representation ``[B, representation_dim]`` (or None).
            If ``return_features`` is True:
                ``(pred_frame, representation, features)``, where ``features``
                are the per-stage encoder outputs (before ``self.norm``).
        """
        x = self.patch_embed(x)
        features = None
        if self.return_features or self.use_skip:
            x, features = self.forward_tokens(x)
        else:
            x = self.forward_tokens(x)
        x = self.norm(x)

        representation = None
        if self.representation_head is not None:
            representation = self.representation_head(x)

        pred_frame = None
        if self.use_decoder:
            if self.use_skip:
                # stage0..stage3 (shallow -> deep). The decoder validates the
                # count and raises a clear error if the encoder has fewer stages.
                pred_frame = self.decoder(x, features[:4])
            else:
                pred_frame = self.decoder(x)

        if self.return_features:
            return pred_frame, representation, features
        return pred_frame, representation


# Factory functions for different model sizes.
def SwiftFormerTemporal_XS(num_frames=3, use_skip=True, **kwargs):
    """Build an XS-sized ``SwiftFormerTemporal``."""
    return SwiftFormerTemporal('XS', num_frames=num_frames, use_skip=use_skip, **kwargs)


def SwiftFormerTemporal_S(num_frames=3, use_skip=True, **kwargs):
    """Build an S-sized ``SwiftFormerTemporal``."""
    return SwiftFormerTemporal('S', num_frames=num_frames, use_skip=use_skip, **kwargs)


def SwiftFormerTemporal_L1(num_frames=3, use_skip=True, **kwargs):
    """Build an L1-sized ``SwiftFormerTemporal``."""
    return SwiftFormerTemporal('l1', num_frames=num_frames, use_skip=use_skip, **kwargs)


def SwiftFormerTemporal_L3(num_frames=3, use_skip=True, **kwargs):
    """Build an L3-sized ``SwiftFormerTemporal``."""
    return SwiftFormerTemporal('l3', num_frames=num_frames, use_skip=use_skip, **kwargs)