# asmo_vhead/util/video_dataset.py
"""
Video frame dataset for temporal self-supervised learning
"""
import os
import random
from pathlib import Path
from typing import Optional, Tuple, List
import torch
from torch.utils.data import Dataset
from torchvision import transforms
from PIL import Image
import numpy as np


class VideoFrameDataset(Dataset):
    """
    Dataset for loading consecutive frames from videos for frame prediction.

    Assumes directory structure:
        dataset_root/
            video1/
                frame_0001.jpg
                frame_0002.jpg
                ...
            video2/
                ...
    """
    def __init__(self,
                 root_dir: str,
                 num_frames: int = 3,
                 frame_size: int = 224,
                 is_train: bool = True,
                 max_interval: int = 1,
                 transform=None):
        """
        Args:
            root_dir: Root directory containing video folders
            num_frames: Number of input frames (T)
            frame_size: Size to resize frames to
            is_train: Whether this is the training set (affects augmentation)
            max_interval: Maximum interval between consecutive frames
            transform: Optional custom transform
        """
        self.root_dir = Path(root_dir)
        self.num_frames = num_frames
        self.frame_size = frame_size
        self.is_train = is_train
        self.max_interval = max_interval
        # Collect all video folders
        self.video_folders = []
        for item in self.root_dir.iterdir():
            if item.is_dir():
                self.video_folders.append(item)
        if len(self.video_folders) == 0:
            raise ValueError(f"No video folders found in {root_dir}")
        # Build frame index: list of (video_idx, start_frame_idx).
        # Frame paths are cached per video so __getitem__ does not re-list directories.
        self.frame_indices = []
        self.video_frames = []
        for video_idx, video_folder in enumerate(self.video_folders):
            # Get all frame files
            frame_files = sorted([f for f in video_folder.iterdir()
                                  if f.suffix.lower() in ['.jpg', '.jpeg', '.png', '.bmp']])
            self.video_frames.append(frame_files)
            if len(frame_files) < num_frames + 1:
                continue  # Skip videos with insufficient frames
            # Add all possible starting positions
            for start_idx in range(len(frame_files) - num_frames):
                self.frame_indices.append((video_idx, start_idx))
        if len(self.frame_indices) == 0:
            raise ValueError("No valid frame sequences found in dataset")
        # Default transforms
        if transform is None:
            self.transform = self._default_transform()
        else:
            self.transform = transform
        # Normalization (ImageNet stats)
        self.normalize = transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225]
        )

    def _default_transform(self):
        """Default transform with augmentation for training"""
        if self.is_train:
            return transforms.Compose([
                transforms.RandomResizedCrop(self.frame_size, scale=(0.8, 1.0)),
                transforms.RandomHorizontalFlip(),
                transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1),
            ])
        else:
            return transforms.Compose([
                transforms.Resize(int(self.frame_size * 1.14)),
                transforms.CenterCrop(self.frame_size),
            ])

    def _load_frame(self, video_idx: int, frame_idx: int) -> Image.Image:
        """Load a single frame as a PIL Image (frame paths are cached in __init__)"""
        frame_path = self.video_frames[video_idx][frame_idx]
        return Image.open(frame_path).convert('RGB')

    def __len__(self) -> int:
        return len(self.frame_indices)

    def __getitem__(self, idx: int) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        Returns:
            input_frames: [3 * num_frames, H, W] concatenated input frames
            target_frame: [3, H, W] target frame to predict
            temporal_idx: temporal index of target frame (for contrastive loss)
        """
        video_idx, start_idx = self.frame_indices[idx]
        # Determine frame interval (for temporal augmentation)
        interval = random.randint(1, self.max_interval) if self.is_train else 1
        # Clamp the interval so the last input frame and the target stay within the video
        num_available = len(self.video_frames[video_idx])
        max_valid_interval = max(1, (num_available - 1 - start_idx) // self.num_frames)
        interval = min(interval, max_valid_interval)
        # Load input frames
        input_frames = []
        for i in range(self.num_frames):
            frame_idx = start_idx + i * interval
            frame = self._load_frame(video_idx, frame_idx)
            # Apply transform (random transforms are sampled independently per frame)
            if self.transform:
                frame = self.transform(frame)
            input_frames.append(frame)
        # Load target frame (next frame after input sequence)
        target_idx = start_idx + self.num_frames * interval
        target_frame = self._load_frame(video_idx, target_idx)
        if self.transform:
            target_frame = self.transform(target_frame)
        # Convert to tensors and normalize
        input_tensors = []
        for frame in input_frames:
            tensor = transforms.ToTensor()(frame)
            tensor = self.normalize(tensor)
            input_tensors.append(tensor)
        target_tensor = transforms.ToTensor()(target_frame)
        target_tensor = self.normalize(target_tensor)
        # Concatenate input frames along channel dimension
        input_concatenated = torch.cat(input_tensors, dim=0)
        # Temporal index (for contrastive loss)
        temporal_idx = torch.tensor(self.num_frames, dtype=torch.long)
        return input_concatenated, target_tensor, temporal_idx
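
# Example usage sketch for VideoFrameDataset (the dataset path below is a placeholder;
# frames can be produced, for example, with ffmpeg: `ffmpeg -i video1.mp4 dataset_root/video1/frame_%04d.jpg`):
#
#     dataset = VideoFrameDataset("dataset_root/", num_frames=3, frame_size=224)
#     inputs, target, temporal_idx = dataset[0]
#     # inputs.shape  -> torch.Size([9, 224, 224])   (3 frames x 3 channels, concatenated)
#     # target.shape  -> torch.Size([3, 224, 224])
#     # temporal_idx  -> tensor(3)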


class SyntheticVideoDataset(Dataset):
    """
    Synthetic dataset for testing - generates random frames
    """

    def __init__(self,
                 num_samples: int = 1000,
                 num_frames: int = 3,
                 frame_size: int = 224,
                 is_train: bool = True):
        self.num_samples = num_samples
        self.num_frames = num_frames
        self.frame_size = frame_size
        self.is_train = is_train
        # Normalization
        self.normalize = transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225]
        )

    def __len__(self):
        return self.num_samples

    def __getitem__(self, idx):
        # Generate random "frames" (noise with temporal correlation)
        input_frames = []
        prev_frame = torch.randn(3, self.frame_size, self.frame_size) * 0.1
        for i in range(self.num_frames):
            # Add some temporal correlation
            frame = prev_frame + torch.randn(3, self.frame_size, self.frame_size) * 0.05
            frame = torch.clamp(frame, -1, 1)
            input_frames.append(self.normalize(frame))
            prev_frame = frame
        # Target frame (next in sequence)
        target_frame = prev_frame + torch.randn(3, self.frame_size, self.frame_size) * 0.05
        target_frame = torch.clamp(target_frame, -1, 1)
        target_tensor = self.normalize(target_frame)
        # Concatenate inputs
        input_concatenated = torch.cat(input_frames, dim=0)
        # Temporal index
        temporal_idx = torch.tensor(self.num_frames, dtype=torch.long)
        return input_concatenated, target_tensor, temporal_idx
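

# Minimal smoke test (illustrative sketch, not part of the original training code).
# It uses SyntheticVideoDataset so it runs without any frames on disk; the batch
# size and frame size below are arbitrary choices for a quick shape check.
if __name__ == "__main__":
    from torch.utils.data import DataLoader

    dataset = SyntheticVideoDataset(num_samples=8, num_frames=3, frame_size=64)
    loader = DataLoader(dataset, batch_size=4, shuffle=True)

    inputs, targets, temporal_idx = next(iter(loader))
    print("inputs: ", tuple(inputs.shape))          # (4, 9, 64, 64)
    print("targets:", tuple(targets.shape))         # (4, 3, 64, 64)
    print("temporal_idx:", temporal_idx.tolist())   # [3, 3, 3, 3]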