Update the model structure: move the large-stride transposed convolution later in the decoder; enable BatchNorm and Tanh
@@ -147,15 +147,15 @@ def save_comparison_figure(input_frames, target_frame, pred_frame, save_path,
     print(target_frame)
     print(pred_frame)
 
-    # debug print - improved to be more informative
-    if isinstance(pred_frame, np.ndarray):
-        print(f"[DEBUG IMAGE] Pred frame shape: {pred_frame.shape}, range: [{pred_frame.min():.2f}, {pred_frame.max():.2f}], mean: {pred_frame.mean():.2f}")
-        # check whether many values sit near 127.5
-        mask_near_127_5 = np.abs(pred_frame - 127.5) < 1.0
-        percent_near_127_5 = np.mean(mask_near_127_5) * 100
-        print(f"[DEBUG IMAGE] Percentage of values near 127.5 (±1.0): {percent_near_127_5:.2f}%")
-    else:
-        print(pred_frame)
+    # # debug print - improved to be more informative
+    # if isinstance(pred_frame, np.ndarray):
+    #     print(f"[DEBUG IMAGE] Pred frame shape: {pred_frame.shape}, range: [{pred_frame.min():.2f}, {pred_frame.max():.2f}], mean: {pred_frame.mean():.2f}")
+    #     # check whether many values sit near 127.5
+    #     mask_near_127_5 = np.abs(pred_frame - 127.5) < 1.0
+    #     percent_near_127_5 = np.mean(mask_near_127_5) * 100
+    #     print(f"[DEBUG IMAGE] Percentage of values near 127.5 (±1.0): {percent_near_127_5:.2f}%")
+    # else:
+    #     print(pred_frame)
 
     plt.tight_layout()
     plt.savefig(save_path, dpi=150, bbox_inches='tight')
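The block being commented out above was a collapse diagnostic: 127.5 is the midpoint of the [0, 255] pixel range, so a large fraction of values within ±1.0 of 127.5 indicates the decoder is emitting a near-constant mid-gray frame. A minimal standalone version of that check could look like the sketch below (the helper name frac_near_midgray is hypothetical, not part of this repository):

import numpy as np

def frac_near_midgray(pred_frame: np.ndarray, tol: float = 1.0) -> float:
    # Percentage of pixels within +/- tol of 127.5, the mid-gray value in [0, 255].
    return float(np.mean(np.abs(pred_frame - 127.5) < tol) * 100.0)

# A frame that has collapsed to mid-gray scores close to 100%.
collapsed = np.full((224, 224), 127.4, dtype=np.float32)
print(f"{frac_near_midgray(collapsed):.1f}% of values near 127.5")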
@@ -347,10 +347,6 @@ def main(args):
         except (pickle.UnpicklingError, TypeError) as e:
             print(f"使用weights_only=False加载失败: {e}")
             print("尝试使用torch.serialization.add_safe_globals...")
-            # from argparse import Namespace
-            # # add the safe globals
-            # torch.serialization.add_safe_globals([Namespace])
-            # checkpoint = torch.load(args.resume, map_location='cpu')
 
         # handle the state dict (it may contain a 'module.' prefix)
         if 'model' in checkpoint:
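The lines removed here were a commented-out fallback for PyTorch's restricted unpickling path, which the surrounding messages still reference. As a rough sketch of that pattern, assuming a checkpoint whose metadata was pickled as an argparse.Namespace (this is not the repository's exact code):

import argparse
import pickle
import torch

def load_checkpoint(path: str):
    try:
        # Trust the file and unpickle everything it contains.
        return torch.load(path, map_location='cpu', weights_only=False)
    except (pickle.UnpicklingError, TypeError) as e:
        print(f"Loading with weights_only=False failed: {e}")
        # Allow-list argparse.Namespace for the restricted loader (PyTorch >= 2.4),
        # then retry with weights_only=True.
        torch.serialization.add_safe_globals([argparse.Namespace])
        return torch.load(path, map_location='cpu', weights_only=True)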
@@ -21,71 +21,79 @@ class DecoderBlock(nn.Module):
             stride=stride,
             padding=padding,
             output_padding=output_padding,
-            bias=True  # enable bias, since BN was removed
+            bias=False  # disable bias, since BN is used
         )
+        self.bn1 = nn.BatchNorm2d(out_channels)
         self.conv1 = nn.Conv2d(out_channels, out_channels,
-                               kernel_size=3, padding=1, bias=True)
+                               kernel_size=3, padding=1, bias=False)
+        self.bn2 = nn.BatchNorm2d(out_channels)
         self.conv2 = nn.Conv2d(out_channels, out_channels,
-                               kernel_size=3, padding=1, bias=True)
+                               kernel_size=3, padding=1, bias=False)
+        self.bn3 = nn.BatchNorm2d(out_channels)
 
         # residual path: used when the channel count or spatial size changes
         self.shortcut = nn.Identity()
         if in_channels != out_channels or stride != 1:
             # a 1x1 conv adjusts the channels; a transposed conv is used when upsampling is needed
             if stride == 1:
-                self.shortcut = nn.Conv2d(in_channels, out_channels,
-                                          kernel_size=1, bias=True)
+                self.shortcut = nn.Sequential(
+                    nn.Conv2d(in_channels, out_channels, kernel_size=1, bias=False),
+                    nn.BatchNorm2d(out_channels)
+                )
             else:
-                self.shortcut = nn.ConvTranspose2d(
+                self.shortcut = nn.Sequential(
+                    nn.ConvTranspose2d(
                         in_channels, out_channels,
                         kernel_size=1,
                         stride=stride,
                         padding=0,
                         output_padding=output_padding,
-                        bias=True
+                        bias=False
+                    ),
+                    nn.BatchNorm2d(out_channels)
                 )
 
-        # use LeakyReLU to avoid dead neurons
-        self.activation = nn.LeakyReLU(0.2, inplace=True)
+        # use the ReLU activation
+        self.activation = nn.ReLU(inplace=True)
 
         # initialize the weights
         self._init_weights()
 
     def _init_weights(self):
         # initialize the transposed convolution
-        nn.init.kaiming_normal_(self.conv_transpose.weight, mode='fan_out', nonlinearity='leaky_relu')
-        if self.conv_transpose.bias is not None:
-            nn.init.constant_(self.conv_transpose.bias, 0)
+        nn.init.kaiming_normal_(self.conv_transpose.weight, mode='fan_out', nonlinearity='relu')
 
         # initialize the convolutions
-        nn.init.kaiming_normal_(self.conv1.weight, mode='fan_out', nonlinearity='leaky_relu')
-        if self.conv1.bias is not None:
-            nn.init.constant_(self.conv1.bias, 0)
-
-        nn.init.kaiming_normal_(self.conv2.weight, mode='fan_out', nonlinearity='leaky_relu')
-        if self.conv2.bias is not None:
-            nn.init.constant_(self.conv2.bias, 0)
+        nn.init.kaiming_normal_(self.conv1.weight, mode='fan_out', nonlinearity='relu')
+        nn.init.kaiming_normal_(self.conv2.weight, mode='fan_out', nonlinearity='relu')
 
         # initialize the shortcut
         if not isinstance(self.shortcut, nn.Identity):
-            if isinstance(self.shortcut, nn.Conv2d):
-                nn.init.kaiming_normal_(self.shortcut.weight, mode='fan_out', nonlinearity='leaky_relu')
-            elif isinstance(self.shortcut, nn.ConvTranspose2d):
-                nn.init.kaiming_normal_(self.shortcut.weight, mode='fan_out', nonlinearity='leaky_relu')
-            if self.shortcut.bias is not None:
-                nn.init.constant_(self.shortcut.bias, 0)
+            # the shortcut is now a Sequential, so initialize the conv layers inside it
+            for module in self.shortcut:
+                if isinstance(module, (nn.Conv2d, nn.ConvTranspose2d)):
+                    nn.init.kaiming_normal_(module.weight, mode='fan_out', nonlinearity='relu')
+        # initialize the BN layers (default initialization)
+        for m in self.modules():
+            if isinstance(m, nn.BatchNorm2d):
+                nn.init.constant_(m.weight, 1)
+                nn.init.constant_(m.bias, 0)
 
     def forward(self, x):
         identity = self.shortcut(x)
 
         # main path
         x = self.conv_transpose(x)
+        x = self.bn1(x)
         x = self.activation(x)
 
         x = self.conv1(x)
+        x = self.bn2(x)
         x = self.activation(x)
 
         x = self.conv2(x)
+        x = self.bn3(x)
 
         # residual connection
         x = x + identity
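Putting the pieces of this hunk together, the DecoderBlock after the change follows a standard post-activation residual layout: transposed conv -> BN -> ReLU, two 3x3 convs each followed by BN, and a BN'd shortcut. The sketch below is an inferred reconstruction for reference only: the constructor signature, the default arguments, and everything above the hunk (e.g. how conv_transpose is created) are assumptions, and the weight initialization is omitted:

import torch
import torch.nn as nn

class DecoderBlock(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size=3, stride=2,
                 padding=1, output_padding=1):
        super().__init__()
        # Main path: transposed conv -> BN -> ReLU -> conv -> BN -> ReLU -> conv -> BN.
        self.conv_transpose = nn.ConvTranspose2d(
            in_channels, out_channels, kernel_size=kernel_size, stride=stride,
            padding=padding, output_padding=output_padding, bias=False)
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.conv1 = nn.Conv2d(out_channels, out_channels, 3, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(out_channels)
        self.conv2 = nn.Conv2d(out_channels, out_channels, 3, padding=1, bias=False)
        self.bn3 = nn.BatchNorm2d(out_channels)
        self.activation = nn.ReLU(inplace=True)

        # Shortcut: identity when shapes match, otherwise a (transposed) 1x1 conv + BN.
        self.shortcut = nn.Identity()
        if in_channels != out_channels or stride != 1:
            if stride == 1:
                self.shortcut = nn.Sequential(
                    nn.Conv2d(in_channels, out_channels, kernel_size=1, bias=False),
                    nn.BatchNorm2d(out_channels))
            else:
                self.shortcut = nn.Sequential(
                    nn.ConvTranspose2d(in_channels, out_channels, kernel_size=1,
                                       stride=stride, padding=0,
                                       output_padding=output_padding, bias=False),
                    nn.BatchNorm2d(out_channels))

    def forward(self, x):
        identity = self.shortcut(x)
        x = self.activation(self.bn1(self.conv_transpose(x)))
        x = self.activation(self.bn2(self.conv1(x)))
        x = self.bn3(self.conv2(x))
        return x + identity  # the hunk ends at the residual add

A note on the bias change: dropping bias on convolutions that feed directly into BatchNorm2d is the usual pairing, since BN's own shift parameter makes the conv bias redundant.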
@@ -102,34 +110,35 @@ class FramePredictionDecoder(nn.Module):
 
         self.blocks = nn.ModuleList()
 
-        # plain DecoderBlocks; the first block uses the large stride
+        # reorder: place the stride=4 block second to last
+        # block 1: stride=2 (220 -> 112)
         self.blocks.append(DecoderBlock(
             decoder_dims[0], decoder_dims[1],
-            kernel_size=3, stride=4, padding=1, output_padding=3  # changed to stride=4
+            kernel_size=3, stride=2, padding=1, output_padding=1
         ))
+        # block 2: stride=2 (112 -> 56)
         self.blocks.append(DecoderBlock(
             decoder_dims[1], decoder_dims[2],
             kernel_size=3, stride=2, padding=1, output_padding=1
         ))
+        # block 3: stride=2 (56 -> 48)
         self.blocks.append(DecoderBlock(
             decoder_dims[2], decoder_dims[3],
             kernel_size=3, stride=2, padding=1, output_padding=1
         ))
-        # block 4: widen to 64 channels
+        # block 4: stride=4 (48 -> 64), placed second to last
         self.blocks.append(DecoderBlock(
             decoder_dims[3], 64,
-            kernel_size=3, stride=2, padding=1, output_padding=1
+            kernel_size=3, stride=4, padding=1, output_padding=3  # the stride=4 block goes here
         ))
 
-        # improved final output layer: no transposed conv, only feature refinement
-        # the input is already at the target size, so only adjust channels and fuse features
         self.final_block = nn.Sequential(
             nn.Conv2d(64, 64, kernel_size=3, padding=1, bias=True),
-            nn.LeakyReLU(0.2, inplace=True),
+            nn.ReLU(inplace=True),
             nn.Conv2d(64, 32, kernel_size=3, padding=1, bias=True),
-            nn.LeakyReLU(0.2, inplace=True),
-            nn.Conv2d(32, output_channels, kernel_size=3, padding=1, bias=True)
-            # Tanh removed so the output can take any range; handled by the loss and normalization
+            nn.ReLU(inplace=True),
+            nn.Conv2d(32, output_channels, kernel_size=3, padding=1, bias=True),
+            nn.Tanh()  # add Tanh to constrain the output to [-1, 1]
         )
 
     def forward(self, x):
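Two things worth noting about the reordered decoder. First, the numbers in the new comments (220 -> 112 -> 56 -> 48 -> 64) appear to be the channel widths taken from decoder_dims, not spatial sizes. Second, with kernel_size=3 and padding=1, PyTorch's ConvTranspose2d gives H_out = (H_in - 1) * stride - 2 * padding + kernel_size + output_padding, so the stride=2 / output_padding=1 blocks exactly double the spatial size and the stride=4 / output_padding=3 block exactly quadruples it; the reorder only changes where in the pipeline the 4x jump happens, not the overall 32x upsampling. A quick shape check:

import torch
import torch.nn as nn

x = torch.randn(1, 8, 7, 7)
up2 = nn.ConvTranspose2d(8, 8, kernel_size=3, stride=2, padding=1, output_padding=1)
up4 = nn.ConvTranspose2d(8, 8, kernel_size=3, stride=4, padding=1, output_padding=3)
print(up2(x).shape)  # torch.Size([1, 8, 14, 14]) -- doubled
print(up4(x).shape)  # torch.Size([1, 8, 28, 28]) -- quadrupled

Since the head now ends in nn.Tanh(), the model's outputs live in [-1, 1]; the training targets presumably need to be normalized to the same range (e.g. pixel / 127.5 - 1) for the loss to be meaningful.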
@@ -156,7 +165,6 @@ class SwiftFormerTemporal(nn.Module):
                  model_name='XS',
                  num_frames=3,
                  use_decoder=True,
-                 return_features=False,
                  **kwargs):
         super().__init__()
 
@@ -167,7 +175,6 @@ class SwiftFormerTemporal(nn.Module):
         # Store configuration
         self.num_frames = num_frames
         self.use_decoder = use_decoder
-        self.return_features = return_features
 
         # Modify stem to accept multiple frames (only Y channel)
         in_channels = num_frames
@@ -207,13 +214,13 @@ class SwiftFormerTemporal(nn.Module):
 
     def _init_weights(self, m):
         if isinstance(m, (nn.Conv2d, nn.Linear)):
-            # Kaiming initialization, suited to ReLU/LeakyReLU
-            nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='leaky_relu')
+            # Kaiming initialization, suited to ReLU
+            nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
             if m.bias is not None:
                 nn.init.constant_(m.bias, 0)
         elif isinstance(m, nn.ConvTranspose2d):
             # transposed conv layers use a dedicated initialization
-            nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='leaky_relu')
+            nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
             if m.bias is not None:
                 nn.init.constant_(m.bias, 0)
         elif isinstance(m, (nn.LayerNorm, nn.BatchNorm2d)):
@@ -221,19 +228,6 @@ class SwiftFormerTemporal(nn.Module):
             nn.init.constant_(m.weight, 1.0)
 
     def forward_tokens(self, x):
-        """Forward through encoder network, return list of stage features if return_features else final output"""
-        if self.return_features:
-            features = []
-            stage_idx = 0
-            for idx, block in enumerate(self.network):
-                x = block(x)
-                # collect each stage's output (stage0, stage1, stage2, stage3)
-                # in the SwiftFormer layout the stages sit at indices 0, 2, 4, 6
-                if idx in [0, 2, 4, 6]:
-                    features.append(x)
-                    stage_idx += 1
-            return x, features
-        else:
-            for block in self.network:
-                x = block(x)
-            return x
+        for block in self.network:
+            x = block(x)
+        return x
@@ -243,16 +237,10 @@ class SwiftFormerTemporal(nn.Module):
         Args:
             x: input frames of shape [B, num_frames, H, W]
         Returns:
-            If return_features is False:
-                pred_frame: predicted frame [B, 1, H, W] (or None)
-            If return_features is True:
-                pred_frame, features (list of stage features)
+            pred_frame: predicted frame [B, 1, H, W] (or None)
         """
         # Encode
         x = self.patch_embed(x)
-        if self.return_features:
-            x, features = self.forward_tokens(x)
-        else:
-            x = self.forward_tokens(x)
+        x = self.forward_tokens(x)
         x = self.norm(x)
 
@@ -261,9 +249,6 @@ class SwiftFormerTemporal(nn.Module):
         if self.use_decoder:
             pred_frame = self.decoder(x)
 
-            if self.return_features:
-                return pred_frame, features
-            else:
-                return pred_frame
+            return pred_frame
 
 