清理代码,删除跳连接部分

This commit is contained in:
2026-01-11 13:25:34 +08:00
parent c5502cc87c
commit df703638da
3 changed files with 68 additions and 268 deletions

View File

@@ -45,6 +45,15 @@ def denormalize(tensor):
# [0, 1] -> [0, 255]
tensor = tensor * 255
return tensor.clamp(0, 255)
# return tensor
def minmax_denormalize(tensor):
tensor_min = tensor.min()
tensor_max = tensor.max()
tensor = (tensor - tensor_min) / (tensor_max - tensor_min)
# tensor = tensor*2-1
tensor = tensor*255
return tensor.clamp(0, 255)
def calculate_metrics(pred, target, debug=False):
@@ -134,6 +143,10 @@ def save_comparison_figure(input_frames, target_frame, pred_frame, save_path,
ax.set_title('Predicted')
ax.axis('off')
#debug print
print(target_frame)
print(pred_frame)
# debug print - 改进为更有信息量的输出
if isinstance(pred_frame, np.ndarray):
print(f"[DEBUG IMAGE] Pred frame shape: {pred_frame.shape}, range: [{pred_frame.min():.2f}, {pred_frame.max():.2f}], mean: {pred_frame.mean():.2f}")
@@ -161,8 +174,8 @@ def evaluate_model(model, data_loader, device, args):
metrics_dict: 包含所有指标的字典
sample_results: 示例结果用于可视化
"""
# model.eval()
model.train() # 临时使用训练模式
model.eval()
# model.train() # 临时使用训练模式
# 初始化指标累加器
total_mse = 0.0
@@ -183,10 +196,11 @@ def evaluate_model(model, data_loader, device, args):
target_frames = target_frames.to(device, non_blocking=True)
# 前向传播
pred_frames, _ = model(input_frames)
pred_frames = model(input_frames)
# 反归一化用于指标计算
pred_denorm = denormalize(pred_frames) # [B, 1, H, W]
# pred_denorm = minmax_denormalize(pred_frames) # [B, 1, H, W]
pred_denorm = denormalize(pred_frames)
target_denorm = denormalize(target_frames) # [B, 1, H, W]
batch_size = input_frames.size(0)
@@ -309,8 +323,6 @@ def main(args):
print(f"创建模型: {args.model}")
model_kwargs = {
'num_frames': args.num_frames,
'use_representation_head': args.use_representation_head,
'representation_dim': args.representation_dim,
}
if args.model == 'SwiftFormerTemporal_XS':
@@ -335,10 +347,10 @@ def main(args):
except (pickle.UnpicklingError, TypeError) as e:
print(f"使用weights_only=False加载失败: {e}")
print("尝试使用torch.serialization.add_safe_globals...")
from argparse import Namespace
# 添加安全全局变量
torch.serialization.add_safe_globals([Namespace])
checkpoint = torch.load(args.resume, map_location='cpu')
# from argparse import Namespace
# # 添加安全全局变量
# torch.serialization.add_safe_globals([Namespace])
# checkpoint = torch.load(args.resume, map_location='cpu')
# 处理状态字典(可能包含'module.'前缀)
if 'model' in checkpoint:
@@ -462,10 +474,6 @@ def get_args_parser():
# 模型参数
parser.add_argument('--model', default='SwiftFormerTemporal_XS', type=str, metavar='MODEL',
help='要评估的模型名称')
parser.add_argument('--use-representation-head', action='store_true',
help='使用表示头进行姿态/速度预测')
parser.add_argument('--representation-dim', default=128, type=int,
help='表示向量的维度')
# 评估参数
parser.add_argument('--batch-size', default=16, type=int,