""" Main training script for SwiftFormerTemporal frame prediction """ import argparse import datetime import numpy as np import time import torch import torch.backends.cudnn as cudnn import json from pathlib import Path from timm.scheduler import create_scheduler from timm.optim import create_optimizer from timm.utils import NativeScaler, get_state_dict, ModelEma from util import * from models import * from models.swiftformer_temporal import SwiftFormerTemporal_XS, SwiftFormerTemporal_S, SwiftFormerTemporal_L1, SwiftFormerTemporal_L3 from util.video_dataset import VideoFrameDataset, SyntheticVideoDataset from util.frame_losses import MultiTaskLoss def get_args_parser(): parser = argparse.ArgumentParser( 'SwiftFormerTemporal training script', add_help=False) # Dataset parameters parser.add_argument('--data-path', default='./videos', type=str, help='Path to video dataset') parser.add_argument('--dataset-type', default='video', choices=['video', 'synthetic'], type=str, help='Dataset type') parser.add_argument('--num-frames', default=3, type=int, help='Number of input frames (T)') parser.add_argument('--frame-size', default=224, type=int, help='Input frame size') parser.add_argument('--max-interval', default=1, type=int, help='Maximum interval between consecutive frames') # Model parameters parser.add_argument('--model', default='SwiftFormerTemporal_XS', type=str, metavar='MODEL', help='Name of model to train') parser.add_argument('--use-representation-head', action='store_true', help='Use representation head for pose/velocity prediction') parser.add_argument('--representation-dim', default=128, type=int, help='Dimension of representation vector') # Training parameters parser.add_argument('--batch-size', default=32, type=int) parser.add_argument('--epochs', default=100, type=int) parser.add_argument('--lr', type=float, default=1e-3, metavar='LR', help='learning rate (default: 1e-3)') parser.add_argument('--weight-decay', type=float, default=0.05, help='weight decay (default: 0.05)') # Loss parameters parser.add_argument('--frame-weight', type=float, default=1.0, help='Weight for frame prediction loss') parser.add_argument('--contrastive-weight', type=float, default=0.1, help='Weight for contrastive loss') parser.add_argument('--l1-weight', type=float, default=1.0, help='Weight for L1 loss') parser.add_argument('--ssim-weight', type=float, default=0.1, help='Weight for SSIM loss') parser.add_argument('--no-contrastive', action='store_true', help='Disable contrastive loss') parser.add_argument('--no-ssim', action='store_true', help='Disable SSIM loss') # System parameters parser.add_argument('--output-dir', default='./output', help='path where to save, empty for no saving') parser.add_argument('--device', default='cuda', help='device to use for training / testing') parser.add_argument('--seed', default=0, type=int) parser.add_argument('--resume', default='', help='resume from checkpoint') parser.add_argument('--start-epoch', default=0, type=int, metavar='N', help='start epoch') parser.add_argument('--eval', action='store_true', help='Perform evaluation only') parser.add_argument('--num-workers', default=4, type=int) parser.add_argument('--pin-mem', action='store_true', help='Pin CPU memory in DataLoader') parser.add_argument('--no-pin-mem', action='store_false', dest='pin_mem') parser.set_defaults(pin_mem=True) # Distributed training parser.add_argument('--world-size', default=1, type=int, help='number of distributed processes') parser.add_argument('--dist-url', default='env://', help='url used 


def build_dataset(is_train, args):
    """Build video frame dataset"""
    if args.dataset_type == 'synthetic':
        dataset = SyntheticVideoDataset(
            num_samples=1000 if is_train else 200,
            num_frames=args.num_frames,
            frame_size=args.frame_size,
            is_train=is_train
        )
    else:
        dataset = VideoFrameDataset(
            root_dir=args.data_path,
            num_frames=args.num_frames,
            frame_size=args.frame_size,
            is_train=is_train,
            max_interval=args.max_interval
        )
    return dataset


def main(args):
    utils.init_distributed_mode(args)
    print(args)

    device = torch.device(args.device)

    # Fix the seed for reproducibility
    seed = args.seed + utils.get_rank()
    torch.manual_seed(seed)
    np.random.seed(seed)
    cudnn.benchmark = True

    # Build datasets
    dataset_train = build_dataset(is_train=True, args=args)
    dataset_val = build_dataset(is_train=False, args=args)

    # Create samplers
    if args.distributed:
        sampler_train = torch.utils.data.DistributedSampler(dataset_train)
        sampler_val = torch.utils.data.DistributedSampler(dataset_val, shuffle=False)
    else:
        sampler_train = torch.utils.data.RandomSampler(dataset_train)
        sampler_val = torch.utils.data.SequentialSampler(dataset_val)

    data_loader_train = torch.utils.data.DataLoader(
        dataset_train, sampler=sampler_train,
        batch_size=args.batch_size,
        num_workers=args.num_workers,
        pin_memory=args.pin_mem,
        drop_last=True,
    )
    data_loader_val = torch.utils.data.DataLoader(
        dataset_val, sampler=sampler_val,
        batch_size=args.batch_size,
        num_workers=args.num_workers,
        pin_memory=args.pin_mem,
        drop_last=False
    )

    # Create model
    print(f"Creating model: {args.model}")
    model_kwargs = {
        'num_frames': args.num_frames,
        'use_representation_head': args.use_representation_head,
        'representation_dim': args.representation_dim,
    }
    if args.model == 'SwiftFormerTemporal_XS':
        model = SwiftFormerTemporal_XS(**model_kwargs)
    elif args.model == 'SwiftFormerTemporal_S':
        model = SwiftFormerTemporal_S(**model_kwargs)
    elif args.model == 'SwiftFormerTemporal_L1':
        model = SwiftFormerTemporal_L1(**model_kwargs)
    elif args.model == 'SwiftFormerTemporal_L3':
        model = SwiftFormerTemporal_L3(**model_kwargs)
    else:
        raise ValueError(f"Unknown model: {args.model}")
    model.to(device)

    # Model EMA
    model_ema = None
    if hasattr(args, 'model_ema') and args.model_ema:
        model_ema = ModelEma(
            model,
            decay=args.model_ema_decay if hasattr(args, 'model_ema_decay') else 0.9999,
            device='cpu' if hasattr(args, 'model_ema_force_cpu') and args.model_ema_force_cpu else '',
            resume='')

    # Distributed training
    model_without_ddp = model
    if args.distributed:
        model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu])
        model_without_ddp = model.module

    n_parameters = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print(f'Number of parameters: {n_parameters}')

    # Create optimizer
    optimizer = create_optimizer(args, model_without_ddp)

    # Create loss scaler
    loss_scaler = NativeScaler()

    # Create scheduler
    lr_scheduler, _ = create_scheduler(args, optimizer)

    # Create loss function
    criterion = MultiTaskLoss(
        frame_weight=args.frame_weight,
        contrastive_weight=args.contrastive_weight,
        l1_weight=args.l1_weight,
        ssim_weight=args.ssim_weight,
        use_contrastive=not args.no_contrastive
    )

    # Resume from checkpoint
    output_dir = Path(args.output_dir)
    if args.resume:
        if args.resume.startswith('https'):
            checkpoint = torch.hub.load_state_dict_from_url(
                args.resume, map_location='cpu', check_hash=True)
        else:
            checkpoint = torch.load(args.resume, map_location='cpu')
        model_without_ddp.load_state_dict(checkpoint['model'])
        if not args.eval and 'optimizer' in checkpoint and 'lr_scheduler' in checkpoint and 'epoch' in checkpoint:
            optimizer.load_state_dict(checkpoint['optimizer'])
            lr_scheduler.load_state_dict(checkpoint['lr_scheduler'])
            args.start_epoch = checkpoint['epoch'] + 1
            if model_ema is not None:
                utils._load_checkpoint_for_ema(model_ema, checkpoint['model_ema'])
            if 'scaler' in checkpoint:
                loss_scaler.load_state_dict(checkpoint['scaler'])

    if args.eval:
        test_stats = evaluate(data_loader_val, model, criterion, device)
        print(f"Test stats: {test_stats}")
        return

    print(f"Start training for {args.epochs} epochs")
    start_time = time.time()

    for epoch in range(args.start_epoch, args.epochs):
        if args.distributed:
            data_loader_train.sampler.set_epoch(epoch)

        train_stats = train_one_epoch(
            model, criterion, data_loader_train,
            optimizer, device, epoch, loss_scaler,
            model_ema=model_ema
        )

        lr_scheduler.step(epoch)

        # Save checkpoint
        if args.output_dir and (epoch % 10 == 0 or epoch == args.epochs - 1):
            checkpoint_path = output_dir / f'checkpoint_epoch{epoch}.pth'
            utils.save_on_master({
                'model': model_without_ddp.state_dict(),
                'optimizer': optimizer.state_dict(),
                'lr_scheduler': lr_scheduler.state_dict(),
                'epoch': epoch,
                'model_ema': get_state_dict(model_ema) if model_ema else None,
                'scaler': loss_scaler.state_dict(),
                'args': args,
            }, checkpoint_path)

        # Evaluate and log every 5 epochs (and on the final epoch), so that
        # test_stats is always defined when it is written to the log
        if epoch % 5 == 0 or epoch == args.epochs - 1:
            test_stats = evaluate(data_loader_val, model, criterion, device)
            print(f"Epoch {epoch}: Test stats: {test_stats}")

            log_stats = {
                **{f'train_{k}': v for k, v in train_stats.items()},
                **{f'test_{k}': v for k, v in test_stats.items()},
                'epoch': epoch,
                'n_parameters': n_parameters
            }
            if args.output_dir and utils.is_main_process():
                with (output_dir / "log.txt").open("a") as f:
                    f.write(json.dumps(log_stats) + "\n")

    total_time = time.time() - start_time
    total_time_str = str(datetime.timedelta(seconds=int(total_time)))
    print(f'Training time {total_time_str}')


def train_one_epoch(model, criterion, data_loader, optimizer, device, epoch,
                    loss_scaler, clip_grad=None, clip_mode='norm',
                    model_ema=None, **kwargs):
    # clip_grad defaults to None (no clipping); passing 0 to NativeScaler would
    # clip to norm 0 and effectively zero the gradients.
    model.train()
    metric_logger = utils.MetricLogger(delimiter="  ")
    metric_logger.add_meter('lr', utils.SmoothedValue(window_size=1, fmt='{value:.6f}'))
    header = f'Epoch: [{epoch}]'
    print_freq = 10

    for input_frames, target_frames, temporal_indices in metric_logger.log_every(
            data_loader, print_freq, header):
        input_frames = input_frames.to(device, non_blocking=True)
        target_frames = target_frames.to(device, non_blocking=True)
        temporal_indices = temporal_indices.to(device, non_blocking=True)

        # Forward pass under mixed precision
        with torch.cuda.amp.autocast():
            pred_frames, representations = model(input_frames)
            loss, loss_dict = criterion(
                pred_frames, target_frames, representations, temporal_indices
            )

        loss_value = loss.item()
        if not math.isfinite(loss_value):
            print(f"Loss is {loss_value}, stopping training")
            raise ValueError(f"Loss is {loss_value}")

        optimizer.zero_grad()
        loss_scaler(loss, optimizer, clip_grad=clip_grad, clip_mode=clip_mode,
                    parameters=model.parameters())

        torch.cuda.synchronize()
        if model_ema is not None:
            model_ema.update(model)

        # Update metrics
        metric_logger.update(loss=loss_value)
        metric_logger.update(lr=optimizer.param_groups[0]["lr"])
        for k, v in loss_dict.items():
            metric_logger.update(**{k: v.item() if torch.is_tensor(v) else v})

    metric_logger.synchronize_between_processes()
    print("Averaged stats:", metric_logger)
    return {k: meter.global_avg for k, meter in metric_logger.meters.items()}

@torch.no_grad()
def evaluate(data_loader, model, criterion, device):
    model.eval()
    metric_logger = utils.MetricLogger(delimiter="  ")
    header = 'Test:'

    for input_frames, target_frames, temporal_indices in metric_logger.log_every(
            data_loader, 10, header):
        input_frames = input_frames.to(device, non_blocking=True)
        target_frames = target_frames.to(device, non_blocking=True)
        temporal_indices = temporal_indices.to(device, non_blocking=True)

        # Compute output
        with torch.cuda.amp.autocast():
            pred_frames, representations = model(input_frames)
            loss, loss_dict = criterion(
                pred_frames, target_frames, representations, temporal_indices
            )

        # Update metrics
        metric_logger.update(loss=loss.item())
        for k, v in loss_dict.items():
            metric_logger.update(**{k: v.item() if torch.is_tensor(v) else v})

    metric_logger.synchronize_between_processes()
    print('* Test stats:', metric_logger)
    return {k: meter.global_avg for k, meter in metric_logger.meters.items()}


if __name__ == '__main__':
    parser = argparse.ArgumentParser(
        'SwiftFormerTemporal training script', parents=[get_args_parser()])
    args = parser.parse_args()
    if args.output_dir:
        Path(args.output_dir).mkdir(parents=True, exist_ok=True)
    main(args)
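
# Example invocation (a sketch; the script filename and paths are assumptions,
# and the synthetic dataset needs no videos on disk):
#
#   python train_swiftformer_temporal.py \
#       --dataset-type synthetic --model SwiftFormerTemporal_XS \
#       --num-frames 3 --batch-size 16 --epochs 10 --output-dir ./output
#
# For real data, point --data-path at a directory laid out the way
# VideoFrameDataset expects and switch --dataset-type to video.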