Update model structure: move the large-stride deconvolution later, enable BN and tanh

2026-01-15 21:12:27 +08:00
parent df703638da
commit a92a0b29e9
2 changed files with 67 additions and 86 deletions


@@ -21,71 +21,79 @@ class DecoderBlock(nn.Module):
             stride=stride,
             padding=padding,
             output_padding=output_padding,
-            bias=True  # use a bias since BN was removed
+            bias=False  # no bias needed since BN follows
         )
+        self.bn1 = nn.BatchNorm2d(out_channels)
         self.conv1 = nn.Conv2d(out_channels, out_channels,
-                               kernel_size=3, padding=1, bias=True)
+                               kernel_size=3, padding=1, bias=False)
+        self.bn2 = nn.BatchNorm2d(out_channels)
         self.conv2 = nn.Conv2d(out_channels, out_channels,
-                               kernel_size=3, padding=1, bias=True)
+                               kernel_size=3, padding=1, bias=False)
+        self.bn3 = nn.BatchNorm2d(out_channels)
         # Residual path: needed when the channel count or spatial size changes
         self.shortcut = nn.Identity()
         if in_channels != out_channels or stride != 1:
             # Use a 1x1 conv to adjust channels; use a transposed conv when upsampling is needed
             if stride == 1:
-                self.shortcut = nn.Conv2d(in_channels, out_channels,
-                                          kernel_size=1, bias=True)
+                self.shortcut = nn.Sequential(
+                    nn.Conv2d(in_channels, out_channels, kernel_size=1, bias=False),
+                    nn.BatchNorm2d(out_channels)
+                )
             else:
-                self.shortcut = nn.ConvTranspose2d(
-                    in_channels, out_channels,
-                    kernel_size=1,
-                    stride=stride,
-                    padding=0,
-                    output_padding=output_padding,
-                    bias=True
-                )
+                self.shortcut = nn.Sequential(
+                    nn.ConvTranspose2d(
+                        in_channels, out_channels,
+                        kernel_size=1,
+                        stride=stride,
+                        padding=0,
+                        output_padding=output_padding,
+                        bias=False
+                    ),
+                    nn.BatchNorm2d(out_channels)
+                )
-        # Use LeakyReLU to avoid dead neurons
-        self.activation = nn.LeakyReLU(0.2, inplace=True)
+        # Use the ReLU activation
+        self.activation = nn.ReLU(inplace=True)
         # Initialize weights
         self._init_weights()
     def _init_weights(self):
         # Initialize the transposed convolution
-        nn.init.kaiming_normal_(self.conv_transpose.weight, mode='fan_out', nonlinearity='leaky_relu')
-        if self.conv_transpose.bias is not None:
-            nn.init.constant_(self.conv_transpose.bias, 0)
+        nn.init.kaiming_normal_(self.conv_transpose.weight, mode='fan_out', nonlinearity='relu')
         # Initialize the convolutions
-        nn.init.kaiming_normal_(self.conv1.weight, mode='fan_out', nonlinearity='leaky_relu')
-        if self.conv1.bias is not None:
-            nn.init.constant_(self.conv1.bias, 0)
-        nn.init.kaiming_normal_(self.conv2.weight, mode='fan_out', nonlinearity='leaky_relu')
-        if self.conv2.bias is not None:
-            nn.init.constant_(self.conv2.bias, 0)
+        nn.init.kaiming_normal_(self.conv1.weight, mode='fan_out', nonlinearity='relu')
+        nn.init.kaiming_normal_(self.conv2.weight, mode='fan_out', nonlinearity='relu')
         # Initialize the shortcut
         if not isinstance(self.shortcut, nn.Identity):
-            if isinstance(self.shortcut, nn.Conv2d):
-                nn.init.kaiming_normal_(self.shortcut.weight, mode='fan_out', nonlinearity='leaky_relu')
-            elif isinstance(self.shortcut, nn.ConvTranspose2d):
-                nn.init.kaiming_normal_(self.shortcut.weight, mode='fan_out', nonlinearity='leaky_relu')
-            if self.shortcut.bias is not None:
-                nn.init.constant_(self.shortcut.bias, 0)
+            # The shortcut is now a Sequential; initialize the conv layers inside it
+            for module in self.shortcut:
+                if isinstance(module, (nn.Conv2d, nn.ConvTranspose2d)):
+                    nn.init.kaiming_normal_(module.weight, mode='fan_out', nonlinearity='relu')
+        # BN layers use the default initialization
+        for m in self.modules():
+            if isinstance(m, nn.BatchNorm2d):
+                nn.init.constant_(m.weight, 1)
+                nn.init.constant_(m.bias, 0)
     def forward(self, x):
         identity = self.shortcut(x)
         # Main path
         x = self.conv_transpose(x)
+        x = self.bn1(x)
         x = self.activation(x)
         x = self.conv1(x)
+        x = self.bn2(x)
         x = self.activation(x)
         x = self.conv2(x)
+        x = self.bn3(x)
         # Residual connection
         x = x + identity
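A quick sanity check on the bias=False change above (a minimal sketch, not part of this commit): BatchNorm subtracts the per-channel batch mean before applying its own learned shift, so a constant conv bias placed directly before a BN layer is cancelled out, which is why the convolutions can safely drop their biases.

    import torch
    import torch.nn as nn

    conv = nn.Conv2d(3, 8, kernel_size=3, padding=1, bias=True)
    bn = nn.BatchNorm2d(8)
    bn.train()  # train mode: BN normalizes with the current batch statistics

    x = torch.randn(4, 3, 16, 16)
    y1 = bn(conv(x))

    with torch.no_grad():
        conv.bias += 5.0  # shift every output channel by a constant
    y2 = bn(conv(x))

    # The constant shift moves the batch mean by exactly the same amount,
    # so BN removes it again: the outputs match up to float tolerance.
    print(torch.allclose(y1, y2, atol=1e-5))  # True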
@@ -102,34 +110,35 @@ class FramePredictionDecoder(nn.Module):
         self.blocks = nn.ModuleList()
-        # Use plain DecoderBlocks; the first block uses the large stride
+        # Reorder: move the stride=4 block to the second-to-last position
+        # First block: stride=2 (220 -> 112)
         self.blocks.append(DecoderBlock(
             decoder_dims[0], decoder_dims[1],
-            kernel_size=3, stride=4, padding=1, output_padding=3  # changed to stride=4
+            kernel_size=3, stride=2, padding=1, output_padding=1
         ))
         # Second block: stride=2 (112 -> 56)
         self.blocks.append(DecoderBlock(
             decoder_dims[1], decoder_dims[2],
             kernel_size=3, stride=2, padding=1, output_padding=1
         ))
         # Third block: stride=2 (56 -> 48)
         self.blocks.append(DecoderBlock(
             decoder_dims[2], decoder_dims[3],
             kernel_size=3, stride=2, padding=1, output_padding=1
         ))
-        # Fourth block: widen to 64 channels
+        # Fourth block: stride=4 (48 -> 64), placed second-to-last
         self.blocks.append(DecoderBlock(
             decoder_dims[3], 64,
-            kernel_size=3, stride=2, padding=1, output_padding=1
+            kernel_size=3, stride=4, padding=1, output_padding=3  # stride=4 goes here
         ))
         # Improved final output layer: no deconvolution, feature refinement only
         # The input is already at the target size; only adjust channels and fuse features
         self.final_block = nn.Sequential(
             nn.Conv2d(64, 64, kernel_size=3, padding=1, bias=True),
-            nn.LeakyReLU(0.2, inplace=True),
+            nn.ReLU(inplace=True),
             nn.Conv2d(64, 32, kernel_size=3, padding=1, bias=True),
-            nn.LeakyReLU(0.2, inplace=True),
-            nn.Conv2d(32, output_channels, kernel_size=3, padding=1, bias=True)
-            # Tanh removed; let the output take any range, handled by the loss and normalization
+            nn.ReLU(inplace=True),
+            nn.Conv2d(32, output_channels, kernel_size=3, padding=1, bias=True),
+            nn.Tanh()  # add Tanh to constrain the output to [-1, 1]
         )
     def forward(self, x):
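Why output_padding=3 pairs with stride=4 while output_padding=1 pairs with stride=2: for nn.ConvTranspose2d, out = (in - 1) * stride - 2 * padding + kernel_size + output_padding, so the settings above give exact 2x and 4x upsampling. A minimal check (not part of the commit):

    import torch
    import torch.nn as nn

    # stride=2, k=3, p=1, op=1  ->  out = (in-1)*2 - 2 + 3 + 1 = 2*in
    # stride=4, k=3, p=1, op=3  ->  out = (in-1)*4 - 2 + 3 + 3 = 4*in
    x = torch.randn(1, 8, 12, 12)
    up2 = nn.ConvTranspose2d(8, 8, kernel_size=3, stride=2, padding=1, output_padding=1)
    up4 = nn.ConvTranspose2d(8, 8, kernel_size=3, stride=4, padding=1, output_padding=3)
    print(up2(x).shape)  # torch.Size([1, 8, 24, 24])
    print(up4(x).shape)  # torch.Size([1, 8, 48, 48])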
@@ -156,7 +165,6 @@ class SwiftFormerTemporal(nn.Module):
                  model_name='XS',
                  num_frames=3,
                  use_decoder=True,
-                 return_features=False,
                  **kwargs):
         super().__init__()
@@ -167,7 +175,6 @@ class SwiftFormerTemporal(nn.Module):
         # Store configuration
         self.num_frames = num_frames
         self.use_decoder = use_decoder
-        self.return_features = return_features
         # Modify stem to accept multiple frames (only Y channel)
         in_channels = num_frames
@@ -207,13 +214,13 @@ class SwiftFormerTemporal(nn.Module):
     def _init_weights(self, m):
         if isinstance(m, (nn.Conv2d, nn.Linear)):
-            # Use Kaiming initialization, suited to ReLU/LeakyReLU
-            nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='leaky_relu')
+            # Use Kaiming initialization, suited to ReLU
+            nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
             if m.bias is not None:
                 nn.init.constant_(m.bias, 0)
         elif isinstance(m, nn.ConvTranspose2d):
             # Transposed convolutions use a dedicated initialization
-            nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='leaky_relu')
+            nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
             if m.bias is not None:
                 nn.init.constant_(m.bias, 0)
         elif isinstance(m, (nn.LayerNorm, nn.BatchNorm2d)):
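For reference (an observation, not from the diff): switching nonlinearity from 'leaky_relu' to 'relu' also changes the Kaiming gain slightly, from sqrt(2 / (1 + a^2)) to sqrt(2). With the old LeakyReLU slope a = 0.2 the difference is small but real:

    import torch.nn as nn

    print(nn.init.calculate_gain('relu'))             # 1.4142...
    print(nn.init.calculate_gain('leaky_relu', 0.2))  # sqrt(2 / 1.04) ≈ 1.3868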
@@ -221,39 +228,20 @@ class SwiftFormerTemporal(nn.Module):
             nn.init.constant_(m.weight, 1.0)
     def forward_tokens(self, x):
-        """Forward through encoder network, return list of stage features if return_features else final output"""
-        if self.return_features:
-            features = []
-            stage_idx = 0
-            for idx, block in enumerate(self.network):
-                x = block(x)
-                # Collect each stage's output (stage0, stage1, stage2, stage3)
-                # In the SwiftFormer structure, stages sit at indices 0, 2, 4, 6
-                if idx in [0, 2, 4, 6]:
-                    features.append(x)
-                    stage_idx += 1
-            return x, features
-        else:
-            for block in self.network:
-                x = block(x)
-            return x
+        """Forward through the encoder network, return the final output"""
+        for block in self.network:
+            x = block(x)
+        return x
     def forward(self, x):
         """
         Args:
             x: input frames of shape [B, num_frames, H, W]
         Returns:
-            If return_features is False:
-                pred_frame: predicted frame [B, 1, H, W] (or None)
-            If return_features is True:
-                pred_frame, features (list of stage features)
+            pred_frame: predicted frame [B, 1, H, W] (or None)
         """
         # Encode
         x = self.patch_embed(x)
-        if self.return_features:
-            x, features = self.forward_tokens(x)
-        else:
-            x = self.forward_tokens(x)
+        x = self.forward_tokens(x)
         x = self.norm(x)
         # Decode to frame
@@ -261,10 +249,7 @@ class SwiftFormerTemporal(nn.Module):
         if self.use_decoder:
             pred_frame = self.decoder(x)
-        if self.return_features:
-            return pred_frame, features
-        else:
-            return pred_frame
+        return pred_frame
 # Factory functions for different model sizes
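A hypothetical usage sketch (the 224x224 input size is an assumption not confirmed by this diff; the constructor arguments are the ones visible above): with the new Tanh head the predicted frame is bounded to [-1, 1], so training targets should be normalized to the same range before computing the loss.

    import torch

    model = SwiftFormerTemporal(model_name='XS', num_frames=3, use_decoder=True)
    model.eval()
    frames = torch.randn(2, 3, 224, 224)  # [B, num_frames, H, W], Y channel only (assumed size)
    with torch.no_grad():
        pred = model(frames)              # [B, 1, H, W], values in [-1, 1]
    print(pred.min().item() >= -1.0, pred.max().item() <= 1.0)  # True True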