"""
SwiftFormerTemporal: temporal extension of SwiftFormer for frame prediction.

The encoder consumes ``num_frames`` stacked single-channel (Y) frames as its
input channels; a convolutional decoder maps the final feature map back to a
single predicted frame.
"""
import torch
import torch.nn as nn
from .swiftformer import (
    SwiftFormer, SwiftFormer_depth, SwiftFormer_width,
    stem, Embedding, Stage
)
from timm.layers import DropPath, trunc_normal_


class DecoderBlock(nn.Module):
    """Upsampling block for the frame prediction decoder (no residual connection).

    Applies ``ConvTranspose2d -> BN -> ReLU`` followed by two
    ``Conv2d(3x3) -> BN -> ReLU`` stages. Only the transposed convolution
    changes the spatial size.

    Args:
        in_channels: channels of the input feature map.
        out_channels: channels produced by every layer of the block.
        kernel_size, stride, padding, output_padding: forwarded to the
            ``nn.ConvTranspose2d``; its output height is
            ``(H - 1) * stride - 2 * padding + kernel_size + output_padding``.
    """

    def __init__(self, in_channels, out_channels, kernel_size=3, stride=2,
                 padding=1, output_padding=1):
        super().__init__()
        # Main path: transposed convolution followed by two 3x3 convolutions.
        # Conv biases are disabled because every conv is followed by BatchNorm.
        self.conv_transpose = nn.ConvTranspose2d(
            in_channels, out_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            output_padding=output_padding,
            bias=False,
        )
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.conv1 = nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(out_channels)
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1, bias=False)
        self.bn3 = nn.BatchNorm2d(out_channels)

        # One parameter-free ReLU module, reused after each BatchNorm.
        self.activation = nn.ReLU(inplace=True)

        self._init_weights()

    def _init_weights(self):
        """Kaiming-normal init for the convs; BN weight=1, bias=0.

        NOTE: inside SwiftFormerTemporal these layers are re-initialized by
        ``SwiftFormerTemporal._init_weights`` with the same scheme.
        """
        # Transposed convolution.
        nn.init.kaiming_normal_(self.conv_transpose.weight, mode='fan_out', nonlinearity='relu')
        # Regular convolutions.
        nn.init.kaiming_normal_(self.conv1.weight, mode='fan_out', nonlinearity='relu')
        nn.init.kaiming_normal_(self.conv2.weight, mode='fan_out', nonlinearity='relu')
        # BatchNorm layers: identity affine transform.
        for m in self.modules():
            if isinstance(m, nn.BatchNorm2d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

    def forward(self, x):
        x = self.conv_transpose(x)
        x = self.bn1(x)
        x = self.activation(x)

        x = self.conv1(x)
        x = self.bn2(x)
        x = self.activation(x)

        x = self.conv2(x)
        x = self.bn3(x)
        x = self.activation(x)
        return x


class FramePredictionDecoder(nn.Module):
    """Decoder that upsamples the deepest encoder features to a frame.

    Four ``DecoderBlock``s upsample by 2, 2, 2 and 4 (32x in total). A
    refinement head of three 3x3 convolutions (no upsampling) then maps to
    ``output_channels`` with a ``tanh`` output in ``[-1, 1]``. No skip
    connections from the encoder are used.

    Args:
        embed_dims: encoder stage widths; only ``embed_dims[-1]`` (the input
            channel count of the decoder) is used.
        output_channels: channels of the predicted frame.
    """

    def __init__(self, embed_dims, output_channels=1):
        super().__init__()
        # Decoder widths halve at every step, independent of the encoder widths
        # (there are no skip connections), e.g. [220, 110, 55, 27] for XS.
        start_dim = embed_dims[-1]
        decoder_dims = [start_dim // (2 ** i) for i in range(4)]

        # (in_channels, out_channels, stride, output_padding) for each block.
        # All blocks use kernel_size=3 / padding=1, so each one scales the
        # spatial size exactly by its stride.
        block_specs = [
            (decoder_dims[0], decoder_dims[1], 2, 1),
            (decoder_dims[1], decoder_dims[2], 2, 1),
            (decoder_dims[2], decoder_dims[3], 2, 1),
            # NOTE(review): with kernel_size=3 < stride=4 the transposed conv
            # only writes output rows/cols 4i-1, 4i and 4i+1. Every row/col
            # whose index is 2 (mod 4), plus the last one, is therefore
            # all-zero before BN (a grid artifact the later 3x3 convs must
            # fill in). kernel_size=4, padding=0, output_padding=0 would cover
            # every pixel at the same 4x scale, but it changes the weight shape
            # and breaks existing checkpoints, so it is left unchanged here.
            (decoder_dims[3], 64, 4, 3),
        ]
        self.blocks = nn.ModuleList()
        for in_ch, out_ch, stride, out_pad in block_specs:
            self.blocks.append(DecoderBlock(
                in_ch, out_ch,
                kernel_size=3, stride=stride, padding=1, output_padding=out_pad,
            ))

        self.final_block = nn.Sequential(
            nn.Conv2d(64, 64, kernel_size=3, padding=1, bias=True),
            nn.ReLU(inplace=True),
            nn.Conv2d(64, 32, kernel_size=3, padding=1, bias=True),
            nn.ReLU(inplace=True),
            nn.Conv2d(32, output_channels, kernel_size=3, padding=1, bias=True),
            nn.Tanh()
        )

    def forward(self, x):
        """
        Args:
            x: input tensor of shape [B, embed_dims[-1], H/32, W/32]

        Returns:
            Tensor of shape [B, output_channels, H, W] with values in [-1, 1].
        """
        # No skip connections: plain sequential upsampling.
        for block in self.blocks:
            x = block(x)
        # Refinement head: feature refinement only, no further upsampling.
        return self.final_block(x)


class SwiftFormerTemporal(nn.Module):
    """SwiftFormer with temporal input for frame prediction.

    Input: ``[B, num_frames, H, W]``. The frames (Y channel only) are stacked
    along the channel axis and fed to the stem.
    Output: the predicted frame ``[B, 1, H, W]``, or ``None`` when
    ``use_decoder=False``.

    Args:
        model_name: key into ``SwiftFormer_depth`` / ``SwiftFormer_width``
            (the factories below use ``'XS'``, ``'S'``, ``'l1'``, ``'l3'``).
        num_frames: number of stacked input frames (stem input channels).
        use_decoder: build and run the frame prediction decoder.
        **kwargs: accepted and currently ignored.
    """

    def __init__(self, model_name='XS', num_frames=3, use_decoder=True, **kwargs):
        super().__init__()
        # Model configuration.
        layers = SwiftFormer_depth[model_name]
        embed_dims = SwiftFormer_width[model_name]

        self.num_frames = num_frames
        self.use_decoder = use_decoder

        # Stem takes the stacked frames as channels (Y channel only).
        in_channels = num_frames
        self.patch_embed = stem(in_channels, embed_dims[0])

        # Build encoder network (same as SwiftFormer): one Stage per depth
        # entry, with a downsampling Embedding between stages whose widths differ.
        network = []
        for i in range(len(layers)):
            stage = Stage(embed_dims[i], i, layers, mlp_ratio=4,
                          act_layer=nn.GELU,
                          drop_rate=0., drop_path_rate=0.,
                          use_layer_scale=True,
                          layer_scale_init_value=1e-5,
                          vit_num=1)
            network.append(stage)
            if i >= len(layers) - 1:
                break
            if embed_dims[i] != embed_dims[i + 1]:
                network.append(
                    Embedding(
                        patch_size=3, stride=2, padding=1,
                        in_chans=embed_dims[i], embed_dim=embed_dims[i + 1]
                    )
                )

        self.network = nn.ModuleList(network)
        self.norm = nn.BatchNorm2d(embed_dims[-1])

        # Frame prediction decoder.
        if use_decoder:
            self.decoder = FramePredictionDecoder(
                embed_dims,
                output_channels=1
            )

        # NOTE(review): this visits every submodule, including the stem and
        # encoder stages, and overrides their Conv2d/Linear/norm initialization.
        self.apply(self._init_weights)

    def _init_weights(self, m):
        """Kaiming-normal (ReLU gain) for conv/linear weights, zero biases;
        norm layers get weight=1, bias=0."""
        if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d, nn.Linear)):
            nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, (nn.LayerNorm, nn.BatchNorm2d)):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

    def forward_tokens(self, x):
        """Run the encoder stages and downsampling embeddings in order."""
        for block in self.network:
            x = block(x)
        return x

    def forward(self, x):
        """
        Args:
            x: input frames of shape [B, num_frames, H, W]

        Returns:
            pred_frame: predicted frame [B, 1, H, W], or None when the model
            was built with use_decoder=False.

        Raises:
            ValueError: if the decoded frame's spatial size differs from the
                input's (H/W not reduced exactly by the encoder).
        """
        in_hw = x.shape[-2:]

        # Encode.
        x = self.patch_embed(x)
        x = self.forward_tokens(x)
        x = self.norm(x)

        # Decode to frame.
        pred_frame = None
        if self.use_decoder:
            pred_frame = self.decoder(x)
            # The decoder upsamples by a fixed 32x. For input sizes the encoder
            # does not reduce exactly, the output would silently differ from the
            # input size, so fail loudly instead of returning a mis-sized frame.
            if pred_frame.shape[-2:] != in_hw:
                raise ValueError(
                    f"Predicted frame size {tuple(pred_frame.shape[-2:])} does not "
                    f"match input frame size {tuple(in_hw)}; the decoder upsamples "
                    f"the encoder output by a fixed factor of 32, so H and W must be "
                    f"divisible by the encoder's total downsampling factor."
                )

        return pred_frame


# Factory functions for different model sizes.
def SwiftFormerTemporal_XS(num_frames=3, **kwargs):
    """Build the XS variant of SwiftFormerTemporal."""
    return SwiftFormerTemporal('XS', num_frames=num_frames, **kwargs)


def SwiftFormerTemporal_S(num_frames=3, **kwargs):
    """Build the S variant of SwiftFormerTemporal."""
    return SwiftFormerTemporal('S', num_frames=num_frames, **kwargs)


def SwiftFormerTemporal_L1(num_frames=3, **kwargs):
    """Build the L1 variant of SwiftFormerTemporal."""
    return SwiftFormerTemporal('l1', num_frames=num_frames, **kwargs)


def SwiftFormerTemporal_L3(num_frames=3, **kwargs):
    """Build the L3 variant of SwiftFormerTemporal."""
    return SwiftFormerTemporal('l3', num_frames=num_frames, **kwargs)