test modify swiftformer to temporal input

This commit is contained in:
2026-01-07 11:03:33 +08:00
parent 4aa6cd6752
commit 7e9564ef20
6 changed files with 1074 additions and 0 deletions

373
main_temporal.py Normal file
View File

@@ -0,0 +1,373 @@
"""
Main training script for SwiftFormerTemporal frame prediction
"""
import argparse
import datetime
import numpy as np
import time
import torch
import torch.backends.cudnn as cudnn
import json
from pathlib import Path
from timm.scheduler import create_scheduler
from timm.optim import create_optimizer
from timm.utils import NativeScaler, get_state_dict, ModelEma
from util import *
from models import *
from models.swiftformer_temporal import SwiftFormerTemporal_XS, SwiftFormerTemporal_S, SwiftFormerTemporal_L1, SwiftFormerTemporal_L3
from util.video_dataset import VideoFrameDataset, SyntheticVideoDataset
from util.frame_losses import MultiTaskLoss
def get_args_parser():
    """Build the command-line parser for the SwiftFormerTemporal training script.

    The parser is created with ``add_help=False`` so that it can be passed as a
    ``parents=[...]`` entry to another ``ArgumentParser``.

    Besides the dataset/model/loss/system options, the parser defines the
    optimizer and LR-schedule options that ``timm.optim.create_optimizer`` and
    ``timm.scheduler.create_scheduler`` read from the namespace (``opt``,
    ``momentum``, ``sched``, ``min_lr``, ``warmup_lr``, ...). Without them the
    timm factories fail with ``AttributeError`` when handed ``args``.

    Returns:
        argparse.ArgumentParser: the configured parser.
    """
    parser = argparse.ArgumentParser(
        'SwiftFormerTemporal training script', add_help=False)

    # Dataset parameters
    parser.add_argument('--data-path', default='./videos', type=str,
                        help='Path to video dataset')
    parser.add_argument('--dataset-type', default='video', choices=['video', 'synthetic'],
                        type=str, help='Dataset type')
    parser.add_argument('--num-frames', default=3, type=int,
                        help='Number of input frames (T)')
    parser.add_argument('--frame-size', default=224, type=int,
                        help='Input frame size')
    parser.add_argument('--max-interval', default=1, type=int,
                        help='Maximum interval between consecutive frames')

    # Model parameters
    parser.add_argument('--model', default='SwiftFormerTemporal_XS', type=str, metavar='MODEL',
                        help='Name of model to train')
    parser.add_argument('--use-representation-head', action='store_true',
                        help='Use representation head for pose/velocity prediction')
    parser.add_argument('--representation-dim', default=128, type=int,
                        help='Dimension of representation vector')

    # Training parameters
    parser.add_argument('--batch-size', default=32, type=int)
    parser.add_argument('--epochs', default=100, type=int)
    parser.add_argument('--lr', type=float, default=1e-3, metavar='LR',
                        help='learning rate (default: 1e-3)')
    parser.add_argument('--weight-decay', type=float, default=0.05,
                        help='weight decay (default: 0.05)')

    # Optimizer parameters (read by timm.optim.create_optimizer)
    parser.add_argument('--opt', default='adamw', type=str, metavar='OPTIMIZER',
                        help='Optimizer (default: "adamw")')
    parser.add_argument('--opt-eps', default=1e-8, type=float, metavar='EPSILON',
                        help='Optimizer epsilon (default: 1e-8)')
    parser.add_argument('--opt-betas', default=None, type=float, nargs='+', metavar='BETA',
                        help='Optimizer betas (default: None, use optimizer default)')
    parser.add_argument('--momentum', type=float, default=0.9, metavar='M',
                        help='SGD momentum (default: 0.9)')
    parser.add_argument('--clip-grad', type=float, default=None, metavar='NORM',
                        help='Clip gradient norm (default: None, no clipping)')
    parser.add_argument('--clip-mode', type=str, default='norm',
                        help='Gradient clipping mode, e.g. "norm", "value" or "agc"')

    # Learning-rate schedule parameters (read by timm.scheduler.create_scheduler)
    parser.add_argument('--sched', default='cosine', type=str, metavar='SCHEDULER',
                        help='LR scheduler (default: "cosine")')
    parser.add_argument('--warmup-lr', type=float, default=1e-6, metavar='LR',
                        help='warmup learning rate (default: 1e-6)')
    parser.add_argument('--min-lr', type=float, default=1e-5, metavar='LR',
                        help='lower lr bound for cyclic schedulers (default: 1e-5)')
    parser.add_argument('--decay-epochs', type=float, default=30, metavar='N',
                        help='epoch interval to decay LR (step scheduler)')
    parser.add_argument('--warmup-epochs', type=int, default=5, metavar='N',
                        help='epochs to warm up LR, if scheduler supports it')
    parser.add_argument('--cooldown-epochs', type=int, default=10, metavar='N',
                        help='epochs to cool down LR at min_lr after the cyclic schedule ends')
    parser.add_argument('--patience-epochs', type=int, default=10, metavar='N',
                        help='patience epochs for the plateau LR scheduler')
    parser.add_argument('--decay-rate', type=float, default=0.1, metavar='RATE',
                        help='LR decay rate (default: 0.1)')

    # Loss parameters
    parser.add_argument('--frame-weight', type=float, default=1.0,
                        help='Weight for frame prediction loss')
    parser.add_argument('--contrastive-weight', type=float, default=0.1,
                        help='Weight for contrastive loss')
    parser.add_argument('--l1-weight', type=float, default=1.0,
                        help='Weight for L1 loss')
    parser.add_argument('--ssim-weight', type=float, default=0.1,
                        help='Weight for SSIM loss')
    parser.add_argument('--no-contrastive', action='store_true',
                        help='Disable contrastive loss')
    parser.add_argument('--no-ssim', action='store_true',
                        help='Disable SSIM loss')

    # System parameters
    parser.add_argument('--output-dir', default='./output',
                        help='path where to save, empty for no saving')
    parser.add_argument('--device', default='cuda',
                        help='device to use for training / testing')
    parser.add_argument('--seed', default=0, type=int)
    parser.add_argument('--resume', default='', help='resume from checkpoint')
    parser.add_argument('--start-epoch', default=0, type=int, metavar='N',
                        help='start epoch')
    parser.add_argument('--eval', action='store_true',
                        help='Perform evaluation only')
    parser.add_argument('--num-workers', default=4, type=int)
    # --pin-mem / --no-pin-mem share one destination; pinning is on by default.
    parser.add_argument('--pin-mem', action='store_true',
                        help='Pin CPU memory in DataLoader')
    parser.add_argument('--no-pin-mem', action='store_false', dest='pin_mem')
    parser.set_defaults(pin_mem=True)

    # Distributed training
    parser.add_argument('--world-size', default=1, type=int,
                        help='number of distributed processes')
    parser.add_argument('--dist-url', default='env://',
                        help='url used to set up distributed training')
    return parser
