Files
asmo_vhead/models/swiftformer_temporal.py

232 lines
8.2 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""
SwiftFormerTemporal: Temporal extension of SwiftFormer for frame prediction
"""
import torch
import torch.nn as nn
from .swiftformer import (
SwiftFormer, SwiftFormer_depth, SwiftFormer_width,
stem, Embedding, Stage
)
from timm.layers import DropPath, trunc_normal_
class DecoderBlock(nn.Module):
    """Upsampling block for the frame-prediction decoder (no residual path).

    A transposed convolution is followed by two 3x3 convolutions. Each of the
    three layers is followed by batch normalisation and a ReLU. With the
    default arguments (kernel 3, stride 2, padding 1, output padding 1) the
    spatial size of the input is doubled.

    Args:
        in_channels: Channel count of the incoming feature map.
        out_channels: Channel count produced by every layer of the block.
        kernel_size: Kernel size of the transposed convolution.
        stride: Stride of the transposed convolution.
        padding: Padding of the transposed convolution.
        output_padding: Output padding of the transposed convolution.
    """

    def __init__(
        self,
        in_channels,
        out_channels,
        kernel_size=3,
        stride=2,
        padding=1,
        output_padding=1,
    ):
        super().__init__()
        # Every convolution is followed by BatchNorm, so biases are disabled.
        refine_kwargs = dict(kernel_size=3, padding=1, bias=False)

        # Main path: learned upsampling followed by two refinement convs.
        self.conv_transpose = nn.ConvTranspose2d(
            in_channels,
            out_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            output_padding=output_padding,
            bias=False,
        )
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.conv1 = nn.Conv2d(out_channels, out_channels, **refine_kwargs)
        self.bn2 = nn.BatchNorm2d(out_channels)
        self.conv2 = nn.Conv2d(out_channels, out_channels, **refine_kwargs)
        self.bn3 = nn.BatchNorm2d(out_channels)

        # A single in-place ReLU module is reused after each BatchNorm.
        self.activation = nn.ReLU(inplace=True)

        self._init_weights()

    def _init_weights(self):
        """Apply Kaiming-normal init to the convs and reset BatchNorm affines."""
        # Order matters for reproducibility under a fixed seed:
        # transposed conv first, then the two refinement convs.
        for conv in (self.conv_transpose, self.conv1, self.conv2):
            nn.init.kaiming_normal_(conv.weight, mode='fan_out', nonlinearity='relu')

        # BatchNorm layers get the identity affine transform (scale 1, shift 0).
        for norm in (self.bn1, self.bn2, self.bn3):
            nn.init.constant_(norm.weight, 1)
            nn.init.constant_(norm.bias, 0)

    def forward(self, x):
        """Upsample ``x`` and refine it; returns the activated feature map."""
        upsampled = self.activation(self.bn1(self.conv_transpose(x)))
        refined = self.activation(self.bn2(self.conv1(upsampled)))
        return self.activation(self.bn3(self.conv2(refined)))
class FramePredictionDecoder(nn.Module):
    """Decode the deepest encoder feature map into a predicted frame.

    Four ``DecoderBlock`` stages upsample by 2, 2, 2 and 4 (32x overall with
    the kernel/padding settings used here). A stack of three 3x3 convolutions
    then refines the result without changing its spatial size, and a final
    ``Tanh`` bounds the output to [-1, 1]. No skip connections are used.

    Args:
        embed_dims: Sequence of encoder widths; only the last entry is read.
        output_channels: Number of channels in the predicted frame.
    """

    def __init__(self, embed_dims, output_channels=1):
        super().__init__()
        # Decoder widths are derived from the last encoder width alone by
        # repeated halving, e.g. 220 -> [220, 110, 55, 27].
        base_width = embed_dims[-1]
        widths = [base_width // (2 ** level) for level in range(4)]

        # (in_channels, out_channels, stride, output_padding) for each stage.
        # The first three stages double the resolution; the fourth, placed
        # right before the refinement head, upsamples by 4 into 64 channels.
        # NOTE(review): the last stage uses kernel_size=3 with stride=4, so
        # one of every four output rows/columns is not reached by any kernel
        # tap before BatchNorm -- confirm this is intended.
        stage_specs = (
            (widths[0], widths[1], 2, 1),
            (widths[1], widths[2], 2, 1),
            (widths[2], widths[3], 2, 1),
            (widths[3], 64, 4, 3),
        )
        self.blocks = nn.ModuleList([
            DecoderBlock(
                c_in, c_out,
                kernel_size=3, stride=stride, padding=1, output_padding=out_pad,
            )
            for c_in, c_out, stride, out_pad in stage_specs
        ])

        # Refinement head: feature clean-up only, no further upsampling.
        self.final_block = nn.Sequential(
            nn.Conv2d(64, 64, kernel_size=3, padding=1, bias=True),
            nn.ReLU(inplace=True),
            nn.Conv2d(64, 32, kernel_size=3, padding=1, bias=True),
            nn.ReLU(inplace=True),
            nn.Conv2d(32, output_channels, kernel_size=3, padding=1, bias=True),
            nn.Tanh(),
        )

    def forward(self, x):
        """Run the upsampling stages and the refinement head.

        Args:
            x: input tensor of shape [B, embed_dims[-1], H/32, W/32]

        Returns:
            Tensor with ``output_channels`` channels and 32x the spatial
            size of ``x``.
        """
        for block in self.blocks:
            x = block(x)
        return self.final_block(x)
class SwiftFormerTemporal(nn.Module):
    """SwiftFormer encoder with a multi-frame stem and optional frame decoder.

    Input: [B, num_frames, H, W] (Y channel only, one channel per frame).
    Output: the decoder's predicted frame, or ``None`` when the model was
    built with ``use_decoder=False``. The prediction is presumably
    [B, 1, H, W], assuming the encoder downsamples by 32 overall -- confirm
    against ``stem`` / ``Embedding``.

    Args:
        model_name: Key into ``SwiftFormer_depth`` and ``SwiftFormer_width``.
        num_frames: Number of input channels given to the stem.
        use_decoder: Whether to build and run ``FramePredictionDecoder``.
        **kwargs: Accepted for call-site compatibility; not used here.
    """

    def __init__(self,
                 model_name='XS',
                 num_frames=3,
                 use_decoder=True,
                 **kwargs):
        super().__init__()

        # Per-variant configuration: blocks per stage and channel widths.
        layers = SwiftFormer_depth[model_name]
        embed_dims = SwiftFormer_width[model_name]

        self.num_frames = num_frames
        self.use_decoder = use_decoder

        # The stem takes one input channel per frame.
        self.patch_embed = stem(num_frames, embed_dims[0])

        # Encoder: one Stage per entry in ``layers``; an Embedding is inserted
        # between consecutive stages whenever the channel width changes.
        final_index = len(layers) - 1
        encoder_modules = []
        for idx in range(len(layers)):
            encoder_modules.append(
                Stage(embed_dims[idx], idx, layers, mlp_ratio=4,
                      act_layer=nn.GELU,
                      drop_rate=0., drop_path_rate=0.,
                      use_layer_scale=True,
                      layer_scale_init_value=1e-5,
                      vit_num=1)
            )
            if idx < final_index and embed_dims[idx] != embed_dims[idx + 1]:
                encoder_modules.append(
                    Embedding(
                        patch_size=3, stride=2, padding=1,
                        in_chans=embed_dims[idx], embed_dim=embed_dims[idx + 1]
                    )
                )
        self.network = nn.ModuleList(encoder_modules)
        self.norm = nn.BatchNorm2d(embed_dims[-1])

        # Frame prediction decoder (single output channel).
        if use_decoder:
            self.decoder = FramePredictionDecoder(embed_dims, output_channels=1)

        # Re-initialise every submodule, decoder included.
        self.apply(self._init_weights)

    def _init_weights(self, m):
        """Initialise one submodule; used as the callback for ``self.apply``."""
        if isinstance(m, (nn.Conv2d, nn.Linear, nn.ConvTranspose2d)):
            # Kaiming-normal init (suited to ReLU) for all weight layers,
            # transposed convolutions included; biases start at zero.
            nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
            return
        if isinstance(m, (nn.LayerNorm, nn.BatchNorm2d)):
            # Normalisation layers start as the identity affine transform.
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

    def forward_tokens(self, x):
        """Pass ``x`` through every encoder module in order."""
        for layer in self.network:
            x = layer(x)
        return x

    def forward(self, x):
        """Encode the input frames and, if enabled, decode a predicted frame.

        Args:
            x: input frames of shape [B, num_frames, H, W]

        Returns:
            pred_frame: the decoder output, or None when ``use_decoder`` is
            false (the encoder is still executed in that case).
        """
        # Encode
        features = self.norm(self.forward_tokens(self.patch_embed(x)))

        # Decode to frame
        if not self.use_decoder:
            return None
        return self.decoder(features)
# Factory functions for different model sizes
def SwiftFormerTemporal_XS(num_frames=3, **kwargs):
    """Build a ``SwiftFormerTemporal`` using the 'XS' configuration.

    Args:
        num_frames: Number of input frames forwarded to the model.
        **kwargs: Forwarded unchanged to ``SwiftFormerTemporal``.
    """
    model = SwiftFormerTemporal(model_name='XS', num_frames=num_frames, **kwargs)
    return model
def SwiftFormerTemporal_S(num_frames=3, **kwargs):
    """Build a ``SwiftFormerTemporal`` using the 'S' configuration.

    Args:
        num_frames: Number of input frames forwarded to the model.
        **kwargs: Forwarded unchanged to ``SwiftFormerTemporal``.
    """
    model = SwiftFormerTemporal(model_name='S', num_frames=num_frames, **kwargs)
    return model
def SwiftFormerTemporal_L1(num_frames=3, **kwargs):
    """Build a ``SwiftFormerTemporal`` using the 'l1' configuration.

    Args:
        num_frames: Number of input frames forwarded to the model.
        **kwargs: Forwarded unchanged to ``SwiftFormerTemporal``.
    """
    model = SwiftFormerTemporal(model_name='l1', num_frames=num_frames, **kwargs)
    return model
def SwiftFormerTemporal_L3(num_frames=3, **kwargs):
    """Build a ``SwiftFormerTemporal`` using the 'l3' configuration.

    Args:
        num_frames: Number of input frames forwarded to the model.
        **kwargs: Forwarded unchanged to ``SwiftFormerTemporal``.
    """
    model = SwiftFormerTemporal(model_name='l3', num_frames=num_frames, **kwargs)
    return model