更新归一化方式,当前直接映射,不利用均值标准差进行标准化

This commit is contained in:
2026-01-08 16:10:24 +08:00
parent f7601e9170
commit 500c2eb18f
3 changed files with 89 additions and 74 deletions

View File

@@ -11,7 +11,7 @@ shift 2
# Default parameters # Default parameters
MODEL=${MODEL:-"SwiftFormerTemporal_XS"} MODEL=${MODEL:-"SwiftFormerTemporal_XS"}
BATCH_SIZE=${BATCH_SIZE:-32} BATCH_SIZE=${BATCH_SIZE:-256}
EPOCHS=${EPOCHS:-100} EPOCHS=${EPOCHS:-100}
LR=${LR:-1e-3} LR=${LR:-1e-3}
OUTPUT_DIR=${OUTPUT_DIR:-"./temporal_output"} OUTPUT_DIR=${OUTPUT_DIR:-"./temporal_output"}

View File

@@ -19,7 +19,7 @@ from timm.utils import NativeScaler, get_state_dict, ModelEma
from util import * from util import *
from models import * from models import *
from models.swiftformer_temporal import SwiftFormerTemporal_XS, SwiftFormerTemporal_S, SwiftFormerTemporal_L1, SwiftFormerTemporal_L3 from models.swiftformer_temporal import SwiftFormerTemporal_XS, SwiftFormerTemporal_S, SwiftFormerTemporal_L1, SwiftFormerTemporal_L3
from util.video_dataset import VideoFrameDataset, SyntheticVideoDataset from util.video_dataset import VideoFrameDataset
from util.frame_losses import MultiTaskLoss from util.frame_losses import MultiTaskLoss
# Try to import TensorBoard # Try to import TensorBoard
@@ -47,7 +47,7 @@ def get_args_parser():
help='Number of input frames (T)') help='Number of input frames (T)')
parser.add_argument('--frame-size', default=224, type=int, parser.add_argument('--frame-size', default=224, type=int,
help='Input frame size') help='Input frame size')
parser.add_argument('--max-interval', default=1, type=int, parser.add_argument('--max-interval', default=4, type=int,
help='Maximum interval between consecutive frames') help='Maximum interval between consecutive frames')
# Model parameters # Model parameters
@@ -109,10 +109,10 @@ def get_args_parser():
help='Weight for frame prediction loss') help='Weight for frame prediction loss')
parser.add_argument('--contrastive-weight', type=float, default=0.1, parser.add_argument('--contrastive-weight', type=float, default=0.1,
help='Weight for contrastive loss') help='Weight for contrastive loss')
parser.add_argument('--l1-weight', type=float, default=1.0, # parser.add_argument('--l1-weight', type=float, default=1.0,
help='Weight for L1 loss') # help='Weight for L1 loss')
parser.add_argument('--ssim-weight', type=float, default=0.1, # parser.add_argument('--ssim-weight', type=float, default=0.1,
help='Weight for SSIM loss') # help='Weight for SSIM loss')
parser.add_argument('--no-contrastive', action='store_true', parser.add_argument('--no-contrastive', action='store_true',
help='Disable contrastive loss') help='Disable contrastive loss')
parser.add_argument('--no-ssim', action='store_true', parser.add_argument('--no-ssim', action='store_true',
@@ -326,7 +326,7 @@ def main(args):
lr_scheduler.step(epoch) lr_scheduler.step(epoch)
# Save checkpoint # Save checkpoint
if args.output_dir and (epoch % 10 == 0 or epoch == args.epochs - 1): if args.output_dir and (epoch % 2 == 0 or epoch == args.epochs - 1):
checkpoint_path = output_dir / f'checkpoint_epoch{epoch}.pth' checkpoint_path = output_dir / f'checkpoint_epoch{epoch}.pth'
utils.save_on_master({ utils.save_on_master({
'model': model_without_ddp.state_dict(), 'model': model_without_ddp.state_dict(),

View File

@@ -48,27 +48,39 @@ class VideoFrameDataset(Dataset):
self.is_train = is_train self.is_train = is_train
self.max_interval = max_interval self.max_interval = max_interval
# Collect all video folders # if num_frames < 1:
# raise ValueError("num_frames must be >= 1")
# if frame_size < 1:
# raise ValueError("frame_size must be >= 1")
# if max_interval < 1:
# raise ValueError("max_interval must be >= 1")
# Collect all video folders and their frame files
self.video_folders = [] self.video_folders = []
self.video_frame_files = [] # list of list of Path objects
for item in self.root_dir.iterdir(): for item in self.root_dir.iterdir():
if item.is_dir(): if item.is_dir():
self.video_folders.append(item) self.video_folders.append(item)
# Get all frame files
frame_files = sorted([f for f in item.iterdir()
if f.suffix.lower() in ['.jpg', '.jpeg', '.png', '.bmp']])
self.video_frame_files.append(frame_files)
if len(self.video_folders) == 0: if len(self.video_folders) == 0:
raise ValueError(f"No video folders found in {root_dir}") raise ValueError(f"No video folders found in {root_dir}")
# Build frame index: list of (video_idx, start_frame_idx) # Build frame index: list of (video_idx, start_frame_idx)
self.frame_indices = [] self.frame_indices = []
for video_idx, video_folder in enumerate(self.video_folders): for video_idx, frame_files in enumerate(self.video_frame_files):
# Get all frame files # Minimum frames needed considering max interval
frame_files = sorted([f for f in video_folder.iterdir() min_frames_needed = num_frames * max_interval + 1
if f.suffix.lower() in ['.jpg', '.jpeg', '.png', '.bmp']]) if len(frame_files) < min_frames_needed:
if len(frame_files) < num_frames + 1:
continue # Skip videos with insufficient frames continue # Skip videos with insufficient frames
# Add all possible starting positions # Add all possible starting positions
for start_idx in range(len(frame_files) - num_frames): # Ensure that for any interval up to max_interval, all frames are within bounds
max_start = len(frame_files) - num_frames * max_interval
for start_idx in range(max_start):
self.frame_indices.append((video_idx, start_idx)) self.frame_indices.append((video_idx, start_idx))
if len(self.frame_indices) == 0: if len(self.frame_indices) == 0:
@@ -80,14 +92,12 @@ class VideoFrameDataset(Dataset):
else: else:
self.transform = transform self.transform = transform
# Normalization for Y channel (single channel) # Simple normalization to [-1, 1] range (不使用ImageNet标准化)
# Compute average of ImageNet RGB means and stds # Convert pixel values [0, 255] to [-1, 1]
y_mean = (0.485 + 0.456 + 0.406) / 3.0 # This matches the model's tanh output range
y_std = (0.229 + 0.224 + 0.225) / 3.0 self.normalize = None # We'll handle normalization manually
self.normalize = transforms.Normalize(
mean=[y_mean], # print(f"[数据集初始化] 使用简单归一化: 像素值[0,255] -> [-1,1]")
std=[y_std]
)
def _default_transform(self): def _default_transform(self):
"""Default transform with augmentation for training""" """Default transform with augmentation for training"""
@@ -105,9 +115,12 @@ class VideoFrameDataset(Dataset):
def _load_frame(self, video_idx: int, frame_idx: int) -> Image.Image: def _load_frame(self, video_idx: int, frame_idx: int) -> Image.Image:
"""Load a single frame as PIL Image""" """Load a single frame as PIL Image"""
video_folder = self.video_folders[video_idx] frame_files = self.video_frame_files[video_idx]
frame_files = sorted([f for f in video_folder.iterdir() if frame_idx < 0 or frame_idx >= len(frame_files):
if f.suffix.lower() in ['.jpg', '.jpeg', '.png', '.bmp']]) raise IndexError(
f"Frame index {frame_idx} out of range for video {video_idx} "
f"(0-{len(frame_files)-1})"
)
frame_path = frame_files[frame_idx] frame_path = frame_files[frame_idx]
return Image.open(frame_path).convert('RGB') return Image.open(frame_path).convert('RGB')
@@ -144,19 +157,21 @@ class VideoFrameDataset(Dataset):
if self.transform: if self.transform:
target_frame = self.transform(target_frame) target_frame = self.transform(target_frame)
# Convert to tensors, normalize, and convert to grayscale (Y channel) # Convert to tensors and convert to grayscale (Y channel)
input_tensors = [] input_tensors = []
for frame in input_frames: for frame in input_frames:
tensor = transforms.ToTensor()(frame) # [3, H, W] tensor = transforms.ToTensor()(frame) # [3, H, W], range [0, 1]
# Convert RGB to grayscale using weighted sum # Convert RGB to grayscale using weighted sum
# Y = 0.2989 * R + 0.5870 * G + 0.1140 * B (same as PIL) # Y = 0.2989 * R + 0.5870 * G + 0.1140 * B (same as PIL)
gray = (0.2989 * tensor[0] + 0.5870 * tensor[1] + 0.1140 * tensor[2]).unsqueeze(0) # [1, H, W] gray = (0.2989 * tensor[0] + 0.5870 * tensor[1] + 0.1140 * tensor[2]).unsqueeze(0) # [1, H, W], range [0, 1]
gray = self.normalize(gray) # normalize with single-channel stats (mean/std broadcast) # Normalize from [0, 1] to [-1, 1]
gray = gray * 2 - 1 # [0,1] -> [-1,1]
input_tensors.append(gray) input_tensors.append(gray)
target_tensor = transforms.ToTensor()(target_frame) # [3, H, W] target_tensor = transforms.ToTensor()(target_frame) # [3, H, W], range [0, 1]
target_gray = (0.2989 * target_tensor[0] + 0.5870 * target_tensor[1] + 0.1140 * target_tensor[2]).unsqueeze(0) target_gray = (0.2989 * target_tensor[0] + 0.5870 * target_tensor[1] + 0.1140 * target_tensor[2]).unsqueeze(0)
target_gray = self.normalize(target_gray) # Normalize from [0, 1] to [-1, 1]
target_gray = target_gray * 2 - 1 # [0,1] -> [-1,1]
# Concatenate input frames along channel dimension # Concatenate input frames along channel dimension
input_concatenated = torch.cat(input_tensors, dim=0) # [num_frames, H, W] input_concatenated = torch.cat(input_tensors, dim=0) # [num_frames, H, W]
@@ -167,52 +182,52 @@ class VideoFrameDataset(Dataset):
return input_concatenated, target_gray, temporal_idx return input_concatenated, target_gray, temporal_idx
class SyntheticVideoDataset(Dataset): # class SyntheticVideoDataset(Dataset):
""" # """
Synthetic dataset for testing - generates random frames # Synthetic dataset for testing - generates random frames
""" # """
def __init__(self, # def __init__(self,
num_samples: int = 1000, # num_samples: int = 1000,
num_frames: int = 3, # num_frames: int = 3,
frame_size: int = 224, # frame_size: int = 224,
is_train: bool = True): # is_train: bool = True):
self.num_samples = num_samples # self.num_samples = num_samples
self.num_frames = num_frames # self.num_frames = num_frames
self.frame_size = frame_size # self.frame_size = frame_size
self.is_train = is_train # self.is_train = is_train
# Normalization for Y channel (single channel) # # Normalization for Y channel (single channel)
y_mean = (0.485 + 0.456 + 0.406) / 3.0 # y_mean = (0.485 + 0.456 + 0.406) / 3.0
y_std = (0.229 + 0.224 + 0.225) / 3.0 # y_std = (0.229 + 0.224 + 0.225) / 3.0
self.normalize = transforms.Normalize( # self.normalize = transforms.Normalize(
mean=[y_mean], # mean=[y_mean],
std=[y_std] # std=[y_std]
) # )
def __len__(self): # def __len__(self):
return self.num_samples # return self.num_samples
def __getitem__(self, idx): # def __getitem__(self, idx):
# Generate random "frames" (noise with temporal correlation) # # Generate random "frames" (noise with temporal correlation)
input_frames = [] # input_frames = []
prev_frame = torch.randn(3, self.frame_size, self.frame_size) * 0.1 # prev_frame = torch.randn(3, self.frame_size, self.frame_size) * 0.1
for i in range(self.num_frames): # for i in range(self.num_frames):
# Add some temporal correlation # # Add some temporal correlation
frame = prev_frame + torch.randn(3, self.frame_size, self.frame_size) * 0.05 # frame = prev_frame + torch.randn(3, self.frame_size, self.frame_size) * 0.05
frame = torch.clamp(frame, -1, 1) # frame = torch.clamp(frame, -1, 1)
input_frames.append(self.normalize(frame)) # input_frames.append(self.normalize(frame))
prev_frame = frame # prev_frame = frame
# Target frame (next in sequence) # # Target frame (next in sequence)
target_frame = prev_frame + torch.randn(3, self.frame_size, self.frame_size) * 0.05 # target_frame = prev_frame + torch.randn(3, self.frame_size, self.frame_size) * 0.05
target_frame = torch.clamp(target_frame, -1, 1) # target_frame = torch.clamp(target_frame, -1, 1)
target_tensor = self.normalize(target_frame) # target_tensor = self.normalize(target_frame)
# Concatenate inputs # # Concatenate inputs
input_concatenated = torch.cat(input_frames, dim=0) # input_concatenated = torch.cat(input_frames, dim=0)
# Temporal index # # Temporal index
temporal_idx = torch.tensor(self.num_frames, dtype=torch.long) # temporal_idx = torch.tensor(self.num_frames, dtype=torch.long)
return input_concatenated, target_tensor, temporal_idx # return input_concatenated, target_tensor, temporal_idx