def build_dataset(is_train, args):
    """Create the train or validation dataset chosen by ``args.dataset_type``.

    Args:
        is_train: forwarded to the dataset as ``is_train``; for the synthetic
            dataset it also selects 1000 (train) vs. 200 (validation) samples.
        args: parsed options; reads ``dataset_type``, ``num_frames``,
            ``frame_size`` and, for real videos, ``data_path`` and
            ``max_interval``.

    Returns:
        A ``SyntheticVideoDataset`` when ``args.dataset_type == 'synthetic'``,
        otherwise a ``VideoFrameDataset``.
    """
    # Options that both dataset classes receive.
    shared = dict(
        num_frames=args.num_frames,
        frame_size=args.frame_size,
        is_train=is_train,
    )
    if args.dataset_type != 'synthetic':
        return VideoFrameDataset(
            root_dir=args.data_path,
            max_interval=args.max_interval,
            **shared,
        )
    sample_count = 1000 if is_train else 200
    return SyntheticVideoDataset(num_samples=sample_count, **shared)
def main(args):
    """Train (or, with ``args.eval``, only evaluate) a SwiftFormerTemporal model.

    Steps: initialise distributed mode, seed the RNGs, build datasets and
    loaders, create the model / optimizer / scheduler / loss, optionally resume
    from a checkpoint, then run the epoch loop with periodic checkpointing,
    evaluation and JSON-lines logging to ``<output_dir>/log.txt``.

    Args:
        args: namespace produced by ``get_args_parser``; it is also handed to
            timm's ``create_optimizer`` / ``create_scheduler``.

    Raises:
        ValueError: if ``args.model`` is not a known SwiftFormerTemporal variant.
    """
    utils.init_distributed_mode(args)
    print(args)
    device = torch.device(args.device)

    # Fix the seed for reproducibility (offset by rank so processes differ).
    seed = args.seed + utils.get_rank()
    torch.manual_seed(seed)
    np.random.seed(seed)
    cudnn.benchmark = True

    # Build datasets
    dataset_train = build_dataset(is_train=True, args=args)
    dataset_val = build_dataset(is_train=False, args=args)

    # Create samplers
    if args.distributed:
        sampler_train = torch.utils.data.DistributedSampler(dataset_train)
        sampler_val = torch.utils.data.DistributedSampler(dataset_val, shuffle=False)
    else:
        sampler_train = torch.utils.data.RandomSampler(dataset_train)
        sampler_val = torch.utils.data.SequentialSampler(dataset_val)

    data_loader_train = torch.utils.data.DataLoader(
        dataset_train, sampler=sampler_train,
        batch_size=args.batch_size,
        num_workers=args.num_workers,
        pin_memory=args.pin_mem,
        drop_last=True,
    )
    data_loader_val = torch.utils.data.DataLoader(
        dataset_val, sampler=sampler_val,
        batch_size=args.batch_size,
        num_workers=args.num_workers,
        pin_memory=args.pin_mem,
        drop_last=False
    )

    # Create model
    print(f"Creating model: {args.model}")
    model_kwargs = {
        'num_frames': args.num_frames,
        'use_representation_head': args.use_representation_head,
        'representation_dim': args.representation_dim,
    }
    model_factories = {
        'SwiftFormerTemporal_XS': SwiftFormerTemporal_XS,
        'SwiftFormerTemporal_S': SwiftFormerTemporal_S,
        'SwiftFormerTemporal_L1': SwiftFormerTemporal_L1,
        'SwiftFormerTemporal_L3': SwiftFormerTemporal_L3,
    }
    if args.model not in model_factories:
        raise ValueError(f"Unknown model: {args.model}")
    model = model_factories[args.model](**model_kwargs)
    model.to(device)

    # Model EMA (optional: only active when the namespace carries model_ema=True)
    model_ema = None
    if getattr(args, 'model_ema', False):
        model_ema = ModelEma(
            model,
            decay=getattr(args, 'model_ema_decay', 0.9999),
            device='cpu' if getattr(args, 'model_ema_force_cpu', False) else '',
            resume='')

    # Distributed training
    model_without_ddp = model
    if args.distributed:
        model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu])
        model_without_ddp = model.module
    n_parameters = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print(f'Number of parameters: {n_parameters}')

    # Create optimizer
    optimizer = create_optimizer(args, model_without_ddp)
    # Create loss scaler
    loss_scaler = NativeScaler()
    # Create scheduler
    lr_scheduler, _ = create_scheduler(args, optimizer)

    # Create loss function. --no-ssim is honoured by zeroing the SSIM weight
    # (assumes MultiTaskLoss scales the SSIM term by ssim_weight - TODO confirm).
    criterion = MultiTaskLoss(
        frame_weight=args.frame_weight,
        contrastive_weight=args.contrastive_weight,
        l1_weight=args.l1_weight,
        ssim_weight=0.0 if args.no_ssim else args.ssim_weight,
        use_contrastive=not args.no_contrastive
    )

    # Resume from checkpoint
    output_dir = Path(args.output_dir)
    if args.resume:
        if args.resume.startswith('https'):
            checkpoint = torch.hub.load_state_dict_from_url(
                args.resume, map_location='cpu', check_hash=True)
        else:
            # NOTE(review): checkpoints written below store the argparse
            # Namespace under 'args'; torch releases whose torch.load defaults
            # to weights_only=True may refuse to unpickle it - confirm against
            # the installed torch version. Only load trusted checkpoint files.
            checkpoint = torch.load(args.resume, map_location='cpu')
        model_without_ddp.load_state_dict(checkpoint['model'])
        if not args.eval and 'optimizer' in checkpoint and 'lr_scheduler' in checkpoint and 'epoch' in checkpoint:
            optimizer.load_state_dict(checkpoint['optimizer'])
            lr_scheduler.load_state_dict(checkpoint['lr_scheduler'])
            args.start_epoch = checkpoint['epoch'] + 1
            # Checkpoints saved without EMA store None (or lack the key).
            if model_ema is not None and checkpoint.get('model_ema') is not None:
                utils._load_checkpoint_for_ema(model_ema, checkpoint['model_ema'])
            if 'scaler' in checkpoint:
                loss_scaler.load_state_dict(checkpoint['scaler'])

    if args.eval:
        test_stats = evaluate(data_loader_val, model, criterion, device)
        print(f"Test stats: {test_stats}")
        return

    print(f"Start training for {args.epochs} epochs")
    start_time = time.time()
    for epoch in range(args.start_epoch, args.epochs):
        if args.distributed:
            data_loader_train.sampler.set_epoch(epoch)

        # Forward the clip settings explicitly; None means "do not clip".
        train_stats = train_one_epoch(
            model, criterion, data_loader_train,
            optimizer, device, epoch, loss_scaler,
            clip_grad=getattr(args, 'clip_grad', None),
            clip_mode=getattr(args, 'clip_mode', 'norm'),
            model_ema=model_ema
        )
        lr_scheduler.step(epoch)

        # Save checkpoint
        if args.output_dir and (epoch % 10 == 0 or epoch == args.epochs - 1):
            checkpoint_path = output_dir / f'checkpoint_epoch{epoch}.pth'
            utils.save_on_master({
                'model': model_without_ddp.state_dict(),
                'optimizer': optimizer.state_dict(),
                'lr_scheduler': lr_scheduler.state_dict(),
                'epoch': epoch,
                'model_ema': get_state_dict(model_ema) if model_ema else None,
                'scaler': loss_scaler.state_dict(),
                'args': args,
            }, checkpoint_path)

        # Evaluate. Reset per epoch so that epochs without an evaluation log no
        # test_* entries instead of reusing (or failing on) an earlier result.
        test_stats = {}
        if epoch % 5 == 0 or epoch == args.epochs - 1:
            test_stats = evaluate(data_loader_val, model, criterion, device)
            print(f"Epoch {epoch}: Test stats: {test_stats}")

        # Log stats
        log_stats = {
            **{f'train_{k}': v for k, v in train_stats.items()},
            **{f'test_{k}': v for k, v in test_stats.items()},
            'epoch': epoch,
            'n_parameters': n_parameters
        }
        if args.output_dir and utils.is_main_process():
            with (output_dir / "log.txt").open("a") as f:
                f.write(json.dumps(log_stats) + "\n")

    total_time = time.time() - start_time
    total_time_str = str(datetime.timedelta(seconds=int(total_time)))
    print(f'Training time {total_time_str}')
