Compare commits

..

3 Commits

Author SHA1 Message Date
f7601e9170 初步可跑通,但loss计算有问题,不收敛 2026-01-08 09:43:23 +08:00
efd76bccd2 update .gitignore 2026-01-07 15:54:52 +08:00
4888619f9d iniit .gitignore 2026-01-07 15:54:20 +08:00
11 changed files with 657 additions and 62 deletions

4
.gitignore vendored Normal file
View File

@@ -0,0 +1,4 @@
.vscode/
__pycache__/
venv/
runs/

57
dist_temporal_train.sh Executable file
View File

@@ -0,0 +1,57 @@
#!/usr/bin/env bash
# Distributed training script for SwiftFormerTemporal
# Usage: ./dist_temporal_train.sh <DATA_PATH> <NUM_GPUS> [OPTIONS]
DATA_PATH=$1
NUM_GPUS=$2
# Shift arguments to pass remaining options to python script
shift 2
# Default parameters
MODEL=${MODEL:-"SwiftFormerTemporal_XS"}
BATCH_SIZE=${BATCH_SIZE:-32}
EPOCHS=${EPOCHS:-100}
LR=${LR:-1e-3}
OUTPUT_DIR=${OUTPUT_DIR:-"./temporal_output"}
echo "Starting distributed training with $NUM_GPUS GPUs"
echo "Data path: $DATA_PATH"
echo "Model: $MODEL"
echo "Batch size: $BATCH_SIZE"
echo "Epochs: $EPOCHS"
echo "Output dir: $OUTPUT_DIR"
# Check if torch.distributed.launch or torchrun should be used
# For newer PyTorch versions (>=1.9), torchrun is recommended
PYTHON_VERSION=$(python -c "import torch; print(torch.__version__)")
echo "PyTorch version: $PYTHON_VERSION"
# Use torchrun for newer PyTorch versions
if [[ "$PYTHON_VERSION" =~ ^2\. ]] || [[ "$PYTHON_VERSION" =~ ^1\.1[0-9]\. ]]; then
echo "Using torchrun (PyTorch >=1.10)"
torchrun --nproc_per_node=$NUM_GPUS --master_port=12345 main_temporal.py \
--data-path "$DATA_PATH" \
--model "$MODEL" \
--batch-size $BATCH_SIZE \
--epochs $EPOCHS \
--lr $LR \
--output-dir "$OUTPUT_DIR" \
"$@"
else
echo "Using torch.distributed.launch"
python -m torch.distributed.launch --nproc_per_node=$NUM_GPUS --master_port=12345 --use_env main_temporal.py \
--data-path "$DATA_PATH" \
--model "$MODEL" \
--batch-size $BATCH_SIZE \
--epochs $EPOCHS \
--lr $LR \
--output-dir "$OUTPUT_DIR" \
"$@"
fi
# For single-node multi-GPU training with specific options:
# --world-size 1 --rank 0 --dist-url 'tcp://localhost:12345'
echo "Training completed. Check logs in $OUTPUT_DIR"

View File

