""" Main training script for SwiftFormerTemporal frame prediction """ import argparse import datetime import numpy as np import time import torch import torch.nn as nn import torch.backends.cudnn as cudnn import json import os from pathlib import Path from timm.scheduler import create_scheduler from timm.optim import create_optimizer from timm.utils import NativeScaler, get_state_dict, ModelEma from util import * from models import * from models.swiftformer_temporal import SwiftFormerTemporal_XS, SwiftFormerTemporal_S, SwiftFormerTemporal_L1, SwiftFormerTemporal_L3 from util.video_dataset import VideoFrameDataset from util.frame_losses import MultiTaskLoss # Try to import TensorBoard try: from torch.utils.tensorboard import SummaryWriter TENSORBOARD_AVAILABLE = True except ImportError: try: from tensorboardX import SummaryWriter TENSORBOARD_AVAILABLE = True except ImportError: TENSORBOARD_AVAILABLE = False def get_args_parser(): parser = argparse.ArgumentParser( 'SwiftFormerTemporal training script', add_help=False) # Dataset parameters parser.add_argument('--data-path', default='./videos', type=str, help='Path to video dataset') parser.add_argument('--dataset-type', default='video', choices=['video', 'synthetic'], type=str, help='Dataset type') parser.add_argument('--num-frames', default=3, type=int, help='Number of input frames (T)') parser.add_argument('--frame-size', default=224, type=int, help='Input frame size') parser.add_argument('--max-interval', default=4, type=int, help='Maximum interval between consecutive frames') # Model parameters parser.add_argument('--model', default='SwiftFormerTemporal_XS', type=str, metavar='MODEL', help='Name of model to train') parser.add_argument('--use-representation-head', action='store_true', help='Use representation head for pose/velocity prediction') parser.add_argument('--representation-dim', default=128, type=int, help='Dimension of representation vector') # Training parameters parser.add_argument('--batch-size', default=32, type=int) parser.add_argument('--epochs', default=100, type=int) # Optimizer parameters (required by timm's create_optimizer) parser.add_argument('--opt', default='adamw', type=str, metavar='OPTIMIZER', help='Optimizer (default: "adamw"') parser.add_argument('--opt-eps', default=1e-8, type=float, metavar='EPSILON', help='Optimizer Epsilon (default: 1e-8)') parser.add_argument('--opt-betas', default=None, type=float, nargs='+', metavar='BETA', help='Optimizer Betas (default: None, use opt default)') parser.add_argument('--clip-grad', type=float, default=0.01, metavar='NORM', help='Clip gradient norm (default: None, no clipping)') parser.add_argument('--clip-mode', type=str, default='agc', help='Gradient clipping mode. One of ("norm", "value", "agc")') parser.add_argument('--momentum', type=float, default=0.9, metavar='M', help='SGD momentum (default: 0.9)') parser.add_argument('--weight-decay', type=float, default=0.05, help='weight decay (default: 0.05)') parser.add_argument('--lr', type=float, default=1e-3, metavar='LR', help='learning rate (default: 1e-3)') # Learning rate schedule parameters (required by timm's create_scheduler) parser.add_argument('--sched', default='cosine', type=str, metavar='SCHEDULER', help='LR scheduler (default: "cosine"') parser.add_argument('--lr-noise', type=float, nargs='+', default=None, metavar='pct, pct', help='learning rate noise on/off epoch percentages') parser.add_argument('--lr-noise-pct', type=float, default=0.67, metavar='PERCENT', help='learning rate noise limit percent (default: 0.67)') parser.add_argument('--lr-noise-std', type=float, default=1.0, metavar='STDDEV', help='learning rate noise std-dev (default: 1.0)') parser.add_argument('--warmup-lr', type=float, default=1e-6, metavar='LR', help='warmup learning rate (default: 1e-6)') parser.add_argument('--min-lr', type=float, default=1e-5, metavar='LR', help='lower lr bound for cyclic schedulers that hit 0 (1e-5)') parser.add_argument('--decay-epochs', type=float, default=30, metavar='N', help='epoch interval to decay LR') parser.add_argument('--warmup-epochs', type=int, default=5, metavar='N', help='epochs to warmup LR, if scheduler supports') parser.add_argument('--cooldown-epochs', type=int, default=10, metavar='N', help='epochs to cooldown LR at min_lr, after cyclic schedule ends') parser.add_argument('--patience-epochs', type=int, default=10, metavar='N', help='patience epochs for Plateau LR scheduler (default: 10') parser.add_argument('--decay-rate', '--dr', type=float, default=0.1, metavar='RATE', help='LR decay rate (default: 0.1)') # Loss parameters parser.add_argument('--frame-weight', type=float, default=1.0, help='Weight for frame prediction loss') parser.add_argument('--contrastive-weight', type=float, default=0.1, help='Weight for contrastive loss') # parser.add_argument('--l1-weight', type=float, default=1.0, # help='Weight for L1 loss') # parser.add_argument('--ssim-weight', type=float, default=0.1, # help='Weight for SSIM loss') parser.add_argument('--no-contrastive', action='store_true', help='Disable contrastive loss') parser.add_argument('--no-ssim', action='store_true', help='Disable SSIM loss') # System parameters parser.add_argument('--output-dir', default='./output', help='path where to save, empty for no saving') parser.add_argument('--device', default='cuda', help='device to use for training / testing') parser.add_argument('--seed', default=0, type=int) parser.add_argument('--resume', default='', help='resume from checkpoint') parser.add_argument('--start-epoch', default=0, type=int, metavar='N', help='start epoch') parser.add_argument('--eval', action='store_true', help='Perform evaluation only') parser.add_argument('--num-workers', default=4, type=int) parser.add_argument('--pin-mem', action='store_true', help='Pin CPU memory in DataLoader') parser.add_argument('--no-pin-mem', action='store_false', dest='pin_mem') parser.set_defaults(pin_mem=True) # Distributed training parser.add_argument('--world-size', default=1, type=int, help='number of distributed processes') parser.add_argument('--dist-url', default='env://', help='url used to set up distributed training') # TensorBoard logging parser.add_argument('--tensorboard-logdir', default='./runs', type=str, help='TensorBoard log directory') parser.add_argument('--log-images', action='store_true', help='Log sample images to TensorBoard') parser.add_argument('--image-log-freq', default=100, type=int, help='Frequency of logging images (in iterations)') return parser def build_dataset(is_train, args): """Build video frame dataset""" dataset = VideoFrameDataset( root_dir=args.data_path, num_frames=args.num_frames, frame_size=args.frame_size, is_train=is_train, max_interval=args.max_interval ) return dataset def main(args): utils.init_distributed_mode(args) print(args) device = torch.device(args.device) # Fix the seed for reproducibility seed = args.seed + utils.get_rank() torch.manual_seed(seed) np.random.seed(seed) cudnn.benchmark = True # Build datasets dataset_train = build_dataset(is_train=True, args=args) dataset_val = build_dataset(is_train=False, args=args) # Create samplers if args.distributed: sampler_train = torch.utils.data.DistributedSampler(dataset_train) sampler_val = torch.utils.data.DistributedSampler(dataset_val, shuffle=False) else: sampler_train = torch.utils.data.RandomSampler(dataset_train) sampler_val = torch.utils.data.SequentialSampler(dataset_val) data_loader_train = torch.utils.data.DataLoader( dataset_train, sampler=sampler_train, batch_size=args.batch_size, num_workers=args.num_workers, pin_memory=args.pin_mem, drop_last=True, ) data_loader_val = torch.utils.data.DataLoader( dataset_val, sampler=sampler_val, batch_size=args.batch_size, num_workers=args.num_workers, pin_memory=args.pin_mem, drop_last=False ) # Create model print(f"Creating model: {args.model}") model_kwargs = { 'num_frames': args.num_frames, 'use_representation_head': args.use_representation_head, 'representation_dim': args.representation_dim, } if args.model == 'SwiftFormerTemporal_XS': model = SwiftFormerTemporal_XS(**model_kwargs) elif args.model == 'SwiftFormerTemporal_S': model = SwiftFormerTemporal_S(**model_kwargs) elif args.model == 'SwiftFormerTemporal_L1': model = SwiftFormerTemporal_L1(**model_kwargs) elif args.model == 'SwiftFormerTemporal_L3': model = SwiftFormerTemporal_L3(**model_kwargs) else: raise ValueError(f"Unknown model: {args.model}") model.to(device) # Model EMA model_ema = None if hasattr(args, 'model_ema') and args.model_ema: model_ema = ModelEma( model, decay=args.model_ema_decay if hasattr(args, 'model_ema_decay') else 0.9999, device='cpu' if hasattr(args, 'model_ema_force_cpu') and args.model_ema_force_cpu else '', resume='') # Distributed training model_without_ddp = model if args.distributed: model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu]) model_without_ddp = model.module n_parameters = sum(p.numel() for p in model.parameters() if p.requires_grad) print(f'Number of parameters: {n_parameters}') # Create optimizer optimizer = create_optimizer(args, model_without_ddp) # Create loss scaler loss_scaler = NativeScaler() # Create scheduler lr_scheduler, _ = create_scheduler(args, optimizer) # Create loss function - simple MSE for Y channel prediction class MSELossWrapper(nn.Module): def __init__(self): super().__init__() self.mse = nn.MSELoss() def forward(self, pred_frame, target_frame, representations=None, temporal_indices=None): loss = self.mse(pred_frame, target_frame) loss_dict = {'mse': loss} return loss, loss_dict criterion = MSELossWrapper() # Resume from checkpoint output_dir = Path(args.output_dir) if args.resume: if args.resume.startswith('https'): checkpoint = torch.hub.load_state_dict_from_url( args.resume, map_location='cpu', check_hash=True) else: checkpoint = torch.load(args.resume, map_location='cpu') model_without_ddp.load_state_dict(checkpoint['model']) if not args.eval and 'optimizer' in checkpoint and 'lr_scheduler' in checkpoint and 'epoch' in checkpoint: optimizer.load_state_dict(checkpoint['optimizer']) lr_scheduler.load_state_dict(checkpoint['lr_scheduler']) args.start_epoch = checkpoint['epoch'] + 1 if model_ema is not None: utils._load_checkpoint_for_ema(model_ema, checkpoint['model_ema']) if 'scaler' in checkpoint: loss_scaler.load_state_dict(checkpoint['scaler']) # Initialize TensorBoard writer writer = None if TENSORBOARD_AVAILABLE and utils.is_main_process(): from datetime import datetime # Create log directory with timestamp timestamp = datetime.now().strftime('%Y%m%d_%H%M%S') log_dir = os.path.join(args.tensorboard_logdir, f"exp_{timestamp}") os.makedirs(log_dir, exist_ok=True) writer = SummaryWriter(log_dir=log_dir) print(f"TensorBoard logs will be saved to: {log_dir}") print(f"To view logs, run: tensorboard --logdir={log_dir}") elif not TENSORBOARD_AVAILABLE and utils.is_main_process(): print("Warning: TensorBoard not available. Install tensorboard or tensorboardX.") print("Training will continue without TensorBoard logging.") if args.eval: test_stats = evaluate(data_loader_val, model, criterion, device) print(f"Test stats: {test_stats}") return print(f"Start training for {args.epochs} epochs") start_time = time.time() # Global step counter for TensorBoard global_step = 0 for epoch in range(args.start_epoch, args.epochs): if args.distributed: data_loader_train.sampler.set_epoch(epoch) train_stats, global_step = train_one_epoch( model, criterion, data_loader_train, optimizer, device, epoch, loss_scaler, model_ema=model_ema, writer=writer, global_step=global_step, args=args ) lr_scheduler.step(epoch) # Save checkpoint if args.output_dir and (epoch % 2 == 0 or epoch == args.epochs - 1): checkpoint_path = output_dir / f'checkpoint_epoch{epoch}.pth' utils.save_on_master({ 'model': model_without_ddp.state_dict(), 'optimizer': optimizer.state_dict(), 'lr_scheduler': lr_scheduler.state_dict(), 'epoch': epoch, 'model_ema': get_state_dict(model_ema) if model_ema else None, 'scaler': loss_scaler.state_dict(), 'args': args, }, checkpoint_path) # Evaluate if epoch % 5 == 0 or epoch == args.epochs - 1: test_stats = evaluate(data_loader_val, model, criterion, device, writer=writer, epoch=epoch) print(f"Epoch {epoch}: Test stats: {test_stats}") # Log stats to text file log_stats = { **{f'train_{k}': v for k, v in train_stats.items()}, **{f'test_{k}': v for k, v in test_stats.items()}, 'epoch': epoch, 'n_parameters': n_parameters } if args.output_dir and utils.is_main_process(): with (output_dir / "log.txt").open("a") as f: f.write(json.dumps(log_stats) + "\n") total_time = time.time() - start_time total_time_str = str(datetime.timedelta(seconds=int(total_time))) print(f'Training time {total_time_str}') # Close TensorBoard writer if writer is not None: writer.close() print(f"TensorBoard logs saved to: {writer.log_dir}") def train_one_epoch(model, criterion, data_loader, optimizer, device, epoch, loss_scaler, clip_grad=0, clip_mode='norm', model_ema=None, writer=None, global_step=0, args=None, **kwargs): model.train() metric_logger = utils.MetricLogger(delimiter=" ") metric_logger.add_meter('lr', utils.SmoothedValue(window_size=1, fmt='{value:.6f}')) header = f'Epoch: [{epoch}]' print_freq = 10 for batch_idx, (input_frames, target_frames, temporal_indices) in enumerate( metric_logger.log_every(data_loader, print_freq, header)): input_frames = input_frames.to(device, non_blocking=True) target_frames = target_frames.to(device, non_blocking=True) temporal_indices = temporal_indices.to(device, non_blocking=True) # Forward pass with torch.cuda.amp.autocast(): pred_frames, representations = model(input_frames) loss, loss_dict = criterion( pred_frames, target_frames, representations, temporal_indices ) loss_value = loss.item() if not torch.isfinite(torch.tensor(loss_value)): print(f"Loss is {loss_value}, stopping training") raise ValueError(f"Loss is {loss_value}") optimizer.zero_grad() loss_scaler(loss, optimizer, clip_grad=clip_grad, clip_mode=clip_mode, parameters=model.parameters()) torch.cuda.synchronize() if model_ema is not None: model_ema.update(model) # Log to TensorBoard if writer is not None: # Log scalar metrics every iteration writer.add_scalar('train/loss', loss_value, global_step) writer.add_scalar('train/lr', optimizer.param_groups[0]["lr"], global_step) # Log individual loss components for k, v in loss_dict.items(): if torch.is_tensor(v): writer.add_scalar(f'train/{k}', v.item(), global_step) else: writer.add_scalar(f'train/{k}', v, global_step) # Log images periodically if args is not None and getattr(args, 'log_images', False) and global_step % getattr(args, 'image_log_freq', 100) == 0: with torch.no_grad(): # Take first sample from batch for visualization pred_vis, _ = model(input_frames[:1]) # Convert to appropriate format for TensorBoard # Assuming frames are in [B, C, H, W] format writer.add_images('train/input', input_frames[:1], global_step) writer.add_images('train/target', target_frames[:1], global_step) writer.add_images('train/predicted', pred_vis[:1], global_step) # Update metrics metric_logger.update(loss=loss_value) metric_logger.update(lr=optimizer.param_groups[0]["lr"]) for k, v in loss_dict.items(): metric_logger.update(**{k: v.item() if torch.is_tensor(v) else v}) global_step += 1 metric_logger.synchronize_between_processes() print("Averaged stats:", metric_logger) # Log epoch-level metrics if writer is not None: for k, meter in metric_logger.meters.items(): writer.add_scalar(f'train_epoch/{k}', meter.global_avg, epoch) return {k: meter.global_avg for k, meter in metric_logger.meters.items()}, global_step @torch.no_grad() def evaluate(data_loader, model, criterion, device, writer=None, epoch=0): model.eval() metric_logger = utils.MetricLogger(delimiter=" ") header = 'Test:' for input_frames, target_frames, temporal_indices in metric_logger.log_every(data_loader, 10, header): input_frames = input_frames.to(device, non_blocking=True) target_frames = target_frames.to(device, non_blocking=True) temporal_indices = temporal_indices.to(device, non_blocking=True) # Compute output with torch.cuda.amp.autocast(): pred_frames, representations = model(input_frames) loss, loss_dict = criterion( pred_frames, target_frames, representations, temporal_indices ) # Update metrics metric_logger.update(loss=loss.item()) for k, v in loss_dict.items(): metric_logger.update(**{k: v.item() if torch.is_tensor(v) else v}) metric_logger.synchronize_between_processes() print('* Test stats:', metric_logger) # Log validation metrics to TensorBoard if writer is not None: for k, meter in metric_logger.meters.items(): writer.add_scalar(f'val/{k}', meter.global_avg, epoch) return {k: meter.global_avg for k, meter in metric_logger.meters.items()} if __name__ == '__main__': parser = argparse.ArgumentParser( 'SwiftFormerTemporal training script', parents=[get_args_parser()]) args = parser.parse_args() if args.output_dir: Path(args.output_dir).mkdir(parents=True, exist_ok=True) main(args)