删除残差路径和shortcut,镜像问题仍存在

This commit is contained in:
2026-01-16 15:21:47 +08:00
parent a92a0b29e9
commit 543beefa2a
3 changed files with 24 additions and 81 deletions

View File

@@ -45,7 +45,6 @@ def denormalize(tensor):
# [0, 1] -> [0, 255]
tensor = tensor * 255
return tensor.clamp(0, 255)
# return tensor
def minmax_denormalize(tensor):
tensor_min = tensor.min()
@@ -76,28 +75,16 @@ def calculate_metrics(pred, target, debug=False):
if target_np.ndim == 3:
target_np = target_np.squeeze(0)
if debug:
print(f"[DEBUG] pred_np range: [{pred_np.min():.2f}, {pred_np.max():.2f}], mean: {pred_np.mean():.2f}")
print(f"[DEBUG] target_np range: [{target_np.min():.2f}, {target_np.max():.2f}], mean: {target_np.mean():.2f}")
print(f"[DEBUG] pred_np sample values (first 5): {pred_np.ravel()[:5]}")
# if debug:
# print(f"[DEBUG] pred_np range: [{pred_np.min():.2f}, {pred_np.max():.2f}], mean: {pred_np.mean():.2f}")
# print(f"[DEBUG] target_np range: [{target_np.min():.2f}, {target_np.max():.2f}], mean: {target_np.mean():.2f}")
# print(f"[DEBUG] pred_np sample values (first 5): {pred_np.ravel()[:5]}")
# 计算MSE - 修复错误的tmp公式
# 原错误公式: tmp = 1 - (pred_np - target_np) / 255 * 2
# 正确公式: 直接计算像素差的平方
mse = np.mean((pred_np - target_np) ** 2)
# 同时计算错误公式的MSE用于对比
tmp = 1 - (pred_np - target_np) / 255 * 2
wrong_mse = np.mean(tmp**2)
if debug:
print(f"[DEBUG] Correct MSE: {mse:.6f}, Wrong MSE (tmp formula): {wrong_mse:.6f}")
# 计算SSIM (数据范围0-255)
data_range = 255.0
ssim_value = ssim(pred_np, target_np, data_range=data_range)
# 计算PSNR
psnr_value = psnr(target_np, pred_np, data_range=data_range)
return mse, ssim_value, psnr_value
@@ -146,16 +133,6 @@ def save_comparison_figure(input_frames, target_frame, pred_frame, save_path,
#debug print
print(target_frame)
print(pred_frame)
# # debug print - 改进为更有信息量的输出
# if isinstance(pred_frame, np.ndarray):
# print(f"[DEBUG IMAGE] Pred frame shape: {pred_frame.shape}, range: [{pred_frame.min():.2f}, {pred_frame.max():.2f}], mean: {pred_frame.mean():.2f}")
# # 检查是否有大量值在127.5附近
# mask_near_127_5 = np.abs(pred_frame - 127.5) < 1.0
# percent_near_127_5 = np.mean(mask_near_127_5) * 100
# print(f"[DEBUG IMAGE] Percentage of values near 127.5 (±1.0): {percent_near_127_5:.2f}%")
# else:
# print(pred_frame)
plt.tight_layout()
plt.savefig(save_path, dpi=150, bbox_inches='tight')
@@ -216,13 +193,13 @@ def evaluate_model(model, data_loader, device, args):
# 对第一个样本启用调试
debug_mode = (batch_idx == 0 and i == 0 and total_samples == 0)
if debug_mode:
print(f"[DEBUG] Raw pred_frames range: [{pred_frames.min():.4f}, {pred_frames.max():.4f}], mean: {pred_frames.mean():.4f}")
print(f"[DEBUG] Raw target_frames range: [{target_frames.min():.4f}, {target_frames.max():.4f}], mean: {target_frames.mean():.4f}")
print(f"[DEBUG] Pred_denorm range: [{pred_denorm.min():.2f}, {pred_denorm.max():.2f}], mean: {pred_denorm.mean():.2f}")
print(f"[DEBUG] Target_denorm range: [{target_denorm.min():.2f}, {target_denorm.max():.2f}], mean: {target_denorm.mean():.2f}")
# if debug_mode:
# print(f"[DEBUG] Raw pred_frames range: [{pred_frames.min():.4f}, {pred_frames.max():.4f}], mean: {pred_frames.mean():.4f}")
# print(f"[DEBUG] Raw target_frames range: [{target_frames.min():.4f}, {target_frames.max():.4f}], mean: {target_frames.mean():.4f}")
# print(f"[DEBUG] Pred_denorm range: [{pred_denorm.min():.2f}, {pred_denorm.max():.2f}], mean: {pred_denorm.mean():.2f}")
# print(f"[DEBUG] Target_denorm range: [{target_denorm.min():.2f}, {target_denorm.max():.2f}], mean: {target_denorm.mean():.2f}")
mse, ssim_value, psnr_value = calculate_metrics(pred_i, target_i, debug=debug_mode)
mse, ssim_value, psnr_value = calculate_metrics(pred_i, target_i, debug=False)
total_mse += mse
total_ssim += ssim_value