asmo_vhead/util/video_dataset.py

"""
Video frame dataset for temporal self-supervised learning
"""
import os
import random
from pathlib import Path
from typing import Optional, Tuple, List
import torch
from torch.utils.data import Dataset
from torchvision import transforms
from PIL import Image
import numpy as np


class VideoFrameDataset(Dataset):
    """
    Dataset for loading consecutive frames from videos for frame prediction.

    Assumes directory structure:
        dataset_root/
            video1/
                frame_0001.jpg
                frame_0002.jpg
                ...
            video2/
                ...
    """

    def __init__(self,
                 root_dir: str,
                 num_frames: int = 3,
                 frame_size: int = 224,
                 is_train: bool = True,
                 max_interval: int = 1,
                 transform=None):
        """
        Args:
            root_dir: Root directory containing video folders
            num_frames: Number of input frames (T)
            frame_size: Size to resize frames to
            is_train: Whether this is the training set (affects augmentation)
            max_interval: Maximum interval between consecutive frames
            transform: Optional custom transform
        """
        self.root_dir = Path(root_dir)
        self.num_frames = num_frames
        self.frame_size = frame_size
        self.is_train = is_train
        self.max_interval = max_interval

        # if num_frames < 1:
        #     raise ValueError("num_frames must be >= 1")
        # if frame_size < 1:
        #     raise ValueError("frame_size must be >= 1")
        # if max_interval < 1:
        #     raise ValueError("max_interval must be >= 1")

        # Collect all video folders and their frame files
        self.video_folders = []
        self.video_frame_files = []  # list of lists of Path objects
        for item in self.root_dir.iterdir():
            if item.is_dir():
                self.video_folders.append(item)
                # Get all frame files
                frame_files = sorted([f for f in item.iterdir()
                                      if f.suffix.lower() in ['.jpg', '.jpeg', '.png', '.bmp']])
                self.video_frame_files.append(frame_files)
        if len(self.video_folders) == 0:
            raise ValueError(f"No video folders found in {root_dir}")

        # Build frame index: list of (video_idx, start_frame_idx)
        self.frame_indices = []
        for video_idx, frame_files in enumerate(self.video_frame_files):
            # Minimum frames needed considering the maximum interval
            min_frames_needed = num_frames * max_interval + 1
            if len(frame_files) < min_frames_needed:
                continue  # Skip videos with insufficient frames
            # Add all possible starting positions.
            # Ensure that for any interval up to max_interval, all frames are within bounds.
            max_start = len(frame_files) - num_frames * max_interval
            for start_idx in range(max_start):
                self.frame_indices.append((video_idx, start_idx))
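        # Worked example of the index arithmetic above: with num_frames=3 and
        # max_interval=2, a video needs at least 3 * 2 + 1 = 7 frames, and a
        # 10-frame video yields start indices 0..3 (max_start = 10 - 6 = 4),
        # so the furthest target frame is 3 + 3 * 2 = 9, the last valid index.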
        if len(self.frame_indices) == 0:
            raise ValueError("No valid frame sequences found in dataset")

        # Default transforms
        if transform is None:
            self.transform = self._default_transform()
        else:
            self.transform = transform

        # Simple normalization to the [-1, 1] range (no ImageNet normalization):
        # pixel values [0, 255] are mapped to [-1, 1], matching the model's tanh output range.
        self.normalize = None  # normalization is handled manually in __getitem__
        # print(f"[Dataset init] Using simple normalization: pixel values [0,255] -> [-1,1]")

    def _default_transform(self):
        """Default transform with augmentation for training"""
        if self.is_train:
            return transforms.Compose([
                transforms.RandomResizedCrop(self.frame_size, scale=(0.8, 1.0)),
                transforms.RandomHorizontalFlip(),
                transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1),
            ])
        else:
            return transforms.Compose([
                transforms.Resize(int(self.frame_size * 1.14)),
                transforms.CenterCrop(self.frame_size),
            ])

    def _load_frame(self, video_idx: int, frame_idx: int) -> Image.Image:
        """Load a single frame as a PIL Image"""
        frame_files = self.video_frame_files[video_idx]
        if frame_idx < 0 or frame_idx >= len(frame_files):
            raise IndexError(
                f"Frame index {frame_idx} out of range for video {video_idx} "
                f"(0-{len(frame_files) - 1})"
            )
        frame_path = frame_files[frame_idx]
        return Image.open(frame_path).convert('RGB')

    def __len__(self) -> int:
        return len(self.frame_indices)

    def __getitem__(self, idx: int) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        Returns:
            input_frames: [num_frames, H, W] concatenated input frames (Y channel only)
            target_frame: [1, H, W] target frame to predict (Y channel only)
            temporal_idx: temporal index of the target frame (for contrastive loss)
        """
        video_idx, start_idx = self.frame_indices[idx]

        # Determine frame interval (for temporal augmentation)
        interval = random.randint(1, self.max_interval) if self.is_train else 1

        # Load input frames
        input_frames = []
        for i in range(self.num_frames):
            frame_idx = start_idx + i * interval
            frame = self._load_frame(video_idx, frame_idx)
            # Apply transform. Note: random transforms re-sample their parameters for
            # each frame, so crops/flips are not guaranteed to match across the sequence.
            if self.transform:
                frame = self.transform(frame)
            input_frames.append(frame)

        # Load target frame (next frame after the input sequence)
        target_idx = start_idx + self.num_frames * interval
        target_frame = self._load_frame(video_idx, target_idx)
        if self.transform:
            target_frame = self.transform(target_frame)

        # Convert to tensors and to grayscale (Y channel)
        input_tensors = []
        for frame in input_frames:
            tensor = transforms.ToTensor()(frame)  # [3, H, W], range [0, 1]
            # Convert RGB to grayscale using a weighted sum:
            # Y = 0.2989 * R + 0.5870 * G + 0.1140 * B (same as PIL)
            gray = (0.2989 * tensor[0] + 0.5870 * tensor[1] + 0.1140 * tensor[2]).unsqueeze(0)  # [1, H, W], range [0, 1]
            # Normalize from [0, 1] to [-1, 1]
            gray = gray * 2 - 1
            input_tensors.append(gray)

        target_tensor = transforms.ToTensor()(target_frame)  # [3, H, W], range [0, 1]
        target_gray = (0.2989 * target_tensor[0] + 0.5870 * target_tensor[1] + 0.1140 * target_tensor[2]).unsqueeze(0)
        # Normalize from [0, 1] to [-1, 1]
        target_gray = target_gray * 2 - 1

        # Concatenate input frames along the channel dimension
        input_concatenated = torch.cat(input_tensors, dim=0)  # [num_frames, H, W]

        # Temporal index (for contrastive loss)
        temporal_idx = torch.tensor(self.num_frames, dtype=torch.long)

        return input_concatenated, target_gray, temporal_idx
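

# ---------------------------------------------------------------------------
# Optional sketch, not used by VideoFrameDataset above: the default random
# transforms re-sample their parameters for every frame, so crops and flips
# may differ between the input frames and the target frame. One possible way
# to get temporally consistent augmentation is to sample the parameters once
# per clip and apply them with torchvision's functional API, roughly as below.
# The helper name and how it would be wired into VideoFrameDataset are
# assumptions, not taken from this repository.
# ---------------------------------------------------------------------------
def augment_clip_consistently(frames: List[Image.Image], frame_size: int) -> List[Image.Image]:
    """Apply one shared RandomResizedCrop + horizontal flip to every frame in a clip."""
    import torchvision.transforms.functional as TF  # local import, used only by this sketch

    # Sample crop parameters once (from the first frame) and reuse them for all frames
    i, j, h, w = transforms.RandomResizedCrop.get_params(
        frames[0], scale=(0.8, 1.0), ratio=(3.0 / 4.0, 4.0 / 3.0))
    flip = random.random() < 0.5
    out = []
    for frame in frames:
        frame = TF.resized_crop(frame, i, j, h, w, [frame_size, frame_size])
        if flip:
            frame = TF.hflip(frame)
        out.append(frame)
    return out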


# class SyntheticVideoDataset(Dataset):
#     """
#     Synthetic dataset for testing - generates random frames
#     """
#     def __init__(self,
#                  num_samples: int = 1000,
#                  num_frames: int = 3,
#                  frame_size: int = 224,
#                  is_train: bool = True):
#         self.num_samples = num_samples
#         self.num_frames = num_frames
#         self.frame_size = frame_size
#         self.is_train = is_train
#         # Normalization for Y channel (single channel)
#         y_mean = (0.485 + 0.456 + 0.406) / 3.0
#         y_std = (0.229 + 0.224 + 0.225) / 3.0
#         self.normalize = transforms.Normalize(
#             mean=[y_mean],
#             std=[y_std]
#         )
#
#     def __len__(self):
#         return self.num_samples
#
#     def __getitem__(self, idx):
#         # Generate random "frames" (noise with temporal correlation)
#         input_frames = []
#         prev_frame = torch.randn(3, self.frame_size, self.frame_size) * 0.1
#         for i in range(self.num_frames):
#             # Add some temporal correlation
#             frame = prev_frame + torch.randn(3, self.frame_size, self.frame_size) * 0.05
#             frame = torch.clamp(frame, -1, 1)
#             input_frames.append(self.normalize(frame))
#             prev_frame = frame
#         # Target frame (next in sequence)
#         target_frame = prev_frame + torch.randn(3, self.frame_size, self.frame_size) * 0.05
#         target_frame = torch.clamp(target_frame, -1, 1)
#         target_tensor = self.normalize(target_frame)
#         # Concatenate inputs
#         input_concatenated = torch.cat(input_frames, dim=0)
#         # Temporal index
#         temporal_idx = torch.tensor(self.num_frames, dtype=torch.long)
#         return input_concatenated, target_tensor, temporal_idx
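

# ---------------------------------------------------------------------------
# Minimal usage sketch. The dataset root below is a placeholder; any directory
# laid out as described in the VideoFrameDataset docstring should work.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    from torch.utils.data import DataLoader

    dataset = VideoFrameDataset(
        root_dir="./data/frames",  # placeholder path, replace with a real frame directory
        num_frames=3,
        frame_size=224,
        is_train=True,
        max_interval=1,
    )
    loader = DataLoader(dataset, batch_size=4, shuffle=True, num_workers=2)

    inputs, target, temporal_idx = next(iter(loader))
    # inputs: [B, num_frames, H, W], target: [B, 1, H, W], values in [-1, 1]
    print(inputs.shape, target.shape, temporal_idx.shape)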