""" Main training script for SwiftFormerTemporal frame prediction """ import argparse import datetime import numpy as np import time import torch import torch.nn as nn import torch.backends.cudnn as cudnn import json import os from pathlib import Path from timm.scheduler import create_scheduler from timm.optim import create_optimizer from timm.utils import NativeScaler, get_state_dict, ModelEma from util import * from models import * from models.swiftformer_temporal import SwiftFormerTemporal_XS, SwiftFormerTemporal_S, SwiftFormerTemporal_L1, SwiftFormerTemporal_L3 from util.video_dataset import VideoFrameDataset # from util.frame_losses import MultiTaskLoss # Try to import TensorBoard try: from torch.utils.tensorboard import SummaryWriter TENSORBOARD_AVAILABLE = True except ImportError: TENSORBOARD_AVAILABLE = False def get_args_parser(): parser = argparse.ArgumentParser( 'SwiftFormerTemporal training script', add_help=False) # Dataset parameters parser.add_argument('--data-path', default='./videos', type=str, help='Path to video dataset') parser.add_argument('--dataset-type', default='video', choices=['video', 'synthetic'], type=str, help='Dataset type') parser.add_argument('--num-frames', default=3, type=int, help='Number of input frames (T)') parser.add_argument('--frame-size', default=224, type=int, help='Input frame size') parser.add_argument('--max-interval', default=4, type=int, help='Maximum interval between consecutive frames') # Model parameters parser.add_argument('--model', default='SwiftFormerTemporal_XS', type=str, metavar='MODEL', help='Name of model to train') # Training parameters parser.add_argument('--batch-size', default=32, type=int) parser.add_argument('--epochs', default=100, type=int) # Optimizer parameters (required by timm's create_optimizer) parser.add_argument('--opt', default='adamw', type=str, metavar='OPTIMIZER', help='Optimizer (default: "adamw"') parser.add_argument('--opt-eps', default=1e-8, type=float, metavar='EPSILON', help='Optimizer Epsilon (default: 1e-8)') parser.add_argument('--opt-betas', default=None, type=float, nargs='+', metavar='BETA', help='Optimizer Betas (default: None, use opt default)') parser.add_argument('--clip-grad', type=float, default=0.01, metavar='NORM', help='Clip gradient norm (default: None, no clipping)') parser.add_argument('--clip-mode', type=str, default='agc', help='Gradient clipping mode. 
One of ("norm", "value", "agc")') parser.add_argument('--momentum', type=float, default=0.9, metavar='M', help='SGD momentum (default: 0.9)') parser.add_argument('--weight-decay', type=float, default=0.05, help='weight decay (default: 0.05)') parser.add_argument('--lr', type=float, default=0.1, metavar='LR', help='learning rate (default: 1e-3)') # Learning rate schedule parameters (required by timm's create_scheduler) parser.add_argument('--sched', default='cosine', type=str, metavar='SCHEDULER', help='LR scheduler (default: "cosine"') parser.add_argument('--lr-noise', type=float, nargs='+', default=None, metavar='pct, pct', help='learning rate noise on/off epoch percentages') parser.add_argument('--lr-noise-pct', type=float, default=0.67, metavar='PERCENT', help='learning rate noise limit percent (default: 0.67)') parser.add_argument('--lr-noise-std', type=float, default=1.0, metavar='STDDEV', help='learning rate noise std-dev (default: 1.0)') parser.add_argument('--warmup-lr', type=float, default=1e-3, metavar='LR', help='warmup learning rate (default: 1e-6)') parser.add_argument('--min-lr', type=float, default=1e-5, metavar='LR', help='lower lr bound for cyclic schedulers that hit 0 (1e-5)') parser.add_argument('--decay-epochs', type=float, default=30, metavar='N', help='epoch interval to decay LR') parser.add_argument('--warmup-epochs', type=int, default=5, metavar='N', help='epochs to warmup LR, if scheduler supports') parser.add_argument('--cooldown-epochs', type=int, default=10, metavar='N', help='epochs to cooldown LR at min_lr, after cyclic schedule ends') parser.add_argument('--patience-epochs', type=int, default=10, metavar='N', help='patience epochs for Plateau LR scheduler (default: 10') parser.add_argument('--decay-rate', '--dr', type=float, default=0.1, metavar='RATE', help='LR decay rate (default: 0.1)') # Loss parameters parser.add_argument('--frame-weight', type=float, default=1.0, help='Weight for frame prediction loss') parser.add_argument('--contrastive-weight', type=float, default=0.1, help='Weight for contrastive loss') # parser.add_argument('--l1-weight', type=float, default=1.0, # help='Weight for L1 loss') # parser.add_argument('--ssim-weight', type=float, default=0.1, # help='Weight for SSIM loss') parser.add_argument('--no-contrastive', action='store_true', help='Disable contrastive loss') parser.add_argument('--no-ssim', action='store_true', help='Disable SSIM loss') # System parameters parser.add_argument('--output-dir', default='./output', help='path where to save, empty for no saving') parser.add_argument('--device', default='cuda', help='device to use for training / testing') parser.add_argument('--seed', default=0, type=int) parser.add_argument('--resume', default='', help='resume from checkpoint') parser.add_argument('--start-epoch', default=0, type=int, metavar='N', help='start epoch') parser.add_argument('--eval', action='store_true', help='Perform evaluation only') parser.add_argument('--num-workers', default=4, type=int) parser.add_argument('--pin-mem', action='store_true', help='Pin CPU memory in DataLoader') parser.add_argument('--no-pin-mem', action='store_false', dest='pin_mem') parser.set_defaults(pin_mem=True) # Distributed training parser.add_argument('--world-size', default=1, type=int, help='number of distributed processes') parser.add_argument('--dist-url', default='env://', help='url used to set up distributed training') # TensorBoard logging parser.add_argument('--tensorboard-logdir', default='./runs', type=str, help='TensorBoard log 
def build_dataset(is_train, args):
    """Build video frame dataset"""
    dataset = VideoFrameDataset(
        root_dir=args.data_path,
        num_frames=args.num_frames,
        frame_size=args.frame_size,
        is_train=is_train,
        max_interval=args.max_interval
    )
    return dataset


def main(args):
    utils.init_distributed_mode(args)
    print(args)

    device = torch.device(args.device)

    # Fix the seed for reproducibility
    seed = args.seed + utils.get_rank()
    torch.manual_seed(seed)
    np.random.seed(seed)
    cudnn.benchmark = True

    # Build datasets
    dataset_train = build_dataset(is_train=True, args=args)
    dataset_val = build_dataset(is_train=False, args=args)

    # Create samplers
    if args.distributed:
        sampler_train = torch.utils.data.DistributedSampler(dataset_train)
        sampler_val = torch.utils.data.DistributedSampler(dataset_val, shuffle=False)
    else:
        sampler_train = torch.utils.data.RandomSampler(dataset_train)
        sampler_val = torch.utils.data.SequentialSampler(dataset_val)

    data_loader_train = torch.utils.data.DataLoader(
        dataset_train, sampler=sampler_train,
        batch_size=args.batch_size,
        num_workers=args.num_workers,
        pin_memory=args.pin_mem,
        drop_last=True,
    )
    data_loader_val = torch.utils.data.DataLoader(
        dataset_val, sampler=sampler_val,
        batch_size=args.batch_size,
        num_workers=args.num_workers,
        pin_memory=args.pin_mem,
        drop_last=False
    )

    # Create model
    print(f"Creating model: {args.model}")
    model_kwargs = {
        'num_frames': args.num_frames,
    }
    if args.model == 'SwiftFormerTemporal_XS':
        model = SwiftFormerTemporal_XS(**model_kwargs)
    elif args.model == 'SwiftFormerTemporal_S':
        model = SwiftFormerTemporal_S(**model_kwargs)
    elif args.model == 'SwiftFormerTemporal_L1':
        model = SwiftFormerTemporal_L1(**model_kwargs)
    elif args.model == 'SwiftFormerTemporal_L3':
        model = SwiftFormerTemporal_L3(**model_kwargs)
    else:
        raise ValueError(f"Unknown model: {args.model}")
    model.to(device)

    # Model EMA
    model_ema = None
    if hasattr(args, 'model_ema') and args.model_ema:
        model_ema = ModelEma(
            model,
            decay=args.model_ema_decay if hasattr(args, 'model_ema_decay') else 0.9999,
            device='cpu' if hasattr(args, 'model_ema_force_cpu') and args.model_ema_force_cpu else '',
            resume='')

    # Distributed training
    model_without_ddp = model
    if args.distributed:
        model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu])
        model_without_ddp = model.module

    n_parameters = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print(f'Number of parameters: {n_parameters}')

    # Create optimizer
    optimizer = create_optimizer(args, model_without_ddp)

    # Create loss scaler
    loss_scaler = NativeScaler()

    # Create scheduler
    lr_scheduler, _ = create_scheduler(args, optimizer)

    # Create loss function - simple MSE for Y channel prediction
    class MSELossWrapper(nn.Module):
        def __init__(self):
            super().__init__()
            self.mse = nn.MSELoss()

        def forward(self, pred_frame, target_frame, temporal_indices=None):
            loss = self.mse(pred_frame, target_frame)
            loss_dict = {'mse': loss}
            return loss, loss_dict

    criterion = MSELossWrapper()

    # Resume from checkpoint
    output_dir = Path(args.output_dir)
    if args.resume:
        if args.resume.startswith('https'):
            checkpoint = torch.hub.load_state_dict_from_url(
                args.resume, map_location='cpu', check_hash=True)
        else:
            checkpoint = torch.load(args.resume, map_location='cpu')
        model_without_ddp.load_state_dict(checkpoint['model'])
        if not args.eval and 'optimizer' in checkpoint and 'lr_scheduler' in checkpoint and 'epoch' in checkpoint:
            optimizer.load_state_dict(checkpoint['optimizer'])
            lr_scheduler.load_state_dict(checkpoint['lr_scheduler'])
            args.start_epoch = checkpoint['epoch'] + 1
            if model_ema is not None:
                utils._load_checkpoint_for_ema(model_ema, checkpoint['model_ema'])
            if 'scaler' in checkpoint:
                loss_scaler.load_state_dict(checkpoint['scaler'])
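    # Note: checkpoints written later in main() bundle model, optimizer,
    # LR scheduler, EMA, and AMP scaler state together with the parsed args,
    # so --resume can restore the full training state, not just the weights.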
    # Initialize TensorBoard writer
    writer = None
    if TENSORBOARD_AVAILABLE and utils.is_main_process():
        # Create log directory with timestamp
        timestamp = datetime.datetime.now().strftime('%Y%m%d_%H%M%S')
        log_dir = os.path.join(args.tensorboard_logdir, f"exp_{timestamp}")
        os.makedirs(log_dir, exist_ok=True)
        writer = SummaryWriter(log_dir=log_dir)
        print(f"TensorBoard logs will be saved to: {log_dir}")
        print(f"To view logs, run: tensorboard --logdir={log_dir}")
    elif not TENSORBOARD_AVAILABLE and utils.is_main_process():
        print("Warning: TensorBoard not available. Install tensorboard or tensorboardX.")
        print("Training will continue without TensorBoard logging.")

    if args.eval:
        test_stats = evaluate(data_loader_val, model, criterion, device)
        print(f"Test stats: {test_stats}")
        return

    print(f"Start training for {args.epochs} epochs")
    start_time = time.time()

    # Global step counter for TensorBoard
    global_step = 0

    for epoch in range(args.start_epoch, args.epochs):
        if args.distributed:
            data_loader_train.sampler.set_epoch(epoch)

        train_stats, global_step = train_one_epoch(
            model, criterion, data_loader_train,
            optimizer, device, epoch, loss_scaler,
            clip_grad=args.clip_grad, clip_mode=args.clip_mode,
            model_ema=model_ema,
            writer=writer, global_step=global_step, args=args
        )

        lr_scheduler.step(epoch)

        # Save a checkpoint every epoch
        if args.output_dir:
            checkpoint_path = output_dir / f'checkpoint_epoch{epoch}.pth'
            utils.save_on_master({
                'model': model_without_ddp.state_dict(),
                'optimizer': optimizer.state_dict(),
                'lr_scheduler': lr_scheduler.state_dict(),
                'epoch': epoch,
                'model_ema': get_state_dict(model_ema) if model_ema else None,
                'scaler': loss_scaler.state_dict(),
                'args': args,
            }, checkpoint_path)

        # Evaluate every 5 epochs and on the final epoch
        if epoch % 5 == 0 or epoch == args.epochs - 1:
            test_stats = evaluate(data_loader_val, model, criterion, device,
                                  writer=writer, epoch=epoch)
            print(f"Epoch {epoch}: Test stats: {test_stats}")

            # Log stats to text file
            log_stats = {
                **{f'train_{k}': v for k, v in train_stats.items()},
                **{f'test_{k}': v for k, v in test_stats.items()},
                'epoch': epoch,
                'n_parameters': n_parameters
            }
            if args.output_dir and utils.is_main_process():
                with (output_dir / "log.txt").open("a") as f:
                    f.write(json.dumps(log_stats) + "\n")

    total_time = time.time() - start_time
    total_time_str = str(datetime.timedelta(seconds=int(total_time)))
    print(f'Training time {total_time_str}')

    # Close TensorBoard writer
    if writer is not None:
        writer.close()
        print(f"TensorBoard logs saved to: {writer.log_dir}")


def train_one_epoch(model, criterion, data_loader, optimizer, device, epoch, loss_scaler,
                    clip_grad=None, clip_mode='norm', model_ema=None,
                    writer=None, global_step=0, args=None, **kwargs):
    model.train()
    metric_logger = utils.MetricLogger(delimiter=" ")
    metric_logger.add_meter('lr', utils.SmoothedValue(window_size=1, fmt='{value:.6f}'))
    header = f'Epoch: [{epoch}]'
    print_freq = 10

    # Add diagnostic meters
    metric_logger.add_meter('pred_mean', utils.SmoothedValue(window_size=1, fmt='{value:.6f}'))
    metric_logger.add_meter('pred_std', utils.SmoothedValue(window_size=1, fmt='{value:.6f}'))
    metric_logger.add_meter('grad_norm', utils.SmoothedValue(window_size=1, fmt='{value:.6f}'))
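    # Per-iteration flow: forward pass under autocast, loss sanity check, then
    # NativeScaler handles loss scaling, unscaling, optional gradient clipping
    # and the optimizer step; prediction statistics and the gradient norm are
    # recorded in the MetricLogger and, if available, in TensorBoard.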
    for batch_idx, (input_frames, target_frames, temporal_indices) in enumerate(
            metric_logger.log_every(data_loader, print_freq, header)):
        input_frames = input_frames.to(device, non_blocking=True)
        target_frames = target_frames.to(device, non_blocking=True)
        temporal_indices = temporal_indices.to(device, non_blocking=True)

        # Forward pass
        with torch.amp.autocast(device_type='cuda'):
            pred_frames = model(input_frames)
            loss, loss_dict = criterion(
                pred_frames, target_frames, temporal_indices
            )

        loss_value = loss.item()

        if not math.isfinite(loss_value):
            print(f"Loss is {loss_value}, stopping training")
            raise ValueError(f"Loss is {loss_value}")

        optimizer.zero_grad()
        loss_scaler(loss, optimizer, clip_grad=clip_grad, clip_mode=clip_mode,
                    parameters=model.parameters())

        torch.cuda.synchronize()
        if model_ema is not None:
            model_ema.update(model)

        # Compute diagnostic statistics of the predictions
        pred_mean = pred_frames.mean().item()
        pred_std = pred_frames.std().item()

        # Compute the global gradient norm (L2 norm over all parameter gradients)
        total_grad_norm = 0.0
        for param in model.parameters():
            if param.grad is not None:
                total_grad_norm += param.grad.norm().item() ** 2
        total_grad_norm = total_grad_norm ** 0.5

        # Record diagnostic metrics
        metric_logger.update(pred_mean=pred_mean)
        metric_logger.update(pred_std=pred_std)
        metric_logger.update(grad_norm=total_grad_norm)

        # Print diagnostics every 50 batches
        if batch_idx % 50 == 0:
            print(f"[Diag] batch {batch_idx}: pred mean={pred_mean:.4f}, "
                  f"pred std={pred_std:.4f}, grad norm={total_grad_norm:.4f}")
            # # Inspect the running statistics of one BatchNorm layer
            # for name, module in model.named_modules():
            #     if isinstance(module, torch.nn.BatchNorm2d) and 'decoder.blocks.0.bn' in name:
            #         print(f"[Diag] {name}: running mean={module.running_mean[0].item():.6f}, "
            #               f"running var={module.running_var[0].item():.6f}")
            #         break

        # Log to TensorBoard
        if writer is not None:
            # Log scalar metrics every iteration
            writer.add_scalar('train/loss', loss_value, global_step)
            writer.add_scalar('train/lr', optimizer.param_groups[0]["lr"], global_step)

            # Log individual loss components
            for k, v in loss_dict.items():
                if torch.is_tensor(v):
                    writer.add_scalar(f'train/{k}', v.item(), global_step)
                else:
                    writer.add_scalar(f'train/{k}', v, global_step)

            # Log diagnostic metrics
            writer.add_scalar('train/pred_mean', pred_mean, global_step)
            writer.add_scalar('train/pred_std', pred_std, global_step)
            writer.add_scalar('train/grad_norm', total_grad_norm, global_step)

            # Log images periodically
            if (args is not None and getattr(args, 'log_images', False)
                    and global_step % getattr(args, 'image_log_freq', 100) == 0):
                with torch.no_grad():
                    # Take first sample from batch for visualization
                    pred_vis = model(input_frames[:1])
                    # Convert to appropriate format for TensorBoard
                    # Assuming frames are in [B, C, H, W] format
                    writer.add_images('train/input', input_frames[:1], global_step)
                    writer.add_images('train/target', target_frames[:1], global_step)
                    writer.add_images('train/predicted', pred_vis[:1], global_step)

        # Update metrics
        metric_logger.update(loss=loss_value)
        metric_logger.update(lr=optimizer.param_groups[0]["lr"])
        for k, v in loss_dict.items():
            metric_logger.update(**{k: v.item() if torch.is_tensor(v) else v})

        global_step += 1

    metric_logger.synchronize_between_processes()
    print("Averaged stats:", metric_logger)

    # Log epoch-level metrics
    if writer is not None:
        for k, meter in metric_logger.meters.items():
            writer.add_scalar(f'train_epoch/{k}', meter.global_avg, epoch)

    return {k: meter.global_avg for k, meter in metric_logger.meters.items()}, global_step


@torch.no_grad()
def evaluate(data_loader, model, criterion, device, writer=None, epoch=0):
    model.eval()
    metric_logger = utils.MetricLogger(delimiter=" ")
    header = 'Test:'

    # Add diagnostic meters
    metric_logger.add_meter('pred_mean', utils.SmoothedValue(window_size=1, fmt='{value:.6f}'))
    metric_logger.add_meter('pred_std', utils.SmoothedValue(window_size=1, fmt='{value:.6f}'))
    metric_logger.add_meter('target_mean', utils.SmoothedValue(window_size=1, fmt='{value:.6f}'))
    metric_logger.add_meter('target_std', utils.SmoothedValue(window_size=1, fmt='{value:.6f}'))

    for batch_idx, (input_frames, target_frames, temporal_indices) in enumerate(
            metric_logger.log_every(data_loader, 10, header)):
        input_frames = input_frames.to(device, non_blocking=True)
        target_frames = target_frames.to(device, non_blocking=True)
        temporal_indices = temporal_indices.to(device, non_blocking=True)

        # Compute output
        with torch.amp.autocast(device_type='cuda'):
            pred_frames = model(input_frames)
            loss, loss_dict = criterion(
                pred_frames, target_frames, temporal_indices
            )

        # Compute diagnostic statistics
        pred_mean = pred_frames.mean().item()
        pred_std = pred_frames.std().item()
        target_mean = target_frames.mean().item()
        target_std = target_frames.std().item()

        # Update diagnostic metrics
        metric_logger.update(pred_mean=pred_mean)
        metric_logger.update(pred_std=pred_std)
        metric_logger.update(target_mean=target_mean)
        metric_logger.update(target_std=target_std)

        # # Print detailed diagnostics for the first batch
        # if batch_idx == 0:
        #     print(f"[Eval diag] batch 0:")
        #     print(f"  pred range: [{pred_frames.min().item():.4f}, {pred_frames.max().item():.4f}]")
        #     print(f"  pred mean: {pred_mean:.4f}, pred std: {pred_std:.4f}")
        #     print(f"  target range: [{target_frames.min().item():.4f}, {target_frames.max().item():.4f}]")
        #     print(f"  target mean: {target_mean:.4f}, target std: {target_std:.4f}")
        #     # Check BatchNorm running statistics
        #     for name, module in model.named_modules():
        #         if isinstance(module, torch.nn.BatchNorm2d) and 'decoder.blocks.0.bn' in name:
        #             print(f"  {name}: running mean={module.running_mean[0].item():.6f}, running var={module.running_var[0].item():.6f}")
        #             if module.running_var[0].item() < 1e-6:
        #                 print(f"  Warning: BatchNorm running variance is close to zero!")
        #             break

        # Update metrics
        metric_logger.update(loss=loss.item())
        for k, v in loss_dict.items():
            metric_logger.update(**{k: v.item() if torch.is_tensor(v) else v})

    metric_logger.synchronize_between_processes()
    print('* Test stats:', metric_logger)

    # Log validation metrics to TensorBoard
    if writer is not None:
        for k, meter in metric_logger.meters.items():
            writer.add_scalar(f'val/{k}', meter.global_avg, epoch)

    return {k: meter.global_avg for k, meter in metric_logger.meters.items()}


if __name__ == '__main__':
    parser = argparse.ArgumentParser(
        'SwiftFormerTemporal training script', parents=[get_args_parser()])
    args = parser.parse_args()
    if args.output_dir:
        Path(args.output_dir).mkdir(parents=True, exist_ok=True)
    main(args)
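# Example invocation (illustrative only; the script filename and dataset path
# are assumptions, the flags are the ones defined in get_args_parser above):
#   python train_temporal.py --data-path ./videos --model SwiftFormerTemporal_XS \
#       --batch-size 32 --epochs 100 --output-dir ./output --log-images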