"""
SwiftFormerTemporal: Temporal extension of SwiftFormer for frame prediction
"""
import torch
import torch.nn as nn

from .swiftformer import (
    SwiftFormer, SwiftFormer_depth, SwiftFormer_width,
    stem, Embedding, Stage
)
from timm.layers import DropPath, trunc_normal_


class DecoderBlock(nn.Module):
    """Upsampling block for the frame prediction decoder.

    ConvTranspose2d (no bias) -> BatchNorm2d -> ReLU. With the default
    arguments (kernel 3, stride 2, padding 1, output_padding 1) the output
    spatial size is exactly twice the input size.
    """

    def __init__(self, in_channels, out_channels, kernel_size=3, stride=2,
                 padding=1, output_padding=1):
        super().__init__()
        self.conv = nn.ConvTranspose2d(
            in_channels, out_channels,
            kernel_size=kernel_size, stride=stride, padding=padding,
            output_padding=output_padding, bias=False
        )
        self.bn = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        return self.relu(self.bn(self.conv(x)))


class FramePredictionDecoder(nn.Module):
    """Lightweight decoder for frame prediction with optional skip connections.

    Three DecoderBlocks each upsample by 2x, then a final head upsamples by
    4x (stride 4, output_padding 3) and maps to ``output_channels`` through
    a Tanh, for 32x upsampling in total.

    When ``use_skip`` is True, the encoder outputs of stages 2, 1 and 0 are
    concatenated (U-Net style) onto the decoder activation right before
    blocks 1, 2 and 3. At each of those points the decoder activation has
    ``embed_dims[k]`` channels for the matching stage ``k``, so the input
    channels of blocks 1-3 are doubled to accept the concatenation. Block 0
    (the bottleneck) never receives a skip. With ``use_skip=False`` the
    module layout is the plain, skip-free decoder.
    """

    def __init__(self, embed_dims, output_channels=1, use_skip=False):
        super().__init__()
        self.use_skip = use_skip

        # The decoder walks the encoder widths from deepest to shallowest.
        decoder_dims = embed_dims[::-1]
        # Concatenating a skip feature of equal width doubles a block's input.
        skip_mult = 2 if use_skip else 1

        self.blocks = nn.ModuleList()
        # Bottleneck (stage3 width) -> stage2 resolution/width.
        self.blocks.append(DecoderBlock(
            decoder_dims[0], decoder_dims[1],
            kernel_size=3, stride=2, padding=1, output_padding=1
        ))
        # stage2 -> stage1 (optionally fed [x, stage2 skip]).
        self.blocks.append(DecoderBlock(
            decoder_dims[1] * skip_mult, decoder_dims[2],
            kernel_size=3, stride=2, padding=1, output_padding=1
        ))
        # stage1 -> stage0 (optionally fed [x, stage1 skip]).
        self.blocks.append(DecoderBlock(
            decoder_dims[2] * skip_mult, decoder_dims[3],
            kernel_size=3, stride=2, padding=1, output_padding=1
        ))
        # stage0 -> input resolution: 4x upsampling (optionally fed
        # [x, stage0 skip]), then project to output_channels.
        self.blocks.append(nn.Sequential(
            nn.ConvTranspose2d(
                decoder_dims[3] * skip_mult, 32,
                kernel_size=3, stride=4, padding=1, output_padding=3
            ),
            nn.BatchNorm2d(32),
            nn.ReLU(inplace=True),
            nn.Conv2d(32, output_channels, kernel_size=3, padding=1),
            nn.Tanh()  # Output in [-1, 1] range
        ))

    def forward(self, x, skip_features=None):
        """
        Args:
            x: bottleneck tensor of shape [B, embed_dims[-1], H/32, W/32].
            skip_features: only used when ``use_skip`` is True. A sequence of
                exactly three encoder stage outputs in encoder order,
                ``[stage0 (H/4), stage1 (H/8), stage2 (H/16)]``, e.g. the
                first three stage features collected by the encoder. Each
                must match the decoder activation it is concatenated with in
                spatial size and channel count. Ignored when ``use_skip`` is
                False.

        Returns:
            Tensor of shape [B, output_channels, H, W] with values in [-1, 1].

        Raises:
            ValueError: if ``use_skip`` is True and ``skip_features`` is None
                or does not contain exactly three tensors.
        """
        if not self.use_skip:
            for block in self.blocks:
                x = block(x)
            return x

        if skip_features is None or len(skip_features) != 3:
            raise ValueError(
                "FramePredictionDecoder built with use_skip=True requires "
                "exactly 3 skip features [stage0, stage1, stage2]"
            )

        # Deepest first: skips[i] is concatenated onto the input of block i+1.
        skips = list(skip_features)[::-1]
        x = self.blocks[0](x)
        for block, skip in zip(self.blocks[1:], skips):
            x = torch.cat([x, skip], dim=1)
            x = block(x)
        return x