def train_one_epoch(model, criterion, data_loader, optimizer, device, epoch, loss_scaler,
                    clip_grad=0, clip_mode='norm', model_ema=None, **kwargs):
    """Run one training epoch and return the averaged metrics.

    Args:
        model: network called as ``model(input_frames)``; must return
            ``(pred_frames, representations)``.
        criterion: called as ``criterion(pred_frames, target_frames,
            representations, temporal_indices)``; must return
            ``(loss, loss_dict)``.
        data_loader: yields ``(input_frames, target_frames, temporal_indices)``.
        optimizer: optimizer stepped through ``loss_scaler``.
        device: device the batch tensors are moved to.
        epoch: epoch index, used only for the log header.
        loss_scaler: callable performing backward, optional clipping and the
            optimizer step (``timm.utils.NativeScaler`` in ``main``).
        clip_grad: gradient clipping threshold; ``0`` or ``None`` disables
            clipping.
        clip_mode: clipping mode forwarded to ``loss_scaler``.
        model_ema: optional EMA wrapper updated after every step.
        **kwargs: accepted and ignored.

    Returns:
        dict mapping each meter name to its ``global_avg``.

    Raises:
        ValueError: if the loss becomes non-finite.
    """
    model.train()
    metric_logger = utils.MetricLogger(delimiter=" ")
    metric_logger.add_meter('lr', utils.SmoothedValue(window_size=1, fmt='{value:.6f}'))
    header = f'Epoch: [{epoch}]'
    print_freq = 10

    # timm's NativeScaler clips whenever clip_grad is not None, and clipping to
    # a max norm of 0 rescales every gradient to zero. Treat a falsy threshold
    # (including the default 0) as "no clipping".
    effective_clip = clip_grad if clip_grad else None

    for input_frames, target_frames, temporal_indices in metric_logger.log_every(
            data_loader, print_freq, header):
        input_frames = input_frames.to(device, non_blocking=True)
        target_frames = target_frames.to(device, non_blocking=True)
        temporal_indices = temporal_indices.to(device, non_blocking=True)

        # Forward pass
        with torch.cuda.amp.autocast():
            pred_frames, representations = model(input_frames)
            loss, loss_dict = criterion(
                pred_frames, target_frames,
                representations, temporal_indices
            )

        loss_value = loss.item()
        if not torch.isfinite(torch.tensor(loss_value)):
            print(f"Loss is {loss_value}, stopping training")
            raise ValueError(f"Loss is {loss_value}")

        optimizer.zero_grad()
        loss_scaler(loss, optimizer, clip_grad=effective_clip, clip_mode=clip_mode,
                    parameters=model.parameters())

        # synchronize() raises on hosts without CUDA, so only call it there.
        if torch.cuda.is_available():
            torch.cuda.synchronize()
        if model_ema is not None:
            model_ema.update(model)

        # Update metrics
        metric_logger.update(loss=loss_value)
        metric_logger.update(lr=optimizer.param_groups[0]["lr"])
        for k, v in loss_dict.items():
            metric_logger.update(**{k: v.item() if torch.is_tensor(v) else v})

    metric_logger.synchronize_between_processes()
    print("Averaged stats:", metric_logger)
    return {k: meter.global_avg for k, meter in metric_logger.meters.items()}
