Compare commits

...

9 Commits

10 changed files with 1271 additions and 475 deletions

4
.gitignore vendored Normal file
View File

@@ -0,0 +1,4 @@
.vscode/
__pycache__/
venv/
runs/

58
dist_temporal_train.sh Executable file
View File

@@ -0,0 +1,58 @@
#!/usr/bin/env bash
# Distributed training script for SwiftFormerTemporal
# Usage: ./dist_temporal_train.sh <DATA_PATH> <NUM_GPUS> [OPTIONS]
DATA_PATH=$1
NUM_GPUS=$2
# Shift arguments to pass remaining options to python script
shift 2
# Default parameters
MODEL=${MODEL:-"SwiftFormerTemporal_XS"}
BATCH_SIZE=${BATCH_SIZE:-128}
EPOCHS=${EPOCHS:-100}
# LR=${LR:-1e-3}
LR=${LR:-0.01}
OUTPUT_DIR=${OUTPUT_DIR:-"./temporal_output"}
echo "Starting distributed training with $NUM_GPUS GPUs"
echo "Data path: $DATA_PATH"
echo "Model: $MODEL"
echo "Batch size: $BATCH_SIZE"
echo "Epochs: $EPOCHS"
echo "Output dir: $OUTPUT_DIR"
# Decide whether to launch with torch.distributed.launch or torchrun
# torchrun is the recommended launcher for PyTorch >= 1.10
TORCH_VERSION=$(python -c "import torch; print(torch.__version__)")
echo "PyTorch version: $TORCH_VERSION"
# Use torchrun for newer PyTorch versions
if [[ "$TORCH_VERSION" =~ ^2\. ]] || [[ "$TORCH_VERSION" =~ ^1\.1[0-9]\. ]]; then
echo "Using torchrun (PyTorch >= 1.10)"
torchrun --nproc_per_node=$NUM_GPUS --master_port=12345 main_temporal.py \
--data-path "$DATA_PATH" \
--model "$MODEL" \
--batch-size $BATCH_SIZE \
--epochs $EPOCHS \
--lr $LR \
--output-dir "$OUTPUT_DIR" \
"$@"
else
echo "Using torch.distributed.launch"
python -m torch.distributed.launch --nproc_per_node=$NUM_GPUS --master_port=12345 --use_env main_temporal.py \
--data-path "$DATA_PATH" \
--model "$MODEL" \
--batch-size $BATCH_SIZE \
--epochs $EPOCHS \
--lr $LR \
--output-dir "$OUTPUT_DIR" \
"$@"
fi
# For single-node multi-GPU training with specific options:
# --world-size 1 --rank 0 --dist-url 'tcp://localhost:12345'
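# Note: the defaults above can also be overridden through environment variables, for example:
#   MODEL=SwiftFormerTemporal_S BATCH_SIZE=64 LR=0.001 ./dist_temporal_train.sh /path/to/frames 4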
echo "Training completed. Check logs in $OUTPUT_DIR"

484
evaluate_temporal.py Normal file
View File

@@ -0,0 +1,484 @@
"""
Evaluation script for SwiftFormerTemporal frame prediction.
Saves predicted frames (note: outputs are denormalized) together with the corresponding metrics: MSE, SSIM and PSNR.
"""
import argparse
import os
import torch
import torch.nn as nn
import pickle
import numpy as np
import random
from pathlib import Path
import json
import matplotlib.pyplot as plt
from PIL import Image
import torchvision.transforms as transforms
import torch.backends.cudnn as cudnn
from util.video_dataset import VideoFrameDataset
from models.swiftformer_temporal import (
SwiftFormerTemporal_XS, SwiftFormerTemporal_S,
SwiftFormerTemporal_L1, SwiftFormerTemporal_L3
)
# Import SSIM and PSNR metrics
from skimage.metrics import structural_similarity as ssim
from skimage.metrics import peak_signal_noise_ratio as psnr
import warnings
warnings.filterwarnings('ignore')
def denormalize(tensor):
"""
Denormalize a tensor from the [-1, 1] range back to [0, 255].
Args:
tensor: shape [B, C, H, W] or [C, H, W], values in [-1, 1]
Returns:
Denormalized tensor with values in [0, 255]
"""
# Clip to the [-1, 1] range
tensor = tensor.clamp(-1, 1)
# [-1, 1] -> [0, 1]
tensor = (tensor + 1) / 2
# [0, 1] -> [0, 255]
tensor = tensor * 255
return tensor.clamp(0, 255)
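# Quick illustrative check of the mapping above (illustrative values, not part of the original script):
# denormalize(torch.tensor([-1.0, 0.0, 1.0])) -> tensor([0.0, 127.5, 255.0])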
def minmax_denormalize(tensor):
tensor_min = tensor.min()
tensor_max = tensor.max()
tensor = (tensor - tensor_min) / (tensor_max - tensor_min)
# tensor = tensor*2-1
tensor = tensor*255
return tensor.clamp(0, 255)
def calculate_metrics(pred, target, debug=False):
"""
Compute the MSE, SSIM and PSNR metrics.
Args:
pred: predicted image, shape [H, W], values in [0, 255]
target: target image, shape [H, W], values in [0, 255]
debug: whether to print debug information
Returns:
mse, ssim_value, psnr_value
"""
# Convert to numpy arrays
pred_np = pred.cpu().numpy() if torch.is_tensor(pred) else pred
target_np = target.cpu().numpy() if torch.is_tensor(target) else target
# Make sure the arrays are 2D
if pred_np.ndim == 3:
pred_np = pred_np.squeeze(0)
if target_np.ndim == 3:
target_np = target_np.squeeze(0)
# if debug:
# print(f"[DEBUG] pred_np range: [{pred_np.min():.2f}, {pred_np.max():.2f}], mean: {pred_np.mean():.2f}")
# print(f"[DEBUG] target_np range: [{target_np.min():.2f}, {target_np.max():.2f}], mean: {target_np.mean():.2f}")
# print(f"[DEBUG] pred_np sample values (first 5): {pred_np.ravel()[:5]}")
mse = np.mean((pred_np - target_np) ** 2)
data_range = 255.0
ssim_value = ssim(pred_np, target_np, data_range=data_range)
psnr_value = psnr(target_np, pred_np, data_range=data_range)
return mse, ssim_value, psnr_value
def save_comparison_figure(input_frames, target_frame, pred_frame, save_path,
input_frame_indices=None, target_frame_index=None):
"""
Save a comparison figure: input frames, target frame, predicted frame.
Args:
input_frames: list of input frames, each of shape [H, W], values in [0, 255]
target_frame: target frame, shape [H, W], values in [0, 255]
pred_frame: predicted frame, shape [H, W], values in [0, 255]
save_path: output path
input_frame_indices: indices of the input frames (optional)
target_frame_index: index of the target frame (optional)
"""
num_input = len(input_frames)
fig, axes = plt.subplots(1, num_input + 2, figsize=(4*(num_input+2), 4))
# Plot the input frames
for i in range(num_input):
ax = axes[i]
ax.imshow(input_frames[i], cmap='gray')
if input_frame_indices is not None:
ax.set_title(f'Input Frame {input_frame_indices[i]}')
else:
ax.set_title(f'Input {i+1}')
ax.axis('off')
# Plot the target frame
ax = axes[num_input]
ax.imshow(target_frame, cmap='gray')
if target_frame_index is not None:
ax.set_title(f'Target Frame {target_frame_index}')
else:
ax.set_title('Target')
ax.axis('off')
# Plot the predicted frame
ax = axes[num_input + 1]
ax.imshow(pred_frame, cmap='gray')
ax.set_title('Predicted')
ax.axis('off')
# Debug prints left over from development; disabled so full arrays are not dumped for every sample
# print(target_frame)
# print(pred_frame)
plt.tight_layout()
plt.savefig(save_path, dpi=150, bbox_inches='tight')
plt.close()
def evaluate_model(model, data_loader, device, args):
"""
Evaluate the model and compute metrics.
Args:
model: trained model
data_loader: data loader
device: device
args: command-line arguments
Returns:
metrics_dict: dictionary with all metrics
sample_results: sample results used for visualization
"""
model.eval()
# model.train()  # temporarily use training mode
# Initialize metric accumulators
total_mse = 0.0
total_ssim = 0.0
total_psnr = 0.0
total_samples = 0
# Store sample results for visualization (randomly selected via reservoir sampling)
sample_results = []
max_samples_to_save = args.num_samples_to_save
max_samples = args.max_samples
# Counter for reservoir sampling: number of processed samples (samples skipped because of the max_samples limit are not counted)
sample_count = 0
with torch.no_grad():
for batch_idx, (input_frames, target_frames, temporal_indices) in enumerate(data_loader):
input_frames = input_frames.to(device, non_blocking=True)
target_frames = target_frames.to(device, non_blocking=True)
# Forward pass
pred_frames = model(input_frames)
# Denormalize for metric computation
# pred_denorm = minmax_denormalize(pred_frames)  # [B, 1, H, W]
pred_denorm = denormalize(pred_frames)
target_denorm = denormalize(target_frames)  # [B, 1, H, W]
batch_size = input_frames.size(0)
# Compute metrics for each sample
for i in range(batch_size):
# Stop if the maximum number of samples has been reached
if max_samples is not None and total_samples >= max_samples:
break
pred_i = pred_denorm[i] # [1, H, W]
target_i = target_denorm[i] # [1, H, W]
# Enable debugging for the very first sample
debug_mode = (batch_idx == 0 and i == 0 and total_samples == 0)
# if debug_mode:
# print(f"[DEBUG] Raw pred_frames range: [{pred_frames.min():.4f}, {pred_frames.max():.4f}], mean: {pred_frames.mean():.4f}")
# print(f"[DEBUG] Raw target_frames range: [{target_frames.min():.4f}, {target_frames.max():.4f}], mean: {target_frames.mean():.4f}")
# print(f"[DEBUG] Pred_denorm range: [{pred_denorm.min():.2f}, {pred_denorm.max():.2f}], mean: {pred_denorm.mean():.2f}")
# print(f"[DEBUG] Target_denorm range: [{target_denorm.min():.2f}, {target_denorm.max():.2f}], mean: {target_denorm.mean():.2f}")
mse, ssim_value, psnr_value = calculate_metrics(pred_i, target_i, debug=False)
total_mse += mse
total_ssim += ssim_value
total_psnr += psnr_value
total_samples += 1
sample_count += 1
# Build the per-sample data dictionary
input_denorm = denormalize(input_frames[i])  # [num_frames, H, W]
# Split the input frames
input_frames_list = []
for j in range(args.num_frames):
input_frame_j = input_denorm[j].squeeze(0) # [H, W]
input_frames_list.append(input_frame_j.cpu().numpy())
sample_data = {
'input_frames': input_frames_list,
'target_frame': target_i.squeeze(0).cpu().numpy(),
'pred_frame': pred_i.squeeze(0).cpu().numpy(),
'metrics': {
'mse': mse,
'ssim': ssim_value,
'psnr': psnr_value
},
'batch_idx': batch_idx,
'sample_idx': i
}
# Reservoir sampling
if sample_count <= max_samples_to_save:
# Reservoir not full yet, add directly
sample_results.append(sample_data)
else:
# With probability max_samples_to_save / sample_count, replace a random slot in the reservoir
r = random.randint(0, sample_count - 1)
if r < max_samples_to_save:
sample_results[r] = sample_data
# Stop if the maximum number of samples has been reached
if max_samples is not None and total_samples >= max_samples:
print(f"Reached the maximum number of samples: {max_samples}")
break
# Progress logging
if (batch_idx + 1) % 10 == 0:
print(f'Processed {batch_idx + 1} batches, {total_samples} samples')
# Compute average metrics
if total_samples > 0:
avg_mse = float(total_mse / total_samples)
avg_ssim = float(total_ssim / total_samples)
avg_psnr = float(total_psnr / total_samples)
else:
avg_mse = avg_ssim = avg_psnr = 0.0
metrics_dict = {
'mse': avg_mse,
'ssim': avg_ssim,
'psnr': avg_psnr,
'num_samples': total_samples
}
return metrics_dict, sample_results
def main(args):
print("Evaluation args:", args)
device = torch.device(args.device)
# Set random seeds
torch.manual_seed(args.seed)
np.random.seed(args.seed)
random.seed(args.seed)
cudnn.benchmark = True
# Build the dataset
print("Building dataset...")
dataset_val = VideoFrameDataset(
root_dir=args.data_path,
num_frames=args.num_frames,
frame_size=args.frame_size,
is_train=False,
max_interval=args.max_interval
)
data_loader_val = torch.utils.data.DataLoader(
dataset_val,
batch_size=args.batch_size,
num_workers=args.num_workers,
pin_memory=args.pin_mem,
shuffle=False,
drop_last=False
)
# Create the model
print(f"Creating model: {args.model}")
model_kwargs = {
'num_frames': args.num_frames,
}
if args.model == 'SwiftFormerTemporal_XS':
model = SwiftFormerTemporal_XS(**model_kwargs)
elif args.model == 'SwiftFormerTemporal_S':
model = SwiftFormerTemporal_S(**model_kwargs)
elif args.model == 'SwiftFormerTemporal_L1':
model = SwiftFormerTemporal_L1(**model_kwargs)
elif args.model == 'SwiftFormerTemporal_L3':
model = SwiftFormerTemporal_L3(**model_kwargs)
else:
raise ValueError(f"Unknown model: {args.model}")
model.to(device)
# Load checkpoint
if args.resume:
print(f"Loading checkpoint: {args.resume}")
try:
# Load with weights_only=False (needed on PyTorch 2.6+, where weights_only defaults to True)
checkpoint = torch.load(args.resume, map_location='cpu', weights_only=False)
except (pickle.UnpicklingError, TypeError) as e:
print(f"Loading with weights_only=False failed: {e}")
print("Falling back to torch.load without the weights_only argument...")
checkpoint = torch.load(args.resume, map_location='cpu')
# Handle the state dict (it may contain a 'module.' prefix)
if 'model' in checkpoint:
state_dict = checkpoint['model']
elif 'state_dict' in checkpoint:
state_dict = checkpoint['state_dict']
else:
state_dict = checkpoint
# Strip the 'module.' prefix if present
if hasattr(model, 'module'):
model.module.load_state_dict(state_dict)
else:
# If the state dict has a 'module.' prefix but the model does not, remove the prefix
if any(key.startswith('module.') for key in state_dict.keys()):
from collections import OrderedDict
new_state_dict = OrderedDict()
for k, v in state_dict.items():
if k.startswith('module.'):
new_state_dict[k[7:]] = v
else:
new_state_dict[k] = v
state_dict = new_state_dict
model.load_state_dict(state_dict)
print(f"Checkpoint loaded successfully, epoch: {checkpoint.get('epoch', 'unknown')}")
else:
print("Warning: no checkpoint path provided, using a randomly initialized model")
# Create the output directory
output_dir = Path(args.output_dir)
output_dir.mkdir(parents=True, exist_ok=True)
# Evaluate the model
print("Starting evaluation...")
metrics, sample_results = evaluate_model(model, data_loader_val, device, args)
# Print metrics
print("\n" + "="*50)
print("Evaluation results:")
print(f"MSE: {metrics['mse']:.6f}")
print(f"SSIM: {metrics['ssim']:.6f}")
print(f"PSNR: {metrics['psnr']:.6f} dB")
print(f"Number of samples: {metrics['num_samples']}")
print("="*50)
# Save metrics to a JSON file
metrics_file = output_dir / 'evaluation_metrics.json'
with open(metrics_file, 'w') as f:
json.dump(metrics, f, indent=4)
print(f"Metrics saved to: {metrics_file}")
# Save sample visualizations
if sample_results:
print(f"\nSaving {len(sample_results)} sample visualizations...")
samples_dir = output_dir / 'sample_predictions'
samples_dir.mkdir(exist_ok=True)
for i, sample in enumerate(sample_results):
save_path = samples_dir / f'sample_{i:03d}.png'
# Generate input frame indices (assuming consecutive frames)
input_frame_indices = list(range(1, args.num_frames + 1))
target_frame_index = args.num_frames + 1
save_comparison_figure(
sample['input_frames'],
sample['target_frame'],
sample['pred_frame'],
save_path,
input_frame_indices=input_frame_indices,
target_frame_index=target_frame_index
)
# Save the metrics for this sample
sample_metrics_file = samples_dir / f'sample_{i:03d}_metrics.txt'
with open(sample_metrics_file, 'w') as f:
f.write(f"Sample {i} (batch {sample['batch_idx']}, idx {sample['sample_idx']})\n")
f.write(f"MSE: {sample['metrics']['mse']:.6f}\n")
f.write(f"SSIM: {sample['metrics']['ssim']:.6f}\n")
f.write(f"PSNR: {sample['metrics']['psnr']:.6f} dB\n")
print(f"Sample visualizations saved to: {samples_dir}")
# Generate a summary report
report_file = output_dir / 'evaluation_report.txt'
with open(report_file, 'w') as f:
f.write("SwiftFormerTemporal frame prediction evaluation report\n")
f.write("="*50 + "\n")
f.write(f"Model: {args.model}\n")
f.write(f"Checkpoint: {args.resume}\n")
f.write(f"Dataset: {args.data_path}\n")
f.write(f"Input frames: {args.num_frames}\n")
f.write(f"Frame size: {args.frame_size}\n")
f.write(f"Batch size: {args.batch_size}\n")
f.write(f"Total samples: {metrics['num_samples']}\n\n")
f.write("Evaluation metrics:\n")
f.write(f" MSE: {metrics['mse']:.6f}\n")
f.write(f" SSIM: {metrics['ssim']:.6f}\n")
f.write(f" PSNR: {metrics['psnr']:.6f} dB\n")
print(f"Evaluation report saved to: {report_file}")
print("\nEvaluation finished!")
def get_args_parser():
parser = argparse.ArgumentParser(
'SwiftFormerTemporal evaluation script', add_help=False)
# Dataset parameters
parser.add_argument('--data-path', default='./videos', type=str,
help='Path to the video dataset')
parser.add_argument('--num-frames', default=3, type=int,
help='Number of input frames (T)')
parser.add_argument('--frame-size', default=224, type=int,
help='Input frame size')
parser.add_argument('--max-interval', default=4, type=int,
help='Maximum interval between consecutive frames')
# Model parameters
parser.add_argument('--model', default='SwiftFormerTemporal_XS', type=str, metavar='MODEL',
help='Name of the model to evaluate')
# Evaluation parameters
parser.add_argument('--batch-size', default=16, type=int,
help='Evaluation batch size')
parser.add_argument('--num-samples-to-save', default=10, type=int,
help='Number of samples to save as visualizations')
parser.add_argument('--max-samples', default=None, type=int,
help='Maximum number of samples to evaluate (None means all)')
# System parameters
parser.add_argument('--output-dir', default='./evaluation_results',
help='Path where results are saved')
parser.add_argument('--device', default='cuda' if torch.cuda.is_available() else 'cpu',
help='Device to use')
parser.add_argument('--seed', default=0, type=int)
parser.add_argument('--resume', default='', help='Path to the checkpoint')
parser.add_argument('--num-workers', default=4, type=int)
parser.add_argument('--pin-mem', action='store_true',
help='Pin CPU memory in the DataLoader')
parser.add_argument('--no-pin-mem', action='store_false', dest='pin_mem')
parser.set_defaults(pin_mem=True)
return parser
if __name__ == '__main__':
parser = argparse.ArgumentParser(
'SwiftFormerTemporal evaluation', parents=[get_args_parser()])
args = parser.parse_args()
# Make sure the output directory exists
if args.output_dir:
Path(args.output_dir).mkdir(parents=True, exist_ok=True)
main(args)
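The visualization samples above are chosen with reservoir sampling, so every evaluated sample ends up in the saved set with equal probability even though the total number of samples is not known up front. A minimal standalone sketch of the same idea (illustrative only, not taken from this repository):

import random

def reservoir_sample(stream, k, seed=0):
    """Keep k items from an iterable of unknown length, each with equal probability."""
    rng = random.Random(seed)
    reservoir = []
    for count, item in enumerate(stream, start=1):
        if count <= k:
            reservoir.append(item)          # reservoir not full yet
        else:
            r = rng.randint(0, count - 1)   # replace a random slot with probability k/count
            if r < k:
                reservoir[r] = item
    return reservoir

print(reservoir_sample(range(1000), 10))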

View File

@@ -6,8 +6,10 @@ import datetime
import numpy as np
import time
import torch
+ import torch.nn as nn
import torch.backends.cudnn as cudnn
import json
+ import os
from pathlib import Path
from timm.scheduler import create_scheduler
@@ -17,8 +19,15 @@ from timm.utils import NativeScaler, get_state_dict, ModelEma
from util import *
from models import *
from models.swiftformer_temporal import SwiftFormerTemporal_XS, SwiftFormerTemporal_S, SwiftFormerTemporal_L1, SwiftFormerTemporal_L3
- from util.video_dataset import VideoFrameDataset, SyntheticVideoDataset
+ from util.video_dataset import VideoFrameDataset
- from util.frame_losses import MultiTaskLoss
+ # from util.frame_losses import MultiTaskLoss
+ # Try to import TensorBoard
+ try:
+ from torch.utils.tensorboard import SummaryWriter
+ TENSORBOARD_AVAILABLE = True
+ except ImportError:
+ TENSORBOARD_AVAILABLE = False
def get_args_parser():
@@ -34,34 +43,68 @@ def get_args_parser():
help='Number of input frames (T)')
parser.add_argument('--frame-size', default=224, type=int,
help='Input frame size')
- parser.add_argument('--max-interval', default=1, type=int,
+ parser.add_argument('--max-interval', default=10, type=int,
help='Maximum interval between consecutive frames')
# Model parameters
parser.add_argument('--model', default='SwiftFormerTemporal_XS', type=str, metavar='MODEL',
help='Name of model to train')
- parser.add_argument('--use-representation-head', action='store_true',
- help='Use representation head for pose/velocity prediction')
- parser.add_argument('--representation-dim', default=128, type=int,
- help='Dimension of representation vector')
# Training parameters
parser.add_argument('--batch-size', default=32, type=int)
parser.add_argument('--epochs', default=100, type=int)
- parser.add_argument('--lr', type=float, default=1e-3, metavar='LR',
- help='learning rate (default: 1e-3)')
+ # Optimizer parameters (required by timm's create_optimizer)
+ parser.add_argument('--opt', default='adamw', type=str, metavar='OPTIMIZER',
+ help='Optimizer (default: "adamw"')
+ parser.add_argument('--opt-eps', default=1e-8, type=float, metavar='EPSILON',
+ help='Optimizer Epsilon (default: 1e-8)')
+ parser.add_argument('--opt-betas', default=None, type=float, nargs='+', metavar='BETA',
+ help='Optimizer Betas (default: None, use opt default)')
+ parser.add_argument('--clip-grad', type=float, default=0.01, metavar='NORM',
+ help='Clip gradient norm (default: None, no clipping)')
+ parser.add_argument('--clip-mode', type=str, default='agc',
+ help='Gradient clipping mode. One of ("norm", "value", "agc")')
+ parser.add_argument('--momentum', type=float, default=0.9, metavar='M',
+ help='SGD momentum (default: 0.9)')
parser.add_argument('--weight-decay', type=float, default=0.05,
help='weight decay (default: 0.05)')
+ parser.add_argument('--lr', type=float, default=0.1, metavar='LR',
+ help='learning rate (default: 1e-3)')
+ # Learning rate schedule parameters (required by timm's create_scheduler)
+ parser.add_argument('--sched', default='cosine', type=str, metavar='SCHEDULER',
+ help='LR scheduler (default: "cosine"')
+ parser.add_argument('--lr-noise', type=float, nargs='+', default=None, metavar='pct, pct',
+ help='learning rate noise on/off epoch percentages')
+ parser.add_argument('--lr-noise-pct', type=float, default=0.67, metavar='PERCENT',
+ help='learning rate noise limit percent (default: 0.67)')
+ parser.add_argument('--lr-noise-std', type=float, default=1.0, metavar='STDDEV',
+ help='learning rate noise std-dev (default: 1.0)')
+ parser.add_argument('--warmup-lr', type=float, default=1e-3, metavar='LR',
+ help='warmup learning rate (default: 1e-6)')
+ parser.add_argument('--min-lr', type=float, default=1e-5, metavar='LR',
+ help='lower lr bound for cyclic schedulers that hit 0 (1e-5)')
+ parser.add_argument('--decay-epochs', type=float, default=30, metavar='N',
+ help='epoch interval to decay LR')
+ parser.add_argument('--warmup-epochs', type=int, default=5, metavar='N',
+ help='epochs to warmup LR, if scheduler supports')
+ parser.add_argument('--cooldown-epochs', type=int, default=10, metavar='N',
+ help='epochs to cooldown LR at min_lr, after cyclic schedule ends')
+ parser.add_argument('--patience-epochs', type=int, default=10, metavar='N',
+ help='patience epochs for Plateau LR scheduler (default: 10')
+ parser.add_argument('--decay-rate', '--dr', type=float, default=0.1, metavar='RATE',
+ help='LR decay rate (default: 0.1)')
# Loss parameters
parser.add_argument('--frame-weight', type=float, default=1.0,
help='Weight for frame prediction loss')
parser.add_argument('--contrastive-weight', type=float, default=0.1,
help='Weight for contrastive loss')
- parser.add_argument('--l1-weight', type=float, default=1.0,
- help='Weight for L1 loss')
- parser.add_argument('--ssim-weight', type=float, default=0.1,
- help='Weight for SSIM loss')
+ # parser.add_argument('--l1-weight', type=float, default=1.0,
+ # help='Weight for L1 loss')
+ # parser.add_argument('--ssim-weight', type=float, default=0.1,
+ # help='Weight for SSIM loss')
parser.add_argument('--no-contrastive', action='store_true',
help='Disable contrastive loss')
parser.add_argument('--no-ssim', action='store_true',
@@ -78,7 +121,7 @@ def get_args_parser():
help='start epoch')
parser.add_argument('--eval', action='store_true',
help='Perform evaluation only')
- parser.add_argument('--num-workers', default=4, type=int)
+ parser.add_argument('--num-workers', default=16, type=int)
parser.add_argument('--pin-mem', action='store_true',
help='Pin CPU memory in DataLoader')
parser.add_argument('--no-pin-mem', action='store_false', dest='pin_mem')
@@ -90,26 +133,26 @@ def get_args_parser():
parser.add_argument('--dist-url', default='env://',
help='url used to set up distributed training')
+ # TensorBoard logging
+ parser.add_argument('--tensorboard-logdir', default='./runs',
+ type=str, help='TensorBoard log directory')
+ parser.add_argument('--log-images', action='store_true',
+ help='Log sample images to TensorBoard')
+ parser.add_argument('--image-log-freq', default=100, type=int,
+ help='Frequency of logging images (in iterations)')
return parser
def build_dataset(is_train, args):
"""Build video frame dataset"""
- if args.dataset_type == 'synthetic':
- dataset = SyntheticVideoDataset(
- num_samples=1000 if is_train else 200,
- num_frames=args.num_frames,
- frame_size=args.frame_size,
- is_train=is_train
- )
- else:
- dataset = VideoFrameDataset(
- root_dir=args.data_path,
- num_frames=args.num_frames,
- frame_size=args.frame_size,
- is_train=is_train,
- max_interval=args.max_interval
- )
+ dataset = VideoFrameDataset(
+ root_dir=args.data_path,
+ num_frames=args.num_frames,
+ frame_size=args.frame_size,
+ is_train=is_train,
+ max_interval=args.max_interval
+ )
return dataset
@@ -159,8 +202,6 @@ def main(args):
print(f"Creating model: {args.model}")
model_kwargs = {
'num_frames': args.num_frames,
- 'use_representation_head': args.use_representation_head,
- 'representation_dim': args.representation_dim,
}
if args.model == 'SwiftFormerTemporal_XS':
@@ -203,14 +244,18 @@ def main(args):
# Create scheduler
lr_scheduler, _ = create_scheduler(args, optimizer)
- # Create loss function
- criterion = MultiTaskLoss(
- frame_weight=args.frame_weight,
- contrastive_weight=args.contrastive_weight,
- l1_weight=args.l1_weight,
- ssim_weight=args.ssim_weight,
- use_contrastive=not args.no_contrastive
- )
+ # Create loss function - simple MSE for Y channel prediction
+ class MSELossWrapper(nn.Module):
+ def __init__(self):
+ super().__init__()
+ self.mse = nn.MSELoss()
+ def forward(self, pred_frame, target_frame, temporal_indices=None):
+ loss = self.mse(pred_frame, target_frame)
+ loss_dict = {'mse': loss}
+ return loss, loss_dict
+ criterion = MSELossWrapper()
# Resume from checkpoint
output_dir = Path(args.output_dir)
@@ -219,7 +264,7 @@ def main(args):
checkpoint = torch.hub.load_state_dict_from_url(
args.resume, map_location='cpu', check_hash=True)
else:
- checkpoint = torch.load(args.resume, map_location='cpu')
+ checkpoint = torch.load(args.resume, map_location='cpu', weights_only=False)
model_without_ddp.load_state_dict(checkpoint['model'])
if not args.eval and 'optimizer' in checkpoint and 'lr_scheduler' in checkpoint and 'epoch' in checkpoint:
@@ -231,6 +276,21 @@ def main(args):
if 'scaler' in checkpoint:
loss_scaler.load_state_dict(checkpoint['scaler'])
+ # Initialize TensorBoard writer
+ writer = None
+ if TENSORBOARD_AVAILABLE and utils.is_main_process():
+ from datetime import datetime
+ # Create log directory with timestamp
+ timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
+ log_dir = os.path.join(args.tensorboard_logdir, f"exp_{timestamp}")
+ os.makedirs(log_dir, exist_ok=True)
+ writer = SummaryWriter(log_dir=log_dir)
+ print(f"TensorBoard logs will be saved to: {log_dir}")
+ print(f"To view logs, run: tensorboard --logdir={log_dir}")
+ elif not TENSORBOARD_AVAILABLE and utils.is_main_process():
+ print("Warning: TensorBoard not available. Install tensorboard or tensorboardX.")
+ print("Training will continue without TensorBoard logging.")
if args.eval:
test_stats = evaluate(data_loader_val, model, criterion, device)
print(f"Test stats: {test_stats}")
@@ -239,20 +299,24 @@ def main(args):
print(f"Start training for {args.epochs} epochs")
start_time = time.time()
+ # Global step counter for TensorBoard
+ global_step = 0
for epoch in range(args.start_epoch, args.epochs):
if args.distributed:
data_loader_train.sampler.set_epoch(epoch)
- train_stats = train_one_epoch(
+ train_stats, global_step = train_one_epoch(
model, criterion, data_loader_train,
- optimizer, device, epoch, loss_scaler,
- model_ema=model_ema
+ optimizer, device, epoch, loss_scaler, args.clip_grad, args.clip_mode,
+ model_ema=model_ema, writer=writer,
+ global_step=global_step, args=args
)
lr_scheduler.step(epoch)
# Save checkpoint
- if args.output_dir and (epoch % 10 == 0 or epoch == args.epochs - 1):
+ if args.output_dir and (epoch % 1 == 0 or epoch == args.epochs - 1):
checkpoint_path = output_dir / f'checkpoint_epoch{epoch}.pth'
utils.save_on_master({
'model': model_without_ddp.state_dict(),
@@ -266,10 +330,10 @@ def main(args):
# Evaluate
if epoch % 5 == 0 or epoch == args.epochs - 1:
- test_stats = evaluate(data_loader_val, model, criterion, device)
+ test_stats = evaluate(data_loader_val, model, criterion, device, writer=writer, epoch=epoch)
print(f"Epoch {epoch}: Test stats: {test_stats}")
- # Log stats
+ # Log stats to text file
log_stats = {
**{f'train_{k}': v for k, v in train_stats.items()},
**{f'test_{k}': v for k, v in test_stats.items()},
@@ -284,29 +348,40 @@ def main(args):
total_time = time.time() - start_time
total_time_str = str(datetime.timedelta(seconds=int(total_time)))
print(f'Training time {total_time_str}')
+ # Close TensorBoard writer
+ if writer is not None:
+ writer.close()
+ print(f"TensorBoard logs saved to: {writer.log_dir}")
def train_one_epoch(model, criterion, data_loader, optimizer, device, epoch, loss_scaler,
- clip_grad=0, clip_mode='norm', model_ema=None, **kwargs):
+ clip_grad=0.01, clip_mode='norm', model_ema=None, writer=None,
+ global_step=0, args=None, **kwargs):
model.train()
metric_logger = utils.MetricLogger(delimiter=" ")
metric_logger.add_meter('lr', utils.SmoothedValue(window_size=1, fmt='{value:.6f}'))
header = f'Epoch: [{epoch}]'
print_freq = 10
+ # Add diagnostic meters
+ metric_logger.add_meter('pred_mean', utils.SmoothedValue(window_size=1, fmt='{value:.6f}'))
+ metric_logger.add_meter('pred_std', utils.SmoothedValue(window_size=1, fmt='{value:.6f}'))
+ metric_logger.add_meter('grad_norm', utils.SmoothedValue(window_size=1, fmt='{value:.6f}'))
- for input_frames, target_frames, temporal_indices in metric_logger.log_every(
- data_loader, print_freq, header):
+ for batch_idx, (input_frames, target_frames, temporal_indices) in enumerate(
+ metric_logger.log_every(data_loader, print_freq, header)):
input_frames = input_frames.to(device, non_blocking=True)
target_frames = target_frames.to(device, non_blocking=True)
temporal_indices = temporal_indices.to(device, non_blocking=True)
# Forward pass
- with torch.cuda.amp.autocast():
- pred_frames, representations = model(input_frames)
+ with torch.amp.autocast(device_type='cuda'):
+ pred_frames = model(input_frames)
loss, loss_dict = criterion(
pred_frames, target_frames,
- representations, temporal_indices
+ temporal_indices
)
loss_value = loss.item()
@@ -315,6 +390,7 @@ def train_one_epoch(model, criterion, data_loader, optimizer, device, epoch, los
raise ValueError(f"Loss is {loss_value}")
optimizer.zero_grad()
loss_scaler(loss, optimizer, clip_grad=clip_grad, clip_mode=clip_mode,
parameters=model.parameters())
@@ -322,36 +398,131 @@ def train_one_epoch(model, criterion, data_loader, optimizer, device, epoch, los
if model_ema is not None:
model_ema.update(model)
+ # Compute diagnostic metrics
+ pred_mean = pred_frames.mean().item()
+ pred_std = pred_frames.std().item()
+ # Compute the gradient norm
+ total_grad_norm = 0.0
+ for param in model.parameters():
+ if param.grad is not None:
+ total_grad_norm += param.grad.norm().item()
+ # Record diagnostic metrics
+ metric_logger.update(pred_mean=pred_mean)
+ metric_logger.update(pred_std=pred_std)
+ metric_logger.update(grad_norm=total_grad_norm)
+ # Print diagnostics every 50 batches
+ if batch_idx % 50 == 0:
+ print(f"[Diag] batch {batch_idx}: pred mean={pred_mean:.4f}, pred std={pred_std:.4f}, grad norm={total_grad_norm:.4f}")
+ # # Check the running statistics of one BatchNorm layer
+ # for name, module in model.named_modules():
+ # if isinstance(module, torch.nn.BatchNorm2d) and 'decoder.blocks.0.bn' in name:
+ # print(f"[Diag] {name}: running mean={module.running_mean[0].item():.6f}, running var={module.running_var[0].item():.6f}")
+ # break
+ # Log to TensorBoard
+ if writer is not None:
+ # Log scalar metrics every iteration
+ writer.add_scalar('train/loss', loss_value, global_step)
+ writer.add_scalar('train/lr', optimizer.param_groups[0]["lr"], global_step)
+ # Log individual loss components
+ for k, v in loss_dict.items():
+ if torch.is_tensor(v):
+ writer.add_scalar(f'train/{k}', v.item(), global_step)
+ else:
+ writer.add_scalar(f'train/{k}', v, global_step)
+ # Log diagnostic metrics
+ writer.add_scalar('train/pred_mean', pred_mean, global_step)
+ writer.add_scalar('train/pred_std', pred_std, global_step)
+ writer.add_scalar('train/grad_norm', total_grad_norm, global_step)
+ # Log images periodically
+ if args is not None and getattr(args, 'log_images', False) and global_step % getattr(args, 'image_log_freq', 100) == 0:
+ with torch.no_grad():
+ # Take first sample from batch for visualization
+ pred_vis = model(input_frames[:1])
+ # Convert to appropriate format for TensorBoard
+ # Assuming frames are in [B, C, H, W] format
+ writer.add_images('train/input', input_frames[:1], global_step)
+ writer.add_images('train/target', target_frames[:1], global_step)
+ writer.add_images('train/predicted', pred_vis[:1], global_step)
# Update metrics
metric_logger.update(loss=loss_value)
metric_logger.update(lr=optimizer.param_groups[0]["lr"])
for k, v in loss_dict.items():
metric_logger.update(**{k: v.item() if torch.is_tensor(v) else v})
+ global_step += 1
metric_logger.synchronize_between_processes()
print("Averaged stats:", metric_logger)
- return {k: meter.global_avg for k, meter in metric_logger.meters.items()}
+ # Log epoch-level metrics
+ if writer is not None:
+ for k, meter in metric_logger.meters.items():
+ writer.add_scalar(f'train_epoch/{k}', meter.global_avg, epoch)
+ return {k: meter.global_avg for k, meter in metric_logger.meters.items()}, global_step
@torch.no_grad()
- def evaluate(data_loader, model, criterion, device):
+ def evaluate(data_loader, model, criterion, device, writer=None, epoch=0):
model.eval()
metric_logger = utils.MetricLogger(delimiter=" ")
header = 'Test:'
+ # Add diagnostic meters
+ metric_logger.add_meter('pred_mean', utils.SmoothedValue(window_size=1, fmt='{value:.6f}'))
+ metric_logger.add_meter('pred_std', utils.SmoothedValue(window_size=1, fmt='{value:.6f}'))
+ metric_logger.add_meter('target_mean', utils.SmoothedValue(window_size=1, fmt='{value:.6f}'))
+ metric_logger.add_meter('target_std', utils.SmoothedValue(window_size=1, fmt='{value:.6f}'))
- for input_frames, target_frames, temporal_indices in metric_logger.log_every(data_loader, 10, header):
+ for batch_idx, (input_frames, target_frames, temporal_indices) in enumerate(metric_logger.log_every(data_loader, 10, header)):
input_frames = input_frames.to(device, non_blocking=True)
target_frames = target_frames.to(device, non_blocking=True)
temporal_indices = temporal_indices.to(device, non_blocking=True)
# Compute output
- with torch.cuda.amp.autocast():
- pred_frames, representations = model(input_frames)
+ with torch.amp.autocast(device_type='cuda'):
+ pred_frames = model(input_frames)
loss, loss_dict = criterion(
pred_frames, target_frames,
- representations, temporal_indices
+ temporal_indices
)
+ # Compute diagnostic metrics
+ pred_mean = pred_frames.mean().item()
+ pred_std = pred_frames.std().item()
+ target_mean = target_frames.mean().item()
+ target_std = target_frames.std().item()
+ # Update diagnostic metrics
+ metric_logger.update(pred_mean=pred_mean)
+ metric_logger.update(pred_std=pred_std)
+ metric_logger.update(target_mean=target_mean)
+ metric_logger.update(target_std=target_std)
+ # # Print detailed diagnostics for the first batch
+ # if batch_idx == 0:
+ # print(f"[Eval diag] batch 0:")
+ # print(f" pred range: [{pred_frames.min().item():.4f}, {pred_frames.max().item():.4f}]")
+ # print(f" pred mean: {pred_mean:.4f}, pred std: {pred_std:.4f}")
+ # print(f" target range: [{target_frames.min().item():.4f}, {target_frames.max().item():.4f}]")
+ # print(f" target mean: {target_mean:.4f}, target std: {target_std:.4f}")
+ # # Check BatchNorm running statistics
+ # for name, module in model.named_modules():
+ # if isinstance(module, torch.nn.BatchNorm2d) and 'decoder.blocks.0.bn' in name:
+ # print(f" {name}: running mean={module.running_mean[0].item():.6f}, running var={module.running_var[0].item():.6f}")
+ # if module.running_var[0].item() < 1e-6:
+ # print(f" Warning: BatchNorm running variance is close to zero!")
+ # break
# Update metrics
metric_logger.update(loss=loss.item())
for k, v in loss_dict.items():
@@ -359,6 +530,12 @@ def evaluate(data_loader, model, criterion, device):
metric_logger.synchronize_between_processes()
print('* Test stats:', metric_logger)
+ # Log validation metrics to TensorBoard
+ if writer is not None:
+ for k, meter in metric_logger.meters.items():
+ writer.add_scalar(f'val/{k}', meter.global_avg, epoch)
return {k: meter.global_avg for k, meter in metric_logger.meters.items()}
@@ -370,4 +547,4 @@ if __name__ == '__main__':
if args.output_dir:
Path(args.output_dir).mkdir(parents=True, exist_ok=True)
main(args)
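The training and evaluation loops above also move from the deprecated torch.cuda.amp.autocast() context manager to torch.amp.autocast(device_type='cuda'); both wrap the forward pass the same way. A minimal sketch of the newer form (assumes a CUDA-capable PyTorch build):

import torch
import torch.nn as nn

model = nn.Linear(8, 8).cuda()
x = torch.randn(4, 8, device='cuda')
with torch.amp.autocast(device_type='cuda'):
    y = model(x)
print(y.dtype)  # typically torch.float16 inside the autocast region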

View File

@@ -6,9 +6,9 @@ import copy
import torch
import torch.nn as nn
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
- from timm.models.layers import DropPath, trunc_normal_
- from timm.models.registry import register_model
- from timm.models.layers.helpers import to_2tuple
+ from timm.layers import DropPath, trunc_normal_
+ from timm.models import register_model
+ from timm.layers import to_2tuple
import einops
SwiftFormer_width = {
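These import changes track the reorganization of timm: timm.models.layers and timm.models.registry were folded into timm.layers and timm.models in newer releases. If the file ever needs to run against both layouts, a guarded import is one option; a hedged sketch, not part of this diff:

try:
    # newer timm layout
    from timm.layers import DropPath, trunc_normal_, to_2tuple
    from timm.models import register_model
except ImportError:
    # older timm layout
    from timm.models.layers import DropPath, trunc_normal_
    from timm.models.layers.helpers import to_2tuple
    from timm.models.registry import register_model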

View File

@@ -7,96 +7,117 @@ from .swiftformer import (
SwiftFormer, SwiftFormer_depth, SwiftFormer_width,
stem, Embedding, Stage
)
- from timm.models.layers import DropPath, trunc_normal_
+ from timm.layers import DropPath, trunc_normal_
class DecoderBlock(nn.Module):
- """Upsampling block for frame prediction decoder"""
+ """Upsampling block for frame prediction decoder without residual connections"""
def __init__(self, in_channels, out_channels, kernel_size=3, stride=2, padding=1, output_padding=1):
super().__init__()
- self.conv = nn.ConvTranspose2d(
+ # Main path: transposed convolution + two conv layers
+ self.conv_transpose = nn.ConvTranspose2d(
in_channels, out_channels,
kernel_size=kernel_size,
stride=stride,
padding=padding,
output_padding=output_padding,
- bias=False
+ bias=False  # bias disabled because BatchNorm follows
)
- self.bn = nn.BatchNorm2d(out_channels)
- self.relu = nn.ReLU(inplace=True)
+ self.bn1 = nn.BatchNorm2d(out_channels)
+ self.conv1 = nn.Conv2d(out_channels, out_channels,
+ kernel_size=3, padding=1, bias=False)
+ self.bn2 = nn.BatchNorm2d(out_channels)
+ self.conv2 = nn.Conv2d(out_channels, out_channels,
+ kernel_size=3, padding=1, bias=False)
+ self.bn3 = nn.BatchNorm2d(out_channels)
+ # ReLU activation
+ self.activation = nn.ReLU(inplace=True)
+ # Initialize weights
+ self._init_weights()
+ def _init_weights(self):
+ # Initialize the transposed convolution
+ nn.init.kaiming_normal_(self.conv_transpose.weight, mode='fan_out', nonlinearity='relu')
+ # Initialize the convolutions
+ nn.init.kaiming_normal_(self.conv1.weight, mode='fan_out', nonlinearity='relu')
+ nn.init.kaiming_normal_(self.conv2.weight, mode='fan_out', nonlinearity='relu')
+ # Initialize the BatchNorm layers (default initialization)
+ for m in self.modules():
+ if isinstance(m, nn.BatchNorm2d):
+ nn.init.constant_(m.weight, 1)
+ nn.init.constant_(m.bias, 0)
def forward(self, x):
- return self.relu(self.bn(self.conv(x)))
+ # Main path
+ x = self.conv_transpose(x)
+ x = self.bn1(x)
+ x = self.activation(x)
+ x = self.conv1(x)
+ x = self.bn2(x)
+ x = self.activation(x)
+ x = self.conv2(x)
+ x = self.bn3(x)
+ x = self.activation(x)
+ return x
class FramePredictionDecoder(nn.Module):
- """Lightweight decoder for frame prediction with optional skip connections"""
+ """Improved decoder for frame prediction"""
- def __init__(self, embed_dims, output_channels=3, use_skip=False):
+ def __init__(self, embed_dims, output_channels=1):
super().__init__()
- self.use_skip = use_skip
- # Reverse the embed_dims for decoder
- decoder_dims = embed_dims[::-1]
+ # Define decoder dimensions independently (no skip connections)
+ start_dim = embed_dims[-1]
+ decoder_dims = [start_dim // (2 ** i) for i in range(4)]  # e.g., [220, 110, 55, 27] for XS
self.blocks = nn.ModuleList()
- # First upsampling from bottleneck to stage4 resolution
+ # First block: stride=2 (decoder_dims[0] -> decoder_dims[1])
self.blocks.append(DecoderBlock(
decoder_dims[0], decoder_dims[1],
kernel_size=3, stride=2, padding=1, output_padding=1
))
- # stage4 to stage3
+ # Second block: stride=2 (decoder_dims[1] -> decoder_dims[2])
self.blocks.append(DecoderBlock(
decoder_dims[1], decoder_dims[2],
kernel_size=3, stride=2, padding=1, output_padding=1
))
- # stage3 to stage2
+ # Third block: stride=2 (decoder_dims[2] -> decoder_dims[3])
self.blocks.append(DecoderBlock(
decoder_dims[2], decoder_dims[3],
kernel_size=3, stride=2, padding=1, output_padding=1
))
- # stage2 to original resolution (4x upsampling total)
- self.blocks.append(nn.Sequential(
- nn.ConvTranspose2d(
- decoder_dims[3], 32,
- kernel_size=3, stride=2, padding=1, output_padding=1
- ),
- nn.BatchNorm2d(32),
- nn.ReLU(inplace=True),
- nn.Conv2d(32, output_channels, kernel_size=3, padding=1),
- nn.Tanh()  # Output in [-1, 1] range
+ # Fourth block: stride=4 (decoder_dims[3] -> 64), placed second to last
+ self.blocks.append(DecoderBlock(
+ decoder_dims[3], 64,
+ kernel_size=3, stride=4, padding=1, output_padding=3  # the stride=4 upsampling happens here
))
- # If using skip connections, we need to adjust input channels for each block
- if use_skip:
- # We'll modify the first three blocks to accept concatenated features
- # Instead of modifying existing blocks, we'll replace them with custom blocks
- # For simplicity, we'll keep the same architecture but forward will handle concatenation
- pass
+ self.final_block = nn.Sequential(
+ nn.Conv2d(64, 64, kernel_size=3, padding=1, bias=True),
+ nn.ReLU(inplace=True),
+ nn.Conv2d(64, 32, kernel_size=3, padding=1, bias=True),
+ nn.ReLU(inplace=True),
+ nn.Conv2d(32, output_channels, kernel_size=3, padding=1, bias=True),
+ nn.Tanh()
+ )
- def forward(self, x, skip_features=None):
+ def forward(self, x):
"""
Args:
x: input tensor of shape [B, embed_dims[-1], H/32, W/32]
- skip_features: list of encoder features from stages [stage2, stage1, stage0]
- each of shape [B, C, H', W'] where C matches decoder dims?
"""
- if self.use_skip and skip_features is not None:
- # Ensure we have exactly 3 skip features (for the first three blocks)
- assert len(skip_features) == 3, "Need 3 skip features for skip connections"
- # Reverse skip_features to match decoder order: stage2, stage1, stage0
- # skip_features[0] should be stage2 (H/16), [1] stage1 (H/8), [2] stage0 (H/4)
- skip_features = skip_features[::-1]  # Now index 0: stage2, 1: stage1, 2: stage0
- for i, block in enumerate(self.blocks):
- if self.use_skip and skip_features is not None and i < 3:
- # Concatenate skip feature along channel dimension
- # Ensure spatial dimensions match (they should because of upsampling)
- x = torch.cat([x, skip_features[i]], dim=1)
- # Need to adjust block to accept extra channels? We'll create a separate block.
- # For now, we'll just pass through, but this will cause channel mismatch.
- # Instead, we should have created custom blocks with appropriate in_channels.
- # This is a placeholder; we need to implement properly.
- pass
- x = block(x)
+ # No skip connections
+ for i in range(4):
+ x = self.blocks[i](x)
+ # Final output layers: feature refinement only, no upsampling
+ x = self.final_block(x)
return x
@@ -104,15 +125,12 @@ class SwiftFormerTemporal(nn.Module):
"""
SwiftFormer with temporal input for frame prediction.
Input: [B, num_frames, H, W] (Y channel only)
- Output: predicted frame [B, 3, H, W] and optional representation
+ Output: predicted frame [B, 1, H, W] and optional representation
"""
def __init__(self,
model_name='XS',
num_frames=3,
use_decoder=True,
- use_representation_head=False,
- representation_dim=128,
- return_features=False,
**kwargs):
super().__init__()
@@ -123,8 +141,6 @@ class SwiftFormerTemporal(nn.Module):
# Store configuration
self.num_frames = num_frames
self.use_decoder = use_decoder
- self.use_representation_head = use_representation_head
- self.return_features = return_features
# Modify stem to accept multiple frames (only Y channel)
in_channels = num_frames
@@ -155,79 +171,51 @@ class SwiftFormerTemporal(nn.Module):
# Frame prediction decoder
if use_decoder:
- self.decoder = FramePredictionDecoder(embed_dims, output_channels=3)
- # Representation head for pose/velocity prediction
- if use_representation_head:
- self.representation_head = nn.Sequential(
- nn.AdaptiveAvgPool2d(1),
- nn.Flatten(),
- nn.Linear(embed_dims[-1], representation_dim),
- nn.ReLU(),
- nn.Linear(representation_dim, representation_dim)
- )
- else:
- self.representation_head = None
+ self.decoder = FramePredictionDecoder(
+ embed_dims,
+ output_channels=1
+ )
self.apply(self._init_weights)
def _init_weights(self, m):
if isinstance(m, (nn.Conv2d, nn.Linear)):
- trunc_normal_(m.weight, std=.02)
+ # Kaiming initialization, suited to ReLU activations
+ nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
if m.bias is not None:
nn.init.constant_(m.bias, 0)
- elif isinstance(m, (nn.LayerNorm)):
+ elif isinstance(m, nn.ConvTranspose2d):
+ # Transposed convolution layers use the same dedicated initialization
+ nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
+ if m.bias is not None:
+ nn.init.constant_(m.bias, 0)
+ elif isinstance(m, (nn.LayerNorm, nn.BatchNorm2d)):
nn.init.constant_(m.bias, 0)
nn.init.constant_(m.weight, 1.0)
def forward_tokens(self, x):
- """Forward through encoder network, return list of stage features if return_features else final output"""
- if self.return_features:
- features = []
- for idx, block in enumerate(self.network):
- x = block(x)
- # Collect output after each stage (indices 0,2,4,6 correspond to stages)
- if idx in [0, 2, 4, 6]:
- features.append(x)
- return x, features
- else:
- for block in self.network:
- x = block(x)
- return x
+ for block in self.network:
+ x = block(x)
+ return x
def forward(self, x):
"""
Args:
x: input frames of shape [B, num_frames, H, W]
Returns:
- If return_features is False:
- pred_frame: predicted frame [B, 3, H, W] (or None)
- representation: optional representation vector [B, representation_dim] (or None)
- If return_features is True:
- pred_frame, representation, features (list of stage features)
+ pred_frame: predicted frame [B, 1, H, W] (or None)
"""
# Encode
x = self.patch_embed(x)
- if self.return_features:
- x, features = self.forward_tokens(x)
- else:
- x = self.forward_tokens(x)
+ x = self.forward_tokens(x)
x = self.norm(x)
- # Get representation if needed
- representation = None
- if self.representation_head is not None:
- representation = self.representation_head(x)
# Decode to frame
pred_frame = None
if self.use_decoder:
pred_frame = self.decoder(x)
- if self.return_features:
- return pred_frame, representation, features
- else:
- return pred_frame, representation
+ return pred_frame
# Factory functions for different model sizes
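The reworked decoder relies on the standard transposed-convolution size formula, H_out = (H_in - 1) * stride - 2 * padding + kernel_size + output_padding (dilation 1). With kernel_size=3 and padding=1 this gives exactly 2x upsampling for stride=2 / output_padding=1 and exactly 4x for stride=4 / output_padding=3, so three stride-2 blocks plus the stride-4 block map an H/32 feature map back to full resolution. A quick standalone check (illustrative only, not part of the diff):

import torch
import torch.nn as nn

up2 = nn.ConvTranspose2d(8, 8, kernel_size=3, stride=2, padding=1, output_padding=1)  # 2x upsampling
up4 = nn.ConvTranspose2d(8, 8, kernel_size=3, stride=4, padding=1, output_padding=3)  # 4x upsampling

x = torch.randn(1, 8, 7, 7)                 # H/32 feature map for a 224x224 input
print(up2(x).shape)                         # torch.Size([1, 8, 14, 14])
print(up4(up2(up2(up2(x)))).shape)          # torch.Size([1, 8, 224, 224])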

View File

@@ -1,60 +0,0 @@
#!/usr/bin/env python3
"""
Test script for SwiftFormerTemporal model
"""
import torch
import sys
import os
# Add current directory to path
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
from models.swiftformer_temporal import SwiftFormerTemporal_XS
def test_model():
print("Testing SwiftFormerTemporal model...")
# Create model
model = SwiftFormerTemporal_XS(num_frames=3, use_representation_head=True)
print(f'Model created: {model.__class__.__name__}')
print(f'Number of parameters: {sum(p.numel() for p in model.parameters()):,}')
# Test forward pass
batch_size = 2
num_frames = 3
height = width = 224
x = torch.randn(batch_size, 3 * num_frames, height, width)
print(f'\nInput shape: {x.shape}')
with torch.no_grad():
pred_frame, representation = model(x)
print(f'Predicted frame shape: {pred_frame.shape}')
print(f'Representation shape: {representation.shape if representation is not None else "None"}')
# Check output ranges
print(f'\nPredicted frame range: [{pred_frame.min():.3f}, {pred_frame.max():.3f}]')
# Test loss function
from util.frame_losses import MultiTaskLoss
criterion = MultiTaskLoss()
target = torch.randn_like(pred_frame)
temporal_indices = torch.tensor([3, 3], dtype=torch.long)
loss, loss_dict = criterion(pred_frame, target, representation, temporal_indices)
print(f'\nLoss test:')
for k, v in loss_dict.items():
print(f' {k}: {v:.4f}')
print('\nAll tests passed!')
return True
if __name__ == '__main__':
try:
test_model()
except Exception as e:
print(f'Test failed with error: {e}')
import traceback
traceback.print_exc()
sys.exit(1)

View File

@@ -1,182 +0,0 @@
"""
Loss functions for frame prediction and representation learning
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
class SSIMLoss(nn.Module):
"""
Structural Similarity Index Measure Loss
Based on: https://github.com/Po-Hsun-Su/pytorch-ssim
"""
def __init__(self, window_size=11, size_average=True):
super().__init__()
self.window_size = window_size
self.size_average = size_average
self.channel = 3
self.window = self.create_window(window_size, self.channel)
def create_window(self, window_size, channel):
def gaussian(window_size, sigma):
gauss = torch.Tensor([math.exp(-(x - window_size//2)**2/float(2*sigma**2)) for x in range(window_size)])
return gauss/gauss.sum()
_1D_window = gaussian(window_size, 1.5).unsqueeze(1)
_2D_window = _1D_window.mm(_1D_window.t()).float().unsqueeze(0).unsqueeze(0)
window = _2D_window.expand(channel, 1, window_size, window_size).contiguous()
return window
def forward(self, img1, img2):
# Ensure window is on correct device
if self.window.device != img1.device:
self.window = self.window.to(img1.device)
mu1 = F.conv2d(img1, self.window, padding=self.window_size//2, groups=self.channel)
mu2 = F.conv2d(img2, self.window, padding=self.window_size//2, groups=self.channel)
mu1_sq = mu1.pow(2)
mu2_sq = mu2.pow(2)
mu1_mu2 = mu1 * mu2
sigma1_sq = F.conv2d(img1*img1, self.window, padding=self.window_size//2, groups=self.channel) - mu1_sq
sigma2_sq = F.conv2d(img2*img2, self.window, padding=self.window_size//2, groups=self.channel) - mu2_sq
sigma12 = F.conv2d(img1*img2, self.window, padding=self.window_size//2, groups=self.channel) - mu1_mu2
C1 = 0.01**2
C2 = 0.03**2
ssim_map = ((2*mu1_mu2 + C1)*(2*sigma12 + C2)) / ((mu1_sq + mu2_sq + C1)*(sigma1_sq + sigma2_sq + C2))
if self.size_average:
return 1 - ssim_map.mean()
else:
return 1 - ssim_map.mean(1).mean(1).mean(1)
class FramePredictionLoss(nn.Module):
"""
Combined loss for frame prediction
"""
def __init__(self, l1_weight=1.0, ssim_weight=0.1, use_ssim=True):
super().__init__()
self.l1_weight = l1_weight
self.ssim_weight = ssim_weight
self.use_ssim = use_ssim
self.l1_loss = nn.L1Loss()
if use_ssim:
self.ssim_loss = SSIMLoss()
def forward(self, pred, target):
"""
Args:
pred: predicted frame [B, 3, H, W] in range [-1, 1]
target: target frame [B, 3, H, W] in range [-1, 1]
Returns:
total_loss, loss_dict
"""
loss_dict = {}
# L1 loss
l1_loss = self.l1_loss(pred, target)
loss_dict['l1'] = l1_loss
total_loss = self.l1_weight * l1_loss
# SSIM loss
if self.use_ssim:
ssim_loss = self.ssim_loss(pred, target)
loss_dict['ssim'] = ssim_loss
total_loss += self.ssim_weight * ssim_loss
loss_dict['total'] = total_loss
return total_loss, loss_dict
class ContrastiveLoss(nn.Module):
"""
Contrastive loss for representation learning
Positive pairs: representations from adjacent frames
Negative pairs: representations from distant frames
"""
def __init__(self, temperature=0.1, margin=1.0):
super().__init__()
self.temperature = temperature
self.margin = margin
self.cosine_similarity = nn.CosineSimilarity(dim=-1)
def forward(self, representations, temporal_indices):
"""
Args:
representations: [B, D] representation vectors
temporal_indices: [B] temporal indices of each sample
Returns:
contrastive_loss
"""
batch_size = representations.size(0)
# Compute similarity matrix
sim_matrix = torch.matmul(representations, representations.T) / self.temperature
# Create positive mask (adjacent frames)
indices_expanded = temporal_indices.unsqueeze(0)
diff = torch.abs(indices_expanded - indices_expanded.T)
positive_mask = (diff == 1).float()
# Create negative mask (distant frames)
negative_mask = (diff > 2).float()
# Positive loss
pos_sim = sim_matrix * positive_mask
pos_loss = -torch.log(torch.exp(pos_sim) / torch.exp(sim_matrix).sum(dim=-1, keepdim=True) + 1e-8)
pos_loss = (pos_loss * positive_mask).sum() / (positive_mask.sum() + 1e-8)
# Negative loss (push apart)
neg_sim = sim_matrix * negative_mask
neg_loss = torch.relu(neg_sim - self.margin).mean()
return pos_loss + 0.1 * neg_loss
class MultiTaskLoss(nn.Module):
"""
Multi-task loss combining frame prediction and representation learning
"""
def __init__(self, frame_weight=1.0, contrastive_weight=0.1,
l1_weight=1.0, ssim_weight=0.1, use_contrastive=True):
super().__init__()
self.frame_weight = frame_weight
self.contrastive_weight = contrastive_weight
self.use_contrastive = use_contrastive
self.frame_loss = FramePredictionLoss(l1_weight=l1_weight, ssim_weight=ssim_weight)
if use_contrastive:
self.contrastive_loss = ContrastiveLoss()
def forward(self, pred_frame, target_frame, representations=None, temporal_indices=None):
"""
Args:
pred_frame: predicted frame [B, 3, H, W]
target_frame: target frame [B, 3, H, W]
representations: [B, D] representation vectors (optional)
temporal_indices: [B] temporal indices (optional)
Returns:
total_loss, loss_dict
"""
loss_dict = {}
# Frame prediction loss
frame_loss, frame_loss_dict = self.frame_loss(pred_frame, target_frame)
loss_dict.update({f'frame_{k}': v for k, v in frame_loss_dict.items()})
total_loss = self.frame_weight * frame_loss
# Contrastive loss (if representations provided)
if self.use_contrastive and representations is not None and temporal_indices is not None:
contrastive_loss = self.contrastive_loss(representations, temporal_indices)
loss_dict['contrastive'] = contrastive_loss
total_loss += self.contrastive_weight * contrastive_loss
loss_dict['total'] = total_loss
return total_loss, loss_dict

View File

@@ -47,28 +47,40 @@ class VideoFrameDataset(Dataset):
self.frame_size = frame_size
self.is_train = is_train
self.max_interval = max_interval
+ # if num_frames < 1:
+ # raise ValueError("num_frames must be >= 1")
+ # if frame_size < 1:
+ # raise ValueError("frame_size must be >= 1")
+ # if max_interval < 1:
+ # raise ValueError("max_interval must be >= 1")
- # Collect all video folders
+ # Collect all video folders and their frame files
self.video_folders = []
+ self.video_frame_files = []  # list of list of Path objects
for item in self.root_dir.iterdir():
if item.is_dir():
self.video_folders.append(item)
+ # Get all frame files
+ frame_files = sorted([f for f in item.iterdir()
+ if f.suffix.lower() in ['.jpg', '.jpeg', '.png', '.bmp']])
+ self.video_frame_files.append(frame_files)
if len(self.video_folders) == 0:
raise ValueError(f"No video folders found in {root_dir}")
# Build frame index: list of (video_idx, start_frame_idx)
self.frame_indices = []
- for video_idx, video_folder in enumerate(self.video_folders):
- # Get all frame files
- frame_files = sorted([f for f in video_folder.iterdir()
- if f.suffix.lower() in ['.jpg', '.jpeg', '.png', '.bmp']])
- if len(frame_files) < num_frames + 1:
+ for video_idx, frame_files in enumerate(self.video_frame_files):
+ # Minimum frames needed considering max interval
+ min_frames_needed = num_frames * max_interval + 1
+ if len(frame_files) < min_frames_needed:
continue  # Skip videos with insufficient frames
# Add all possible starting positions
- for start_idx in range(len(frame_files) - num_frames):
+ # Ensure that for any interval up to max_interval, all frames are within bounds
+ max_start = len(frame_files) - num_frames * max_interval
+ for start_idx in range(max_start):
self.frame_indices.append((video_idx, start_idx))
if len(self.frame_indices) == 0:
@@ -80,11 +92,12 @@ class VideoFrameDataset(Dataset):
         else:
             self.transform = transform
-        # Normalization (ImageNet stats)
-        self.normalize = transforms.Normalize(
-            mean=[0.485, 0.456, 0.406],
-            std=[0.229, 0.224, 0.225]
-        )
+        # Simple normalization to [-1, 1] range (no ImageNet normalization)
+        # Convert pixel values [0, 255] to [-1, 1]
+        # This matches the model's tanh output range
+        self.normalize = None  # We'll handle normalization manually
+        # print(f"[Dataset init] Using simple normalization: pixel values [0, 255] -> [-1, 1]")
     def _default_transform(self):
         """Default transform with augmentation for training"""
@@ -102,9 +115,12 @@ class VideoFrameDataset(Dataset):
     def _load_frame(self, video_idx: int, frame_idx: int) -> Image.Image:
         """Load a single frame as PIL Image"""
-        video_folder = self.video_folders[video_idx]
-        frame_files = sorted([f for f in video_folder.iterdir()
-                              if f.suffix.lower() in ['.jpg', '.jpeg', '.png', '.bmp']])
+        frame_files = self.video_frame_files[video_idx]
+        if frame_idx < 0 or frame_idx >= len(frame_files):
+            raise IndexError(
+                f"Frame index {frame_idx} out of range for video {video_idx} "
+                f"(0-{len(frame_files)-1})"
+            )
         frame_path = frame_files[frame_idx]
         return Image.open(frame_path).convert('RGB')
@@ -114,8 +130,8 @@ class VideoFrameDataset(Dataset):
     def __getitem__(self, idx: int) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
         """
         Returns:
-            input_frames: [3 * num_frames, H, W] concatenated input frames
-            target_frame: [3, H, W] target frame to predict
+            input_frames: [num_frames, H, W] concatenated input frames (Y channel only)
+            target_frame: [1, H, W] target frame to predict (Y channel only)
             temporal_idx: temporal index of target frame (for contrastive loss)
         """
         video_idx, start_idx = self.frame_indices[idx]
@@ -141,69 +157,77 @@ class VideoFrameDataset(Dataset):
         if self.transform:
             target_frame = self.transform(target_frame)
-        # Convert to tensors and normalize
+        # Convert to tensors and convert to grayscale (Y channel)
         input_tensors = []
         for frame in input_frames:
-            tensor = transforms.ToTensor()(frame)
-            tensor = self.normalize(tensor)
-            input_tensors.append(tensor)
+            tensor = transforms.ToTensor()(frame)  # [3, H, W], range [0, 1]
+            # Convert RGB to grayscale using weighted sum
+            # Y = 0.2989 * R + 0.5870 * G + 0.1140 * B (same as PIL)
+            gray = (0.2989 * tensor[0] + 0.5870 * tensor[1] + 0.1140 * tensor[2]).unsqueeze(0)  # [1, H, W], range [0, 1]
+            # Normalize from [0, 1] to [-1, 1]
+            gray = gray * 2 - 1  # [0,1] -> [-1,1]
+            input_tensors.append(gray)
-        target_tensor = transforms.ToTensor()(target_frame)
-        target_tensor = self.normalize(target_tensor)
+        target_tensor = transforms.ToTensor()(target_frame)  # [3, H, W], range [0, 1]
+        target_gray = (0.2989 * target_tensor[0] + 0.5870 * target_tensor[1] + 0.1140 * target_tensor[2]).unsqueeze(0)
+        # Normalize from [0, 1] to [-1, 1]
+        target_gray = target_gray * 2 - 1  # [0,1] -> [-1,1]
         # Concatenate input frames along channel dimension
-        input_concatenated = torch.cat(input_tensors, dim=0)
+        input_concatenated = torch.cat(input_tensors, dim=0)  # [num_frames, H, W]
         # Temporal index (for contrastive loss)
         temporal_idx = torch.tensor(self.num_frames, dtype=torch.long)
-        return input_concatenated, target_tensor, temporal_idx
+        return input_concatenated, target_gray, temporal_idx
-class SyntheticVideoDataset(Dataset):
-    """
-    Synthetic dataset for testing - generates random frames
-    """
-    def __init__(self,
-                 num_samples: int = 1000,
-                 num_frames: int = 3,
-                 frame_size: int = 224,
-                 is_train: bool = True):
-        self.num_samples = num_samples
-        self.num_frames = num_frames
-        self.frame_size = frame_size
-        self.is_train = is_train
-        # Normalization
-        self.normalize = transforms.Normalize(
-            mean=[0.485, 0.456, 0.406],
-            std=[0.229, 0.224, 0.225]
-        )
-    def __len__(self):
-        return self.num_samples
-    def __getitem__(self, idx):
-        # Generate random "frames" (noise with temporal correlation)
-        input_frames = []
-        prev_frame = torch.randn(3, self.frame_size, self.frame_size) * 0.1
-        for i in range(self.num_frames):
-            # Add some temporal correlation
-            frame = prev_frame + torch.randn(3, self.frame_size, self.frame_size) * 0.05
-            frame = torch.clamp(frame, -1, 1)
-            input_frames.append(self.normalize(frame))
-            prev_frame = frame
-        # Target frame (next in sequence)
-        target_frame = prev_frame + torch.randn(3, self.frame_size, self.frame_size) * 0.05
-        target_frame = torch.clamp(target_frame, -1, 1)
-        target_tensor = self.normalize(target_frame)
-        # Concatenate inputs
-        input_concatenated = torch.cat(input_frames, dim=0)
-        # Temporal index
-        temporal_idx = torch.tensor(self.num_frames, dtype=torch.long)
-        return input_concatenated, target_tensor, temporal_idx
+# class SyntheticVideoDataset(Dataset):
+#     """
+#     Synthetic dataset for testing - generates random frames
+#     """
+#     def __init__(self,
+#                  num_samples: int = 1000,
+#                  num_frames: int = 3,
+#                  frame_size: int = 224,
+#                  is_train: bool = True):
+#         self.num_samples = num_samples
+#         self.num_frames = num_frames
+#         self.frame_size = frame_size
+#         self.is_train = is_train
+#         # Normalization for Y channel (single channel)
+#         y_mean = (0.485 + 0.456 + 0.406) / 3.0
+#         y_std = (0.229 + 0.224 + 0.225) / 3.0
+#         self.normalize = transforms.Normalize(
+#             mean=[y_mean],
+#             std=[y_std]
+#         )
+#     def __len__(self):
+#         return self.num_samples
+#     def __getitem__(self, idx):
+#         # Generate random "frames" (noise with temporal correlation)
+#         input_frames = []
+#         prev_frame = torch.randn(3, self.frame_size, self.frame_size) * 0.1
+#         for i in range(self.num_frames):
+#             # Add some temporal correlation
+#             frame = prev_frame + torch.randn(3, self.frame_size, self.frame_size) * 0.05
+#             frame = torch.clamp(frame, -1, 1)
+#             input_frames.append(self.normalize(frame))
+#             prev_frame = frame
+#         # Target frame (next in sequence)
+#         target_frame = prev_frame + torch.randn(3, self.frame_size, self.frame_size) * 0.05
+#         target_frame = torch.clamp(target_frame, -1, 1)
+#         target_tensor = self.normalize(target_frame)
+#         # Concatenate inputs
+#         input_concatenated = torch.cat(input_frames, dim=0)
+#         # Temporal index
+#         temporal_idx = torch.tensor(self.num_frames, dtype=torch.long)
+#         return input_concatenated, target_tensor, temporal_idx
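A quick worked check of the new indexing bounds, assuming the loader samples num_frames input frames plus one target frame at a stride of up to max_interval (the numbers are hypothetical):

num_frames, max_interval = 3, 4
frames_in_video = 20

min_frames_needed = num_frames * max_interval + 1        # 13: shortest clip that can be indexed
max_start = frames_in_video - num_frames * max_interval  # 8: valid start indices are 0..7

# Worst case: the last valid start combined with the largest stride still stays in range.
start = max_start - 1
sampled = [start + i * max_interval for i in range(num_frames + 1)]
print(sampled)  # [7, 11, 15, 19], all < 20, so even the target frame index is in bounds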

303
video_preprocessor.py Normal file
View File

@@ -0,0 +1,303 @@
#!/usr/bin/env python3
"""
Video preprocessing script - converts MP4 videos into 224x224 frame images
Supports multi-threaded processing, a progress bar, and interrupt/resume
"""
import os
import sys
import json
import argparse
import subprocess
import threading
from pathlib import Path
from concurrent.futures import ThreadPoolExecutor, as_completed
from tqdm import tqdm
import time
from typing import List, Dict, Optional
class VideoPreprocessor:
"""视频预处理器,支持多线程和中断恢复"""
def __init__(self,
input_dir: str,
output_dir: str,
frame_size: int = 224,
fps: int = 30,
num_workers: int = 4,
quality: int = 2,
resume: bool = True):
"""
Initialize the preprocessor
Args:
input_dir: input video directory
output_dir: output frame directory
frame_size: frame size (square)
fps: frame extraction rate
num_workers: number of concurrent worker threads
quality: JPEG quality (1-31, lower means higher quality)
resume: whether to enable interrupt/resume
"""
self.input_dir = Path(input_dir)
self.output_dir = Path(output_dir)
self.frame_size = frame_size
self.fps = fps
self.num_workers = num_workers
self.quality = quality
self.resume = resume
# Path to the state file
self.state_file = self.output_dir / ".preprocessing_state.json"
# Create the output directory
self.output_dir.mkdir(parents=True, exist_ok=True)
# Initialize processing state
self.state = self._load_state()
# Collect all video files
self.video_files = self._collect_video_files()
def _load_state(self) -> Dict:
"""加载处理状态"""
if self.resume and self.state_file.exists():
try:
with open(self.state_file, 'r') as f:
return json.load(f)
except (json.JSONDecodeError, IOError):
print(f"警告: 无法读取状态文件,将重新开始处理")
return {
"completed": [],
"failed": [],
"total_processed": 0,
"start_time": None,
"last_update": None
}
def _save_state(self):
"""保存处理状态"""
self.state["last_update"] = time.time()
try:
with open(self.state_file, 'w') as f:
json.dump(self.state, f, indent=2)
except IOError as e:
print(f"警告: 无法保存状态文件: {e}")
def _collect_video_files(self) -> List[Path]:
"""收集所有需要处理的视频文件"""
video_files = []
for file_path in self.input_dir.glob("*.mp4"):
if file_path.name not in self.state["completed"]:
video_files.append(file_path)
return sorted(video_files)
def _parse_video_name(self, video_path: Path) -> Dict[str, str]:
"""解析视频文件名使用完整文件名作为ID"""
name_without_ext = video_path.stem
# Use the full file name as the ID so that each mp4 file gets its own output directory
return {
"video_id": name_without_ext,
"start_frame": "unknown",
"end_frame": "unknown",
"full_name": name_without_ext
}
def _extract_frames(self, video_path: Path) -> bool:
"""提取单个视频的帧"""
try:
# Parse the video name
video_info = self._parse_video_name(video_path)
output_subdir = self.output_dir / video_info["video_id"]
output_subdir.mkdir(exist_ok=True)
# Build the FFmpeg command
output_pattern = output_subdir / "frame_%04d.jpg"
cmd = [
"ffmpeg",
"-i", str(video_path),
"-vf", f"fps={self.fps},scale={self.frame_size}:{self.frame_size}",
"-q:v", str(self.quality),
"-y", # 覆盖输出文件
str(output_pattern)
]
# Run the FFmpeg command
result = subprocess.run(
cmd,
capture_output=True,
text=True,
timeout=300  # 5-minute timeout
)
if result.returncode != 0:
print(f"FFmpeg错误处理 {video_path.name}: {result.stderr}")
return False
# Verify that output frames were produced
output_frames = list(output_subdir.glob("frame_*.jpg"))
if len(output_frames) == 0:
print(f"警告: {video_path.name} 没有生成任何帧")
return False
return True
except subprocess.TimeoutExpired:
print(f"超时处理 {video_path.name}")
return False
except Exception as e:
print(f"处理 {video_path.name} 时发生错误: {e}")
return False
def _process_video(self, video_path: Path) -> tuple[bool, str]:
"""处理单个视频文件"""
video_name = video_path.name
try:
success = self._extract_frames(video_path)
if success:
self.state["completed"].append(video_name)
if video_name in self.state["failed"]:
self.state["failed"].remove(video_name)
return True, video_name
else:
self.state["failed"].append(video_name)
return False, video_name
except Exception as e:
print(f"处理 {video_name} 时发生异常: {e}")
self.state["failed"].append(video_name)
return False, video_name
def process_all_videos(self):
"""处理所有视频文件"""
if not self.video_files:
print("没有找到需要处理的视频文件")
return
print(f"找到 {len(self.video_files)} 个待处理视频文件")
print(f"输出目录: {self.output_dir}")
print(f"帧大小: {self.frame_size}x{self.frame_size}")
print(f"帧率: {self.fps} fps")
print(f"并发线程数: {self.num_workers}")
if self.state["completed"]:
print(f"跳过 {len(self.state['completed'])} 个已处理的视频")
# Record the start time
if self.state["start_time"] is None:
self.state["start_time"] = time.time()
# Create the progress bar
with tqdm(total=len(self.video_files), desc="Processing videos", unit="") as pbar:
with ThreadPoolExecutor(max_workers=self.num_workers) as executor:
# Submit all tasks
future_to_video = {
executor.submit(self._process_video, video_path): video_path
for video_path in self.video_files
}
# Handle tasks as they complete
for future in as_completed(future_to_video):
video_path = future_to_video[future]
try:
success, video_name = future.result()
if success:
pbar.set_postfix({"状态": "成功", "文件": video_name[:20]})
else:
pbar.set_postfix({"状态": "失败", "文件": video_name[:20]})
except Exception as e:
print(f"处理 {video_path.name} 时发生异常: {e}")
pbar.set_postfix({"状态": "异常", "文件": video_path.name[:20]})
pbar.update(1)
self.state["total_processed"] += 1
# Periodically save state
if self.state["total_processed"] % 5 == 0:
self._save_state()
# Final state save
self._save_state()
# Print the processing summary
self._print_summary()
def _print_summary(self):
"""打印处理摘要"""
print("\n" + "="*50)
print("处理完成摘要:")
print(f"总处理视频数: {len(self.state['completed'])}")
print(f"失败视频数: {len(self.state['failed'])}")
if self.state["failed"]:
print("\n失败的视频:")
for video_name in self.state["failed"]:
print(f" - {video_name}")
if self.state["start_time"]:
elapsed_time = time.time() - self.state["start_time"]
print(f"\n总耗时: {elapsed_time:.2f}")
if self.state["total_processed"] > 0:
avg_time = elapsed_time / self.state["total_processed"]
print(f"平均每个视频: {avg_time:.2f}")
print("="*50)
def main():
"""主函数"""
parser = argparse.ArgumentParser(description="Video preprocessing script")
parser.add_argument("--input_dir", type=str, default="/home/hexone/Workplace/ws_asmo/vhead/sekai-real-drone/sekai-real-drone", help="Input video directory")
parser.add_argument("--output_dir", type=str, default="/home/hexone/Workplace/ws_asmo/vhead/sekai-real-drone/processed", help="Output frame directory")
parser.add_argument("--size", type=int, default=224, help="Frame size (default: 224)")
parser.add_argument("--fps", type=int, default=10, help="Frame extraction rate (default: 10)")
parser.add_argument("--workers", type=int, default=32, help="Number of worker threads (default: 32)")
parser.add_argument("--quality", type=int, default=2, help="JPEG quality 1-31 (default: 2)")
parser.add_argument("--no-resume", action="store_true", help="Disable interrupt/resume")
args = parser.parse_args()
# Check the input directory
if not Path(args.input_dir).exists():
print(f"错误: 输入目录不存在: {args.input_dir}")
sys.exit(1)
# Check that FFmpeg is available
try:
subprocess.run(["ffmpeg", "-version"], capture_output=True, check=True)
except (subprocess.CalledProcessError, FileNotFoundError):
print("错误: FFmpeg未安装或不在PATH中")
sys.exit(1)
# Create the preprocessor and start processing
preprocessor = VideoPreprocessor(
input_dir=args.input_dir,
output_dir=args.output_dir,
frame_size=args.size,
fps=args.fps,
num_workers=args.workers,
quality=args.quality,
resume=not args.no_resume
)
try:
preprocessor.process_all_videos()
except KeyboardInterrupt:
print("\n\n用户中断处理,状态已保存")
preprocessor._save_state()
print("可以使用相同命令恢复处理")
except Exception as e:
print(f"\n处理过程中发生错误: {e}")
preprocessor._save_state()
sys.exit(1)
if __name__ == "__main__":
main()
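A minimal sketch of driving the preprocessor from Python rather than through the CLI; the paths are placeholders and FFmpeg must be available on PATH:

from video_preprocessor import VideoPreprocessor

preprocessor = VideoPreprocessor(
    input_dir="/path/to/raw_mp4s",     # directory containing *.mp4 files
    output_dir="/path/to/processed",   # one frame_XXXX.jpg folder per video
    frame_size=224,
    fps=10,
    num_workers=8,
    quality=2,
    resume=True,  # videos recorded as completed in .preprocessing_state.json are skipped
)
preprocessor.process_all_videos()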