class SwiftFormerTemporal(nn.Module):
    """
    SwiftFormer with temporal input for frame prediction.

    Input: [B, num_frames, H, W] (Y channel only)
    Output: predicted frame [B, 1, H, W] and optional representation

    Args:
        model_name: key into SwiftFormer_depth / SwiftFormer_width.
        num_frames: number of stacked input frames (used as stem in_channels).
        use_decoder: build the frame prediction decoder.
        use_representation_head: build the pooled MLP representation head.
        representation_dim: hidden/output width of the representation head.
        return_features: also return the per-stage encoder features.
        use_skip: feed encoder stage0-2 features into the decoder as skip
            connections (only meaningful with use_decoder=True).
        **kwargs: accepted for compatibility and ignored.
    """

    def __init__(self, model_name='XS', num_frames=3, use_decoder=True,
                 use_representation_head=False, representation_dim=128,
                 return_features=False, use_skip=False, **kwargs):
        super().__init__()

        # Get model configuration
        layers = SwiftFormer_depth[model_name]
        embed_dims = SwiftFormer_width[model_name]

        # Store configuration
        self.num_frames = num_frames
        self.use_decoder = use_decoder
        self.use_representation_head = use_representation_head
        self.return_features = return_features
        self.use_skip = use_skip

        # Frames are stacked on the channel axis (one Y channel per frame).
        self.patch_embed = stem(num_frames, embed_dims[0])

        # Build encoder network (same as SwiftFormer). Embedding downsamplers
        # are only inserted between stages whose widths differ, so the
        # positions of the Stage modules are recorded here instead of being
        # assumed to be every other index.
        network = []
        stage_indices = []
        for i in range(len(layers)):
            stage = Stage(embed_dims[i], i, layers, mlp_ratio=4,
                          act_layer=nn.GELU,
                          drop_rate=0., drop_path_rate=0.,
                          use_layer_scale=True,
                          layer_scale_init_value=1e-5,
                          vit_num=1)
            stage_indices.append(len(network))
            network.append(stage)
            if i >= len(layers) - 1:
                break
            if embed_dims[i] != embed_dims[i + 1]:
                network.append(
                    Embedding(
                        patch_size=3, stride=2, padding=1,
                        in_chans=embed_dims[i], embed_dim=embed_dims[i + 1]
                    )
                )
        self.network = nn.ModuleList(network)
        self._stage_indices = tuple(stage_indices)
        self.norm = nn.BatchNorm2d(embed_dims[-1])

        # Frame prediction decoder
        if use_decoder:
            self.decoder = FramePredictionDecoder(
                embed_dims, output_channels=1, use_skip=use_skip
            )

        # Representation head for pose/velocity prediction
        if use_representation_head:
            self.representation_head = nn.Sequential(
                nn.AdaptiveAvgPool2d(1),
                nn.Flatten(),
                nn.Linear(embed_dims[-1], representation_dim),
                nn.ReLU(),
                nn.Linear(representation_dim, representation_dim)
            )
        else:
            self.representation_head = None

        self.apply(self._init_weights)

    def _init_weights(self, m):
        """Truncated-normal init for Conv2d/Linear weights, unit LayerNorm.

        NOTE: nn.ConvTranspose2d is not a subclass of nn.Conv2d, so the
        decoder's transposed convolutions (and BatchNorm layers) keep
        PyTorch's default initialization.
        """
        if isinstance(m, (nn.Conv2d, nn.Linear)):
            trunc_normal_(m.weight, std=.02)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, (nn.LayerNorm)):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

    def _encode(self, x, collect=True):
        """Run self.network; return (final output, list of Stage outputs).

        The list is empty when ``collect`` is False.
        """
        stage_positions = set(self._stage_indices)
        features = []
        for idx, block in enumerate(self.network):
            x = block(x)
            if collect and idx in stage_positions:
                features.append(x)
        return x, features

    def forward_tokens(self, x):
        """Forward through the encoder network.

        Returns ``(x, features)`` when ``return_features`` is True, where
        ``features`` holds the output of every Stage in encoder order;
        otherwise returns only the final encoder output ``x``.
        """
        if self.return_features:
            return self._encode(x, collect=True)
        x, _ = self._encode(x, collect=False)
        return x

    def forward(self, x):
        """
        Args:
            x: input frames of shape [B, num_frames, H, W]
        Returns:
            If return_features is False:
                pred_frame: predicted frame [B, 1, H, W] (or None)
                representation: optional representation vector
                    [B, representation_dim] (or None)
            If return_features is True:
                pred_frame, representation, features (list of stage features)
        """
        # Encode
        x = self.patch_embed(x)
        use_skips = self.use_decoder and self.use_skip
        x, features = self._encode(
            x, collect=self.return_features or use_skips
        )
        x = self.norm(x)

        # Get representation if needed
        representation = None
        if self.representation_head is not None:
            representation = self.representation_head(x)

        # Decode to frame; skips are the shallowest three stages
        # [stage0, stage1, stage2] in encoder order.
        pred_frame = None
        if self.use_decoder:
            skip_features = features[:3] if use_skips else None
            pred_frame = self.decoder(x, skip_features=skip_features)

        if self.return_features:
            return pred_frame, representation, features
        return pred_frame, representation


# Factory functions for different model sizes
def SwiftFormerTemporal_XS(num_frames=3, **kwargs):
    return SwiftFormerTemporal('XS', num_frames=num_frames, **kwargs)


def SwiftFormerTemporal_S(num_frames=3, **kwargs):
    return SwiftFormerTemporal('S', num_frames=num_frames, **kwargs)


def SwiftFormerTemporal_L1(num_frames=3, **kwargs):
    return SwiftFormerTemporal('l1', num_frames=num_frames, **kwargs)


def SwiftFormerTemporal_L3(num_frames=3, **kwargs):
    return SwiftFormerTemporal('l3', num_frames=num_frames, **kwargs)