""" 评估脚本 for SwiftFormerTemporal frame prediction 输出预测图(注意反归一化)以及对应指标(mse&ssim&psnr) """ import argparse import os import torch import torch.nn as nn import pickle import numpy as np import random from pathlib import Path import json import matplotlib.pyplot as plt from PIL import Image import torchvision.transforms as transforms import torch.backends.cudnn as cudnn from util.video_dataset import VideoFrameDataset from models.swiftformer_temporal import ( SwiftFormerTemporal_XS, SwiftFormerTemporal_S, SwiftFormerTemporal_L1, SwiftFormerTemporal_L3 ) # 导入SSIM和PSNR计算 from skimage.metrics import structural_similarity as ssim from skimage.metrics import peak_signal_noise_ratio as psnr import warnings warnings.filterwarnings('ignore') def denormalize(tensor): """ 将[-1, 1]范围的张量反归一化到[0, 255]范围 Args: tensor: 形状为[B, C, H, W]或[C, H, W],值在[-1, 1] Returns: 反归一化后的张量,值在[0, 255] """ # clip 到 [-1, 1] 范围 tensor = tensor.clamp(-1, 1) # [-1, 1] -> [0, 1] tensor = (tensor + 1) / 2 # [0, 1] -> [0, 255] tensor = tensor * 255 return tensor.clamp(0, 255) def calculate_metrics(pred, target, debug=False): """ 计算MSE, SSIM, PSNR指标 Args: pred: 预测图像,形状[H, W],值在[0, 255] target: 目标图像,形状[H, W],值在[0, 255] debug: 是否输出调试信息 Returns: mse, ssim_value, psnr_value """ # 转换为numpy数组 pred_np = pred.cpu().numpy() if torch.is_tensor(pred) else pred target_np = target.cpu().numpy() if torch.is_tensor(target) else target # 确保是2D数组 if pred_np.ndim == 3: pred_np = pred_np.squeeze(0) if target_np.ndim == 3: target_np = target_np.squeeze(0) if debug: print(f"[DEBUG] pred_np range: [{pred_np.min():.2f}, {pred_np.max():.2f}], mean: {pred_np.mean():.2f}") print(f"[DEBUG] target_np range: [{target_np.min():.2f}, {target_np.max():.2f}], mean: {target_np.mean():.2f}") print(f"[DEBUG] pred_np sample values (first 5): {pred_np.ravel()[:5]}") # 计算MSE - 修复错误的tmp公式 # 原错误公式: tmp = 1 - (pred_np - target_np) / 255 * 2 # 正确公式: 直接计算像素差的平方 mse = np.mean((pred_np - target_np) ** 2) # 同时计算错误公式的MSE用于对比 tmp = 1 - (pred_np - target_np) / 255 * 2 wrong_mse = np.mean(tmp**2) if debug: print(f"[DEBUG] Correct MSE: {mse:.6f}, Wrong MSE (tmp formula): {wrong_mse:.6f}") # 计算SSIM (数据范围0-255) data_range = 255.0 ssim_value = ssim(pred_np, target_np, data_range=data_range) # 计算PSNR psnr_value = psnr(target_np, pred_np, data_range=data_range) return mse, ssim_value, psnr_value def save_comparison_figure(input_frames, target_frame, pred_frame, save_path, input_frame_indices=None, target_frame_index=None): """ 保存对比图:输入帧、目标帧、预测帧 Args: input_frames: 输入帧列表,每个形状为[H, W],值在[0, 255] target_frame: 目标帧,形状[H, W],值在[0, 255] pred_frame: 预测帧,形状[H, W],值在[0, 255] save_path: 保存路径 input_frame_indices: 输入帧的索引列表(可选) target_frame_index: 目标帧索引(可选) """ num_input = len(input_frames) fig, axes = plt.subplots(1, num_input + 2, figsize=(4*(num_input+2), 4)) # 绘制输入帧 for i in range(num_input): ax = axes[i] ax.imshow(input_frames[i], cmap='gray') if input_frame_indices is not None: ax.set_title(f'Input Frame {input_frame_indices[i]}') else: ax.set_title(f'Input {i+1}') ax.axis('off') # 绘制目标帧 ax = axes[num_input] ax.imshow(target_frame, cmap='gray') if target_frame_index is not None: ax.set_title(f'Target Frame {target_frame_index}') else: ax.set_title('Target') ax.axis('off') # 绘制预测帧 ax = axes[num_input + 1] ax.imshow(pred_frame, cmap='gray') ax.set_title('Predicted') ax.axis('off') # debug print - 改进为更有信息量的输出 if isinstance(pred_frame, np.ndarray): print(f"[DEBUG IMAGE] Pred frame shape: {pred_frame.shape}, range: [{pred_frame.min():.2f}, {pred_frame.max():.2f}], mean: {pred_frame.mean():.2f}") # 

def save_comparison_figure(input_frames, target_frame, pred_frame, save_path,
                           input_frame_indices=None, target_frame_index=None):
    """
    Save a comparison figure: input frames, target frame and predicted frame.

    Args:
        input_frames: list of input frames, each of shape [H, W], values in [0, 255]
        target_frame: target frame, shape [H, W], values in [0, 255]
        pred_frame: predicted frame, shape [H, W], values in [0, 255]
        save_path: output path
        input_frame_indices: indices of the input frames (optional)
        target_frame_index: index of the target frame (optional)
    """
    num_input = len(input_frames)
    fig, axes = plt.subplots(1, num_input + 2, figsize=(4 * (num_input + 2), 4))

    # Input frames
    for i in range(num_input):
        ax = axes[i]
        ax.imshow(input_frames[i], cmap='gray')
        if input_frame_indices is not None:
            ax.set_title(f'Input Frame {input_frame_indices[i]}')
        else:
            ax.set_title(f'Input {i + 1}')
        ax.axis('off')

    # Target frame
    ax = axes[num_input]
    ax.imshow(target_frame, cmap='gray')
    if target_frame_index is not None:
        ax.set_title(f'Target Frame {target_frame_index}')
    else:
        ax.set_title('Target')
    ax.axis('off')

    # Predicted frame
    ax = axes[num_input + 1]
    ax.imshow(pred_frame, cmap='gray')
    ax.set_title('Predicted')
    ax.axis('off')

    # Debug print, made more informative
    if isinstance(pred_frame, np.ndarray):
        print(f"[DEBUG IMAGE] Pred frame shape: {pred_frame.shape}, range: [{pred_frame.min():.2f}, {pred_frame.max():.2f}], mean: {pred_frame.mean():.2f}")
        # Check whether a large fraction of the values clusters around 127.5
        # (a raw output of 0 de-normalizes to 127.5, i.e. a flat gray prediction)
        mask_near_127_5 = np.abs(pred_frame - 127.5) < 1.0
        percent_near_127_5 = np.mean(mask_near_127_5) * 100
        print(f"[DEBUG IMAGE] Percentage of values near 127.5 (±1.0): {percent_near_127_5:.2f}%")
    else:
        print(f"[DEBUG IMAGE] Pred frame is not an ndarray: {type(pred_frame)}")

    plt.tight_layout()
    plt.savefig(save_path, dpi=150, bbox_inches='tight')
    plt.close()


def evaluate_model(model, data_loader, device, args):
    """
    Evaluate the model and compute metrics.

    Args:
        model: trained model
        data_loader: data loader
        device: device
        args: command-line arguments
    Returns:
        metrics_dict: dictionary with all metrics
        sample_results: sample results for visualization
    """
    # model.eval()
    model.train()  # NOTE: temporarily run in train mode (eval() disabled)

    # Metric accumulators
    total_mse = 0.0
    total_ssim = 0.0
    total_psnr = 0.0
    total_samples = 0

    # Sample results for visualization (chosen at random via reservoir sampling)
    sample_results = []
    max_samples_to_save = args.num_samples_to_save
    max_samples = args.max_samples

    # Counter for reservoir sampling (number of processed samples, excluding
    # those skipped because of the max_samples limit)
    sample_count = 0

    with torch.no_grad():
        for batch_idx, (input_frames, target_frames, temporal_indices) in enumerate(data_loader):
            input_frames = input_frames.to(device, non_blocking=True)
            target_frames = target_frames.to(device, non_blocking=True)

            # Forward pass
            pred_frames, _ = model(input_frames)

            # De-normalize for metric computation
            pred_denorm = denormalize(pred_frames)      # [B, 1, H, W]
            target_denorm = denormalize(target_frames)  # [B, 1, H, W]

            batch_size = input_frames.size(0)

            # Per-sample metrics
            for i in range(batch_size):
                # Stop once the maximum number of samples is reached
                if max_samples is not None and total_samples >= max_samples:
                    break

                pred_i = pred_denorm[i]      # [1, H, W]
                target_i = target_denorm[i]  # [1, H, W]

                # Enable debug output for the very first sample
                debug_mode = (batch_idx == 0 and i == 0 and total_samples == 0)
                if debug_mode:
                    print(f"[DEBUG] Raw pred_frames range: [{pred_frames.min():.4f}, {pred_frames.max():.4f}], mean: {pred_frames.mean():.4f}")
                    print(f"[DEBUG] Raw target_frames range: [{target_frames.min():.4f}, {target_frames.max():.4f}], mean: {target_frames.mean():.4f}")
                    print(f"[DEBUG] Pred_denorm range: [{pred_denorm.min():.2f}, {pred_denorm.max():.2f}], mean: {pred_denorm.mean():.2f}")
                    print(f"[DEBUG] Target_denorm range: [{target_denorm.min():.2f}, {target_denorm.max():.2f}], mean: {target_denorm.mean():.2f}")

                mse, ssim_value, psnr_value = calculate_metrics(pred_i, target_i, debug=debug_mode)

                total_mse += mse
                total_ssim += ssim_value
                total_psnr += psnr_value
                total_samples += 1
                sample_count += 1

                # Build the sample data dictionary
                input_denorm = denormalize(input_frames[i])  # [num_frames, H, W]

                # Split the input frames
                input_frames_list = []
                for j in range(args.num_frames):
                    input_frame_j = input_denorm[j].squeeze(0)  # [H, W]
                    input_frames_list.append(input_frame_j.cpu().numpy())

                sample_data = {
                    'input_frames': input_frames_list,
                    'target_frame': target_i.squeeze(0).cpu().numpy(),
                    'pred_frame': pred_i.squeeze(0).cpu().numpy(),
                    'metrics': {
                        'mse': mse,
                        'ssim': ssim_value,
                        'psnr': psnr_value
                    },
                    'batch_idx': batch_idx,
                    'sample_idx': i
                }

                # Reservoir sampling
                if sample_count <= max_samples_to_save:
                    # Reservoir not full yet: add directly
                    sample_results.append(sample_data)
                else:
                    # Replace a random slot with probability max_samples_to_save / sample_count
                    r = random.randint(0, sample_count - 1)
                    if r < max_samples_to_save:
                        sample_results[r] = sample_data

            # Stop once the maximum number of samples is reached
            if max_samples is not None and total_samples >= max_samples:
                print(f"Reached the maximum number of samples: {max_samples}")
                break

            # Progress
            if (batch_idx + 1) % 10 == 0:
                print(f'Processed {batch_idx + 1} batches, {total_samples} samples')

    # Average metrics
    if total_samples > 0:
        avg_mse = float(total_mse / total_samples)
        avg_ssim = float(total_ssim / total_samples)
        avg_psnr = float(total_psnr / total_samples)
    else:
        avg_mse = avg_ssim = avg_psnr = 0.0

    metrics_dict = {
        'mse': avg_mse,
        'ssim': avg_ssim,
        'psnr': avg_psnr,
        'num_samples': total_samples
    }

    return metrics_dict, sample_results
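
# Minimal standalone illustration of the reservoir-sampling rule used in
# evaluate_model() above (a hypothetical helper, not called by the pipeline):
# it keeps a uniform random subset of size k from a stream without storing the
# whole stream, which is exactly how sample_results is filled.
def _reservoir_sample(stream, k):
    reservoir = []
    for count, item in enumerate(stream, start=1):
        if count <= k:
            # Reservoir not full yet: add directly
            reservoir.append(item)
        else:
            # Replace a random slot with probability k / count
            r = random.randint(0, count - 1)
            if r < k:
                reservoir[r] = item
    return reservoir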

def main(args):
    print("Evaluation args:", args)

    device = torch.device(args.device)

    # Set random seeds
    torch.manual_seed(args.seed)
    np.random.seed(args.seed)
    random.seed(args.seed)
    cudnn.benchmark = True

    # Build the dataset
    print("Building dataset...")
    dataset_val = VideoFrameDataset(
        root_dir=args.data_path,
        num_frames=args.num_frames,
        frame_size=args.frame_size,
        is_train=False,
        max_interval=args.max_interval
    )

    data_loader_val = torch.utils.data.DataLoader(
        dataset_val,
        batch_size=args.batch_size,
        num_workers=args.num_workers,
        pin_memory=args.pin_mem,
        shuffle=False,
        drop_last=False
    )

    # Create the model
    print(f"Creating model: {args.model}")
    model_kwargs = {
        'num_frames': args.num_frames,
        'use_representation_head': args.use_representation_head,
        'representation_dim': args.representation_dim,
    }

    if args.model == 'SwiftFormerTemporal_XS':
        model = SwiftFormerTemporal_XS(**model_kwargs)
    elif args.model == 'SwiftFormerTemporal_S':
        model = SwiftFormerTemporal_S(**model_kwargs)
    elif args.model == 'SwiftFormerTemporal_L1':
        model = SwiftFormerTemporal_L1(**model_kwargs)
    elif args.model == 'SwiftFormerTemporal_L3':
        model = SwiftFormerTemporal_L3(**model_kwargs)
    else:
        raise ValueError(f"Unknown model: {args.model}")

    model.to(device)

    # Load the checkpoint
    if args.resume:
        print(f"Loading checkpoint: {args.resume}")
        try:
            # Try loading with weights_only=False (required on PyTorch 2.6+)
            checkpoint = torch.load(args.resume, map_location='cpu', weights_only=False)
        except (pickle.UnpicklingError, TypeError) as e:
            print(f"Loading with weights_only=False failed: {e}")
            print("Falling back to torch.serialization.add_safe_globals...")
            from argparse import Namespace
            # Allow-list the Namespace class for unpickling
            torch.serialization.add_safe_globals([Namespace])
            checkpoint = torch.load(args.resume, map_location='cpu')

        # Handle the state dict (it may be nested and/or carry a 'module.' prefix)
        if 'model' in checkpoint:
            state_dict = checkpoint['model']
        elif 'state_dict' in checkpoint:
            state_dict = checkpoint['state_dict']
        else:
            state_dict = checkpoint

        # Remove the 'module.' prefix if present
        if hasattr(model, 'module'):
            model.module.load_state_dict(state_dict)
        else:
            # If the state dict has a 'module.' prefix but the model does not, strip it
            if any(key.startswith('module.') for key in state_dict.keys()):
                from collections import OrderedDict
                new_state_dict = OrderedDict()
                for k, v in state_dict.items():
                    if k.startswith('module.'):
                        new_state_dict[k[7:]] = v
                    else:
                        new_state_dict[k] = v
                state_dict = new_state_dict
            model.load_state_dict(state_dict)

        print(f"Checkpoint loaded, epoch: {checkpoint.get('epoch', 'unknown')}")
    else:
        print("Warning: no checkpoint provided, using a randomly initialized model")

    # Create the output directory
    output_dir = Path(args.output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)

    # Run the evaluation
    print("Starting evaluation...")
    metrics, sample_results = evaluate_model(model, data_loader_val, device, args)

    # Print metrics
    print("\n" + "=" * 50)
    print("Evaluation results:")
    print(f"MSE: {metrics['mse']:.6f}")
    print(f"SSIM: {metrics['ssim']:.6f}")
    print(f"PSNR: {metrics['psnr']:.6f} dB")
    print(f"Number of samples: {metrics['num_samples']}")
    print("=" * 50)

    # Save metrics to a JSON file
    metrics_file = output_dir / 'evaluation_metrics.json'
    with open(metrics_file, 'w') as f:
        json.dump(metrics, f, indent=4)
    print(f"Metrics saved to: {metrics_file}")

    # Save sample visualizations
    if sample_results:
        print(f"\nSaving {len(sample_results)} sample visualizations...")
        samples_dir = output_dir / 'sample_predictions'
        samples_dir.mkdir(exist_ok=True)

        for i, sample in enumerate(sample_results):
            save_path = samples_dir / f'sample_{i:03d}.png'

            # Generate input frame indices (assumed to be consecutive)
            input_frame_indices = list(range(1, args.num_frames + 1))
            target_frame_index = args.num_frames + 1

            save_comparison_figure(
                sample['input_frames'],
                sample['target_frame'],
                sample['pred_frame'],
                save_path,
                input_frame_indices=input_frame_indices,
                target_frame_index=target_frame_index
            )

            # Save this sample's metrics
            sample_metrics_file = samples_dir / f'sample_{i:03d}_metrics.txt'
            with open(sample_metrics_file, 'w') as f:
                f.write(f"Sample {i} (batch {sample['batch_idx']}, idx {sample['sample_idx']})\n")
                f.write(f"MSE: {sample['metrics']['mse']:.6f}\n")
                f.write(f"SSIM: {sample['metrics']['ssim']:.6f}\n")
                f.write(f"PSNR: {sample['metrics']['psnr']:.6f} dB\n")

        print(f"Sample visualizations saved to: {samples_dir}")

    # Write the summary report
    report_file = output_dir / 'evaluation_report.txt'
    with open(report_file, 'w') as f:
        f.write("SwiftFormerTemporal frame prediction evaluation report\n")
        f.write("=" * 50 + "\n")
        f.write(f"Model: {args.model}\n")
        f.write(f"Checkpoint: {args.resume}\n")
        f.write(f"Dataset: {args.data_path}\n")
        f.write(f"Number of input frames: {args.num_frames}\n")
        f.write(f"Frame size: {args.frame_size}\n")
        f.write(f"Batch size: {args.batch_size}\n")
        f.write(f"Total samples: {metrics['num_samples']}\n\n")
        f.write("Metrics:\n")
        f.write(f"  MSE: {metrics['mse']:.6f}\n")
        f.write(f"  SSIM: {metrics['ssim']:.6f}\n")
        f.write(f"  PSNR: {metrics['psnr']:.6f} dB\n")

    print(f"Evaluation report saved to: {report_file}")
    print("\nEvaluation finished!")


def get_args_parser():
    parser = argparse.ArgumentParser(
        'SwiftFormerTemporal evaluation script', add_help=False)

    # Dataset parameters
    parser.add_argument('--data-path', default='./videos', type=str,
                        help='path to the video dataset')
    parser.add_argument('--num-frames', default=3, type=int,
                        help='number of input frames (T)')
    parser.add_argument('--frame-size', default=224, type=int,
                        help='input frame size')
    parser.add_argument('--max-interval', default=4, type=int,
                        help='maximum interval between consecutive frames')

    # Model parameters
    parser.add_argument('--model', default='SwiftFormerTemporal_XS', type=str, metavar='MODEL',
                        help='name of the model to evaluate')
    parser.add_argument('--use-representation-head', action='store_true',
                        help='use the representation head for pose/velocity prediction')
    parser.add_argument('--representation-dim', default=128, type=int,
                        help='dimension of the representation vector')

    # Evaluation parameters
    parser.add_argument('--batch-size', default=16, type=int,
                        help='evaluation batch size')
    parser.add_argument('--num-samples-to-save', default=10, type=int,
                        help='number of samples to save as visualizations')
    parser.add_argument('--max-samples', default=None, type=int,
                        help='maximum number of samples to evaluate (None = all)')

    # System parameters
    parser.add_argument('--output-dir', default='./evaluation_results',
                        help='path where results are saved')
    parser.add_argument('--device', default='cuda' if torch.cuda.is_available() else 'cpu',
                        help='device to use')
    parser.add_argument('--seed', default=0, type=int)
    parser.add_argument('--resume', default='', help='checkpoint path')
    parser.add_argument('--num-workers', default=4, type=int)
    parser.add_argument('--pin-mem', action='store_true',
                        help='pin CPU memory in the DataLoader')
    parser.add_argument('--no-pin-mem', action='store_false', dest='pin_mem')
    parser.set_defaults(pin_mem=True)

    return parser


if __name__ == '__main__':
    parser = argparse.ArgumentParser(
        'SwiftFormerTemporal evaluation', parents=[get_args_parser()])
    args = parser.parse_args()

    # Make sure the output directory exists
    if args.output_dir:
        Path(args.output_dir).mkdir(parents=True, exist_ok=True)

    main(args)
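
# Example invocation (a sketch only -- the script name, dataset path and
# checkpoint file below are placeholders, not files shipped with this code;
# the flags themselves are the ones defined in get_args_parser()):
#
#   python evaluate_temporal.py \
#       --model SwiftFormerTemporal_XS \
#       --data-path ./videos \
#       --resume ./checkpoints/checkpoint_best.pth \
#       --num-frames 3 --frame-size 224 --batch-size 16 \
#       --num-samples-to-save 10 \
#       --output-dir ./evaluation_results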