@@ -6,8 +6,10 @@ import datetime
import numpy as np
import time
import torch
import torch.nn as nn
import torch.backends.cudnn as cudnn
import json
import os
from pathlib import Path
from timm.scheduler import create_scheduler
@@ -20,6 +22,17 @@ from models.swiftformer_temporal import SwiftFormerTemporal_XS, SwiftFormerTempo
from util.video_dataset import VideoFrameDataset, SyntheticVideoDataset
from util.frame_losses import MultiTaskLoss
# Try to import TensorBoard
try:
from torch.utils.tensorboard import SummaryWriter
TENSORBOARD_AVAILABLE = True
except ImportError:
try:
from tensorboardX import SummaryWriter
TENSORBOARD_AVAILABLE = True
except ImportError:
TENSORBOARD_AVAILABLE = False
def get_args_parser():
parser = argparse.ArgumentParser(
@@ -48,10 +61,48 @@ def get_args_parser():
# Training parameters
parser.add_argument('--batch-size', default=32, type=int)
parser.add_argument('--epochs', default=100, type=int)
parser.add_argument('--lr', type=float, default=1e-3, metavar='LR',
help='learning rate (default: 1e-3)')
# Optimizer parameters (required by timm's create_optimizer)
parser.add_argument('--opt', default='adamw', type=str, metavar='OPTIMIZER',
help='Optimizer (default: "adamw"')
parser.add_argument('--opt-eps', default=1e-8, type=float, metavar='EPSILON',
help='Optimizer Epsilon (default: 1e-8)')
parser.add_argument('--opt-betas', default=None, type=float, nargs='+', metavar='BETA',
help='Optimizer Betas (default: None, use opt default)')
parser.add_argument('--clip-grad', type=float, default=0.01, metavar='NORM',
help='Clip gradient norm (default: None, no clipping)')
parser.add_argument('--clip-mode', type=str, default='agc',
help='Gradient clipping mode. One of ("norm", "value", "agc")')
parser.add_argument('--momentum', type=float, default=0.9, metavar='M',
help='SGD momentum (default: 0.9)')
parser.add_argument('--weight-decay', type=float, default=0.05,
help='weight decay (default: 0.05)')
parser.add_argument('--lr', type=float, default=1e-3, metavar='LR',
help='learning rate (default: 1e-3)')
# Learning rate schedule parameters (required by timm's create_scheduler)
parser.add_argument('--sched', default='cosine', type=str, metavar='SCHEDULER',
help='LR scheduler (default: "cosine"')
parser.add_argument('--lr-noise', type=float, nargs='+', default=None, metavar='pct, pct',
help='learning rate noise on/off epoch percentages')
parser.add_argument('--lr-noise-pct', type=float, default=0.67, metavar='PERCENT',
help='learning rate noise limit percent (default: 0.67)')
parser.add_argument('--lr-noise-std', type=float, default=1.0, metavar='STDDEV',
help='learning rate noise std-dev (default: 1.0)')
parser.add_argument('--warmup-lr', type=float, default=1e-6, metavar='LR',
help='warmup learning rate (default: 1e-6)')
parser.add_argument('--min-lr', type=float, default=1e-5, metavar='LR',
help='lower lr bound for cyclic schedulers that hit 0 (1e-5)')
parser.add_argument('--decay-epochs', type=float, default=30, metavar='N',
help='epoch interval to decay LR')
parser.add_argument('--warmup-epochs', type=int, default=5, metavar='N',
help='epochs to warmup LR, if scheduler supports')
parser.add_argument('--cooldown-epochs', type=int, default=10, metavar='N',
help='epochs to cooldown LR at min_lr, after cyclic schedule ends')
parser.add_argument('--patience-epochs', type=int, default=10, metavar='N',
help='patience epochs for Plateau LR scheduler (default: 10')
parser.add_argument('--decay-rate', '--dr', type=float, default=0.1, metavar='RATE',
help='LR decay rate (default: 0.1)')
# Loss parameters
parser.add_argument('--frame-weight', type=float, default=1.0,
@@ -90,26 +141,26 @@ def get_args_parser():
parser.add_argument('--dist-url', default='env://',
help='url used to set up distributed training')
# TensorBoard logging
parser.add_argument('--tensorboard-logdir', default='./runs',
type=str, help='TensorBoard log directory')
parser.add_argument('--log-images', action='store_true',
help='Log sample images to TensorBoard')
parser.add_argument('--image-log-freq', default=100, type=int,
help='Frequency of logging images (in iterations)')
return parser
def build_dataset(is_train, args):
"""Build video frame dataset"""
if args.dataset_type == 'synthetic':
dataset = SyntheticVideoDataset(
num_samples=1000 if is_train else 200,
num_frames=args.num_frames,
frame_size=args.frame_size,
is_train=is_train
)
else:
dataset = VideoFrameDataset(
root_dir=args.data_path,
num_frames=args.num_frames,
frame_size=args.frame_size,
is_train=is_train,
max_interval=args.max_interval
)
dataset = VideoFrameDataset(
root_dir=args.data_path,
num_frames=args.num_frames,
frame_size=args.frame_size,
is_train=is_train,
max_interval=args.max_interval
)
return dataset
@@ -203,14 +254,18 @@ def main(args):
# Create scheduler
lr_scheduler, _ = create_scheduler(args, optimizer)
# Create loss function
criterion = MultiTaskLoss(
frame_weight=args.frame_weight,
contrastive_weight=args.contrastive_weight,
l1_weight=args.l1_weight,
ssim_weight=args.ssim_weight,
use_contrastive=not args.no_contrastive
)
# Create loss function - simple MSE for Y channel prediction
class MSELossWrapper(nn.Module):
def __init__(self):
super().__init__()
self.mse = nn.MSELoss()
def forward(self, pred_frame, target_frame, representations=None, temporal_indices=None):
loss = self.mse(pred_frame, target_frame)
loss_dict = {'mse': loss}
return loss, loss_dict
criterion = MSELossWrapper()
# Resume from checkpoint
output_dir = Path(args.output_dir)
@@ -231,6 +286,21 @@ def main(args):
if 'scaler' in checkpoint:
loss_scaler.load_state_dict(checkpoint['scaler'])
# Initialize TensorBoard writer
writer = None
if TENSORBOARD_AVAILABLE and utils.is_main_process():
from datetime import datetime
# Create log directory with timestamp
timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
log_dir = os.path.join(args.tensorboard_logdir, f"exp_{timestamp}")
os.makedirs(log_dir, exist_ok=True)
writer = SummaryWriter(log_dir=log_dir)
print(f"TensorBoard logs will be saved to: {log_dir}")
print(f"To view logs, run: tensorboard --logdir={log_dir}")
elif not TENSORBOARD_AVAILABLE and utils.is_main_process():
print("Warning: TensorBoard not available. Install tensorboard or tensorboardX.")
print("Training will continue without TensorBoard logging.")
if args.eval:
test_stats = evaluate(data_loader_val, model, criterion, device)
print(f"Test stats: {test_stats}")
@@ -239,14 +309,18 @@ def main(args):
print(f"Start training for {args.epochs} epochs")
start_time = time.time()
# Global step counter for TensorBoard
global_step = 0
for epoch in range(args.start_epoch, args.epochs):
if args.distributed:
data_loader_train.sampler.set_epoch(epoch)
train_stats = train_one_epoch(
train_stats, global_step = train_one_epoch(
model, criterion, data_loader_train,
optimizer, device, epoch, loss_scaler,
model_ema=model_ema
model_ema=model_ema, writer=writer,
global_step=global_step, args=args
)
lr_scheduler.step(epoch)
@@ -266,10 +340,10 @@ def main(args):
# Evaluate
if epoch % 5 == 0 or epoch == args.epochs - 1:
test_stats = evaluate(data_loader_val, model, criterion, device)
test_stats = evaluate(data_loader_val, model, criterion, device, writer=writer, epoch=epoch)
print(f"Epoch {epoch}: Test stats: {test_stats}")
# Log stats
# Log stats to text file
log_stats = {
**{f'train_{k}': v for k, v in train_stats.items()},
**{f'test_{k}': v for k, v in test_stats.items()},
@@ -284,18 +358,24 @@ def main(args):
total_time = time.time() - start_time
total_time_str = str(datetime.timedelta(seconds=int(total_time)))
print(f'Training time {total_time_str}')
# Close TensorBoard writer
if writer is not None:
writer.close()
print(f"TensorBoard logs saved to: {writer.log_dir}")
def train_one_epoch(model, criterion, data_loader, optimizer, device, epoch, loss_scaler,
clip_grad=0, clip_mode='norm', model_ema=None, **kwargs):
def train_one_epoch(model, criterion, data_loader, optimizer, device, epoch, loss_scaler,
clip_grad=0, clip_mode='norm', model_ema=None, writer=None,
global_step=0, args=None, **kwargs):
model.train()
metric_logger = utils.MetricLogger(delimiter=" ")
metric_logger.add_meter('lr', utils.SmoothedValue(window_size=1, fmt='{value:.6f}'))
header = f'Epoch: [{epoch}]'
print_freq = 10
for input_frames, target_frames, temporal_indices in metric_logger.log_every(
data_loader, print_freq, header):
for batch_idx, (input_frames, target_frames, temporal_indices) in enumerate(
metric_logger.log_every(data_loader, print_freq, header)):
input_frames = input_frames.to(device, non_blocking=True)
target_frames = target_frames.to(device, non_blocking=True)
@@ -305,7 +385,7 @@ def train_one_epoch(model, criterion, data_loader, optimizer, device, epoch, los
with torch.cuda.amp.autocast():
pred_frames, representations = model(input_frames)
loss, loss_dict = criterion(
pred_frames, target_frames,
pred_frames, target_frames,
representations, temporal_indices
)
@@ -322,19 +402,51 @@ def train_one_epoch(model, criterion, data_loader, optimizer, device, epoch, los
if model_ema is not None:
model_ema.update(model)
# Log to TensorBoard
if writer is not None:
# Log scalar metrics every iteration
writer.add_scalar('train/loss', loss_value, global_step)
writer.add_scalar('train/lr', optimizer.param_groups[0]["lr"], global_step)
# Log individual loss components
for k, v in loss_dict.items():
if torch.is_tensor(v):
writer.add_scalar(f'train/{k}', v.item(), global_step)
else:
writer.add_scalar(f'train/{k}', v, global_step)
# Log images periodically
if args is not None and getattr(args, 'log_images', False) and global_step % getattr(args, 'image_log_freq', 100) == 0:
with torch.no_grad():
# Take first sample from batch for visualization
pred_vis, _ = model(input_frames[:1])
# Convert to appropriate format for TensorBoard
# Assuming frames are in [B, C, H, W] format
writer.add_images('train/input', input_frames[:1], global_step)
writer.add_images('train/target', target_frames[:1], global_step)
writer.add_images('train/predicted', pred_vis[:1], global_step)
# Update metrics
metric_logger.update(loss=loss_value)
metric_logger.update(lr=optimizer.param_groups[0]["lr"])
for k, v in loss_dict.items():
metric_logger.update(**{k: v.item() if torch.is_tensor(v) else v})
global_step += 1
metric_logger.synchronize_between_processes()
print("Averaged stats:", metric_logger)
return {k: meter.global_avg for k, meter in metric_logger.meters.items()}
# Log epoch-level metrics
if writer is not None:
for k, meter in metric_logger.meters.items():
writer.add_scalar(f'train_epoch/{k}', meter.global_avg, epoch)
return {k: meter.global_avg for k, meter in metric_logger.meters.items()}, global_step
@torch.no_grad()
def evaluate(data_loader, model, criterion, device):
def evaluate(data_loader, model, criterion, device, writer=None, epoch=0):
model.eval()
metric_logger = utils.MetricLogger(delimiter=" ")
header = 'Test:'
@@ -359,6 +471,12 @@ def evaluate(data_loader, model, criterion, device):
metric_logger.synchronize_between_processes()
print('* Test stats:', metric_logger)
# Log validation metrics to TensorBoard
if writer is not None:
for k, meter in metric_logger.meters.items():
writer.add_scalar(f'val/{k}', meter.global_avg, epoch)
return {k: meter.global_avg for k, meter in metric_logger.meters.items()}

View File

@@ -6,9 +6,9 @@ import copy
import torch
import torch.nn as nn
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from timm.models.layers import DropPath, trunc_normal_
from timm.models.registry import register_model
from timm.models.layers.helpers import to_2tuple
from timm.layers import DropPath, trunc_normal_
from timm.models import register_model
from timm.layers import to_2tuple
import einops
SwiftFormer_width = {

View File

@@ -7,7 +7,7 @@ from .swiftformer import (
SwiftFormer, SwiftFormer_depth, SwiftFormer_width,
stem, Embedding, Stage
)
from timm.models.layers import DropPath, trunc_normal_
from timm.layers import DropPath, trunc_normal_
class DecoderBlock(nn.Module):
@@ -31,7 +31,7 @@ class DecoderBlock(nn.Module):
class FramePredictionDecoder(nn.Module):
"""Lightweight decoder for frame prediction with optional skip connections"""
def __init__(self, embed_dims, output_channels=3, use_skip=False):
def __init__(self, embed_dims, output_channels=1, use_skip=False):
super().__init__()
self.use_skip = use_skip
# Reverse the embed_dims for decoder
@@ -53,11 +53,11 @@ class FramePredictionDecoder(nn.Module):
decoder_dims[2], decoder_dims[3],
kernel_size=3, stride=2, padding=1, output_padding=1
))
# stage2 to original resolution (4x upsampling total)
# stage2 to original resolution (now 8x upsampling total with stride 4)
self.blocks.append(nn.Sequential(
nn.ConvTranspose2d(
decoder_dims[3], 32,
kernel_size=3, stride=2, padding=1, output_padding=1
kernel_size=3, stride=4, padding=1, output_padding=3
),
nn.BatchNorm2d(32),
nn.ReLU(inplace=True),
@@ -104,7 +104,7 @@ class SwiftFormerTemporal(nn.Module):
"""
SwiftFormer with temporal input for frame prediction.
Input: [B, num_frames, H, W] (Y channel only)
Output: predicted frame [B, 3, H, W] and optional representation
Output: predicted frame [B, 1, H, W] and optional representation
"""
def __init__(self,
model_name='XS',
@@ -155,7 +155,7 @@ class SwiftFormerTemporal(nn.Module):
# Frame prediction decoder
if use_decoder:
self.decoder = FramePredictionDecoder(embed_dims, output_channels=3)
self.decoder = FramePredictionDecoder(embed_dims, output_channels=1)
# Representation head for pose/velocity prediction
if use_representation_head:
@@ -201,7 +201,7 @@ class SwiftFormerTemporal(nn.Module):
x: input frames of shape [B, num_frames, H, W]
Returns:
If return_features is False:
pred_frame: predicted frame [B, 3, H, W] (or None)
pred_frame: predicted frame [B, 1, H, W] (or None)
representation: optional representation vector [B, representation_dim] (or None)
If return_features is True:
pred_frame, representation, features (list of stage features)

26
multi_gpu_temporal_train.sh Executable file
View File

@@ -0,0 +1,26 @@
#!/usr/bin/env bash
# Simple multi-GPU training script for SwiftFormerTemporal
# Usage: ./multi_gpu_temporal_train.sh <NUM_GPUS> [OPTIONS]
NUM_GPUS=${1:-2}
shift
echo "Starting multi-GPU training with $NUM_GPUS GPUs"
# Set environment variables for distributed training
export MASTER_PORT=12345
export MASTER_ADDR=localhost
export WORLD_SIZE=$NUM_GPUS
# Launch training
torchrun --nproc_per_node=$NUM_GPUS --master_port=$MASTER_PORT main_temporal.py \
--data-path "./videos" \
--model SwiftFormerTemporal_XS \
--batch-size 32 \
--epochs 100 \
--lr 1e-3 \
--output-dir "./temporal_output_multi" \
--num-workers 8 \
--pin-mem \
"$@"

0
temporal_train.sh Normal file
View File

45
test_cuda.py Normal file
View File

@@ -0,0 +1,45 @@
import torch
def test_cuda_availability():
"""全面测试CUDA可用性"""
print("="*50)
print("PyTorch CUDA 测试")
print("="*50)
# 基本信息
print(f"PyTorch版本: {torch.__version__}")
print(f"CUDA可用: {torch.cuda.is_available()}")
if not torch.cuda.is_available():
print("CUDA不可用可能原因")
print("1. 未安装CUDA驱动")
print("2. 安装的是CPU版本的PyTorch")
print("3. CUDA版本与PyTorch不匹配")
return False
# 设备信息
device_count = torch.cuda.device_count()
print(f"发现 {device_count} 个CUDA设备")
for i in range(device_count):
print(f"\n设备 {i}:")
print(f" 名称: {torch.cuda.get_device_name(i)}")
print(f" 内存总量: {torch.cuda.get_device_properties(i).total_memory / 1e9:.2f} GB")
print(f" 计算能力: {torch.cuda.get_device_properties(i).major}.{torch.cuda.get_device_properties(i).minor}")
# 简单张量测试
print("\n运行CUDA测试...")
try:
x = torch.randn(3, 3).cuda()
y = torch.randn(3, 3).cuda()
z = x + y
print("CUDA计算测试: 成功!")
print(f"设备上的张量形状: {z.shape}")
return True
except Exception as e:
print(f"CUDA计算测试失败: {e}")
return False
if __name__ == "__main__":
test_cuda_availability()

33
test_import.py Normal file
View File

@@ -0,0 +1,33 @@
#!/usr/bin/env python3
"""
测试 timm 导入是否正常工作
"""
import sys
print("Python version:", sys.version)
try:
from timm.layers import to_2tuple, DropPath, trunc_normal_
from timm.models import register_model
print("✓ 成功导入 timm.layers.to_2tuple")
print("✓ 成功导入 timm.layers.DropPath")
print("✓ 成功导入 timm.layers.trunc_normal_")
print("✓ 成功导入 timm.models.register_model")
except ImportError as e:
print(f"✗ 导入失败: {e}")
sys.exit(1)
try:
from models.swiftformer import SwiftFormer_XS
print("✓ 成功导入 SwiftFormer_XS")
except ImportError as e:
print(f"✗ 导入 SwiftFormer_XS 失败: {e}")
sys.exit(1)
try:
from models.swiftformer_temporal import SwiftFormerTemporal_XS
print("✓ 成功导入 SwiftFormerTemporal_XS")
except ImportError as e:
print(f"✗ 导入 SwiftFormerTemporal_XS 失败: {e}")
sys.exit(1)
print("\n✅ 所有导入测试通过!")

View File

@@ -80,10 +80,13 @@ class VideoFrameDataset(Dataset):
else:
self.transform = transform
# Normalization (ImageNet stats)
# Normalization for Y channel (single channel)
# Compute average of ImageNet RGB means and stds
y_mean = (0.485 + 0.456 + 0.406) / 3.0
y_std = (0.229 + 0.224 + 0.225) / 3.0
self.normalize = transforms.Normalize(
mean=[0.485, 0.456, 0.406],
std=[0.229, 0.224, 0.225]
mean=[y_mean],
std=[y_std]
)
def _default_transform(self):
@@ -114,8 +117,8 @@ class VideoFrameDataset(Dataset):
def __getitem__(self, idx: int) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""
Returns:
input_frames: [3 * num_frames, H, W] concatenated input frames
target_frame: [3, H, W] target frame to predict
input_frames: [num_frames, H, W] concatenated input frames (Y channel only)
target_frame: [1, H, W] target frame to predict (Y channel only)
temporal_idx: temporal index of target frame (for contrastive loss)
"""
video_idx, start_idx = self.frame_indices[idx]
@@ -141,23 +144,27 @@ class VideoFrameDataset(Dataset):
if self.transform:
target_frame = self.transform(target_frame)
# Convert to tensors and normalize
# Convert to tensors, normalize, and convert to grayscale (Y channel)
input_tensors = []
for frame in input_frames:
tensor = transforms.ToTensor()(frame)
tensor = self.normalize(tensor)
input_tensors.append(tensor)
tensor = transforms.ToTensor()(frame) # [3, H, W]
# Convert RGB to grayscale using weighted sum
# Y = 0.2989 * R + 0.5870 * G + 0.1140 * B (same as PIL)
gray = (0.2989 * tensor[0] + 0.5870 * tensor[1] + 0.1140 * tensor[2]).unsqueeze(0) # [1, H, W]
gray = self.normalize(gray) # normalize with single-channel stats (mean/std broadcast)
input_tensors.append(gray)
target_tensor = transforms.ToTensor()(target_frame)
target_tensor = self.normalize(target_tensor)
target_tensor = transforms.ToTensor()(target_frame) # [3, H, W]
target_gray = (0.2989 * target_tensor[0] + 0.5870 * target_tensor[1] + 0.1140 * target_tensor[2]).unsqueeze(0)
target_gray = self.normalize(target_gray)
# Concatenate input frames along channel dimension
input_concatenated = torch.cat(input_tensors, dim=0)
input_concatenated = torch.cat(input_tensors, dim=0) # [num_frames, H, W]
# Temporal index (for contrastive loss)
temporal_idx = torch.tensor(self.num_frames, dtype=torch.long)
return input_concatenated, target_tensor, temporal_idx
return input_concatenated, target_gray, temporal_idx
class SyntheticVideoDataset(Dataset):
@@ -174,10 +181,12 @@ class SyntheticVideoDataset(Dataset):
self.frame_size = frame_size
self.is_train = is_train
# Normalization
# Normalization for Y channel (single channel)
y_mean = (0.485 + 0.456 + 0.406) / 3.0
y_std = (0.229 + 0.224 + 0.225) / 3.0
self.normalize = transforms.Normalize(
mean=[0.485, 0.456, 0.406],
std=[0.229, 0.224, 0.225]
mean=[y_mean],
std=[y_std]
)
def __len__(self):

303
video_preprocessor.py Normal file
View File

@@ -0,0 +1,303 @@
#!/usr/bin/env python3
"""
视频预处理脚本 - 将MP4视频转换为224x224帧图像
支持多线程并发处理、进度条显示和中断恢复功能
"""
import os
import sys
import json
import argparse
import subprocess
import threading
from pathlib import Path
from concurrent.futures import ThreadPoolExecutor, as_completed
from tqdm import tqdm
import time
from typing import List, Dict, Optional
class VideoPreprocessor:
"""视频预处理器,支持多线程和中断恢复"""
def __init__(self,
input_dir: str,
output_dir: str,
frame_size: int = 224,
fps: int = 30,
num_workers: int = 4,
quality: int = 2,
resume: bool = True):
"""
初始化预处理器
Args:
input_dir: 输入视频目录
output_dir: 输出帧目录
frame_size: 帧大小(正方形)
fps: 提取帧率
num_workers: 并发工作线程数
quality: JPEG质量 (1-31, 数值越小质量越高)
resume: 是否启用中断恢复
"""
self.input_dir = Path(input_dir)
self.output_dir = Path(output_dir)
self.frame_size = frame_size
self.fps = fps
self.num_workers = num_workers
self.quality = quality
self.resume = resume
# 状态文件路径
self.state_file = self.output_dir / ".preprocessing_state.json"
# 创建输出目录
self.output_dir.mkdir(parents=True, exist_ok=True)
# 初始化状态
self.state = self._load_state()
# 收集所有视频文件
self.video_files = self._collect_video_files()
def _load_state(self) -> Dict:
"""加载处理状态"""
if self.resume and self.state_file.exists():
try:
with open(self.state_file, 'r') as f:
return json.load(f)
except (json.JSONDecodeError, IOError):
print(f"警告: 无法读取状态文件,将重新开始处理")
return {
"completed": [],
"failed": [],
"total_processed": 0,
"start_time": None,
"last_update": None
}
def _save_state(self):
"""保存处理状态"""
self.state["last_update"] = time.time()
try:
with open(self.state_file, 'w') as f:
json.dump(self.state, f, indent=2)
except IOError as e:
print(f"警告: 无法保存状态文件: {e}")
def _collect_video_files(self) -> List[Path]:
"""收集所有需要处理的视频文件"""
video_files = []
for file_path in self.input_dir.glob("*.mp4"):
if file_path.name not in self.state["completed"]:
video_files.append(file_path)
return sorted(video_files)
def _parse_video_name(self, video_path: Path) -> Dict[str, str]:
"""解析视频文件名使用完整文件名作为ID"""
name_without_ext = video_path.stem
# 直接使用完整文件名作为ID确保每个mp4文件有独立的输出目录
return {
"video_id": name_without_ext,
"start_frame": "unknown",
"end_frame": "unknown",
"full_name": name_without_ext
}
def _extract_frames(self, video_path: Path) -> bool:
"""提取单个视频的帧"""
try:
# 解析视频名称
video_info = self._parse_video_name(video_path)
output_subdir = self.output_dir / video_info["video_id"]
output_subdir.mkdir(exist_ok=True)
# 构建FFmpeg命令
output_pattern = output_subdir / "frame_%04d.jpg"
cmd = [
"ffmpeg",
"-i", str(video_path),
"-vf", f"fps={self.fps},scale={self.frame_size}:{self.frame_size}",
"-q:v", str(self.quality),
"-y", # 覆盖输出文件
str(output_pattern)
]
# 执行FFmpeg命令
result = subprocess.run(
cmd,
capture_output=True,
text=True,
timeout=300 # 5分钟超时
)
if result.returncode != 0:
print(f"FFmpeg错误处理 {video_path.name}: {result.stderr}")
return False
# 验证输出帧数量
output_frames = list(output_subdir.glob("frame_*.jpg"))
if len(output_frames) == 0:
print(f"警告: {video_path.name} 没有生成任何帧")
return False
return True
except subprocess.TimeoutExpired:
print(f"超时处理 {video_path.name}")
return False
except Exception as e:
print(f"处理 {video_path.name} 时发生错误: {e}")
return False
def _process_video(self, video_path: Path) -> tuple[bool, str]:
"""处理单个视频文件"""
video_name = video_path.name
try:
success = self._extract_frames(video_path)
if success:
self.state["completed"].append(video_name)
if video_name in self.state["failed"]:
self.state["failed"].remove(video_name)
return True, video_name
else:
self.state["failed"].append(video_name)
return False, video_name
except Exception as e:
print(f"处理 {video_name} 时发生异常: {e}")
self.state["failed"].append(video_name)
return False, video_name
def process_all_videos(self):
"""处理所有视频文件"""
if not self.video_files:
print("没有找到需要处理的视频文件")
return
print(f"找到 {len(self.video_files)} 个待处理视频文件")
print(f"输出目录: {self.output_dir}")
print(f"帧大小: {self.frame_size}x{self.frame_size}")
print(f"帧率: {self.fps} fps")
print(f"并发线程数: {self.num_workers}")
if self.state["completed"]:
print(f"跳过 {len(self.state['completed'])} 个已处理的视频")
# 记录开始时间
if self.state["start_time"] is None:
self.state["start_time"] = time.time()
# 创建进度条
with tqdm(total=len(self.video_files), desc="处理视频", unit="") as pbar:
with ThreadPoolExecutor(max_workers=self.num_workers) as executor:
# 提交所有任务
future_to_video = {
executor.submit(self._process_video, video_path): video_path
for video_path in self.video_files
}
# 处理完成的任务
for future in as_completed(future_to_video):
video_path = future_to_video[future]
try:
success, video_name = future.result()
if success:
pbar.set_postfix({"状态": "成功", "文件": video_name[:20]})
else:
pbar.set_postfix({"状态": "失败", "文件": video_name[:20]})
except Exception as e:
print(f"处理 {video_path.name} 时发生异常: {e}")
pbar.set_postfix({"状态": "异常", "文件": video_path.name[:20]})
pbar.update(1)
self.state["total_processed"] += 1
# 定期保存状态
if self.state["total_processed"] % 5 == 0:
self._save_state()
# 最终保存状态
self._save_state()
# 打印处理结果
self._print_summary()
def _print_summary(self):
"""打印处理摘要"""
print("\n" + "="*50)
print("处理完成摘要:")
print(f"总处理视频数: {len(self.state['completed'])}")
print(f"失败视频数: {len(self.state['failed'])}")
if self.state["failed"]:
print("\n失败的视频:")
for video_name in self.state["failed"]:
print(f" - {video_name}")
if self.state["start_time"]:
elapsed_time = time.time() - self.state["start_time"]
print(f"\n总耗时: {elapsed_time:.2f}")
if self.state["total_processed"] > 0:
avg_time = elapsed_time / self.state["total_processed"]
print(f"平均每个视频: {avg_time:.2f}")
print("="*50)
def main():
"""主函数"""
parser = argparse.ArgumentParser(description="视频预处理脚本")
parser.add_argument("--input_dir", type=str, default="/home/hexone/Workplace/ws_asmo/vhead/sekai-real-drone/sekai-real-drone", help="输入视频目录")
parser.add_argument("--output_dir", type=str, default="/home/hexone/Workplace/ws_asmo/vhead/sekai-real-drone/processed", help="输出帧目录")
parser.add_argument("--size", type=int, default=224, help="帧大小 (默认: 224)")
parser.add_argument("--fps", type=int, default=10, help="提取帧率 (默认: 30)")
parser.add_argument("--workers", type=int, default=32, help="并发线程数 (默认: 4)")
parser.add_argument("--quality", type=int, default=2, help="JPEG质量 1-31 (默认: 2)")
parser.add_argument("--no-resume", action="store_true", help="不启用中断恢复")
args = parser.parse_args()
# 检查输入目录
if not Path(args.input_dir).exists():
print(f"错误: 输入目录不存在: {args.input_dir}")
sys.exit(1)
# 检查FFmpeg是否可用
try:
subprocess.run(["ffmpeg", "-version"], capture_output=True, check=True)
except (subprocess.CalledProcessError, FileNotFoundError):
print("错误: FFmpeg未安装或不在PATH中")
sys.exit(1)
# 创建预处理器并开始处理
preprocessor = VideoPreprocessor(
input_dir=args.input_dir,
output_dir=args.output_dir,
frame_size=args.size,
fps=args.fps,
num_workers=args.workers,
quality=args.quality,
resume=not args.no_resume
)
try:
preprocessor.process_all_videos()
except KeyboardInterrupt:
print("\n\n用户中断处理,状态已保存")
preprocessor._save_state()
print("可以使用相同命令恢复处理")
except Exception as e:
print(f"\n处理过程中发生错误: {e}")
preprocessor._save_state()
sys.exit(1)
if __name__ == "__main__":
main()