Compare commits

..

2 Commits

9 changed files with 981 additions and 317 deletions

View File

@@ -11,9 +11,10 @@ shift 2
# Default parameters
MODEL=${MODEL:-"SwiftFormerTemporal_XS"}
BATCH_SIZE=${BATCH_SIZE:-32}
BATCH_SIZE=${BATCH_SIZE:-128}
EPOCHS=${EPOCHS:-100}
LR=${LR:-1e-3}
# LR=${LR:-1e-3}
LR=${LR:-0.01}
OUTPUT_DIR=${OUTPUT_DIR:-"./temporal_output"}
echo "Starting distributed training with $NUM_GPUS GPUs"

evaluate_temporal.py (new file, 503 lines)
View File

@@ -0,0 +1,503 @@
"""
评估脚本 for SwiftFormerTemporal frame prediction
输出预测图注意反归一化以及对应指标mse&ssim&psnr
"""
import argparse
import os
import torch
import torch.nn as nn
import pickle
import numpy as np
import random
from pathlib import Path
import json
import matplotlib.pyplot as plt
from PIL import Image
import torchvision.transforms as transforms
import torch.backends.cudnn as cudnn
from util.video_dataset import VideoFrameDataset
from models.swiftformer_temporal import (
SwiftFormerTemporal_XS, SwiftFormerTemporal_S,
SwiftFormerTemporal_L1, SwiftFormerTemporal_L3
)
# Import SSIM and PSNR metrics
from skimage.metrics import structural_similarity as ssim
from skimage.metrics import peak_signal_noise_ratio as psnr
import warnings
warnings.filterwarnings('ignore')
def denormalize(tensor):
"""
将[-1, 1]范围的张量反归一化到[0, 255]范围
Args:
tensor: 形状为[B, C, H, W]或[C, H, W],值在[-1, 1]
Returns:
反归一化后的张量,值在[0, 255]
"""
# clip 到 [-1, 1] 范围
tensor = tensor.clamp(-1, 1)
# [-1, 1] -> [0, 1]
tensor = (tensor + 1) / 2
# [0, 1] -> [0, 255]
tensor = tensor * 255
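# e.g. -1.0 -> 0.0, 0.0 -> 127.5, 1.0 -> 255.0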
return tensor.clamp(0, 255)
def calculate_metrics(pred, target, debug=False):
"""
计算MSE, SSIM, PSNR指标
Args:
pred: 预测图像,形状[H, W],值在[0, 255]
target: 目标图像,形状[H, W],值在[0, 255]
debug: 是否输出调试信息
Returns:
mse, ssim_value, psnr_value
"""
# Convert to numpy arrays
pred_np = pred.cpu().numpy() if torch.is_tensor(pred) else pred
target_np = target.cpu().numpy() if torch.is_tensor(target) else target
# Make sure the arrays are 2D
if pred_np.ndim == 3:
pred_np = pred_np.squeeze(0)
if target_np.ndim == 3:
target_np = target_np.squeeze(0)
if debug:
print(f"[DEBUG] pred_np range: [{pred_np.min():.2f}, {pred_np.max():.2f}], mean: {pred_np.mean():.2f}")
print(f"[DEBUG] target_np range: [{target_np.min():.2f}, {target_np.max():.2f}], mean: {target_np.mean():.2f}")
print(f"[DEBUG] pred_np sample values (first 5): {pred_np.ravel()[:5]}")
# Compute MSE -- fixes the earlier, incorrect "tmp" formula
# Broken formula: tmp = 1 - (pred_np - target_np) / 255 * 2
# Correct formula: mean of the squared pixel differences
mse = np.mean((pred_np - target_np) ** 2)
# Also compute the MSE from the broken formula, for comparison
tmp = 1 - (pred_np - target_np) / 255 * 2
wrong_mse = np.mean(tmp**2)
if debug:
print(f"[DEBUG] Correct MSE: {mse:.6f}, Wrong MSE (tmp formula): {wrong_mse:.6f}")
# Compute SSIM (data range 0-255)
data_range = 255.0
ssim_value = ssim(pred_np, target_np, data_range=data_range)
# Compute PSNR
psnr_value = psnr(target_np, pred_np, data_range=data_range)
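# Note: PSNR relates to MSE as 10 * log10(data_range**2 / MSE)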
return mse, ssim_value, psnr_value
def save_comparison_figure(input_frames, target_frame, pred_frame, save_path,
input_frame_indices=None, target_frame_index=None):
"""
保存对比图:输入帧、目标帧、预测帧
Args:
input_frames: 输入帧列表,每个形状为[H, W],值在[0, 255]
target_frame: 目标帧,形状[H, W],值在[0, 255]
pred_frame: 预测帧,形状[H, W],值在[0, 255]
save_path: 保存路径
input_frame_indices: 输入帧的索引列表(可选)
target_frame_index: 目标帧索引(可选)
"""
num_input = len(input_frames)
fig, axes = plt.subplots(1, num_input + 2, figsize=(4*(num_input+2), 4))
# Plot the input frames
for i in range(num_input):
ax = axes[i]
ax.imshow(input_frames[i], cmap='gray')
if input_frame_indices is not None:
ax.set_title(f'Input Frame {input_frame_indices[i]}')
else:
ax.set_title(f'Input {i+1}')
ax.axis('off')
# Plot the target frame
ax = axes[num_input]
ax.imshow(target_frame, cmap='gray')
if target_frame_index is not None:
ax.set_title(f'Target Frame {target_frame_index}')
else:
ax.set_title('Target')
ax.axis('off')
# Plot the predicted frame
ax = axes[num_input + 1]
ax.imshow(pred_frame, cmap='gray')
ax.set_title('Predicted')
ax.axis('off')
# Debug print -- made more informative
if isinstance(pred_frame, np.ndarray):
print(f"[DEBUG IMAGE] Pred frame shape: {pred_frame.shape}, range: [{pred_frame.min():.2f}, {pred_frame.max():.2f}], mean: {pred_frame.mean():.2f}")
# Check whether a large fraction of values sit near 127.5
mask_near_127_5 = np.abs(pred_frame - 127.5) < 1.0
percent_near_127_5 = np.mean(mask_near_127_5) * 100
print(f"[DEBUG IMAGE] Percentage of values near 127.5 (±1.0): {percent_near_127_5:.2f}%")
else:
print(pred_frame)
plt.tight_layout()
plt.savefig(save_path, dpi=150, bbox_inches='tight')
plt.close()
def evaluate_model(model, data_loader, device, args):
"""
评估模型并计算指标
Args:
model: 训练好的模型
data_loader: 数据加载器
device: 设备
args: 命令行参数
Returns:
metrics_dict: 包含所有指标的字典
sample_results: 示例结果用于可视化
"""
# model.eval()
model.train()  # temporarily run in training mode
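# (Train mode sidesteps the BatchNorm running statistics, which the training
# diagnostics suggest may be unreliable; see the running-variance checks there.)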
# Initialize metric accumulators
total_mse = 0.0
total_ssim = 0.0
total_psnr = 0.0
total_samples = 0
# Store sample results for visualization (picked by reservoir sampling)
sample_results = []
max_samples_to_save = args.num_samples_to_save
max_samples = args.max_samples
# Reservoir-sampling counter: samples processed so far (excluding samples skipped by the max_samples cap)
sample_count = 0
with torch.no_grad():
for batch_idx, (input_frames, target_frames, temporal_indices) in enumerate(data_loader):
input_frames = input_frames.to(device, non_blocking=True)
target_frames = target_frames.to(device, non_blocking=True)
# Forward pass
pred_frames, _ = model(input_frames)
# Denormalize for metric computation
pred_denorm = denormalize(pred_frames) # [B, 1, H, W]
target_denorm = denormalize(target_frames) # [B, 1, H, W]
batch_size = input_frames.size(0)
# Compute metrics for each sample
for i in range(batch_size):
# Stop once the maximum sample count is reached
if max_samples is not None and total_samples >= max_samples:
break
pred_i = pred_denorm[i] # [1, H, W]
target_i = target_denorm[i] # [1, H, W]
# Enable debug output for the first sample
debug_mode = (batch_idx == 0 and i == 0 and total_samples == 0)
if debug_mode:
print(f"[DEBUG] Raw pred_frames range: [{pred_frames.min():.4f}, {pred_frames.max():.4f}], mean: {pred_frames.mean():.4f}")
print(f"[DEBUG] Raw target_frames range: [{target_frames.min():.4f}, {target_frames.max():.4f}], mean: {target_frames.mean():.4f}")
print(f"[DEBUG] Pred_denorm range: [{pred_denorm.min():.2f}, {pred_denorm.max():.2f}], mean: {pred_denorm.mean():.2f}")
print(f"[DEBUG] Target_denorm range: [{target_denorm.min():.2f}, {target_denorm.max():.2f}], mean: {target_denorm.mean():.2f}")
mse, ssim_value, psnr_value = calculate_metrics(pred_i, target_i, debug=debug_mode)
total_mse += mse
total_ssim += ssim_value
total_psnr += psnr_value
total_samples += 1
sample_count += 1
# Build the sample data dictionary
input_denorm = denormalize(input_frames[i]) # [num_frames, H, W]
# Split out the individual input frames
input_frames_list = []
for j in range(args.num_frames):
input_frame_j = input_denorm[j].squeeze(0) # [H, W]
input_frames_list.append(input_frame_j.cpu().numpy())
sample_data = {
'input_frames': input_frames_list,
'target_frame': target_i.squeeze(0).cpu().numpy(),
'pred_frame': pred_i.squeeze(0).cpu().numpy(),
'metrics': {
'mse': mse,
'ssim': ssim_value,
'psnr': psnr_value
},
'batch_idx': batch_idx,
'sample_idx': i
}
# Reservoir sampling
if sample_count <= max_samples_to_save:
# Reservoir not full yet: append directly
sample_results.append(sample_data)
else:
# Replace a random reservoir slot with probability max_samples_to_save / sample_count
r = random.randint(0, sample_count - 1)
if r < max_samples_to_save:
sample_results[r] = sample_data
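# Invariant: after n samples have been seen, each one remains in the
# reservoir with probability max_samples_to_save / n (a uniform subset).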
# Stop once the maximum sample count is reached
if max_samples is not None and total_samples >= max_samples:
print(f"达到最大样本数限制: {max_samples}")
break
# Progress logging
if (batch_idx + 1) % 10 == 0:
print(f'Processed {batch_idx + 1} batches, {total_samples} samples')
# Compute the average metrics
if total_samples > 0:
avg_mse = float(total_mse / total_samples)
avg_ssim = float(total_ssim / total_samples)
avg_psnr = float(total_psnr / total_samples)
else:
avg_mse = avg_ssim = avg_psnr = 0.0
metrics_dict = {
'mse': avg_mse,
'ssim': avg_ssim,
'psnr': avg_psnr,
'num_samples': total_samples
}
return metrics_dict, sample_results
def main(args):
print("评估参数:", args)
device = torch.device(args.device)
# Set random seeds
torch.manual_seed(args.seed)
np.random.seed(args.seed)
random.seed(args.seed)
cudnn.benchmark = True
# Build the dataset
print("Building dataset...")
dataset_val = VideoFrameDataset(
root_dir=args.data_path,
num_frames=args.num_frames,
frame_size=args.frame_size,
is_train=False,
max_interval=args.max_interval
)
data_loader_val = torch.utils.data.DataLoader(
dataset_val,
batch_size=args.batch_size,
num_workers=args.num_workers,
pin_memory=args.pin_mem,
shuffle=False,
drop_last=False
)
# Create the model
print(f"Creating model: {args.model}")
model_kwargs = {
'num_frames': args.num_frames,
'use_representation_head': args.use_representation_head,
'representation_dim': args.representation_dim,
}
if args.model == 'SwiftFormerTemporal_XS':
model = SwiftFormerTemporal_XS(**model_kwargs)
elif args.model == 'SwiftFormerTemporal_S':
model = SwiftFormerTemporal_S(**model_kwargs)
elif args.model == 'SwiftFormerTemporal_L1':
model = SwiftFormerTemporal_L1(**model_kwargs)
elif args.model == 'SwiftFormerTemporal_L3':
model = SwiftFormerTemporal_L3(**model_kwargs)
else:
raise ValueError(f"未知模型: {args.model}")
model.to(device)
# Load the checkpoint
if args.resume:
print(f"Loading checkpoint: {args.resume}")
try:
# Try loading with weights_only=False (required on PyTorch 2.6+)
checkpoint = torch.load(args.resume, map_location='cpu', weights_only=False)
except (pickle.UnpicklingError, TypeError) as e:
print(f"Loading with weights_only=False failed: {e}")
print("Trying torch.serialization.add_safe_globals...")
from argparse import Namespace
# Allowlist Namespace as a safe global
torch.serialization.add_safe_globals([Namespace])
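# (PyTorch 2.6 made weights_only=True the default for torch.load, which only
# unpickles allowlisted classes; checkpoints that stored args as an
# argparse.Namespace need it allowlisted.)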
checkpoint = torch.load(args.resume, map_location='cpu')
# Handle the state dict (it may contain a 'module.' prefix)
if 'model' in checkpoint:
state_dict = checkpoint['model']
elif 'state_dict' in checkpoint:
state_dict = checkpoint['state_dict']
else:
state_dict = checkpoint
# Strip the 'module.' prefix if present
if hasattr(model, 'module'):
model.module.load_state_dict(state_dict)
else:
# The state dict has a 'module.' prefix but the model does not; strip it
if any(key.startswith('module.') for key in state_dict.keys()):
from collections import OrderedDict
new_state_dict = OrderedDict()
for k, v in state_dict.items():
if k.startswith('module.'):
new_state_dict[k[7:]] = v
else:
new_state_dict[k] = v
state_dict = new_state_dict
model.load_state_dict(state_dict)
print(f"检查点加载成功epoch: {checkpoint.get('epoch', 'unknown')}")
else:
print("警告: 未提供检查点路径,使用随机初始化的模型")
# Create the output directory
output_dir = Path(args.output_dir)
output_dir.mkdir(parents=True, exist_ok=True)
# Evaluate the model
print("Starting evaluation...")
metrics, sample_results = evaluate_model(model, data_loader_val, device, args)
# Print the metrics
print("\n" + "="*50)
print("Evaluation results:")
print(f"MSE: {metrics['mse']:.6f}")
print(f"SSIM: {metrics['ssim']:.6f}")
print(f"PSNR: {metrics['psnr']:.6f} dB")
print(f"样本数量: {metrics['num_samples']}")
print("="*50)
# Save the metrics to a JSON file
metrics_file = output_dir / 'evaluation_metrics.json'
with open(metrics_file, 'w') as f:
json.dump(metrics, f, indent=4)
print(f"指标已保存到: {metrics_file}")
# Save sample visualizations
if sample_results:
print(f"\nSaving {len(sample_results)} sample visualizations...")
samples_dir = output_dir / 'sample_predictions'
samples_dir.mkdir(exist_ok=True)
for i, sample in enumerate(sample_results):
save_path = samples_dir / f'sample_{i:03d}.png'
# Generate input frame indices (assumed consecutive)
input_frame_indices = list(range(1, args.num_frames + 1))
target_frame_index = args.num_frames + 1
save_comparison_figure(
sample['input_frames'],
sample['target_frame'],
sample['pred_frame'],
save_path,
input_frame_indices=input_frame_indices,
target_frame_index=target_frame_index
)
# Save this sample's metrics
sample_metrics_file = samples_dir / f'sample_{i:03d}_metrics.txt'
with open(sample_metrics_file, 'w') as f:
f.write(f"Sample {i} (batch {sample['batch_idx']}, idx {sample['sample_idx']})\n")
f.write(f"MSE: {sample['metrics']['mse']:.6f}\n")
f.write(f"SSIM: {sample['metrics']['ssim']:.6f}\n")
f.write(f"PSNR: {sample['metrics']['psnr']:.6f} dB\n")
print(f"示例可视化已保存到: {samples_dir}")
# Write a summary report
report_file = output_dir / 'evaluation_report.txt'
with open(report_file, 'w') as f:
f.write("SwiftFormerTemporal 帧预测评估报告\n")
f.write("="*50 + "\n")
f.write(f"模型: {args.model}\n")
f.write(f"检查点: {args.resume}\n")
f.write(f"数据集: {args.data_path}\n")
f.write(f"输入帧数: {args.num_frames}\n")
f.write(f"帧大小: {args.frame_size}\n")
f.write(f"批次大小: {args.batch_size}\n")
f.write(f"样本总数: {metrics['num_samples']}\n\n")
f.write("评估指标:\n")
f.write(f" MSE: {metrics['mse']:.6f}\n")
f.write(f" SSIM: {metrics['ssim']:.6f}\n")
f.write(f" PSNR: {metrics['psnr']:.6f} dB\n")
print(f"评估报告已保存到: {report_file}")
print("\n评估完成!")
def get_args_parser():
parser = argparse.ArgumentParser(
'SwiftFormerTemporal evaluation script', add_help=False)
# Dataset parameters
parser.add_argument('--data-path', default='./videos', type=str,
help='Path to the video dataset')
parser.add_argument('--num-frames', default=3, type=int,
help='Number of input frames (T)')
parser.add_argument('--frame-size', default=224, type=int,
help='Input frame size')
parser.add_argument('--max-interval', default=4, type=int,
help='Maximum interval between consecutive frames')
# Model parameters
parser.add_argument('--model', default='SwiftFormerTemporal_XS', type=str, metavar='MODEL',
help='Name of the model to evaluate')
parser.add_argument('--use-representation-head', action='store_true',
help='Use a representation head for pose/velocity prediction')
parser.add_argument('--representation-dim', default=128, type=int,
help='Dimension of the representation vector')
# Evaluation parameters
parser.add_argument('--batch-size', default=16, type=int,
help='Evaluation batch size')
parser.add_argument('--num-samples-to-save', default=10, type=int,
help='Number of samples to save as visualizations')
parser.add_argument('--max-samples', default=None, type=int,
help='Maximum number of samples to evaluate (None = all)')
# System parameters
parser.add_argument('--output-dir', default='./evaluation_results',
help='Path to save the results')
parser.add_argument('--device', default='cuda' if torch.cuda.is_available() else 'cpu',
help='Device to use')
parser.add_argument('--seed', default=0, type=int)
parser.add_argument('--resume', default='', help='Path to the checkpoint')
parser.add_argument('--num-workers', default=4, type=int)
parser.add_argument('--pin-mem', action='store_true',
help='Pin CPU memory in the DataLoader')
parser.add_argument('--no-pin-mem', action='store_false', dest='pin_mem')
parser.set_defaults(pin_mem=True)
return parser
if __name__ == '__main__':
parser = argparse.ArgumentParser(
'SwiftFormerTemporal evaluation', parents=[get_args_parser()])
args = parser.parse_args()
# Make sure the output directory exists
if args.output_dir:
Path(args.output_dir).mkdir(parents=True, exist_ok=True)
main(args)
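A typical invocation of the new script might look like this (checkpoint and data paths are illustrative):

python evaluate_temporal.py \
--data-path ./videos \
--model SwiftFormerTemporal_XS \
--resume ./temporal_output/checkpoint_epoch98.pth \
--output-dir ./evaluation_results \
--num-frames 3 --max-interval 4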

View File

@@ -19,7 +19,7 @@ from timm.utils import NativeScaler, get_state_dict, ModelEma
from util import *
from models import *
from models.swiftformer_temporal import SwiftFormerTemporal_XS, SwiftFormerTemporal_S, SwiftFormerTemporal_L1, SwiftFormerTemporal_L3
from util.video_dataset import VideoFrameDataset, SyntheticVideoDataset
from util.video_dataset import VideoFrameDataset
from util.frame_losses import MultiTaskLoss
# Try to import TensorBoard
@@ -47,7 +47,7 @@ def get_args_parser():
help='Number of input frames (T)')
parser.add_argument('--frame-size', default=224, type=int,
help='Input frame size')
parser.add_argument('--max-interval', default=1, type=int,
parser.add_argument('--max-interval', default=4, type=int,
help='Maximum interval between consecutive frames')
# Model parameters
@@ -57,6 +57,7 @@ def get_args_parser():
help='Use representation head for pose/velocity prediction')
parser.add_argument('--representation-dim', default=128, type=int,
help='Dimension of representation vector')
parser.add_argument('--use-skip', action='store_true', dest='use_skip',
help='use skip connections')
parser.add_argument('--no-skip', action='store_false', dest='use_skip')
parser.set_defaults(use_skip=True)
# Training parameters
parser.add_argument('--batch-size', default=32, type=int)
@@ -77,7 +78,7 @@ def get_args_parser():
help='SGD momentum (default: 0.9)')
parser.add_argument('--weight-decay', type=float, default=0.05,
help='weight decay (default: 0.05)')
parser.add_argument('--lr', type=float, default=1e-3, metavar='LR',
parser.add_argument('--lr', type=float, default=0.1, metavar='LR',
help='learning rate (default: 0.1)')
# Learning rate schedule parameters (required by timm's create_scheduler)
@@ -89,7 +90,7 @@ def get_args_parser():
help='learning rate noise limit percent (default: 0.67)')
parser.add_argument('--lr-noise-std', type=float, default=1.0, metavar='STDDEV',
help='learning rate noise std-dev (default: 1.0)')
parser.add_argument('--warmup-lr', type=float, default=1e-6, metavar='LR',
parser.add_argument('--warmup-lr', type=float, default=1e-3, metavar='LR',
help='warmup learning rate (default: 1e-3)')
parser.add_argument('--min-lr', type=float, default=1e-5, metavar='LR',
help='lower lr bound for cyclic schedulers that hit 0 (1e-5)')
@@ -109,10 +110,10 @@ def get_args_parser():
help='Weight for frame prediction loss')
parser.add_argument('--contrastive-weight', type=float, default=0.1,
help='Weight for contrastive loss')
parser.add_argument('--l1-weight', type=float, default=1.0,
help='Weight for L1 loss')
parser.add_argument('--ssim-weight', type=float, default=0.1,
help='Weight for SSIM loss')
# parser.add_argument('--l1-weight', type=float, default=1.0,
# help='Weight for L1 loss')
# parser.add_argument('--ssim-weight', type=float, default=0.1,
# help='Weight for SSIM loss')
parser.add_argument('--no-contrastive', action='store_true',
help='Disable contrastive loss')
parser.add_argument('--no-ssim', action='store_true',
@@ -212,6 +213,7 @@ def main(args):
'num_frames': args.num_frames,
'use_representation_head': args.use_representation_head,
'representation_dim': args.representation_dim,
'use_skip': args.use_skip,
}
if args.model == 'SwiftFormerTemporal_XS':
@@ -326,7 +328,7 @@ def main(args):
lr_scheduler.step(epoch)
# Save checkpoint
if args.output_dir and (epoch % 10 == 0 or epoch == args.epochs - 1):
if args.output_dir and (epoch % 2 == 0 or epoch == args.epochs - 1):
checkpoint_path = output_dir / f'checkpoint_epoch{epoch}.pth'
utils.save_on_master({
'model': model_without_ddp.state_dict(),
@@ -374,6 +376,11 @@ def train_one_epoch(model, criterion, data_loader, optimizer, device, epoch, los
header = f'Epoch: [{epoch}]'
print_freq = 10
# Add diagnostic meters
metric_logger.add_meter('pred_mean', utils.SmoothedValue(window_size=1, fmt='{value:.6f}'))
metric_logger.add_meter('pred_std', utils.SmoothedValue(window_size=1, fmt='{value:.6f}'))
metric_logger.add_meter('grad_norm', utils.SmoothedValue(window_size=1, fmt='{value:.6f}'))
for batch_idx, (input_frames, target_frames, temporal_indices) in enumerate(
metric_logger.log_every(data_loader, print_freq, header)):
@@ -382,7 +389,7 @@ def train_one_epoch(model, criterion, data_loader, optimizer, device, epoch, los
temporal_indices = temporal_indices.to(device, non_blocking=True)
# Forward pass
with torch.cuda.amp.autocast():
with torch.amp.autocast(device_type='cuda'):
pred_frames, representations = model(input_frames)
loss, loss_dict = criterion(
pred_frames, target_frames,
@@ -395,6 +402,8 @@ def train_one_epoch(model, criterion, data_loader, optimizer, device, epoch, los
raise ValueError(f"Loss is {loss_value}")
optimizer.zero_grad()
# The diagnostics below read the gradients produced by this backward pass
loss_scaler(loss, optimizer, clip_grad=clip_grad, clip_mode=clip_mode,
parameters=model.parameters())
@@ -402,6 +411,30 @@ def train_one_epoch(model, criterion, data_loader, optimizer, device, epoch, los
if model_ema is not None:
model_ema.update(model)
# Compute diagnostic metrics
pred_mean = pred_frames.mean().item()
pred_std = pred_frames.std().item()
# Compute the gradient norm
total_grad_norm = 0.0
for param in model.parameters():
if param.grad is not None:
total_grad_norm += param.grad.norm().item()
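# Note: this sums per-parameter norms, which upper-bounds the true global
# L2 norm (the square root of the summed squared norms).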
# Record diagnostic metrics
metric_logger.update(pred_mean=pred_mean)
metric_logger.update(pred_std=pred_std)
metric_logger.update(grad_norm=total_grad_norm)
# Print diagnostics and BatchNorm statistics every 50 batches
if batch_idx % 50 == 0:
print(f"[Diagnostics] batch {batch_idx}: pred mean={pred_mean:.4f}, pred std={pred_std:.4f}, grad norm={total_grad_norm:.4f}")
# Inspect the running statistics of one BatchNorm layer
for name, module in model.named_modules():
if isinstance(module, torch.nn.BatchNorm2d) and 'decoder.blocks.0.bn' in name:
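# NB: the refactored DecoderBlock no longer defines a .bn submodule,
# so this name match may find nothing after this change.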
print(f"[诊断] {name}: 运行均值={module.running_mean[0].item():.6f}, 运行方差={module.running_var[0].item():.6f}")
break
# Log to TensorBoard
if writer is not None:
# Log scalar metrics every iteration
@@ -415,6 +448,11 @@ def train_one_epoch(model, criterion, data_loader, optimizer, device, epoch, los
else:
writer.add_scalar(f'train/{k}', v, global_step)
# Log diagnostic metrics
writer.add_scalar('train/pred_mean', pred_mean, global_step)
writer.add_scalar('train/pred_std', pred_std, global_step)
writer.add_scalar('train/grad_norm', total_grad_norm, global_step)
# Log images periodically
if args is not None and getattr(args, 'log_images', False) and global_step % getattr(args, 'image_log_freq', 100) == 0:
with torch.no_grad():
@@ -451,19 +489,53 @@ def evaluate(data_loader, model, criterion, device, writer=None, epoch=0):
metric_logger = utils.MetricLogger(delimiter=" ")
header = 'Test:'
for input_frames, target_frames, temporal_indices in metric_logger.log_every(data_loader, 10, header):
# Add diagnostic meters
metric_logger.add_meter('pred_mean', utils.SmoothedValue(window_size=1, fmt='{value:.6f}'))
metric_logger.add_meter('pred_std', utils.SmoothedValue(window_size=1, fmt='{value:.6f}'))
metric_logger.add_meter('target_mean', utils.SmoothedValue(window_size=1, fmt='{value:.6f}'))
metric_logger.add_meter('target_std', utils.SmoothedValue(window_size=1, fmt='{value:.6f}'))
for batch_idx, (input_frames, target_frames, temporal_indices) in enumerate(metric_logger.log_every(data_loader, 10, header)):
input_frames = input_frames.to(device, non_blocking=True)
target_frames = target_frames.to(device, non_blocking=True)
temporal_indices = temporal_indices.to(device, non_blocking=True)
# Compute output
with torch.cuda.amp.autocast():
with torch.amp.autocast(device_type='cuda'):
pred_frames, representations = model(input_frames)
loss, loss_dict = criterion(
pred_frames, target_frames,
representations, temporal_indices
)
# Compute diagnostic metrics
pred_mean = pred_frames.mean().item()
pred_std = pred_frames.std().item()
target_mean = target_frames.mean().item()
target_std = target_frames.std().item()
# Update diagnostic meters
metric_logger.update(pred_mean=pred_mean)
metric_logger.update(pred_std=pred_std)
metric_logger.update(target_mean=target_mean)
metric_logger.update(target_std=target_std)
# Print detailed diagnostics for the first batch
if batch_idx == 0:
print(f"[Eval diagnostics] batch 0:")
print(f"  pred range: [{pred_frames.min().item():.4f}, {pred_frames.max().item():.4f}]")
print(f"  pred mean: {pred_mean:.4f}, pred std: {pred_std:.4f}")
print(f"  target range: [{target_frames.min().item():.4f}, {target_frames.max().item():.4f}]")
print(f"  target mean: {target_mean:.4f}, target std: {target_std:.4f}")
# Check BatchNorm running statistics
for name, module in model.named_modules():
if isinstance(module, torch.nn.BatchNorm2d) and 'decoder.blocks.0.bn' in name:
print(f" {name}: 运行均值={module.running_mean[0].item():.6f}, 运行方差={module.running_var[0].item():.6f}")
if module.running_var[0].item() < 1e-6:
print(f" 警告: BatchNorm运行方差接近零!")
break
# Update metrics
metric_logger.update(loss=loss.item())
for k, v in loss_dict.items():

View File

@@ -11,26 +11,188 @@ from timm.layers import DropPath, trunc_normal_
class DecoderBlock(nn.Module):
"""Upsampling block for frame prediction decoder"""
"""Upsampling block for frame prediction decoder with residual connections"""
def __init__(self, in_channels, out_channels, kernel_size=3, stride=2, padding=1, output_padding=1):
super().__init__()
self.conv = nn.ConvTranspose2d(
# Main path: transposed conv + two conv layers
self.conv_transpose = nn.ConvTranspose2d(
in_channels, out_channels,
kernel_size=kernel_size,
stride=stride,
padding=padding,
output_padding=output_padding,
bias=False
bias=True  # enable bias since BatchNorm was removed
)
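# Transposed-conv output size: out = (in-1)*stride - 2*padding + kernel_size + output_padding,
# so kernel=3, stride=2, padding=1, output_padding=1 gives out = 2*in.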
self.bn = nn.BatchNorm2d(out_channels)
self.relu = nn.ReLU(inplace=True)
self.conv1 = nn.Conv2d(out_channels, out_channels,
kernel_size=3, padding=1, bias=True)
self.conv2 = nn.Conv2d(out_channels, out_channels,
kernel_size=3, padding=1, bias=True)
# Residual path: used when channels or spatial size change
self.shortcut = nn.Identity()
if in_channels != out_channels or stride != 1:
# Use a 1x1 conv to adjust channels; a transposed conv when upsampling is needed
if stride == 1:
self.shortcut = nn.Conv2d(in_channels, out_channels,
kernel_size=1, bias=True)
else:
self.shortcut = nn.ConvTranspose2d(
in_channels, out_channels,
kernel_size=1,
stride=stride,
padding=0,
output_padding=output_padding,
bias=True
)
# Use LeakyReLU to avoid dead neurons
self.activation = nn.LeakyReLU(0.2, inplace=True)
# Initialize weights
self._init_weights()
def _init_weights(self):
# Initialize the transposed conv
nn.init.kaiming_normal_(self.conv_transpose.weight, mode='fan_out', nonlinearity='leaky_relu')
if self.conv_transpose.bias is not None:
nn.init.constant_(self.conv_transpose.bias, 0)
# Initialize the conv layers
nn.init.kaiming_normal_(self.conv1.weight, mode='fan_out', nonlinearity='leaky_relu')
if self.conv1.bias is not None:
nn.init.constant_(self.conv1.bias, 0)
nn.init.kaiming_normal_(self.conv2.weight, mode='fan_out', nonlinearity='leaky_relu')
if self.conv2.bias is not None:
nn.init.constant_(self.conv2.bias, 0)
# Initialize the shortcut
if not isinstance(self.shortcut, nn.Identity):
if isinstance(self.shortcut, nn.Conv2d):
nn.init.kaiming_normal_(self.shortcut.weight, mode='fan_out', nonlinearity='leaky_relu')
elif isinstance(self.shortcut, nn.ConvTranspose2d):
nn.init.kaiming_normal_(self.shortcut.weight, mode='fan_out', nonlinearity='leaky_relu')
if self.shortcut.bias is not None:
nn.init.constant_(self.shortcut.bias, 0)
def forward(self, x):
return self.relu(self.bn(self.conv(x)))
identity = self.shortcut(x)
# Main path
x = self.conv_transpose(x)
x = self.activation(x)
x = self.conv1(x)
x = self.activation(x)
x = self.conv2(x)
# Residual connection
x = x + identity
x = self.activation(x)
return x
class DecoderBlockWithSkip(nn.Module):
"""Decoder block with skip connection support"""
def __init__(self, in_channels, out_channels, skip_channels=0, kernel_size=3, stride=2, padding=1, output_padding=1):
super().__init__()
# Total input channels = input channels + skip channels
total_in_channels = in_channels + skip_channels
# Main path: transposed conv + two conv layers
self.conv_transpose = nn.ConvTranspose2d(
total_in_channels, out_channels,
kernel_size=kernel_size,
stride=stride,
padding=padding,
output_padding=output_padding,
bias=True
)
self.conv1 = nn.Conv2d(out_channels, out_channels,
kernel_size=3, padding=1, bias=True)
self.conv2 = nn.Conv2d(out_channels, out_channels,
kernel_size=3, padding=1, bias=True)
# Residual path: used when channels or spatial size change
self.shortcut = nn.Identity()
if total_in_channels != out_channels or stride != 1:
if stride == 1:
self.shortcut = nn.Conv2d(total_in_channels, out_channels,
kernel_size=1, bias=True)
else:
self.shortcut = nn.ConvTranspose2d(
total_in_channels, out_channels,
kernel_size=1,
stride=stride,
padding=0,
output_padding=output_padding,
bias=True
)
# Use LeakyReLU to avoid dead neurons
self.activation = nn.LeakyReLU(0.2, inplace=True)
# Initialize weights
self._init_weights()
def _init_weights(self):
# Initialize the transposed conv
nn.init.kaiming_normal_(self.conv_transpose.weight, mode='fan_out', nonlinearity='leaky_relu')
if self.conv_transpose.bias is not None:
nn.init.constant_(self.conv_transpose.bias, 0)
# Initialize the conv layers
nn.init.kaiming_normal_(self.conv1.weight, mode='fan_out', nonlinearity='leaky_relu')
if self.conv1.bias is not None:
nn.init.constant_(self.conv1.bias, 0)
nn.init.kaiming_normal_(self.conv2.weight, mode='fan_out', nonlinearity='leaky_relu')
if self.conv2.bias is not None:
nn.init.constant_(self.conv2.bias, 0)
# Initialize the shortcut
if not isinstance(self.shortcut, nn.Identity):
if isinstance(self.shortcut, nn.Conv2d):
nn.init.kaiming_normal_(self.shortcut.weight, mode='fan_out', nonlinearity='leaky_relu')
elif isinstance(self.shortcut, nn.ConvTranspose2d):
nn.init.kaiming_normal_(self.shortcut.weight, mode='fan_out', nonlinearity='leaky_relu')
if self.shortcut.bias is not None:
nn.init.constant_(self.shortcut.bias, 0)
def forward(self, x, skip_feature=None):
# If a skip feature is provided, concatenate it with the input
if skip_feature is not None:
# Make sure the skip feature's spatial size matches x
if skip_feature.shape[2:] != x.shape[2:]:
# Bilinear interpolation for up- or downsampling
skip_feature = torch.nn.functional.interpolate(
skip_feature,
size=x.shape[2:],
mode='bilinear',
align_corners=False
)
x = torch.cat([x, skip_feature], dim=1)
identity = self.shortcut(x)
# Main path
x = self.conv_transpose(x)
x = self.activation(x)
x = self.conv1(x)
x = self.activation(x)
x = self.conv2(x)
# Residual connection
x = x + identity
x = self.activation(x)
return x
class FramePredictionDecoder(nn.Module):
"""Lightweight decoder for frame prediction with optional skip connections"""
"""Improved decoder for frame prediction with better upsampling strategy"""
def __init__(self, embed_dims, output_channels=1, use_skip=False):
super().__init__()
self.use_skip = use_skip
@@ -38,65 +200,109 @@ class FramePredictionDecoder(nn.Module):
decoder_dims = embed_dims[::-1]
self.blocks = nn.ModuleList()
# First upsampling from bottleneck to stage4 resolution
self.blocks.append(DecoderBlock(
if use_skip:
# Use blocks that support skip connections.
# Block 1: bottleneck to stage4, large stride (stride=4), skip from stage3
self.blocks.append(DecoderBlockWithSkip(
decoder_dims[0], decoder_dims[1],
skip_channels=embed_dims[3],  # channels of stage3
kernel_size=3, stride=4, padding=1, output_padding=3  # changed to stride=4
))
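# With kernel=3, stride=4, padding=1, output_padding=3: out = (in-1)*4 - 2 + 3 + 3 = 4*in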
# Block 2: stage4 to stage3, stride=2, skip from stage2
self.blocks.append(DecoderBlockWithSkip(
decoder_dims[1], decoder_dims[2],
skip_channels=embed_dims[2],  # channels of stage2
kernel_size=3, stride=2, padding=1, output_padding=1
))
# stage4 to stage3
# Block 3: stage3 to stage2, stride=2, skip from stage1
self.blocks.append(DecoderBlockWithSkip(
decoder_dims[2], decoder_dims[3],
skip_channels=embed_dims[1],  # channels of stage1
kernel_size=3, stride=2, padding=1, output_padding=1
))
# Block 4: stage2 to stage1, stride=2, skip from stage0
self.blocks.append(DecoderBlockWithSkip(
decoder_dims[3], 64,  # project down to 64 channels
skip_channels=embed_dims[0],  # channels of stage0
kernel_size=3, stride=2, padding=1, output_padding=1
))
else:
# Plain DecoderBlocks; the first block uses a large stride
self.blocks.append(DecoderBlock(
decoder_dims[0], decoder_dims[1],
kernel_size=3, stride=4, padding=1, output_padding=3  # changed to stride=4
))
self.blocks.append(DecoderBlock(
decoder_dims[1], decoder_dims[2],
kernel_size=3, stride=2, padding=1, output_padding=1
))
# stage3 to stage2
self.blocks.append(DecoderBlock(
decoder_dims[2], decoder_dims[3],
kernel_size=3, stride=2, padding=1, output_padding=1
))
# stage2 to original resolution (now 8x upsampling total with stride 4)
self.blocks.append(nn.Sequential(
nn.ConvTranspose2d(
decoder_dims[3], 32,
kernel_size=3, stride=4, padding=1, output_padding=3
),
nn.BatchNorm2d(32),
nn.ReLU(inplace=True),
nn.Conv2d(32, output_channels, kernel_size=3, padding=1),
nn.Tanh() # Output in [-1, 1] range
# Block 4: widen to 64 channels
self.blocks.append(DecoderBlock(
decoder_dims[3], 64,
kernel_size=3, stride=2, padding=1, output_padding=1
))
# If using skip connections, we need to adjust input channels for each block
if use_skip:
# We'll modify the first three blocks to accept concatenated features
# Instead of modifying existing blocks, we'll replace them with custom blocks
# For simplicity, we'll keep the same architecture but forward will handle concatenation
pass
# Improved final output layer: no transposed conv, feature refinement only.
# The input is already at the target size; just adjust channels and fuse features.
self.final_block = nn.Sequential(
nn.Conv2d(64, 64, kernel_size=3, padding=1, bias=True),
nn.LeakyReLU(0.2, inplace=True),
nn.Conv2d(64, 32, kernel_size=3, padding=1, bias=True),
nn.LeakyReLU(0.2, inplace=True),
nn.Conv2d(32, output_channels, kernel_size=3, padding=1, bias=True)
# Tanh removed: the output range is left unconstrained, handled by the loss and normalization
)
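# Size check: for a 224x224 input the bottleneck is 7x7 (H/32); strides
# 4, 2, 2, 2 give 7 -> 28 -> 56 -> 112 -> 224, and final_block preserves 224.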
def forward(self, x, skip_features=None):
"""
Args:
x: input tensor of shape [B, embed_dims[-1], H/32, W/32]
skip_features: list of encoder features from stages [stage2, stage1, stage0]
each of shape [B, C, H', W'] where C matches decoder dims?
skip_features: list of encoder features from stages [stage3, stage2, stage1, stage0]
each of shape [B, C, H', W'] where C matches encoder dims
"""
if self.use_skip and skip_features is not None:
# Ensure we have exactly 3 skip features (for the first three blocks)
assert len(skip_features) == 3, "Need 3 skip features for skip connections"
# Reverse skip_features to match decoder order: stage2, stage1, stage0
# skip_features[0] should be stage2 (H/16), [1] stage1 (H/8), [2] stage0 (H/4)
skip_features = skip_features[::-1] # Now index 0: stage2, 1: stage1, 2: stage0
if self.use_skip:
if skip_features is None:
raise ValueError("skip_features must be provided when use_skip=True")
for i, block in enumerate(self.blocks):
if self.use_skip and skip_features is not None and i < 3:
# Concatenate skip feature along channel dimension
# Ensure spatial dimensions match (they should because of upsampling)
x = torch.cat([x, skip_features[i]], dim=1)
# Need to adjust block to accept extra channels? We'll create a separate block.
# For now, we'll just pass through, but this will cause channel mismatch.
# Instead, we should have created custom blocks with appropriate in_channels.
# This is a placeholder; we need to implement properly.
pass
x = block(x)
# There must be exactly 4 skip features
assert len(skip_features) == 4, f"Need 4 skip features, got {len(skip_features)}"
# Reverse to match decoder order: stage3, stage2, stage1, stage0
skip_features = skip_features[::-1]
# Each DecoderBlockWithSkip bilinearly resizes its skip feature to match
# its input before concatenation, so the encoder features can be passed
# straight through without a separate resizing pass here.
# Run the four blocks with skip connections
for i in range(4):
x = self.blocks[i](x, skip_features[i])
else:
# No skip connections
for i in range(4):
x = self.blocks[i](x)
# Final output layer: feature refinement only, no upsampling
x = self.final_block(x)
return x
@@ -110,6 +316,7 @@ class SwiftFormerTemporal(nn.Module):
model_name='XS',
num_frames=3,
use_decoder=True,
use_skip=True,  # new: whether to use skip connections
use_representation_head=False,
representation_dim=128,
return_features=False,
@@ -123,6 +330,7 @@ class SwiftFormerTemporal(nn.Module):
# Store configuration
self.num_frames = num_frames
self.use_decoder = use_decoder
self.use_skip = use_skip  # store the skip-connection setting
self.use_representation_head = use_representation_head
self.return_features = return_features
@@ -155,7 +363,11 @@ class SwiftFormerTemporal(nn.Module):
# Frame prediction decoder
if use_decoder:
self.decoder = FramePredictionDecoder(embed_dims, output_channels=1)
self.decoder = FramePredictionDecoder(
embed_dims,
output_channels=1,
use_skip=use_skip  # pass the skip-connection setting through
)
# Representation head for pose/velocity prediction
if use_representation_head:
@@ -173,22 +385,31 @@ class SwiftFormerTemporal(nn.Module):
def _init_weights(self, m):
if isinstance(m, (nn.Conv2d, nn.Linear)):
trunc_normal_(m.weight, std=.02)
# Kaiming initialization, suited to ReLU/LeakyReLU
nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='leaky_relu')
if m.bias is not None:
nn.init.constant_(m.bias, 0)
elif isinstance(m, (nn.LayerNorm)):
elif isinstance(m, nn.ConvTranspose2d):
# Specific initialization for transposed conv layers
nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='leaky_relu')
if m.bias is not None:
nn.init.constant_(m.bias, 0)
elif isinstance(m, (nn.LayerNorm, nn.BatchNorm2d)):
nn.init.constant_(m.bias, 0)
nn.init.constant_(m.weight, 1.0)
def forward_tokens(self, x):
"""Forward through encoder network, return list of stage features if return_features else final output"""
if self.return_features:
if self.return_features or self.use_skip:
features = []
stage_idx = 0
for idx, block in enumerate(self.network):
x = block(x)
# Collect output after each stage (indices 0,2,4,6 correspond to stages)
# Collect the output after each stage (stage0, stage1, stage2, stage3)
# In the SwiftFormer layout, stages sit at indices 0, 2, 4, 6
if idx in [0, 2, 4, 6]:
features.append(x)
stage_idx += 1
return x, features
else:
for block in self.network:
@@ -208,7 +429,7 @@ class SwiftFormerTemporal(nn.Module):
"""
# Encode
x = self.patch_embed(x)
if self.return_features:
if self.return_features or self.use_skip:
x, features = self.forward_tokens(x)
else:
x = self.forward_tokens(x)
@@ -222,6 +443,22 @@ class SwiftFormerTemporal(nn.Module):
# Decode to frame
pred_frame = None
if self.use_decoder:
if self.use_skip:
# Gather the features used for skip connections.
# `features` holds every stage output; we need stage0..stage3.
# The SwiftFormer layout yields 4 stage features.
if len(features) >= 4:
# Take the four stage features: stage0, stage1, stage2, stage3
skip_features = [features[0], features[1], features[2], features[3]]
else:
# Fall back to the available features
skip_features = features[:4]
# Pad with None if still short (note: a None skip will not match the blocks' fixed skip_channels)
while len(skip_features) < 4:
skip_features.append(None)
pred_frame = self.decoder(x, skip_features)
else:
pred_frame = self.decoder(x)
if self.return_features:
@@ -231,14 +468,14 @@ class SwiftFormerTemporal(nn.Module):
# Factory functions for different model sizes
def SwiftFormerTemporal_XS(num_frames=3, **kwargs):
return SwiftFormerTemporal('XS', num_frames=num_frames, **kwargs)
def SwiftFormerTemporal_XS(num_frames=3, use_skip=True, **kwargs):
return SwiftFormerTemporal('XS', num_frames=num_frames, use_skip=use_skip, **kwargs)
def SwiftFormerTemporal_S(num_frames=3, **kwargs):
return SwiftFormerTemporal('S', num_frames=num_frames, **kwargs)
def SwiftFormerTemporal_S(num_frames=3, use_skip=True, **kwargs):
return SwiftFormerTemporal('S', num_frames=num_frames, use_skip=use_skip, **kwargs)
def SwiftFormerTemporal_L1(num_frames=3, **kwargs):
return SwiftFormerTemporal('l1', num_frames=num_frames, **kwargs)
def SwiftFormerTemporal_L1(num_frames=3, use_skip=True, **kwargs):
return SwiftFormerTemporal('l1', num_frames=num_frames, use_skip=use_skip, **kwargs)
def SwiftFormerTemporal_L3(num_frames=3, **kwargs):
return SwiftFormerTemporal('l3', num_frames=num_frames, **kwargs)
def SwiftFormerTemporal_L3(num_frames=3, use_skip=True, **kwargs):
return SwiftFormerTemporal('l3', num_frames=num_frames, use_skip=use_skip, **kwargs)

View File

@@ -1,26 +0,0 @@
#!/usr/bin/env bash
# Simple multi-GPU training script for SwiftFormerTemporal
# Usage: ./multi_gpu_temporal_train.sh <NUM_GPUS> [OPTIONS]
NUM_GPUS=${1:-2}
shift
echo "Starting multi-GPU training with $NUM_GPUS GPUs"
# Set environment variables for distributed training
export MASTER_PORT=12345
export MASTER_ADDR=localhost
export WORLD_SIZE=$NUM_GPUS
# Launch training
torchrun --nproc_per_node=$NUM_GPUS --master_port=$MASTER_PORT main_temporal.py \
--data-path "./videos" \
--model SwiftFormerTemporal_XS \
--batch-size 32 \
--epochs 100 \
--lr 1e-3 \
--output-dir "./temporal_output_multi" \
--num-workers 8 \
--pin-mem \
"$@"

View File

@@ -1,45 +0,0 @@
import torch
def test_cuda_availability():
"""全面测试CUDA可用性"""
print("="*50)
print("PyTorch CUDA 测试")
print("="*50)
# Basic info
print(f"PyTorch version: {torch.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")
if not torch.cuda.is_available():
print("CUDA unavailable; possible causes:")
print("1. CUDA driver not installed")
print("2. CPU-only build of PyTorch installed")
print("3. CUDA version does not match PyTorch")
return False
# Device info
device_count = torch.cuda.device_count()
print(f"Found {device_count} CUDA device(s)")
for i in range(device_count):
print(f"\nDevice {i}:")
print(f"  Name: {torch.cuda.get_device_name(i)}")
print(f"  Total memory: {torch.cuda.get_device_properties(i).total_memory / 1e9:.2f} GB")
print(f"  Compute capability: {torch.cuda.get_device_properties(i).major}.{torch.cuda.get_device_properties(i).minor}")
# Simple tensor test
print("\nRunning CUDA test...")
try:
x = torch.randn(3, 3).cuda()
y = torch.randn(3, 3).cuda()
z = x + y
print("CUDA计算测试: 成功!")
print(f"设备上的张量形状: {z.shape}")
return True
except Exception as e:
print(f"CUDA计算测试失败: {e}")
return False
if __name__ == "__main__":
test_cuda_availability()

View File

@@ -1,33 +0,0 @@
#!/usr/bin/env python3
"""
测试 timm 导入是否正常工作
"""
import sys
print("Python version:", sys.version)
try:
from timm.layers import to_2tuple, DropPath, trunc_normal_
from timm.models import register_model
print("✓ 成功导入 timm.layers.to_2tuple")
print("✓ 成功导入 timm.layers.DropPath")
print("✓ 成功导入 timm.layers.trunc_normal_")
print("✓ 成功导入 timm.models.register_model")
except ImportError as e:
print(f"✗ 导入失败: {e}")
sys.exit(1)
try:
from models.swiftformer import SwiftFormer_XS
print("✓ 成功导入 SwiftFormer_XS")
except ImportError as e:
print(f"✗ 导入 SwiftFormer_XS 失败: {e}")
sys.exit(1)
try:
from models.swiftformer_temporal import SwiftFormerTemporal_XS
print("✓ 成功导入 SwiftFormerTemporal_XS")
except ImportError as e:
print(f"✗ 导入 SwiftFormerTemporal_XS 失败: {e}")
sys.exit(1)
print("\n✅ 所有导入测试通过!")

View File

@@ -1,60 +0,0 @@
#!/usr/bin/env python3
"""
Test script for SwiftFormerTemporal model
"""
import torch
import sys
import os
# Add current directory to path
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
from models.swiftformer_temporal import SwiftFormerTemporal_XS
def test_model():
print("Testing SwiftFormerTemporal model...")
# Create model
model = SwiftFormerTemporal_XS(num_frames=3, use_representation_head=True)
print(f'Model created: {model.__class__.__name__}')
print(f'Number of parameters: {sum(p.numel() for p in model.parameters()):,}')
# Test forward pass
batch_size = 2
num_frames = 3
height = width = 224
x = torch.randn(batch_size, 3 * num_frames, height, width)
print(f'\nInput shape: {x.shape}')
with torch.no_grad():
pred_frame, representation = model(x)
print(f'Predicted frame shape: {pred_frame.shape}')
print(f'Representation shape: {representation.shape if representation is not None else "None"}')
# Check output ranges
print(f'\nPredicted frame range: [{pred_frame.min():.3f}, {pred_frame.max():.3f}]')
# Test loss function
from util.frame_losses import MultiTaskLoss
criterion = MultiTaskLoss()
target = torch.randn_like(pred_frame)
temporal_indices = torch.tensor([3, 3], dtype=torch.long)
loss, loss_dict = criterion(pred_frame, target, representation, temporal_indices)
print(f'\nLoss test:')
for k, v in loss_dict.items():
print(f' {k}: {v:.4f}')
print('\nAll tests passed!')
return True
if __name__ == '__main__':
try:
test_model()
except Exception as e:
print(f'Test failed with error: {e}')
import traceback
traceback.print_exc()
sys.exit(1)

View File

@@ -48,27 +48,39 @@ class VideoFrameDataset(Dataset):
self.is_train = is_train
self.max_interval = max_interval
# Collect all video folders
# if num_frames < 1:
# raise ValueError("num_frames must be >= 1")
# if frame_size < 1:
# raise ValueError("frame_size must be >= 1")
# if max_interval < 1:
# raise ValueError("max_interval must be >= 1")
# Collect all video folders and their frame files
self.video_folders = []
self.video_frame_files = [] # list of list of Path objects
for item in self.root_dir.iterdir():
if item.is_dir():
self.video_folders.append(item)
# Get all frame files
frame_files = sorted([f for f in item.iterdir()
if f.suffix.lower() in ['.jpg', '.jpeg', '.png', '.bmp']])
self.video_frame_files.append(frame_files)
if len(self.video_folders) == 0:
raise ValueError(f"No video folders found in {root_dir}")
# Build frame index: list of (video_idx, start_frame_idx)
self.frame_indices = []
for video_idx, video_folder in enumerate(self.video_folders):
# Get all frame files
frame_files = sorted([f for f in video_folder.iterdir()
if f.suffix.lower() in ['.jpg', '.jpeg', '.png', '.bmp']])
if len(frame_files) < num_frames + 1:
for video_idx, frame_files in enumerate(self.video_frame_files):
# Minimum frames needed considering max interval
min_frames_needed = num_frames * max_interval + 1
if len(frame_files) < min_frames_needed:
continue # Skip videos with insufficient frames
# Add all possible starting positions
for start_idx in range(len(frame_files) - num_frames):
# Ensure that for any interval up to max_interval, all frames are within bounds
max_start = len(frame_files) - num_frames * max_interval
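# (largest index used is start_idx + num_frames * max_interval, which stays
# within bounds for every start_idx < max_start)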
for start_idx in range(max_start):
self.frame_indices.append((video_idx, start_idx))
if len(self.frame_indices) == 0:
@@ -80,14 +92,12 @@ class VideoFrameDataset(Dataset):
else:
self.transform = transform
# Normalization for Y channel (single channel)
# Compute average of ImageNet RGB means and stds
y_mean = (0.485 + 0.456 + 0.406) / 3.0
y_std = (0.229 + 0.224 + 0.225) / 3.0
self.normalize = transforms.Normalize(
mean=[y_mean],
std=[y_std]
)
# Simple normalization to the [-1, 1] range (no ImageNet normalization).
# Convert pixel values [0, 255] to [-1, 1].
# This matches the [-1, 1] range used for the prediction targets.
self.normalize = None # We'll handle normalization manually
# print(f"[Dataset init] simple normalization: pixels [0,255] -> [-1,1]")
def _default_transform(self):
"""Default transform with augmentation for training"""
@@ -105,9 +115,12 @@ class VideoFrameDataset(Dataset):
def _load_frame(self, video_idx: int, frame_idx: int) -> Image.Image:
"""Load a single frame as PIL Image"""
video_folder = self.video_folders[video_idx]
frame_files = sorted([f for f in video_folder.iterdir()
if f.suffix.lower() in ['.jpg', '.jpeg', '.png', '.bmp']])
frame_files = self.video_frame_files[video_idx]
if frame_idx < 0 or frame_idx >= len(frame_files):
raise IndexError(
f"Frame index {frame_idx} out of range for video {video_idx} "
f"(0-{len(frame_files)-1})"
)
frame_path = frame_files[frame_idx]
return Image.open(frame_path).convert('RGB')
@@ -144,19 +157,21 @@ class VideoFrameDataset(Dataset):
if self.transform:
target_frame = self.transform(target_frame)
# Convert to tensors, normalize, and convert to grayscale (Y channel)
# Convert to tensors and convert to grayscale (Y channel)
input_tensors = []
for frame in input_frames:
tensor = transforms.ToTensor()(frame) # [3, H, W]
tensor = transforms.ToTensor()(frame) # [3, H, W], range [0, 1]
# Convert RGB to grayscale using weighted sum
# Y = 0.2989 * R + 0.5870 * G + 0.1140 * B (same as PIL)
gray = (0.2989 * tensor[0] + 0.5870 * tensor[1] + 0.1140 * tensor[2]).unsqueeze(0) # [1, H, W]
gray = self.normalize(gray) # normalize with single-channel stats (mean/std broadcast)
gray = (0.2989 * tensor[0] + 0.5870 * tensor[1] + 0.1140 * tensor[2]).unsqueeze(0) # [1, H, W], range [0, 1]
# Normalize from [0, 1] to [-1, 1]
gray = gray * 2 - 1 # [0,1] -> [-1,1]
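# (evaluate_temporal.denormalize inverts this mapping back to [0, 255])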
input_tensors.append(gray)
target_tensor = transforms.ToTensor()(target_frame) # [3, H, W]
target_tensor = transforms.ToTensor()(target_frame) # [3, H, W], range [0, 1]
target_gray = (0.2989 * target_tensor[0] + 0.5870 * target_tensor[1] + 0.1140 * target_tensor[2]).unsqueeze(0)
target_gray = self.normalize(target_gray)
# Normalize from [0, 1] to [-1, 1]
target_gray = target_gray * 2 - 1 # [0,1] -> [-1,1]
# Concatenate input frames along channel dimension
input_concatenated = torch.cat(input_tensors, dim=0) # [num_frames, H, W]
@@ -167,52 +182,52 @@ class VideoFrameDataset(Dataset):
return input_concatenated, target_gray, temporal_idx
class SyntheticVideoDataset(Dataset):
"""
Synthetic dataset for testing - generates random frames
"""
def __init__(self,
num_samples: int = 1000,
num_frames: int = 3,
frame_size: int = 224,
is_train: bool = True):
self.num_samples = num_samples
self.num_frames = num_frames
self.frame_size = frame_size
self.is_train = is_train
# class SyntheticVideoDataset(Dataset):
# """
# Synthetic dataset for testing - generates random frames
# """
# def __init__(self,
# num_samples: int = 1000,
# num_frames: int = 3,
# frame_size: int = 224,
# is_train: bool = True):
# self.num_samples = num_samples
# self.num_frames = num_frames
# self.frame_size = frame_size
# self.is_train = is_train
# Normalization for Y channel (single channel)
y_mean = (0.485 + 0.456 + 0.406) / 3.0
y_std = (0.229 + 0.224 + 0.225) / 3.0
self.normalize = transforms.Normalize(
mean=[y_mean],
std=[y_std]
)
# # Normalization for Y channel (single channel)
# y_mean = (0.485 + 0.456 + 0.406) / 3.0
# y_std = (0.229 + 0.224 + 0.225) / 3.0
# self.normalize = transforms.Normalize(
# mean=[y_mean],
# std=[y_std]
# )
def __len__(self):
return self.num_samples
# def __len__(self):
# return self.num_samples
def __getitem__(self, idx):
# Generate random "frames" (noise with temporal correlation)
input_frames = []
prev_frame = torch.randn(3, self.frame_size, self.frame_size) * 0.1
# def __getitem__(self, idx):
# # Generate random "frames" (noise with temporal correlation)
# input_frames = []
# prev_frame = torch.randn(3, self.frame_size, self.frame_size) * 0.1
for i in range(self.num_frames):
# Add some temporal correlation
frame = prev_frame + torch.randn(3, self.frame_size, self.frame_size) * 0.05
frame = torch.clamp(frame, -1, 1)
input_frames.append(self.normalize(frame))
prev_frame = frame
# for i in range(self.num_frames):
# # Add some temporal correlation
# frame = prev_frame + torch.randn(3, self.frame_size, self.frame_size) * 0.05
# frame = torch.clamp(frame, -1, 1)
# input_frames.append(self.normalize(frame))
# prev_frame = frame
# Target frame (next in sequence)
target_frame = prev_frame + torch.randn(3, self.frame_size, self.frame_size) * 0.05
target_frame = torch.clamp(target_frame, -1, 1)
target_tensor = self.normalize(target_frame)
# # Target frame (next in sequence)
# target_frame = prev_frame + torch.randn(3, self.frame_size, self.frame_size) * 0.05
# target_frame = torch.clamp(target_frame, -1, 1)
# target_tensor = self.normalize(target_frame)
# Concatenate inputs
input_concatenated = torch.cat(input_frames, dim=0)
# # Concatenate inputs
# input_concatenated = torch.cat(input_frames, dim=0)
# Temporal index
temporal_idx = torch.tensor(self.num_frames, dtype=torch.long)
# # Temporal index
# temporal_idx = torch.tensor(self.num_frames, dtype=torch.long)
return input_concatenated, target_tensor, temporal_idx
# return input_concatenated, target_tensor, temporal_idx