From 7e9564ef206c78d2eb3e2d9f8c99bb456541824f Mon Sep 17 00:00:00 2001
From: CaoWangrenbo
Date: Wed, 7 Jan 2026 11:03:33 +0800
Subject: [PATCH] Modify SwiftFormer to accept temporal input (test)

---
 main_temporal.py               | 373 +++++++++++++++++++++++++++++++++
 models/__init__.py             |   6 +
 models/swiftformer_temporal.py | 244 +++++++++++++++++++++
 test_model.py                  |  60 ++++++
 util/frame_losses.py           | 182 ++++++++++++++++
 util/video_dataset.py          | 209 ++++++++++++++++++
 6 files changed, 1074 insertions(+)
 create mode 100644 main_temporal.py
 create mode 100644 models/swiftformer_temporal.py
 create mode 100644 test_model.py
 create mode 100644 util/frame_losses.py
 create mode 100644 util/video_dataset.py

diff --git a/main_temporal.py b/main_temporal.py
new file mode 100644
index 0000000..6cd8126
--- /dev/null
+++ b/main_temporal.py
@@ -0,0 +1,373 @@
+"""
+Main training script for SwiftFormerTemporal frame prediction
+"""
+import argparse
+import datetime
+import numpy as np
+import time
+import torch
+import torch.backends.cudnn as cudnn
+import json
+from pathlib import Path
+
+from timm.scheduler import create_scheduler
+from timm.optim import create_optimizer
+from timm.utils import NativeScaler, get_state_dict, ModelEma
+
+from util import *
+from models import *
+from models.swiftformer_temporal import SwiftFormerTemporal_XS, SwiftFormerTemporal_S, SwiftFormerTemporal_L1, SwiftFormerTemporal_L3
+from util.video_dataset import VideoFrameDataset, SyntheticVideoDataset
+from util.frame_losses import MultiTaskLoss
+
+
+def get_args_parser():
+    parser = argparse.ArgumentParser(
+        'SwiftFormerTemporal training script', add_help=False)
+
+    # Dataset parameters
+    parser.add_argument('--data-path', default='./videos', type=str,
+                        help='Path to video dataset')
+    parser.add_argument('--dataset-type', default='video', choices=['video', 'synthetic'],
+                        type=str, help='Dataset type')
+    parser.add_argument('--num-frames', default=3, type=int,
+                        help='Number of input frames (T)')
+    parser.add_argument('--frame-size', default=224, type=int,
+                        help='Input frame size')
+    parser.add_argument('--max-interval', default=1, type=int,
+                        help='Maximum interval between consecutive frames')
+
+    # Model parameters
+    parser.add_argument('--model', default='SwiftFormerTemporal_XS', type=str, metavar='MODEL',
+                        help='Name of model to train')
+    parser.add_argument('--use-representation-head', action='store_true',
+                        help='Use representation head for pose/velocity prediction')
+    parser.add_argument('--representation-dim', default=128, type=int,
+                        help='Dimension of representation vector')
+
+    # Training parameters
+    parser.add_argument('--batch-size', default=32, type=int)
+    parser.add_argument('--epochs', default=100, type=int)
+    parser.add_argument('--lr', type=float, default=1e-3, metavar='LR',
+                        help='learning rate (default: 1e-3)')
+    parser.add_argument('--weight-decay', type=float, default=0.05,
+                        help='weight decay (default: 0.05)')
+
+    # Optimizer parameters (required by timm.optim.create_optimizer)
+    parser.add_argument('--opt', default='adamw', type=str, metavar='OPTIMIZER',
+                        help='Optimizer (default: "adamw")')
+    parser.add_argument('--opt-eps', default=1e-8, type=float, metavar='EPSILON',
+                        help='Optimizer epsilon (default: 1e-8)')
+    parser.add_argument('--opt-betas', default=None, type=float, nargs='+', metavar='BETA',
+                        help='Optimizer betas (default: None, use opt default)')
+    parser.add_argument('--momentum', type=float, default=0.9, metavar='M',
+                        help='SGD momentum (default: 0.9)')
+
+    # Learning rate schedule parameters (required by timm.scheduler.create_scheduler)
+    parser.add_argument('--sched', default='cosine', type=str, metavar='SCHEDULER',
+                        help='LR scheduler (default: "cosine")')
+    parser.add_argument('--warmup-lr', type=float, default=1e-6, metavar='LR',
+                        help='warmup learning rate (default: 1e-6)')
+    parser.add_argument('--min-lr', type=float, default=1e-5, metavar='LR',
+                        help='lower lr bound for cyclic schedulers that hit 0 (default: 1e-5)')
+    parser.add_argument('--warmup-epochs', type=int, default=5, metavar='N',
+                        help='epochs to warmup LR, if scheduler supports')
+    parser.add_argument('--decay-epochs', type=float, default=30, metavar='N',
+                        help='epoch interval to decay LR')
+    parser.add_argument('--cooldown-epochs', type=int, default=10, metavar='N',
+                        help='epochs to cooldown LR at min_lr after cyclic schedule ends')
+    parser.add_argument('--decay-rate', '--dr', type=float, default=0.1, metavar='RATE',
+                        help='LR decay rate (default: 0.1)')
+
+    # Loss parameters
+    parser.add_argument('--frame-weight', type=float, default=1.0,
+                        help='Weight for frame prediction loss')
+    parser.add_argument('--contrastive-weight', type=float, default=0.1,
+                        help='Weight for contrastive loss')
+    parser.add_argument('--l1-weight', type=float, default=1.0,
+                        help='Weight for L1 loss')
+    parser.add_argument('--ssim-weight', type=float, default=0.1,
+                        help='Weight for SSIM loss')
+    parser.add_argument('--no-contrastive', action='store_true',
+                        help='Disable contrastive loss')
+    parser.add_argument('--no-ssim', action='store_true',
+                        help='Disable SSIM loss')
+
+    # System
parameters + parser.add_argument('--output-dir', default='./output', + help='path where to save, empty for no saving') + parser.add_argument('--device', default='cuda', + help='device to use for training / testing') + parser.add_argument('--seed', default=0, type=int) + parser.add_argument('--resume', default='', help='resume from checkpoint') + parser.add_argument('--start-epoch', default=0, type=int, metavar='N', + help='start epoch') + parser.add_argument('--eval', action='store_true', + help='Perform evaluation only') + parser.add_argument('--num-workers', default=4, type=int) + parser.add_argument('--pin-mem', action='store_true', + help='Pin CPU memory in DataLoader') + parser.add_argument('--no-pin-mem', action='store_false', dest='pin_mem') + parser.set_defaults(pin_mem=True) + + # Distributed training + parser.add_argument('--world-size', default=1, type=int, + help='number of distributed processes') + parser.add_argument('--dist-url', default='env://', + help='url used to set up distributed training') + + return parser + + +def build_dataset(is_train, args): + """Build video frame dataset""" + if args.dataset_type == 'synthetic': + dataset = SyntheticVideoDataset( + num_samples=1000 if is_train else 200, + num_frames=args.num_frames, + frame_size=args.frame_size, + is_train=is_train + ) + else: + dataset = VideoFrameDataset( + root_dir=args.data_path, + num_frames=args.num_frames, + frame_size=args.frame_size, + is_train=is_train, + max_interval=args.max_interval + ) + + return dataset + + +def main(args): + utils.init_distributed_mode(args) + print(args) + + device = torch.device(args.device) + + # Fix the seed for reproducibility + seed = args.seed + utils.get_rank() + torch.manual_seed(seed) + np.random.seed(seed) + + cudnn.benchmark = True + + # Build datasets + dataset_train = build_dataset(is_train=True, args=args) + dataset_val = build_dataset(is_train=False, args=args) + + # Create samplers + if args.distributed: + sampler_train = torch.utils.data.DistributedSampler(dataset_train) + sampler_val = torch.utils.data.DistributedSampler(dataset_val, shuffle=False) + else: + sampler_train = torch.utils.data.RandomSampler(dataset_train) + sampler_val = torch.utils.data.SequentialSampler(dataset_val) + + data_loader_train = torch.utils.data.DataLoader( + dataset_train, sampler=sampler_train, + batch_size=args.batch_size, + num_workers=args.num_workers, + pin_memory=args.pin_mem, + drop_last=True, + ) + + data_loader_val = torch.utils.data.DataLoader( + dataset_val, sampler=sampler_val, + batch_size=args.batch_size, + num_workers=args.num_workers, + pin_memory=args.pin_mem, + drop_last=False + ) + + # Create model + print(f"Creating model: {args.model}") + model_kwargs = { + 'num_frames': args.num_frames, + 'use_representation_head': args.use_representation_head, + 'representation_dim': args.representation_dim, + } + + if args.model == 'SwiftFormerTemporal_XS': + model = SwiftFormerTemporal_XS(**model_kwargs) + elif args.model == 'SwiftFormerTemporal_S': + model = SwiftFormerTemporal_S(**model_kwargs) + elif args.model == 'SwiftFormerTemporal_L1': + model = SwiftFormerTemporal_L1(**model_kwargs) + elif args.model == 'SwiftFormerTemporal_L3': + model = SwiftFormerTemporal_L3(**model_kwargs) + else: + raise ValueError(f"Unknown model: {args.model}") + + model.to(device) + + # Model EMA + model_ema = None + if hasattr(args, 'model_ema') and args.model_ema: + model_ema = ModelEma( + model, + decay=args.model_ema_decay if hasattr(args, 'model_ema_decay') else 0.9999, + device='cpu' if 
hasattr(args, 'model_ema_force_cpu') and args.model_ema_force_cpu else '',
+            resume='')
+
+    # Distributed training
+    model_without_ddp = model
+    if args.distributed:
+        model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu])
+        model_without_ddp = model.module
+
+    n_parameters = sum(p.numel() for p in model.parameters() if p.requires_grad)
+    print(f'Number of parameters: {n_parameters}')
+
+    # Create optimizer
+    optimizer = create_optimizer(args, model_without_ddp)
+
+    # Create loss scaler
+    loss_scaler = NativeScaler()
+
+    # Create scheduler
+    lr_scheduler, _ = create_scheduler(args, optimizer)
+
+    # Create loss function (--no-ssim is honored by zeroing the SSIM weight)
+    criterion = MultiTaskLoss(
+        frame_weight=args.frame_weight,
+        contrastive_weight=args.contrastive_weight,
+        l1_weight=args.l1_weight,
+        ssim_weight=0.0 if args.no_ssim else args.ssim_weight,
+        use_contrastive=not args.no_contrastive
+    )
+
+    # Resume from checkpoint
+    output_dir = Path(args.output_dir)
+    if args.resume:
+        if args.resume.startswith('https'):
+            checkpoint = torch.hub.load_state_dict_from_url(
+                args.resume, map_location='cpu', check_hash=True)
+        else:
+            checkpoint = torch.load(args.resume, map_location='cpu')
+
+        model_without_ddp.load_state_dict(checkpoint['model'])
+        if not args.eval and 'optimizer' in checkpoint and 'lr_scheduler' in checkpoint and 'epoch' in checkpoint:
+            optimizer.load_state_dict(checkpoint['optimizer'])
+            lr_scheduler.load_state_dict(checkpoint['lr_scheduler'])
+            args.start_epoch = checkpoint['epoch'] + 1
+            if model_ema is not None:
+                utils._load_checkpoint_for_ema(model_ema, checkpoint['model_ema'])
+            if 'scaler' in checkpoint:
+                loss_scaler.load_state_dict(checkpoint['scaler'])
+
+    if args.eval:
+        test_stats = evaluate(data_loader_val, model, criterion, device)
+        print(f"Test stats: {test_stats}")
+        return
+
+    print(f"Start training for {args.epochs} epochs")
+    start_time = time.time()
+
+    # Initialized so log_stats is well-defined on epochs without evaluation
+    test_stats = {}
+
+    for epoch in range(args.start_epoch, args.epochs):
+        if args.distributed:
+            data_loader_train.sampler.set_epoch(epoch)
+
+        train_stats = train_one_epoch(
+            model, criterion, data_loader_train,
+            optimizer, device, epoch, loss_scaler,
+            model_ema=model_ema
+        )
+
+        lr_scheduler.step(epoch)
+
+        # Save checkpoint
+        if args.output_dir and (epoch % 10 == 0 or epoch == args.epochs - 1):
+            checkpoint_path = output_dir / f'checkpoint_epoch{epoch}.pth'
+            utils.save_on_master({
+                'model': model_without_ddp.state_dict(),
+                'optimizer': optimizer.state_dict(),
+                'lr_scheduler': lr_scheduler.state_dict(),
+                'epoch': epoch,
+                'model_ema': get_state_dict(model_ema) if model_ema else None,
+                'scaler': loss_scaler.state_dict(),
+                'args': args,
+            }, checkpoint_path)
+
+        # Evaluate
+        if epoch % 5 == 0 or epoch == args.epochs - 1:
+            test_stats = evaluate(data_loader_val, model, criterion, device)
+            print(f"Epoch {epoch}: Test stats: {test_stats}")
+
+        # Log stats
+        log_stats = {
+            **{f'train_{k}': v for k, v in train_stats.items()},
+            **{f'test_{k}': v for k, v in test_stats.items()},
+            'epoch': epoch,
+            'n_parameters': n_parameters
+        }
+
+        if args.output_dir and utils.is_main_process():
+            with (output_dir / "log.txt").open("a") as f:
+                f.write(json.dumps(log_stats) + "\n")
+
+    total_time = time.time() - start_time
+    total_time_str = str(datetime.timedelta(seconds=int(total_time)))
+    print(f'Training time {total_time_str}')
+
+
+# clip_grad defaults to None: timm's NativeScaler clips whenever clip_grad is
+# not None, so a default of 0 would clip every gradient norm to zero.
+def train_one_epoch(model, criterion, data_loader, optimizer, device, epoch, loss_scaler,
+                    clip_grad=None, clip_mode='norm', model_ema=None, **kwargs):
+    model.train()
+    metric_logger = utils.MetricLogger(delimiter="  ")
+
metric_logger.add_meter('lr', utils.SmoothedValue(window_size=1, fmt='{value:.6f}')) + header = f'Epoch: [{epoch}]' + print_freq = 10 + + for input_frames, target_frames, temporal_indices in metric_logger.log_every( + data_loader, print_freq, header): + + input_frames = input_frames.to(device, non_blocking=True) + target_frames = target_frames.to(device, non_blocking=True) + temporal_indices = temporal_indices.to(device, non_blocking=True) + + # Forward pass + with torch.cuda.amp.autocast(): + pred_frames, representations = model(input_frames) + loss, loss_dict = criterion( + pred_frames, target_frames, + representations, temporal_indices + ) + + loss_value = loss.item() + if not torch.isfinite(torch.tensor(loss_value)): + print(f"Loss is {loss_value}, stopping training") + raise ValueError(f"Loss is {loss_value}") + + optimizer.zero_grad() + loss_scaler(loss, optimizer, clip_grad=clip_grad, clip_mode=clip_mode, + parameters=model.parameters()) + + torch.cuda.synchronize() + if model_ema is not None: + model_ema.update(model) + + # Update metrics + metric_logger.update(loss=loss_value) + metric_logger.update(lr=optimizer.param_groups[0]["lr"]) + for k, v in loss_dict.items(): + metric_logger.update(**{k: v.item() if torch.is_tensor(v) else v}) + + metric_logger.synchronize_between_processes() + print("Averaged stats:", metric_logger) + return {k: meter.global_avg for k, meter in metric_logger.meters.items()} + + +@torch.no_grad() +def evaluate(data_loader, model, criterion, device): + model.eval() + metric_logger = utils.MetricLogger(delimiter=" ") + header = 'Test:' + + for input_frames, target_frames, temporal_indices in metric_logger.log_every(data_loader, 10, header): + input_frames = input_frames.to(device, non_blocking=True) + target_frames = target_frames.to(device, non_blocking=True) + temporal_indices = temporal_indices.to(device, non_blocking=True) + + # Compute output + with torch.cuda.amp.autocast(): + pred_frames, representations = model(input_frames) + loss, loss_dict = criterion( + pred_frames, target_frames, + representations, temporal_indices + ) + + # Update metrics + metric_logger.update(loss=loss.item()) + for k, v in loss_dict.items(): + metric_logger.update(**{k: v.item() if torch.is_tensor(v) else v}) + + metric_logger.synchronize_between_processes() + print('* Test stats:', metric_logger) + return {k: meter.global_avg for k, meter in metric_logger.meters.items()} + + +if __name__ == '__main__': + parser = argparse.ArgumentParser( + 'SwiftFormerTemporal training script', parents=[get_args_parser()]) + args = parser.parse_args() + + if args.output_dir: + Path(args.output_dir).mkdir(parents=True, exist_ok=True) + + main(args) \ No newline at end of file diff --git a/models/__init__.py b/models/__init__.py index b07c562..79935f9 100644 --- a/models/__init__.py +++ b/models/__init__.py @@ -1 +1,7 @@ from .swiftformer import SwiftFormer_XS, SwiftFormer_S, SwiftFormer_L1, SwiftFormer_L3 +from .swiftformer_temporal import ( + SwiftFormerTemporal_XS, + SwiftFormerTemporal_S, + SwiftFormerTemporal_L1, + SwiftFormerTemporal_L3 +) diff --git a/models/swiftformer_temporal.py b/models/swiftformer_temporal.py new file mode 100644 index 0000000..6a105ca --- /dev/null +++ b/models/swiftformer_temporal.py @@ -0,0 +1,244 @@ +""" +SwiftFormerTemporal: Temporal extension of SwiftFormer for frame prediction +""" +import torch +import torch.nn as nn +from .swiftformer import ( + SwiftFormer, SwiftFormer_depth, SwiftFormer_width, + stem, Embedding, Stage +) +from timm.models.layers import 
DropPath, trunc_normal_
+
+
+class DecoderBlock(nn.Module):
+    """Upsampling block for frame prediction decoder"""
+    def __init__(self, in_channels, out_channels, kernel_size=3, stride=2, padding=1, output_padding=1):
+        super().__init__()
+        self.conv = nn.ConvTranspose2d(
+            in_channels, out_channels,
+            kernel_size=kernel_size,
+            stride=stride,
+            padding=padding,
+            output_padding=output_padding,
+            bias=False
+        )
+        self.bn = nn.BatchNorm2d(out_channels)
+        self.relu = nn.ReLU(inplace=True)
+
+    def forward(self, x):
+        return self.relu(self.bn(self.conv(x)))
+
+
+class FramePredictionDecoder(nn.Module):
+    """Lightweight decoder for frame prediction with optional skip connections"""
+    def __init__(self, embed_dims, output_channels=3, use_skip=False):
+        super().__init__()
+        self.use_skip = use_skip
+        # Reverse the embed_dims for the decoder: [d3, d2, d1, d0]
+        decoder_dims = embed_dims[::-1]
+
+        # With skip connections, blocks 1-3 receive the previous decoder output
+        # concatenated with the matching encoder feature, which doubles their
+        # input channels.
+        skip_mult = 2 if use_skip else 1
+
+        self.blocks = nn.ModuleList()
+        # Bottleneck (H/32) to stage-2 resolution (H/16)
+        self.blocks.append(DecoderBlock(
+            decoder_dims[0], decoder_dims[1],
+            kernel_size=3, stride=2, padding=1, output_padding=1
+        ))
+        # H/16 to H/8
+        self.blocks.append(DecoderBlock(
+            decoder_dims[1] * skip_mult, decoder_dims[2],
+            kernel_size=3, stride=2, padding=1, output_padding=1
+        ))
+        # H/8 to H/4
+        self.blocks.append(DecoderBlock(
+            decoder_dims[2] * skip_mult, decoder_dims[3],
+            kernel_size=3, stride=2, padding=1, output_padding=1
+        ))
+        # H/4 back to full resolution (4x upsampling, done in two 2x steps)
+        self.blocks.append(nn.Sequential(
+            nn.ConvTranspose2d(
+                decoder_dims[3] * skip_mult, 32,
+                kernel_size=3, stride=2, padding=1, output_padding=1
+            ),
+            nn.BatchNorm2d(32),
+            nn.ReLU(inplace=True),
+            nn.ConvTranspose2d(
+                32, 32,
+                kernel_size=3, stride=2, padding=1, output_padding=1
+            ),
+            nn.BatchNorm2d(32),
+            nn.ReLU(inplace=True),
+            nn.Conv2d(32, output_channels, kernel_size=3, padding=1),
+            nn.Tanh()  # Output in [-1, 1] range
+        ))
+
+    def forward(self, x, skip_features=None):
+        """
+        Args:
+            x: input tensor of shape [B, embed_dims[-1], H/32, W/32]
+            skip_features: encoder features [stage0, stage1, stage2] at
+                resolutions [H/4, H/8, H/16] with channels [d0, d1, d2]
+        """
+        if self.use_skip and skip_features is not None:
+            assert len(skip_features) == 3, "Need 3 skip features for skip connections"
+            # The decoder consumes them deepest-first: stage2 (H/16),
+            # then stage1 (H/8), then stage0 (H/4).
+            skip_features = skip_features[::-1]
+
+        for i, block in enumerate(self.blocks):
+            if self.use_skip and skip_features is not None and i >= 1:
+                # Concatenate the matching encoder feature along the channel
+                # dimension; spatial sizes align because each decoder block
+                # mirrors one encoder downsampling step.
+                x = torch.cat([x, skip_features[i - 1]], dim=1)
+            x = block(x)
+        return x
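+
+
+# A quick shape sketch of the decoder path (not part of the module; assumes
+# the XS widths [48, 56, 112, 220] from SwiftFormer_width and a 224x224
+# input, so the encoder bottleneck is 220 x 7 x 7):
+#
+#     decoder = FramePredictionDecoder([48, 56, 112, 220])
+#     x = torch.randn(1, 220, 7, 7)    # [B, embed_dims[-1], H/32, W/32]
+#     out = decoder(x)                 # 7 -> 14 -> 28 -> 56 -> 112 -> 224
+#     assert out.shape == (1, 3, 224, 224)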
+
+
+class SwiftFormerTemporal(nn.Module):
+    """
+    SwiftFormer with temporal input for frame prediction.
+
+    Input: [B, 3 * num_frames, H, W] (RGB frames concatenated along the channel dimension)
+    Output: predicted frame [B, 3, H, W] and optional representation
+    """
+    def __init__(self,
+                 model_name='XS',
+                 num_frames=3,
+                 use_decoder=True,
+                 use_representation_head=False,
+                 representation_dim=128,
+                 return_features=False,
+                 **kwargs):
+        super().__init__()
+
+        # Get model configuration
+        layers = SwiftFormer_depth[model_name]
+        embed_dims = SwiftFormer_width[model_name]
+
+        # Store configuration
+        self.num_frames = num_frames
+        self.use_decoder = use_decoder
+        self.use_representation_head = use_representation_head
+        self.return_features = return_features
+
+        # Modify the stem to accept a stack of RGB frames (3 channels per
+        # frame), matching what VideoFrameDataset/SyntheticVideoDataset emit.
+        in_channels = 3 * num_frames
+        self.patch_embed = stem(in_channels, embed_dims[0])
+
+        # Build encoder network (same as SwiftFormer)
+        network = []
+        for i in range(len(layers)):
+            stage = Stage(embed_dims[i], i, layers, mlp_ratio=4,
+                          act_layer=nn.GELU,
+                          drop_rate=0., drop_path_rate=0.,
+                          use_layer_scale=True,
+                          layer_scale_init_value=1e-5,
+                          vit_num=1)
+            network.append(stage)
+            if i >= len(layers) - 1:
+                break
+            if embed_dims[i] != embed_dims[i + 1]:
+                network.append(
+                    Embedding(
+                        patch_size=3, stride=2, padding=1,
+                        in_chans=embed_dims[i], embed_dim=embed_dims[i + 1]
+                    )
+                )
+
+        self.network = nn.ModuleList(network)
+        self.norm = nn.BatchNorm2d(embed_dims[-1])
+
+        # Frame prediction decoder
+        if use_decoder:
+            self.decoder = FramePredictionDecoder(embed_dims, output_channels=3)
+
+        # Representation head for pose/velocity prediction
+        if use_representation_head:
+            self.representation_head = nn.Sequential(
+                nn.AdaptiveAvgPool2d(1),
+                nn.Flatten(),
+                nn.Linear(embed_dims[-1], representation_dim),
+                nn.ReLU(),
+                nn.Linear(representation_dim, representation_dim)
+            )
+        else:
+            self.representation_head = None
+
+        self.apply(self._init_weights)
+
+    def _init_weights(self, m):
+        if isinstance(m, (nn.Conv2d, nn.Linear)):
+            trunc_normal_(m.weight, std=.02)
+            if m.bias is not None:
+                nn.init.constant_(m.bias, 0)
+        elif isinstance(m, nn.LayerNorm):
+            nn.init.constant_(m.bias, 0)
+            nn.init.constant_(m.weight, 1.0)
+
+    def forward_tokens(self, x):
+        """Forward through the encoder network; returns a list of stage features
+        as well if return_features is set, else only the final output."""
+        if self.return_features:
+            features = []
+            for idx, block in enumerate(self.network):
+                x = block(x)
+                # Collect output after each stage (indices 0, 2, 4, 6 are stages;
+                # odd indices are the downsampling Embedding layers)
+                if idx in [0, 2, 4, 6]:
+                    features.append(x)
+            return x, features
+        else:
+            for block in self.network:
+                x = block(x)
+            return x
+
+    def forward(self, x):
+        """
+        Args:
+            x: input frames of shape [B, 3 * num_frames, H, W]
+        Returns:
+            If return_features is False:
+                pred_frame: predicted frame [B, 3, H, W] (or None)
+                representation: optional representation vector [B, representation_dim] (or None)
+            If return_features is True:
+                pred_frame, representation, features (list of stage features)
+        """
+        # Encode
+        x = self.patch_embed(x)
+        if self.return_features:
+            x, features = self.forward_tokens(x)
+        else:
+            x = self.forward_tokens(x)
+        x = self.norm(x)
+
+        # Get representation if needed
+        representation = None
+        if self.representation_head is not None:
+            representation = self.representation_head(x)
+
+        # Decode to frame
+        pred_frame = None
+        if self.use_decoder:
+            pred_frame = self.decoder(x)
+
+        if self.return_features:
+            return pred_frame, representation, features
+        else:
+            return pred_frame, representation
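+
+
+# Minimal usage sketch (assumes a 224x224 input and num_frames=3, so the
+# stem sees 3 * 3 = 9 input channels):
+#
+#     model = SwiftFormerTemporal('XS', num_frames=3, use_representation_head=True)
+#     frames = torch.randn(2, 9, 224, 224)   # [B, 3*num_frames, H, W]
+#     pred, rep = model(frames)              # pred: [2, 3, 224, 224], rep: [2, 128]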
+
+
+# Factory functions for different model sizes
+def SwiftFormerTemporal_XS(num_frames=3, **kwargs):
+    return SwiftFormerTemporal('XS', num_frames=num_frames, **kwargs)
+
+def SwiftFormerTemporal_S(num_frames=3, **kwargs):
+    return SwiftFormerTemporal('S', num_frames=num_frames, **kwargs)
+
+def SwiftFormerTemporal_L1(num_frames=3, **kwargs):
+    return SwiftFormerTemporal('l1', num_frames=num_frames, **kwargs)
+
+def SwiftFormerTemporal_L3(num_frames=3, **kwargs):
+    return SwiftFormerTemporal('l3', num_frames=num_frames, **kwargs)
\ No newline at end of file
diff --git a/test_model.py b/test_model.py
new file mode 100644
index 0000000..0caab88
--- /dev/null
+++ b/test_model.py
@@ -0,0 +1,60 @@
+#!/usr/bin/env python3
+"""
+Test script for SwiftFormerTemporal model
+"""
+import torch
+import sys
+import os
+
+# Add current directory to path
+sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
+
+from models.swiftformer_temporal import SwiftFormerTemporal_XS
+
+def test_model():
+    print("Testing SwiftFormerTemporal model...")
+
+    # Create model
+    model = SwiftFormerTemporal_XS(num_frames=3, use_representation_head=True)
+    print(f'Model created: {model.__class__.__name__}')
+    print(f'Number of parameters: {sum(p.numel() for p in model.parameters()):,}')
+
+    # Test forward pass
+    batch_size = 2
+    num_frames = 3
+    height = width = 224
+    # Stacked RGB frames: [B, 3 * num_frames, H, W]
+    x = torch.randn(batch_size, 3 * num_frames, height, width)
+
+    print(f'\nInput shape: {x.shape}')
+
+    with torch.no_grad():
+        pred_frame, representation = model(x)
+
+    print(f'Predicted frame shape: {pred_frame.shape}')
+    print(f'Representation shape: {representation.shape if representation is not None else "None"}')
+
+    # Check output ranges
+    print(f'\nPredicted frame range: [{pred_frame.min():.3f}, {pred_frame.max():.3f}]')
+
+    # Test loss function
+    from util.frame_losses import MultiTaskLoss
+    criterion = MultiTaskLoss()
+    target = torch.randn_like(pred_frame)
+    # Adjacent indices so the contrastive term sees a positive pair
+    temporal_indices = torch.tensor([2, 3], dtype=torch.long)
+
+    loss, loss_dict = criterion(pred_frame, target, representation, temporal_indices)
+    print('\nLoss test:')
+    for k, v in loss_dict.items():
+        print(f'  {k}: {v:.4f}')
+
+    print('\nAll tests passed!')
+    return True
+
+if __name__ == '__main__':
+    try:
+        test_model()
+    except Exception as e:
+        print(f'Test failed with error: {e}')
+        import traceback
+        traceback.print_exc()
+        sys.exit(1)
\ No newline at end of file
diff --git a/util/frame_losses.py b/util/frame_losses.py
new file mode 100644
index 0000000..e27cc05
--- /dev/null
+++ b/util/frame_losses.py
@@ -0,0 +1,182 @@
+"""
+Loss functions for frame prediction and representation learning
+"""
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import math
+
+
+class SSIMLoss(nn.Module):
+    """
+    Structural Similarity Index Measure Loss
+    Based on: https://github.com/Po-Hsun-Su/pytorch-ssim
+    """
+    def __init__(self, window_size=11, size_average=True):
+        super().__init__()
+        self.window_size = window_size
+        self.size_average = size_average
+        self.channel = 3
+        self.window = self.create_window(window_size, self.channel)
+
+    def create_window(self, window_size, channel):
+        def gaussian(window_size, sigma):
+            gauss = torch.Tensor([math.exp(-(x - window_size//2)**2/float(2*sigma**2)) for x in range(window_size)])
+            return gauss/gauss.sum()
+
+        _1D_window = gaussian(window_size, 1.5).unsqueeze(1)
+        _2D_window = _1D_window.mm(_1D_window.t()).float().unsqueeze(0).unsqueeze(0)
+        window = _2D_window.expand(channel, 1, window_size, window_size).contiguous()
+        return window
+
+    def forward(self, img1, img2):
+        # Ensure the window is on the correct device
+        if self.window.device != img1.device:
+            self.window = self.window.to(img1.device)
+
+        mu1 = F.conv2d(img1, self.window, padding=self.window_size//2, groups=self.channel)
+        mu2 = F.conv2d(img2, self.window, padding=self.window_size//2, groups=self.channel)
+
+        mu1_sq = mu1.pow(2)
+        mu2_sq = mu2.pow(2)
+        mu1_mu2 = mu1 * mu2
+
+        sigma1_sq = F.conv2d(img1*img1, self.window, padding=self.window_size//2, groups=self.channel) - mu1_sq
+        sigma2_sq = F.conv2d(img2*img2, self.window, padding=self.window_size//2, groups=self.channel) - mu2_sq
+        sigma12 = F.conv2d(img1*img2, self.window, padding=self.window_size//2, groups=self.channel) - mu1_mu2
+
+        C1 = 0.01**2
+        C2 = 0.03**2
+
+        ssim_map = ((2*mu1_mu2 + C1)*(2*sigma12 + C2)) / ((mu1_sq + mu2_sq + C1)*(sigma1_sq + sigma2_sq + C2))
+
+        if self.size_average:
+            return 1 - ssim_map.mean()
+        else:
+            return 1 - ssim_map.mean(1).mean(1).mean(1)
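+
+
+# For reference, the map computed above follows the standard SSIM form
+#     SSIM(x, y) = ((2*mu_x*mu_y + C1) * (2*sigma_xy + C2))
+#                  / ((mu_x^2 + mu_y^2 + C1) * (sigma_x^2 + sigma_y^2 + C2))
+# with C1 = 0.01^2 and C2 = 0.03^2 (the usual constants for images scaled
+# to [0, 1]), and the loss returned is 1 - mean(SSIM).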
+
+
+class FramePredictionLoss(nn.Module):
+    """
+    Combined loss for frame prediction
+    """
+    def __init__(self, l1_weight=1.0, ssim_weight=0.1, use_ssim=True):
+        super().__init__()
+        self.l1_weight = l1_weight
+        self.ssim_weight = ssim_weight
+        self.use_ssim = use_ssim
+
+        self.l1_loss = nn.L1Loss()
+        if use_ssim:
+            self.ssim_loss = SSIMLoss()
+
+    def forward(self, pred, target):
+        """
+        Args:
+            pred: predicted frame [B, 3, H, W] in range [-1, 1]
+            target: target frame [B, 3, H, W] in range [-1, 1]
+        Returns:
+            total_loss, loss_dict
+        """
+        loss_dict = {}
+
+        # L1 loss
+        l1_loss = self.l1_loss(pred, target)
+        loss_dict['l1'] = l1_loss
+        total_loss = self.l1_weight * l1_loss
+
+        # SSIM loss
+        if self.use_ssim:
+            ssim_loss = self.ssim_loss(pred, target)
+            loss_dict['ssim'] = ssim_loss
+            total_loss += self.ssim_weight * ssim_loss
+
+        loss_dict['total'] = total_loss
+        return total_loss, loss_dict
+
+
+class ContrastiveLoss(nn.Module):
+    """
+    Contrastive loss for representation learning
+    Positive pairs: representations from adjacent frames
+    Negative pairs: representations from distant frames
+    """
+    def __init__(self, temperature=0.1, margin=1.0):
+        super().__init__()
+        self.temperature = temperature
+        self.margin = margin
+
+    def forward(self, representations, temporal_indices):
+        """
+        Args:
+            representations: [B, D] representation vectors
+            temporal_indices: [B] temporal indices of each sample
+        Returns:
+            contrastive_loss
+        """
+        # L2-normalize so the dot products below are cosine similarities and
+        # exp() stays numerically stable.
+        representations = F.normalize(representations, dim=-1)
+
+        # Compute similarity matrix
+        sim_matrix = torch.matmul(representations, representations.T) / self.temperature
+
+        # Create positive mask (adjacent frames)
+        indices_expanded = temporal_indices.unsqueeze(0)
+        diff = torch.abs(indices_expanded - indices_expanded.T)
+        positive_mask = (diff == 1).float()
+
+        # Create negative mask (distant frames)
+        negative_mask = (diff > 2).float()
+
+        # Positive loss
+        pos_sim = sim_matrix * positive_mask
+        pos_loss = -torch.log(torch.exp(pos_sim) / torch.exp(sim_matrix).sum(dim=-1, keepdim=True) + 1e-8)
+        pos_loss = (pos_loss * positive_mask).sum() / (positive_mask.sum() + 1e-8)
+
+        # Negative loss (push apart)
+        neg_sim = sim_matrix * negative_mask
+        neg_loss = torch.relu(neg_sim - self.margin).mean()
+
+        return pos_loss + 0.1 * neg_loss
+
+
+class MultiTaskLoss(nn.Module):
+    """
+    Multi-task loss combining frame prediction and representation learning
+    """
+    def __init__(self, frame_weight=1.0, contrastive_weight=0.1,
+                 l1_weight=1.0, ssim_weight=0.1, use_contrastive=True):
+        super().__init__()
+        self.frame_weight = frame_weight
+        self.contrastive_weight = contrastive_weight
+
self.use_contrastive = use_contrastive + + self.frame_loss = FramePredictionLoss(l1_weight=l1_weight, ssim_weight=ssim_weight) + if use_contrastive: + self.contrastive_loss = ContrastiveLoss() + + def forward(self, pred_frame, target_frame, representations=None, temporal_indices=None): + """ + Args: + pred_frame: predicted frame [B, 3, H, W] + target_frame: target frame [B, 3, H, W] + representations: [B, D] representation vectors (optional) + temporal_indices: [B] temporal indices (optional) + Returns: + total_loss, loss_dict + """ + loss_dict = {} + + # Frame prediction loss + frame_loss, frame_loss_dict = self.frame_loss(pred_frame, target_frame) + loss_dict.update({f'frame_{k}': v for k, v in frame_loss_dict.items()}) + total_loss = self.frame_weight * frame_loss + + # Contrastive loss (if representations provided) + if self.use_contrastive and representations is not None and temporal_indices is not None: + contrastive_loss = self.contrastive_loss(representations, temporal_indices) + loss_dict['contrastive'] = contrastive_loss + total_loss += self.contrastive_weight * contrastive_loss + + loss_dict['total'] = total_loss + return total_loss, loss_dict \ No newline at end of file diff --git a/util/video_dataset.py b/util/video_dataset.py new file mode 100644 index 0000000..8b6d57f --- /dev/null +++ b/util/video_dataset.py @@ -0,0 +1,209 @@ +""" +Video frame dataset for temporal self-supervised learning +""" +import os +import random +from pathlib import Path +from typing import Optional, Tuple, List + +import torch +from torch.utils.data import Dataset +from torchvision import transforms +from PIL import Image +import numpy as np + + +class VideoFrameDataset(Dataset): + """ + Dataset for loading consecutive frames from videos for frame prediction. + + Assumes directory structure: + dataset_root/ + video1/ + frame_0001.jpg + frame_0002.jpg + ... + video2/ + ... 
+ """ + def __init__(self, + root_dir: str, + num_frames: int = 3, + frame_size: int = 224, + is_train: bool = True, + max_interval: int = 1, + transform=None): + """ + Args: + root_dir: Root directory containing video folders + num_frames: Number of input frames (T) + frame_size: Size to resize frames to + is_train: Whether this is training set (affects augmentation) + max_interval: Maximum interval between consecutive frames + transform: Optional custom transform + """ + self.root_dir = Path(root_dir) + self.num_frames = num_frames + self.frame_size = frame_size + self.is_train = is_train + self.max_interval = max_interval + + # Collect all video folders + self.video_folders = [] + for item in self.root_dir.iterdir(): + if item.is_dir(): + self.video_folders.append(item) + + if len(self.video_folders) == 0: + raise ValueError(f"No video folders found in {root_dir}") + + # Build frame index: list of (video_idx, start_frame_idx) + self.frame_indices = [] + for video_idx, video_folder in enumerate(self.video_folders): + # Get all frame files + frame_files = sorted([f for f in video_folder.iterdir() + if f.suffix.lower() in ['.jpg', '.jpeg', '.png', '.bmp']]) + + if len(frame_files) < num_frames + 1: + continue # Skip videos with insufficient frames + + # Add all possible starting positions + for start_idx in range(len(frame_files) - num_frames): + self.frame_indices.append((video_idx, start_idx)) + + if len(self.frame_indices) == 0: + raise ValueError("No valid frame sequences found in dataset") + + # Default transforms + if transform is None: + self.transform = self._default_transform() + else: + self.transform = transform + + # Normalization (ImageNet stats) + self.normalize = transforms.Normalize( + mean=[0.485, 0.456, 0.406], + std=[0.229, 0.224, 0.225] + ) + + def _default_transform(self): + """Default transform with augmentation for training""" + if self.is_train: + return transforms.Compose([ + transforms.RandomResizedCrop(self.frame_size, scale=(0.8, 1.0)), + transforms.RandomHorizontalFlip(), + transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1), + ]) + else: + return transforms.Compose([ + transforms.Resize(int(self.frame_size * 1.14)), + transforms.CenterCrop(self.frame_size), + ]) + + def _load_frame(self, video_idx: int, frame_idx: int) -> Image.Image: + """Load a single frame as PIL Image""" + video_folder = self.video_folders[video_idx] + frame_files = sorted([f for f in video_folder.iterdir() + if f.suffix.lower() in ['.jpg', '.jpeg', '.png', '.bmp']]) + frame_path = frame_files[frame_idx] + return Image.open(frame_path).convert('RGB') + + def __len__(self) -> int: + return len(self.frame_indices) + + def __getitem__(self, idx: int) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """ + Returns: + input_frames: [3 * num_frames, H, W] concatenated input frames + target_frame: [3, H, W] target frame to predict + temporal_idx: temporal index of target frame (for contrastive loss) + """ + video_idx, start_idx = self.frame_indices[idx] + + # Determine frame interval (for temporal augmentation) + interval = random.randint(1, self.max_interval) if self.is_train else 1 + + # Load input frames + input_frames = [] + for i in range(self.num_frames): + frame_idx = start_idx + i * interval + frame = self._load_frame(video_idx, frame_idx) + + # Apply transform (same for all frames in sequence) + if self.transform: + frame = self.transform(frame) + + input_frames.append(frame) + + # Load target frame (next frame after input sequence) + target_idx = start_idx 
+ self.num_frames * interval + target_frame = self._load_frame(video_idx, target_idx) + if self.transform: + target_frame = self.transform(target_frame) + + # Convert to tensors and normalize + input_tensors = [] + for frame in input_frames: + tensor = transforms.ToTensor()(frame) + tensor = self.normalize(tensor) + input_tensors.append(tensor) + + target_tensor = transforms.ToTensor()(target_frame) + target_tensor = self.normalize(target_tensor) + + # Concatenate input frames along channel dimension + input_concatenated = torch.cat(input_tensors, dim=0) + + # Temporal index (for contrastive loss) + temporal_idx = torch.tensor(self.num_frames, dtype=torch.long) + + return input_concatenated, target_tensor, temporal_idx + + +class SyntheticVideoDataset(Dataset): + """ + Synthetic dataset for testing - generates random frames + """ + def __init__(self, + num_samples: int = 1000, + num_frames: int = 3, + frame_size: int = 224, + is_train: bool = True): + self.num_samples = num_samples + self.num_frames = num_frames + self.frame_size = frame_size + self.is_train = is_train + + # Normalization + self.normalize = transforms.Normalize( + mean=[0.485, 0.456, 0.406], + std=[0.229, 0.224, 0.225] + ) + + def __len__(self): + return self.num_samples + + def __getitem__(self, idx): + # Generate random "frames" (noise with temporal correlation) + input_frames = [] + prev_frame = torch.randn(3, self.frame_size, self.frame_size) * 0.1 + + for i in range(self.num_frames): + # Add some temporal correlation + frame = prev_frame + torch.randn(3, self.frame_size, self.frame_size) * 0.05 + frame = torch.clamp(frame, -1, 1) + input_frames.append(self.normalize(frame)) + prev_frame = frame + + # Target frame (next in sequence) + target_frame = prev_frame + torch.randn(3, self.frame_size, self.frame_size) * 0.05 + target_frame = torch.clamp(target_frame, -1, 1) + target_tensor = self.normalize(target_frame) + + # Concatenate inputs + input_concatenated = torch.cat(input_frames, dim=0) + + # Temporal index + temporal_idx = torch.tensor(self.num_frames, dtype=torch.long) + + return input_concatenated, target_tensor, temporal_idx \ No newline at end of file
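
A minimal end-to-end smoke run tying the pieces together (a sketch, run from the repository root with the files above in place; it uses the synthetic dataset so no video data is required, and assumes the default num_frames=3, i.e. 9 input channels):

    import torch
    from torch.utils.data import DataLoader
    from models.swiftformer_temporal import SwiftFormerTemporal_XS
    from util.video_dataset import SyntheticVideoDataset
    from util.frame_losses import MultiTaskLoss

    dataset = SyntheticVideoDataset(num_samples=8, num_frames=3, frame_size=224)
    loader = DataLoader(dataset, batch_size=4)
    model = SwiftFormerTemporal_XS(num_frames=3, use_representation_head=True)
    criterion = MultiTaskLoss()

    # frames: [4, 9, 224, 224], target: [4, 3, 224, 224], t_idx: [4]
    frames, target, t_idx = next(iter(loader))
    pred, rep = model(frames)
    loss, parts = criterion(pred, target, rep, t_idx)
    loss.backward()
    print({k: float(v) for k, v in parts.items()})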