删除残差路径和shortcut,镜像问题仍存在

This commit is contained in:
2026-01-16 15:21:47 +08:00
parent a92a0b29e9
commit 543beefa2a
3 changed files with 24 additions and 81 deletions

View File

@@ -45,7 +45,6 @@ def denormalize(tensor):
# [0, 1] -> [0, 255] # [0, 1] -> [0, 255]
tensor = tensor * 255 tensor = tensor * 255
return tensor.clamp(0, 255) return tensor.clamp(0, 255)
# return tensor
def minmax_denormalize(tensor): def minmax_denormalize(tensor):
tensor_min = tensor.min() tensor_min = tensor.min()
@@ -76,28 +75,16 @@ def calculate_metrics(pred, target, debug=False):
if target_np.ndim == 3: if target_np.ndim == 3:
target_np = target_np.squeeze(0) target_np = target_np.squeeze(0)
if debug: # if debug:
print(f"[DEBUG] pred_np range: [{pred_np.min():.2f}, {pred_np.max():.2f}], mean: {pred_np.mean():.2f}") # print(f"[DEBUG] pred_np range: [{pred_np.min():.2f}, {pred_np.max():.2f}], mean: {pred_np.mean():.2f}")
print(f"[DEBUG] target_np range: [{target_np.min():.2f}, {target_np.max():.2f}], mean: {target_np.mean():.2f}") # print(f"[DEBUG] target_np range: [{target_np.min():.2f}, {target_np.max():.2f}], mean: {target_np.mean():.2f}")
print(f"[DEBUG] pred_np sample values (first 5): {pred_np.ravel()[:5]}") # print(f"[DEBUG] pred_np sample values (first 5): {pred_np.ravel()[:5]}")
# 计算MSE - 修复错误的tmp公式
# 原错误公式: tmp = 1 - (pred_np - target_np) / 255 * 2
# 正确公式: 直接计算像素差的平方
mse = np.mean((pred_np - target_np) ** 2) mse = np.mean((pred_np - target_np) ** 2)
# 同时计算错误公式的MSE用于对比
tmp = 1 - (pred_np - target_np) / 255 * 2
wrong_mse = np.mean(tmp**2)
if debug:
print(f"[DEBUG] Correct MSE: {mse:.6f}, Wrong MSE (tmp formula): {wrong_mse:.6f}")
# 计算SSIM (数据范围0-255)
data_range = 255.0 data_range = 255.0
ssim_value = ssim(pred_np, target_np, data_range=data_range) ssim_value = ssim(pred_np, target_np, data_range=data_range)
# 计算PSNR
psnr_value = psnr(target_np, pred_np, data_range=data_range) psnr_value = psnr(target_np, pred_np, data_range=data_range)
return mse, ssim_value, psnr_value return mse, ssim_value, psnr_value
@@ -147,16 +134,6 @@ def save_comparison_figure(input_frames, target_frame, pred_frame, save_path,
print(target_frame) print(target_frame)
print(pred_frame) print(pred_frame)
# # debug print - 改进为更有信息量的输出
# if isinstance(pred_frame, np.ndarray):
# print(f"[DEBUG IMAGE] Pred frame shape: {pred_frame.shape}, range: [{pred_frame.min():.2f}, {pred_frame.max():.2f}], mean: {pred_frame.mean():.2f}")
# # 检查是否有大量值在127.5附近
# mask_near_127_5 = np.abs(pred_frame - 127.5) < 1.0
# percent_near_127_5 = np.mean(mask_near_127_5) * 100
# print(f"[DEBUG IMAGE] Percentage of values near 127.5 (±1.0): {percent_near_127_5:.2f}%")
# else:
# print(pred_frame)
plt.tight_layout() plt.tight_layout()
plt.savefig(save_path, dpi=150, bbox_inches='tight') plt.savefig(save_path, dpi=150, bbox_inches='tight')
plt.close() plt.close()
@@ -216,13 +193,13 @@ def evaluate_model(model, data_loader, device, args):
# 对第一个样本启用调试 # 对第一个样本启用调试
debug_mode = (batch_idx == 0 and i == 0 and total_samples == 0) debug_mode = (batch_idx == 0 and i == 0 and total_samples == 0)
if debug_mode: # if debug_mode:
print(f"[DEBUG] Raw pred_frames range: [{pred_frames.min():.4f}, {pred_frames.max():.4f}], mean: {pred_frames.mean():.4f}") # print(f"[DEBUG] Raw pred_frames range: [{pred_frames.min():.4f}, {pred_frames.max():.4f}], mean: {pred_frames.mean():.4f}")
print(f"[DEBUG] Raw target_frames range: [{target_frames.min():.4f}, {target_frames.max():.4f}], mean: {target_frames.mean():.4f}") # print(f"[DEBUG] Raw target_frames range: [{target_frames.min():.4f}, {target_frames.max():.4f}], mean: {target_frames.mean():.4f}")
print(f"[DEBUG] Pred_denorm range: [{pred_denorm.min():.2f}, {pred_denorm.max():.2f}], mean: {pred_denorm.mean():.2f}") # print(f"[DEBUG] Pred_denorm range: [{pred_denorm.min():.2f}, {pred_denorm.max():.2f}], mean: {pred_denorm.mean():.2f}")
print(f"[DEBUG] Target_denorm range: [{target_denorm.min():.2f}, {target_denorm.max():.2f}], mean: {target_denorm.mean():.2f}") # print(f"[DEBUG] Target_denorm range: [{target_denorm.min():.2f}, {target_denorm.max():.2f}], mean: {target_denorm.mean():.2f}")
mse, ssim_value, psnr_value = calculate_metrics(pred_i, target_i, debug=debug_mode) mse, ssim_value, psnr_value = calculate_metrics(pred_i, target_i, debug=False)
total_mse += mse total_mse += mse
total_ssim += ssim_value total_ssim += ssim_value

View File

@@ -43,7 +43,7 @@ def get_args_parser():
help='Number of input frames (T)') help='Number of input frames (T)')
parser.add_argument('--frame-size', default=224, type=int, parser.add_argument('--frame-size', default=224, type=int,
help='Input frame size') help='Input frame size')
parser.add_argument('--max-interval', default=4, type=int, parser.add_argument('--max-interval', default=10, type=int,
help='Maximum interval between consecutive frames') help='Maximum interval between consecutive frames')
# Model parameters # Model parameters
@@ -121,7 +121,7 @@ def get_args_parser():
help='start epoch') help='start epoch')
parser.add_argument('--eval', action='store_true', parser.add_argument('--eval', action='store_true',
help='Perform evaluation only') help='Perform evaluation only')
parser.add_argument('--num-workers', default=4, type=int) parser.add_argument('--num-workers', default=16, type=int)
parser.add_argument('--pin-mem', action='store_true', parser.add_argument('--pin-mem', action='store_true',
help='Pin CPU memory in DataLoader') help='Pin CPU memory in DataLoader')
parser.add_argument('--no-pin-mem', action='store_false', dest='pin_mem') parser.add_argument('--no-pin-mem', action='store_false', dest='pin_mem')
@@ -264,7 +264,7 @@ def main(args):
checkpoint = torch.hub.load_state_dict_from_url( checkpoint = torch.hub.load_state_dict_from_url(
args.resume, map_location='cpu', check_hash=True) args.resume, map_location='cpu', check_hash=True)
else: else:
checkpoint = torch.load(args.resume, map_location='cpu') checkpoint = torch.load(args.resume, map_location='cpu', weights_only=False)
model_without_ddp.load_state_dict(checkpoint['model']) model_without_ddp.load_state_dict(checkpoint['model'])
if not args.eval and 'optimizer' in checkpoint and 'lr_scheduler' in checkpoint and 'epoch' in checkpoint: if not args.eval and 'optimizer' in checkpoint and 'lr_scheduler' in checkpoint and 'epoch' in checkpoint:
@@ -308,7 +308,7 @@ def main(args):
train_stats, global_step = train_one_epoch( train_stats, global_step = train_one_epoch(
model, criterion, data_loader_train, model, criterion, data_loader_train,
optimizer, device, epoch, loss_scaler, optimizer, device, epoch, loss_scaler, args.clip_grad, args.clip_mode,
model_ema=model_ema, writer=writer, model_ema=model_ema, writer=writer,
global_step=global_step, args=args global_step=global_step, args=args
) )
@@ -356,7 +356,7 @@ def main(args):
def train_one_epoch(model, criterion, data_loader, optimizer, device, epoch, loss_scaler, def train_one_epoch(model, criterion, data_loader, optimizer, device, epoch, loss_scaler,
clip_grad=None, clip_mode='norm', model_ema=None, writer=None, clip_grad=0.01, clip_mode='norm', model_ema=None, writer=None,
global_step=0, args=None, **kwargs): global_step=0, args=None, **kwargs):
model.train() model.train()
metric_logger = utils.MetricLogger(delimiter=" ") metric_logger = utils.MetricLogger(delimiter=" ")

