First pass runs end to end, but the loss computation has problems and training does not converge
.gitignore (vendored): 4 lines changed
@@ -1,2 +1,4 @@
+.vscode/
 __pycache__/
 venv/
+runs/
dist_temporal_train.sh (new executable file): 57 lines
@@ -0,0 +1,57 @@
+#!/usr/bin/env bash
+
+# Distributed training script for SwiftFormerTemporal
+# Usage: ./dist_temporal_train.sh <DATA_PATH> <NUM_GPUS> [OPTIONS]
+
+DATA_PATH=$1
+NUM_GPUS=$2
+
+# Shift arguments to pass remaining options to python script
+shift 2
+
+# Default parameters
+MODEL=${MODEL:-"SwiftFormerTemporal_XS"}
+BATCH_SIZE=${BATCH_SIZE:-32}
+EPOCHS=${EPOCHS:-100}
+LR=${LR:-1e-3}
+OUTPUT_DIR=${OUTPUT_DIR:-"./temporal_output"}
+
+echo "Starting distributed training with $NUM_GPUS GPUs"
+echo "Data path: $DATA_PATH"
+echo "Model: $MODEL"
+echo "Batch size: $BATCH_SIZE"
+echo "Epochs: $EPOCHS"
+echo "Output dir: $OUTPUT_DIR"
+
+# Check if torch.distributed.launch or torchrun should be used
+# For newer PyTorch versions (>=1.9), torchrun is recommended
+PYTHON_VERSION=$(python -c "import torch; print(torch.__version__)")
+echo "PyTorch version: $PYTHON_VERSION"
+
+# Use torchrun for newer PyTorch versions
+if [[ "$PYTHON_VERSION" =~ ^2\. ]] || [[ "$PYTHON_VERSION" =~ ^1\.1[0-9]\. ]]; then
+    echo "Using torchrun (PyTorch >=1.10)"
+    torchrun --nproc_per_node=$NUM_GPUS --master_port=12345 main_temporal.py \
+        --data-path "$DATA_PATH" \
+        --model "$MODEL" \
+        --batch-size $BATCH_SIZE \
+        --epochs $EPOCHS \
+        --lr $LR \
+        --output-dir "$OUTPUT_DIR" \
+        "$@"
+else
+    echo "Using torch.distributed.launch"
+    python -m torch.distributed.launch --nproc_per_node=$NUM_GPUS --master_port=12345 --use_env main_temporal.py \
+        --data-path "$DATA_PATH" \
+        --model "$MODEL" \
+        --batch-size $BATCH_SIZE \
+        --epochs $EPOCHS \
+        --lr $LR \
+        --output-dir "$OUTPUT_DIR" \
+        "$@"
fi
+
+# For single-node multi-GPU training with specific options:
+# --world-size 1 --rank 0 --dist-url 'tcp://localhost:12345'
+
+echo "Training completed. Check logs in $OUTPUT_DIR"
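Note: the first two positional arguments are consumed by the script and everything else is forwarded to main_temporal.py through "$@"; MODEL, BATCH_SIZE, EPOCHS, LR and OUTPUT_DIR can also be overridden from the environment. A typical invocation (paths illustrative) might be: BATCH_SIZE=64 ./dist_temporal_train.sh ./processed 4 --num-workers 8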
main_temporal.py: 190 lines changed
@@ -6,8 +6,10 @@ import datetime
 import numpy as np
 import time
 import torch
+import torch.nn as nn
 import torch.backends.cudnn as cudnn
 import json
+import os
 from pathlib import Path
 
 from timm.scheduler import create_scheduler
@@ -20,6 +22,17 @@ from models.swiftformer_temporal import SwiftFormerTemporal_XS, SwiftFormerTempo
 from util.video_dataset import VideoFrameDataset, SyntheticVideoDataset
 from util.frame_losses import MultiTaskLoss
 
+# Try to import TensorBoard
+try:
+    from torch.utils.tensorboard import SummaryWriter
+    TENSORBOARD_AVAILABLE = True
+except ImportError:
+    try:
+        from tensorboardX import SummaryWriter
+        TENSORBOARD_AVAILABLE = True
+    except ImportError:
+        TENSORBOARD_AVAILABLE = False
+
 
 def get_args_parser():
     parser = argparse.ArgumentParser(
@@ -48,10 +61,48 @@ def get_args_parser():
     # Training parameters
     parser.add_argument('--batch-size', default=32, type=int)
     parser.add_argument('--epochs', default=100, type=int)
-    parser.add_argument('--lr', type=float, default=1e-3, metavar='LR',
-                        help='learning rate (default: 1e-3)')
+    # Optimizer parameters (required by timm's create_optimizer)
+    parser.add_argument('--opt', default='adamw', type=str, metavar='OPTIMIZER',
+                        help='Optimizer (default: "adamw")')
+    parser.add_argument('--opt-eps', default=1e-8, type=float, metavar='EPSILON',
+                        help='Optimizer Epsilon (default: 1e-8)')
+    parser.add_argument('--opt-betas', default=None, type=float, nargs='+', metavar='BETA',
+                        help='Optimizer Betas (default: None, use opt default)')
+    parser.add_argument('--clip-grad', type=float, default=0.01, metavar='NORM',
+                        help='Clip gradient norm (default: 0.01)')
+    parser.add_argument('--clip-mode', type=str, default='agc',
+                        help='Gradient clipping mode. One of ("norm", "value", "agc")')
+    parser.add_argument('--momentum', type=float, default=0.9, metavar='M',
+                        help='SGD momentum (default: 0.9)')
     parser.add_argument('--weight-decay', type=float, default=0.05,
                         help='weight decay (default: 0.05)')
+    parser.add_argument('--lr', type=float, default=1e-3, metavar='LR',
+                        help='learning rate (default: 1e-3)')
+
+    # Learning rate schedule parameters (required by timm's create_scheduler)
+    parser.add_argument('--sched', default='cosine', type=str, metavar='SCHEDULER',
+                        help='LR scheduler (default: "cosine")')
+    parser.add_argument('--lr-noise', type=float, nargs='+', default=None, metavar='pct, pct',
+                        help='learning rate noise on/off epoch percentages')
+    parser.add_argument('--lr-noise-pct', type=float, default=0.67, metavar='PERCENT',
+                        help='learning rate noise limit percent (default: 0.67)')
+    parser.add_argument('--lr-noise-std', type=float, default=1.0, metavar='STDDEV',
+                        help='learning rate noise std-dev (default: 1.0)')
+    parser.add_argument('--warmup-lr', type=float, default=1e-6, metavar='LR',
+                        help='warmup learning rate (default: 1e-6)')
+    parser.add_argument('--min-lr', type=float, default=1e-5, metavar='LR',
+                        help='lower lr bound for cyclic schedulers that hit 0 (1e-5)')
+    parser.add_argument('--decay-epochs', type=float, default=30, metavar='N',
+                        help='epoch interval to decay LR')
+    parser.add_argument('--warmup-epochs', type=int, default=5, metavar='N',
+                        help='epochs to warmup LR, if scheduler supports')
+    parser.add_argument('--cooldown-epochs', type=int, default=10, metavar='N',
+                        help='epochs to cooldown LR at min_lr, after cyclic schedule ends')
+    parser.add_argument('--patience-epochs', type=int, default=10, metavar='N',
+                        help='patience epochs for Plateau LR scheduler (default: 10)')
+    parser.add_argument('--decay-rate', '--dr', type=float, default=0.1, metavar='RATE',
+                        help='LR decay rate (default: 0.1)')
+
     # Loss parameters
     parser.add_argument('--frame-weight', type=float, default=1.0,
@@ -90,26 +141,26 @@ def get_args_parser():
     parser.add_argument('--dist-url', default='env://',
                         help='url used to set up distributed training')
 
+    # TensorBoard logging
+    parser.add_argument('--tensorboard-logdir', default='./runs',
+                        type=str, help='TensorBoard log directory')
+    parser.add_argument('--log-images', action='store_true',
+                        help='Log sample images to TensorBoard')
+    parser.add_argument('--image-log-freq', default=100, type=int,
+                        help='Frequency of logging images (in iterations)')
+
     return parser
 
 
 def build_dataset(is_train, args):
     """Build video frame dataset"""
-    if args.dataset_type == 'synthetic':
-        dataset = SyntheticVideoDataset(
-            num_samples=1000 if is_train else 200,
-            num_frames=args.num_frames,
-            frame_size=args.frame_size,
-            is_train=is_train
-        )
-    else:
-        dataset = VideoFrameDataset(
-            root_dir=args.data_path,
-            num_frames=args.num_frames,
-            frame_size=args.frame_size,
-            is_train=is_train,
-            max_interval=args.max_interval
-        )
+    dataset = VideoFrameDataset(
+        root_dir=args.data_path,
+        num_frames=args.num_frames,
+        frame_size=args.frame_size,
+        is_train=is_train,
+        max_interval=args.max_interval
+    )
     return dataset
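With the synthetic branch removed, build_dataset now always constructs a VideoFrameDataset over the preprocessed frame directory. A minimal usage sketch (the argument values below are hypothetical, not the parser defaults):

    from argparse import Namespace

    # Hypothetical values; in main() these come from get_args_parser()
    args = Namespace(data_path='./processed', num_frames=4,
                     frame_size=224, max_interval=1)
    train_set = build_dataset(is_train=True, args=args)
    inp, target, t_idx = train_set[0]
    print(inp.shape, target.shape)  # expected: [num_frames, H, W] and [1, H, W]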
@@ -203,14 +254,18 @@ def main(args):
     # Create scheduler
     lr_scheduler, _ = create_scheduler(args, optimizer)
 
-    # Create loss function
-    criterion = MultiTaskLoss(
-        frame_weight=args.frame_weight,
-        contrastive_weight=args.contrastive_weight,
-        l1_weight=args.l1_weight,
-        ssim_weight=args.ssim_weight,
-        use_contrastive=not args.no_contrastive
-    )
+    # Create loss function - simple MSE for Y channel prediction
+    class MSELossWrapper(nn.Module):
+        def __init__(self):
+            super().__init__()
+            self.mse = nn.MSELoss()
+
+        def forward(self, pred_frame, target_frame, representations=None, temporal_indices=None):
+            loss = self.mse(pred_frame, target_frame)
+            loss_dict = {'mse': loss}
+            return loss, loss_dict
+
+    criterion = MSELossWrapper()
 
     # Resume from checkpoint
     output_dir = Path(args.output_dir)
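The wrapper keeps MultiTaskLoss's (pred, target, representations, temporal_indices) call signature and its (loss, loss_dict) return shape, so train_one_epoch stays unchanged when the two criteria are swapped. A standalone sanity check of that contract (batch size and resolution hypothetical):

    import torch
    import torch.nn as nn

    mse = nn.MSELoss()
    pred = torch.randn(2, 1, 224, 224)    # [B, 1, H, W] predicted Y channel
    target = torch.randn(2, 1, 224, 224)  # [B, 1, H, W] target Y channel
    loss = mse(pred, target)              # scalar, mean over all elements
    print(loss.item())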
@@ -231,6 +286,21 @@ def main(args):
             if 'scaler' in checkpoint:
                 loss_scaler.load_state_dict(checkpoint['scaler'])
 
+    # Initialize TensorBoard writer
+    writer = None
+    if TENSORBOARD_AVAILABLE and utils.is_main_process():
+        from datetime import datetime
+        # Create log directory with timestamp
+        timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
+        log_dir = os.path.join(args.tensorboard_logdir, f"exp_{timestamp}")
+        os.makedirs(log_dir, exist_ok=True)
+        writer = SummaryWriter(log_dir=log_dir)
+        print(f"TensorBoard logs will be saved to: {log_dir}")
+        print(f"To view logs, run: tensorboard --logdir={log_dir}")
+    elif not TENSORBOARD_AVAILABLE and utils.is_main_process():
+        print("Warning: TensorBoard not available. Install tensorboard or tensorboardX.")
+        print("Training will continue without TensorBoard logging.")
+
     if args.eval:
         test_stats = evaluate(data_loader_val, model, criterion, device)
         print(f"Test stats: {test_stats}")
@@ -239,14 +309,18 @@ def main(args):
     print(f"Start training for {args.epochs} epochs")
     start_time = time.time()
 
+    # Global step counter for TensorBoard
+    global_step = 0
+
     for epoch in range(args.start_epoch, args.epochs):
         if args.distributed:
            data_loader_train.sampler.set_epoch(epoch)
 
-        train_stats = train_one_epoch(
+        train_stats, global_step = train_one_epoch(
             model, criterion, data_loader_train,
             optimizer, device, epoch, loss_scaler,
-            model_ema=model_ema
+            model_ema=model_ema, writer=writer,
+            global_step=global_step, args=args
         )
 
         lr_scheduler.step(epoch)
@@ -266,10 +340,10 @@ def main(args):
 
         # Evaluate
         if epoch % 5 == 0 or epoch == args.epochs - 1:
-            test_stats = evaluate(data_loader_val, model, criterion, device)
+            test_stats = evaluate(data_loader_val, model, criterion, device, writer=writer, epoch=epoch)
             print(f"Epoch {epoch}: Test stats: {test_stats}")
 
-        # Log stats
+        # Log stats to text file
         log_stats = {
             **{f'train_{k}': v for k, v in train_stats.items()},
             **{f'test_{k}': v for k, v in test_stats.items()},
@@ -284,18 +358,24 @@ def main(args):
     total_time = time.time() - start_time
     total_time_str = str(datetime.timedelta(seconds=int(total_time)))
     print(f'Training time {total_time_str}')
 
+    # Close TensorBoard writer
+    if writer is not None:
+        writer.close()
+        print(f"TensorBoard logs saved to: {writer.log_dir}")
+
 
 def train_one_epoch(model, criterion, data_loader, optimizer, device, epoch, loss_scaler,
-                    clip_grad=0, clip_mode='norm', model_ema=None, **kwargs):
+                    clip_grad=0, clip_mode='norm', model_ema=None, writer=None,
+                    global_step=0, args=None, **kwargs):
     model.train()
     metric_logger = utils.MetricLogger(delimiter="  ")
     metric_logger.add_meter('lr', utils.SmoothedValue(window_size=1, fmt='{value:.6f}'))
     header = f'Epoch: [{epoch}]'
     print_freq = 10
 
-    for input_frames, target_frames, temporal_indices in metric_logger.log_every(
-            data_loader, print_freq, header):
+    for batch_idx, (input_frames, target_frames, temporal_indices) in enumerate(
+            metric_logger.log_every(data_loader, print_freq, header)):
 
         input_frames = input_frames.to(device, non_blocking=True)
         target_frames = target_frames.to(device, non_blocking=True)
@@ -305,7 +385,7 @@ def train_one_epoch(model, criterion, data_loader, optimizer, device, epoch, los
         with torch.cuda.amp.autocast():
             pred_frames, representations = model(input_frames)
             loss, loss_dict = criterion(
                 pred_frames, target_frames,
                 representations, temporal_indices
             )
@@ -322,19 +402,51 @@ def train_one_epoch(model, criterion, data_loader, optimizer, device, epoch, los
         if model_ema is not None:
             model_ema.update(model)
 
+        # Log to TensorBoard
+        if writer is not None:
+            # Log scalar metrics every iteration
+            writer.add_scalar('train/loss', loss_value, global_step)
+            writer.add_scalar('train/lr', optimizer.param_groups[0]["lr"], global_step)
+
+            # Log individual loss components
+            for k, v in loss_dict.items():
+                if torch.is_tensor(v):
+                    writer.add_scalar(f'train/{k}', v.item(), global_step)
+                else:
+                    writer.add_scalar(f'train/{k}', v, global_step)
+
+            # Log images periodically
+            if args is not None and getattr(args, 'log_images', False) and global_step % getattr(args, 'image_log_freq', 100) == 0:
+                with torch.no_grad():
+                    # Take first sample from batch for visualization
+                    pred_vis, _ = model(input_frames[:1])
+                    # Convert to appropriate format for TensorBoard
+                    # Assuming frames are in [B, C, H, W] format
+                    writer.add_images('train/input', input_frames[:1], global_step)
+                    writer.add_images('train/target', target_frames[:1], global_step)
+                    writer.add_images('train/predicted', pred_vis[:1], global_step)
+
         # Update metrics
         metric_logger.update(loss=loss_value)
         metric_logger.update(lr=optimizer.param_groups[0]["lr"])
         for k, v in loss_dict.items():
             metric_logger.update(**{k: v.item() if torch.is_tensor(v) else v})
 
+        global_step += 1
+
     metric_logger.synchronize_between_processes()
     print("Averaged stats:", metric_logger)
-    return {k: meter.global_avg for k, meter in metric_logger.meters.items()}
+
+    # Log epoch-level metrics
+    if writer is not None:
+        for k, meter in metric_logger.meters.items():
+            writer.add_scalar(f'train_epoch/{k}', meter.global_avg, epoch)
+
+    return {k: meter.global_avg for k, meter in metric_logger.meters.items()}, global_step
 
 
 @torch.no_grad()
-def evaluate(data_loader, model, criterion, device):
+def evaluate(data_loader, model, criterion, device, writer=None, epoch=0):
     model.eval()
     metric_logger = utils.MetricLogger(delimiter="  ")
     header = 'Test:'
@@ -359,6 +471,12 @@ def evaluate(data_loader, model, criterion, device):
 
     metric_logger.synchronize_between_processes()
     print('* Test stats:', metric_logger)
 
+    # Log validation metrics to TensorBoard
+    if writer is not None:
+        for k, meter in metric_logger.meters.items():
+            writer.add_scalar(f'val/{k}', meter.global_avg, epoch)
+
     return {k: meter.global_avg for k, meter in metric_logger.meters.items()}
models/swiftformer.py
@@ -6,9 +6,9 @@ import copy
 import torch
 import torch.nn as nn
 from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
-from timm.models.layers import DropPath, trunc_normal_
+from timm.layers import DropPath, trunc_normal_
-from timm.models.registry import register_model
+from timm.models import register_model
-from timm.models.layers.helpers import to_2tuple
+from timm.layers import to_2tuple
 import einops
 
 SwiftFormer_width = {
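These paths changed when timm moved its layer utilities from timm.models.layers into timm.layers (around timm 0.9). If the code still needs to run against older timm releases as well, a version-tolerant shim is one option (a sketch, not what this commit does):

    try:
        # timm >= 0.9: layer utilities live in timm.layers
        from timm.layers import DropPath, trunc_normal_, to_2tuple
    except ImportError:
        # older timm releases
        from timm.models.layers import DropPath, trunc_normal_, to_2tuple
    from timm.models import register_model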
models/swiftformer_temporal.py
@@ -7,7 +7,7 @@ from .swiftformer import (
     SwiftFormer, SwiftFormer_depth, SwiftFormer_width,
     stem, Embedding, Stage
 )
-from timm.models.layers import DropPath, trunc_normal_
+from timm.layers import DropPath, trunc_normal_
 
 
 class DecoderBlock(nn.Module):
@@ -31,7 +31,7 @@ class DecoderBlock(nn.Module):
 
 class FramePredictionDecoder(nn.Module):
     """Lightweight decoder for frame prediction with optional skip connections"""
-    def __init__(self, embed_dims, output_channels=3, use_skip=False):
+    def __init__(self, embed_dims, output_channels=1, use_skip=False):
         super().__init__()
         self.use_skip = use_skip
         # Reverse the embed_dims for decoder
@@ -53,11 +53,11 @@ class FramePredictionDecoder(nn.Module):
             decoder_dims[2], decoder_dims[3],
             kernel_size=3, stride=2, padding=1, output_padding=1
         ))
-        # stage2 to original resolution (4x upsampling total)
+        # stage2 to original resolution (now 8x upsampling total with stride 4)
         self.blocks.append(nn.Sequential(
             nn.ConvTranspose2d(
                 decoder_dims[3], 32,
-                kernel_size=3, stride=2, padding=1, output_padding=1
+                kernel_size=3, stride=4, padding=1, output_padding=3
             ),
             nn.BatchNorm2d(32),
             nn.ReLU(inplace=True),
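The stride change follows from the ConvTranspose2d output-size formula out = (in - 1) * stride - 2 * padding + kernel_size + output_padding: with stride=4, padding=1, kernel_size=3, output_padding=3 the block upsamples by exactly 4x. A quick check (channel counts hypothetical):

    import torch
    import torch.nn as nn

    # out = (in - 1)*4 - 2*1 + 3 + 3 = 4*in
    up = nn.ConvTranspose2d(48, 32, kernel_size=3, stride=4, padding=1, output_padding=3)
    x = torch.randn(1, 48, 14, 14)
    print(up(x).shape)  # torch.Size([1, 32, 56, 56])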
@@ -104,7 +104,7 @@ class SwiftFormerTemporal(nn.Module):
     """
     SwiftFormer with temporal input for frame prediction.
     Input: [B, num_frames, H, W] (Y channel only)
-    Output: predicted frame [B, 3, H, W] and optional representation
+    Output: predicted frame [B, 1, H, W] and optional representation
     """
     def __init__(self,
                  model_name='XS',
@@ -155,7 +155,7 @@ class SwiftFormerTemporal(nn.Module):
 
         # Frame prediction decoder
         if use_decoder:
-            self.decoder = FramePredictionDecoder(embed_dims, output_channels=3)
+            self.decoder = FramePredictionDecoder(embed_dims, output_channels=1)
 
         # Representation head for pose/velocity prediction
         if use_representation_head:
@@ -201,7 +201,7 @@ class SwiftFormerTemporal(nn.Module):
             x: input frames of shape [B, num_frames, H, W]
         Returns:
             If return_features is False:
-                pred_frame: predicted frame [B, 3, H, W] (or None)
+                pred_frame: predicted frame [B, 1, H, W] (or None)
                 representation: optional representation vector [B, representation_dim] (or None)
             If return_features is True:
                 pred_frame, representation, features (list of stage features)
multi_gpu_temporal_train.sh (new executable file): 26 lines
@@ -0,0 +1,26 @@
+#!/usr/bin/env bash
+
+# Simple multi-GPU training script for SwiftFormerTemporal
+# Usage: ./multi_gpu_temporal_train.sh <NUM_GPUS> [OPTIONS]
+
+NUM_GPUS=${1:-2}
+shift
+
+echo "Starting multi-GPU training with $NUM_GPUS GPUs"
+
+# Set environment variables for distributed training
+export MASTER_PORT=12345
+export MASTER_ADDR=localhost
+export WORLD_SIZE=$NUM_GPUS
+
+# Launch training
+torchrun --nproc_per_node=$NUM_GPUS --master_port=$MASTER_PORT main_temporal.py \
+    --data-path "./videos" \
+    --model SwiftFormerTemporal_XS \
+    --batch-size 32 \
+    --epochs 100 \
+    --lr 1e-3 \
+    --output-dir "./temporal_output_multi" \
+    --num-workers 8 \
+    --pin-mem \
+    "$@"
temporal_train.sh (new file): 0 lines (empty)
test_cuda.py (new file): 45 lines
@@ -0,0 +1,45 @@
+import torch
+
+def test_cuda_availability():
+    """Comprehensive CUDA availability test"""
+
+    print("="*50)
+    print("PyTorch CUDA test")
+    print("="*50)
+
+    # Basic info
+    print(f"PyTorch version: {torch.__version__}")
+    print(f"CUDA available: {torch.cuda.is_available()}")
+
+    if not torch.cuda.is_available():
+        print("CUDA is unavailable; possible causes:")
+        print("1. CUDA driver not installed")
+        print("2. CPU-only build of PyTorch installed")
+        print("3. CUDA version does not match the PyTorch build")
+        return False
+
+    # Device info
+    device_count = torch.cuda.device_count()
+    print(f"Found {device_count} CUDA device(s)")
+
+    for i in range(device_count):
+        print(f"\nDevice {i}:")
+        print(f"  Name: {torch.cuda.get_device_name(i)}")
+        print(f"  Total memory: {torch.cuda.get_device_properties(i).total_memory / 1e9:.2f} GB")
+        print(f"  Compute capability: {torch.cuda.get_device_properties(i).major}.{torch.cuda.get_device_properties(i).minor}")
+
+    # Simple tensor test
+    print("\nRunning CUDA test...")
+    try:
+        x = torch.randn(3, 3).cuda()
+        y = torch.randn(3, 3).cuda()
+        z = x + y
+        print("CUDA compute test: success!")
+        print(f"Tensor shape on device: {z.shape}")
+        return True
+    except Exception as e:
+        print(f"CUDA compute test failed: {e}")
+        return False
+
+if __name__ == "__main__":
+    test_cuda_availability()
test_import.py (new file): 33 lines
@@ -0,0 +1,33 @@
+#!/usr/bin/env python3
+"""
+Test that the timm imports work correctly
+"""
+import sys
+print("Python version:", sys.version)
+
+try:
+    from timm.layers import to_2tuple, DropPath, trunc_normal_
+    from timm.models import register_model
+    print("✓ Imported timm.layers.to_2tuple")
+    print("✓ Imported timm.layers.DropPath")
+    print("✓ Imported timm.layers.trunc_normal_")
+    print("✓ Imported timm.models.register_model")
+except ImportError as e:
+    print(f"✗ Import failed: {e}")
+    sys.exit(1)
+
+try:
+    from models.swiftformer import SwiftFormer_XS
+    print("✓ Imported SwiftFormer_XS")
+except ImportError as e:
+    print(f"✗ Failed to import SwiftFormer_XS: {e}")
+    sys.exit(1)
+
+try:
+    from models.swiftformer_temporal import SwiftFormerTemporal_XS
+    print("✓ Imported SwiftFormerTemporal_XS")
+except ImportError as e:
+    print(f"✗ Failed to import SwiftFormerTemporal_XS: {e}")
+    sys.exit(1)
+
+print("\n✅ All import tests passed!")
util/video_dataset.py
@@ -80,10 +80,13 @@ class VideoFrameDataset(Dataset):
         else:
             self.transform = transform
 
-        # Normalization (ImageNet stats)
+        # Normalization for Y channel (single channel)
+        # Compute average of ImageNet RGB means and stds
+        y_mean = (0.485 + 0.456 + 0.406) / 3.0
+        y_std = (0.229 + 0.224 + 0.225) / 3.0
         self.normalize = transforms.Normalize(
-            mean=[0.485, 0.456, 0.406],
+            mean=[y_mean],
-            std=[0.229, 0.224, 0.225]
+            std=[y_std]
         )
 
     def _default_transform(self):
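For reference, the averaged single-channel stats work out to y_mean = (0.485 + 0.456 + 0.406) / 3 ≈ 0.449 and y_std = (0.229 + 0.224 + 0.225) / 3 ≈ 0.226. This is only an approximation: the standard deviation of a weighted sum of correlated channels is not the mean of the per-channel standard deviations, so the true Y-channel std will differ somewhat.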
@@ -114,8 +117,8 @@ class VideoFrameDataset(Dataset):
     def __getitem__(self, idx: int) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
         """
         Returns:
-            input_frames: [3 * num_frames, H, W] concatenated input frames
-            target_frame: [3, H, W] target frame to predict
+            input_frames: [num_frames, H, W] concatenated input frames (Y channel only)
+            target_frame: [1, H, W] target frame to predict (Y channel only)
             temporal_idx: temporal index of target frame (for contrastive loss)
         """
         video_idx, start_idx = self.frame_indices[idx]
@@ -141,23 +144,27 @@ class VideoFrameDataset(Dataset):
         if self.transform:
             target_frame = self.transform(target_frame)
 
-        # Convert to tensors and normalize
+        # Convert to tensors, normalize, and convert to grayscale (Y channel)
         input_tensors = []
         for frame in input_frames:
-            tensor = transforms.ToTensor()(frame)
-            tensor = self.normalize(tensor)
-            input_tensors.append(tensor)
+            tensor = transforms.ToTensor()(frame)  # [3, H, W]
+            # Convert RGB to grayscale using weighted sum
+            # Y = 0.2989 * R + 0.5870 * G + 0.1140 * B (same as PIL)
+            gray = (0.2989 * tensor[0] + 0.5870 * tensor[1] + 0.1140 * tensor[2]).unsqueeze(0)  # [1, H, W]
+            gray = self.normalize(gray)  # normalize with single-channel stats (mean/std broadcast)
+            input_tensors.append(gray)
 
-        target_tensor = transforms.ToTensor()(target_frame)
-        target_tensor = self.normalize(target_tensor)
+        target_tensor = transforms.ToTensor()(target_frame)  # [3, H, W]
+        target_gray = (0.2989 * target_tensor[0] + 0.5870 * target_tensor[1] + 0.1140 * target_tensor[2]).unsqueeze(0)
+        target_gray = self.normalize(target_gray)
 
         # Concatenate input frames along channel dimension
-        input_concatenated = torch.cat(input_tensors, dim=0)
+        input_concatenated = torch.cat(input_tensors, dim=0)  # [num_frames, H, W]
 
         # Temporal index (for contrastive loss)
         temporal_idx = torch.tensor(self.num_frames, dtype=torch.long)
 
-        return input_concatenated, target_tensor, temporal_idx
+        return input_concatenated, target_gray, temporal_idx
 
 
 class SyntheticVideoDataset(Dataset):
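The 0.2989/0.5870/0.1140 weights are the ITU-R BT.601 luma coefficients, the same ones PIL's Image.convert('L') uses, so the tensor-side conversion should match PIL up to 8-bit rounding. A quick self-contained check:

    import torch
    from PIL import Image
    from torchvision import transforms

    img = Image.new('RGB', (4, 4), (120, 200, 40))   # tiny synthetic frame
    t = transforms.ToTensor()(img)                   # [3, H, W] in [0, 1]
    y = (0.2989 * t[0] + 0.5870 * t[1] + 0.1140 * t[2]).unsqueeze(0)
    y_pil = transforms.ToTensor()(img.convert('L'))  # PIL's BT.601 conversion
    print(torch.allclose(y, y_pil, atol=1 / 255))    # True up to 8-bit rounding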
@@ -174,10 +181,12 @@ class SyntheticVideoDataset(Dataset):
         self.frame_size = frame_size
         self.is_train = is_train
 
-        # Normalization
+        # Normalization for Y channel (single channel)
+        y_mean = (0.485 + 0.456 + 0.406) / 3.0
+        y_std = (0.229 + 0.224 + 0.225) / 3.0
         self.normalize = transforms.Normalize(
-            mean=[0.485, 0.456, 0.406],
+            mean=[y_mean],
-            std=[0.229, 0.224, 0.225]
+            std=[y_std]
         )
 
     def __len__(self):
video_preprocessor.py (new file): 303 lines
@@ -0,0 +1,303 @@
+#!/usr/bin/env python3
+"""
+Video preprocessing script - converts MP4 videos into 224x224 frame images.
+Supports multi-threaded processing, a progress bar, and resume-after-interrupt.
+"""
+
+import os
+import sys
+import json
+import argparse
+import subprocess
+import threading
+from pathlib import Path
+from concurrent.futures import ThreadPoolExecutor, as_completed
+from tqdm import tqdm
+import time
+from typing import List, Dict, Optional
+
+
+class VideoPreprocessor:
+    """Video preprocessor with multi-threading and resume support"""
+
+    def __init__(self,
+                 input_dir: str,
+                 output_dir: str,
+                 frame_size: int = 224,
+                 fps: int = 30,
+                 num_workers: int = 4,
+                 quality: int = 2,
+                 resume: bool = True):
+        """
+        Initialize the preprocessor.
+
+        Args:
+            input_dir: input video directory
+            output_dir: output frame directory
+            frame_size: frame size (square)
+            fps: extraction frame rate
+            num_workers: number of concurrent worker threads
+            quality: JPEG quality (1-31, lower is better)
+            resume: enable resume after interruption
+        """
+        self.input_dir = Path(input_dir)
+        self.output_dir = Path(output_dir)
+        self.frame_size = frame_size
+        self.fps = fps
+        self.num_workers = num_workers
+        self.quality = quality
+        self.resume = resume
+
+        # State file path
+        self.state_file = self.output_dir / ".preprocessing_state.json"
+
+        # Create the output directory
+        self.output_dir.mkdir(parents=True, exist_ok=True)
+
+        # Initialize state
+        self.state = self._load_state()
+
+        # Collect all video files
+        self.video_files = self._collect_video_files()
+
+    def _load_state(self) -> Dict:
+        """Load processing state"""
+        if self.resume and self.state_file.exists():
+            try:
+                with open(self.state_file, 'r') as f:
+                    return json.load(f)
+            except (json.JSONDecodeError, IOError):
+                print("Warning: could not read the state file; starting over")
+
+        return {
+            "completed": [],
+            "failed": [],
+            "total_processed": 0,
+            "start_time": None,
+            "last_update": None
+        }
+
+    def _save_state(self):
+        """Save processing state"""
+        self.state["last_update"] = time.time()
+        try:
+            with open(self.state_file, 'w') as f:
+                json.dump(self.state, f, indent=2)
+        except IOError as e:
+            print(f"Warning: could not save the state file: {e}")
+
+    def _collect_video_files(self) -> List[Path]:
+        """Collect all video files that still need processing"""
+        video_files = []
+        for file_path in self.input_dir.glob("*.mp4"):
+            if file_path.name not in self.state["completed"]:
+                video_files.append(file_path)
+
+        return sorted(video_files)
+
+    def _parse_video_name(self, video_path: Path) -> Dict[str, str]:
+        """Parse the video file name, using the full name as the ID"""
+        name_without_ext = video_path.stem
+
+        # Use the full file name as the ID so every mp4 gets its own output directory
+        return {
+            "video_id": name_without_ext,
+            "start_frame": "unknown",
+            "end_frame": "unknown",
+            "full_name": name_without_ext
+        }
+
+    def _extract_frames(self, video_path: Path) -> bool:
+        """Extract frames from a single video"""
+        try:
+            # Parse the video name
+            video_info = self._parse_video_name(video_path)
+            output_subdir = self.output_dir / video_info["video_id"]
+            output_subdir.mkdir(exist_ok=True)
+
+            # Build the FFmpeg command
+            output_pattern = output_subdir / "frame_%04d.jpg"
+
+            cmd = [
+                "ffmpeg",
+                "-i", str(video_path),
+                "-vf", f"fps={self.fps},scale={self.frame_size}:{self.frame_size}",
+                "-q:v", str(self.quality),
+                "-y",  # overwrite output files
+                str(output_pattern)
+            ]
+
+            # Run the FFmpeg command
+            result = subprocess.run(
+                cmd,
+                capture_output=True,
+                text=True,
+                timeout=300  # 5-minute timeout
+            )
+
+            if result.returncode != 0:
+                print(f"FFmpeg error while processing {video_path.name}: {result.stderr}")
+                return False
+
+            # Verify the number of output frames
+            output_frames = list(output_subdir.glob("frame_*.jpg"))
+            if len(output_frames) == 0:
+                print(f"Warning: {video_path.name} produced no frames")
+                return False
+
+            return True
+
+        except subprocess.TimeoutExpired:
+            print(f"Timeout while processing {video_path.name}")
+            return False
+        except Exception as e:
+            print(f"Error while processing {video_path.name}: {e}")
+            return False
+
+    def _process_video(self, video_path: Path) -> tuple[bool, str]:
+        """Process a single video file"""
+        video_name = video_path.name
+
+        try:
+            success = self._extract_frames(video_path)
+
+            if success:
+                self.state["completed"].append(video_name)
+                if video_name in self.state["failed"]:
+                    self.state["failed"].remove(video_name)
+                return True, video_name
+            else:
+                self.state["failed"].append(video_name)
+                return False, video_name
+
+        except Exception as e:
+            print(f"Exception while processing {video_name}: {e}")
+            self.state["failed"].append(video_name)
+            return False, video_name
+
+    def process_all_videos(self):
+        """Process all video files"""
+        if not self.video_files:
+            print("No video files to process")
+            return
+
+        print(f"Found {len(self.video_files)} video file(s) to process")
+        print(f"Output directory: {self.output_dir}")
+        print(f"Frame size: {self.frame_size}x{self.frame_size}")
+        print(f"Frame rate: {self.fps} fps")
+        print(f"Worker threads: {self.num_workers}")
+
+        if self.state["completed"]:
+            print(f"Skipping {len(self.state['completed'])} already-processed video(s)")
+
+        # Record the start time
+        if self.state["start_time"] is None:
+            self.state["start_time"] = time.time()
+
+        # Create the progress bar
+        with tqdm(total=len(self.video_files), desc="Processing videos", unit="video") as pbar:
+            with ThreadPoolExecutor(max_workers=self.num_workers) as executor:
+                # Submit all tasks
+                future_to_video = {
+                    executor.submit(self._process_video, video_path): video_path
+                    for video_path in self.video_files
+                }
+
+                # Handle completed tasks
+                for future in as_completed(future_to_video):
+                    video_path = future_to_video[future]
+                    try:
+                        success, video_name = future.result()
+                        if success:
+                            pbar.set_postfix({"status": "ok", "file": video_name[:20]})
+                        else:
+                            pbar.set_postfix({"status": "failed", "file": video_name[:20]})
+                    except Exception as e:
+                        print(f"Exception while processing {video_path.name}: {e}")
+                        pbar.set_postfix({"status": "error", "file": video_path.name[:20]})
+
+                    pbar.update(1)
+                    self.state["total_processed"] += 1
+
+                    # Save state periodically
+                    if self.state["total_processed"] % 5 == 0:
+                        self._save_state()
+
+        # Final state save
+        self._save_state()
+
+        # Print the processing summary
+        self._print_summary()
+
+    def _print_summary(self):
+        """Print a processing summary"""
+        print("\n" + "="*50)
+        print("Processing summary:")
+        print(f"Total videos processed: {len(self.state['completed'])}")
+        print(f"Failed videos: {len(self.state['failed'])}")
+
+        if self.state["failed"]:
+            print("\nFailed videos:")
+            for video_name in self.state["failed"]:
+                print(f"  - {video_name}")
+
+        if self.state["start_time"]:
+            elapsed_time = time.time() - self.state["start_time"]
+            print(f"\nTotal time: {elapsed_time:.2f} s")
+            if self.state["total_processed"] > 0:
+                avg_time = elapsed_time / self.state["total_processed"]
+                print(f"Average per video: {avg_time:.2f} s")
+
+        print("="*50)
+
+
+def main():
+    """Main entry point"""
+    parser = argparse.ArgumentParser(description="Video preprocessing script")
+    parser.add_argument("--input_dir", type=str, default="/home/hexone/Workplace/ws_asmo/vhead/sekai-real-drone/sekai-real-drone", help="input video directory")
+    parser.add_argument("--output_dir", type=str, default="/home/hexone/Workplace/ws_asmo/vhead/sekai-real-drone/processed", help="output frame directory")
+    parser.add_argument("--size", type=int, default=224, help="frame size (default: 224)")
+    parser.add_argument("--fps", type=int, default=10, help="extraction frame rate (default: 10)")
+    parser.add_argument("--workers", type=int, default=32, help="number of worker threads (default: 32)")
+    parser.add_argument("--quality", type=int, default=2, help="JPEG quality 1-31 (default: 2)")
+    parser.add_argument("--no-resume", action="store_true", help="disable resume after interruption")
+
+    args = parser.parse_args()
+
+    # Check the input directory
+    if not Path(args.input_dir).exists():
+        print(f"Error: input directory does not exist: {args.input_dir}")
+        sys.exit(1)
+
+    # Check that FFmpeg is available
+    try:
+        subprocess.run(["ffmpeg", "-version"], capture_output=True, check=True)
+    except (subprocess.CalledProcessError, FileNotFoundError):
+        print("Error: FFmpeg is not installed or not on PATH")
+        sys.exit(1)
+
+    # Create the preprocessor and start processing
+    preprocessor = VideoPreprocessor(
+        input_dir=args.input_dir,
+        output_dir=args.output_dir,
+        frame_size=args.size,
+        fps=args.fps,
+        num_workers=args.workers,
+        quality=args.quality,
+        resume=not args.no_resume
+    )
+
+    try:
+        preprocessor.process_all_videos()
+    except KeyboardInterrupt:
+        print("\n\nInterrupted by user; state saved")
+        preprocessor._save_state()
+        print("Re-run the same command to resume")
+    except Exception as e:
+        print(f"\nError during processing: {e}")
+        preprocessor._save_state()
+        sys.exit(1)
+
+
+if __name__ == "__main__":
+    main()
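A typical invocation (paths illustrative; the built-in defaults point at a machine-specific /home/hexone/... tree) would be: python video_preprocessor.py --input_dir ./raw_videos --output_dir ./processed --fps 10 --workers 32. Interrupting and re-running the same command resumes from the .preprocessing_state.json saved in the output directory.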