@torch.no_grad()
def evaluate(data_loader, model, criterion, device):
    """Evaluate ``model`` over ``data_loader`` with gradients disabled.

    Each batch is a ``(input_frames, target_frames, temporal_indices)`` triple.
    The model is called on the input frames and the criterion on its two
    outputs plus the targets and temporal indices; the total loss and every
    entry of the returned loss dict are accumulated in a metric logger.

    Returns:
        dict mapping each meter name to its ``global_avg``.
    """
    model.eval()
    stats = utils.MetricLogger(delimiter=" ")

    for batch in stats.log_every(data_loader, 10, 'Test:'):
        # Move the whole triple to the target device in one go.
        inputs, targets, time_idx = (
            tensor.to(device, non_blocking=True) for tensor in batch
        )

        # Compute output
        with torch.cuda.amp.autocast():
            predictions, reps = model(inputs)
            total_loss, parts = criterion(predictions, targets, reps, time_idx)

        # Update metrics; tensor-valued entries are converted to Python scalars.
        stats.update(loss=total_loss.item())
        for name, value in parts.items():
            scalar = value.item() if torch.is_tensor(value) else value
            stats.update(**{name: scalar})

    stats.synchronize_between_processes()
    print('* Test stats:', stats)

    averages = {}
    for name, meter in stats.meters.items():
        averages[name] = meter.global_avg
    return averages
if __name__ == '__main__':
    # get_args_parser() is built without -h/--help; wrapping it as a parent
    # gives the command line a regular help option.
    parser = argparse.ArgumentParser(
        'SwiftFormerTemporal training script',
        parents=[get_args_parser()],
    )
    args = parser.parse_args()

    # An empty --output-dir means "do not save", so only create a real path.
    if args.output_dir:
        Path(args.output_dir).mkdir(exist_ok=True, parents=True)

    main(args)