View File

@@ -11,7 +11,7 @@ from timm.layers import DropPath, trunc_normal_
class DecoderBlock(nn.Module): class DecoderBlock(nn.Module):
"""Upsampling block for frame prediction decoder with residual connections""" """Upsampling block for frame prediction decoder without residual connections"""
def __init__(self, in_channels, out_channels, kernel_size=3, stride=2, padding=1, output_padding=1): def __init__(self, in_channels, out_channels, kernel_size=3, stride=2, padding=1, output_padding=1):
super().__init__() super().__init__()
# 主路径:反卷积 + 两个卷积层 # 主路径:反卷积 + 两个卷积层
@@ -31,28 +31,6 @@ class DecoderBlock(nn.Module):
kernel_size=3, padding=1, bias=False) kernel_size=3, padding=1, bias=False)
self.bn3 = nn.BatchNorm2d(out_channels) self.bn3 = nn.BatchNorm2d(out_channels)
# 残差路径:如果需要改变通道数或空间尺寸
self.shortcut = nn.Identity()
if in_channels != out_channels or stride != 1:
# 使用1x1卷积调整通道数如果需要上采样则使用反卷积
if stride == 1:
self.shortcut = nn.Sequential(
nn.Conv2d(in_channels, out_channels, kernel_size=1, bias=False),
nn.BatchNorm2d(out_channels)
)
else:
self.shortcut = nn.Sequential(
nn.ConvTranspose2d(
in_channels, out_channels,
kernel_size=1,
stride=stride,
padding=0,
output_padding=output_padding,
bias=False
),
nn.BatchNorm2d(out_channels)
)
# 使用ReLU激活函数 # 使用ReLU激活函数
self.activation = nn.ReLU(inplace=True) self.activation = nn.ReLU(inplace=True)
@@ -67,13 +45,6 @@ class DecoderBlock(nn.Module):
nn.init.kaiming_normal_(self.conv1.weight, mode='fan_out', nonlinearity='relu') nn.init.kaiming_normal_(self.conv1.weight, mode='fan_out', nonlinearity='relu')
nn.init.kaiming_normal_(self.conv2.weight, mode='fan_out', nonlinearity='relu') nn.init.kaiming_normal_(self.conv2.weight, mode='fan_out', nonlinearity='relu')
# 初始化shortcut
if not isinstance(self.shortcut, nn.Identity):
# shortcut现在是Sequential需要初始化其中的卷积层
for module in self.shortcut:
if isinstance(module, (nn.Conv2d, nn.ConvTranspose2d)):
nn.init.kaiming_normal_(module.weight, mode='fan_out', nonlinearity='relu')
# 初始化BN层使用默认初始化 # 初始化BN层使用默认初始化
for m in self.modules(): for m in self.modules():
if isinstance(m, nn.BatchNorm2d): if isinstance(m, nn.BatchNorm2d):
@@ -81,8 +52,6 @@ class DecoderBlock(nn.Module):
nn.init.constant_(m.bias, 0) nn.init.constant_(m.bias, 0)
def forward(self, x): def forward(self, x):
identity = self.shortcut(x)
# 主路径 # 主路径
x = self.conv_transpose(x) x = self.conv_transpose(x)
x = self.bn1(x) x = self.bn1(x)
@@ -94,9 +63,6 @@ class DecoderBlock(nn.Module):
x = self.conv2(x) x = self.conv2(x)
x = self.bn3(x) x = self.bn3(x)
# 残差连接
x = x + identity
x = self.activation(x) x = self.activation(x)
return x return x
@@ -105,28 +71,28 @@ class FramePredictionDecoder(nn.Module):
"""Improved decoder for frame prediction""" """Improved decoder for frame prediction"""
def __init__(self, embed_dims, output_channels=1): def __init__(self, embed_dims, output_channels=1):
super().__init__() super().__init__()
# Reverse the embed_dims for decoder # Define decoder dimensions independently (no skip connections)
decoder_dims = embed_dims[::-1] start_dim = embed_dims[-1]
decoder_dims = [start_dim // (2 ** i) for i in range(4)] # e.g., [220, 110, 55, 27] for XS
self.blocks = nn.ModuleList() self.blocks = nn.ModuleList()
# 调整顺序将stride=4放在倒数第二的位置 # 第一个blockstride=2 (decoder_dims[0] -> decoder_dims[1])
# 第一个blockstride=2 (220 -> 112)
self.blocks.append(DecoderBlock( self.blocks.append(DecoderBlock(
decoder_dims[0], decoder_dims[1], decoder_dims[0], decoder_dims[1],
kernel_size=3, stride=2, padding=1, output_padding=1 kernel_size=3, stride=2, padding=1, output_padding=1
)) ))
# 第二个blockstride=2 (112 -> 56) # 第二个blockstride=2 (decoder_dims[1] -> decoder_dims[2])
self.blocks.append(DecoderBlock( self.blocks.append(DecoderBlock(
decoder_dims[1], decoder_dims[2], decoder_dims[1], decoder_dims[2],
kernel_size=3, stride=2, padding=1, output_padding=1 kernel_size=3, stride=2, padding=1, output_padding=1
)) ))
# 第三个blockstride=2 (56 -> 48) # 第三个blockstride=2 (decoder_dims[2] -> decoder_dims[3])
self.blocks.append(DecoderBlock( self.blocks.append(DecoderBlock(
decoder_dims[2], decoder_dims[3], decoder_dims[2], decoder_dims[3],
kernel_size=3, stride=2, padding=1, output_padding=1 kernel_size=3, stride=2, padding=1, output_padding=1
)) ))
# 第四个blockstride=4 (48 -> 64),放在倒数第二的位置 # 第四个blockstride=4 (decoder_dims[3] -> 64),放在倒数第二的位置
self.blocks.append(DecoderBlock( self.blocks.append(DecoderBlock(
decoder_dims[3], 64, decoder_dims[3], 64,
kernel_size=3, stride=4, padding=1, output_padding=3 # stride=4放在这里 kernel_size=3, stride=4, padding=1, output_padding=3 # stride=4放在这里
@@ -138,7 +104,7 @@ class FramePredictionDecoder(nn.Module):
nn.Conv2d(64, 32, kernel_size=3, padding=1, bias=True), nn.Conv2d(64, 32, kernel_size=3, padding=1, bias=True),
nn.ReLU(inplace=True), nn.ReLU(inplace=True),
nn.Conv2d(32, output_channels, kernel_size=3, padding=1, bias=True), nn.Conv2d(32, output_channels, kernel_size=3, padding=1, bias=True),
nn.Tanh() # 添加Tanh激活函数约束输出在[-1, 1]范围内 nn.Tanh()
) )
def forward(self, x): def forward(self, x):