Runs end to end for now, but the loss computation is wrong and training does not converge

2026-01-08 09:43:23 +08:00
parent efd76bccd2
commit f7601e9170
11 changed files with 656 additions and 63 deletions


@@ -80,10 +80,13 @@ class VideoFrameDataset(Dataset):
         else:
             self.transform = transform

-        # Normalization (ImageNet stats)
+        # Normalization for Y channel (single channel)
+        # Compute average of ImageNet RGB means and stds
+        y_mean = (0.485 + 0.456 + 0.406) / 3.0
+        y_std = (0.229 + 0.224 + 0.225) / 3.0
         self.normalize = transforms.Normalize(
-            mean=[0.485, 0.456, 0.406],
-            std=[0.229, 0.224, 0.225]
+            mean=[y_mean],
+            std=[y_std]
         )

     def _default_transform(self):
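
For context, a minimal sketch of what the single-channel Normalize above does, assuming standard torchvision semantics (the one-element mean/std is broadcast over H x W; the 224x224 size is only illustrative):

import torch
from torchvision import transforms

# Same averaged ImageNet stats as in the diff (plain average over R, G, B).
y_mean = (0.485 + 0.456 + 0.406) / 3.0  # ~0.449
y_std = (0.229 + 0.224 + 0.225) / 3.0   # ~0.226
normalize = transforms.Normalize(mean=[y_mean], std=[y_std])

y = torch.rand(1, 224, 224)  # one Y-channel frame in [0, 1], illustrative size
out = normalize(y)           # elementwise (y - y_mean) / y_std
print(out.shape)             # torch.Size([1, 224, 224])

One design note: the plain average (~0.449) is not the luma-weighted combination of the ImageNet means (0.2989 * 0.485 + 0.5870 * 0.456 + 0.1140 * 0.406 ≈ 0.459), which is the mean a Y channel built from those weights would actually have. The gap is small, but the conversion and the normalization are not perfectly matched.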
@@ -114,8 +117,8 @@ class VideoFrameDataset(Dataset):
     def __getitem__(self, idx: int) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
         """
         Returns:
-            input_frames: [3 * num_frames, H, W] concatenated input frames
-            target_frame: [3, H, W] target frame to predict
+            input_frames: [num_frames, H, W] concatenated input frames (Y channel only)
+            target_frame: [1, H, W] target frame to predict (Y channel only)
             temporal_idx: temporal index of target frame (for contrastive loss)
         """
         video_idx, start_idx = self.frame_indices[idx]
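
To make the new contract concrete, a quick shape check (a sketch; the constructor arguments here are hypothetical stand-ins, not the repo's real ones):

# Assuming VideoFrameDataset is imported from this module.
# Hypothetical construction; the real arguments live elsewhere in this repo.
dataset = VideoFrameDataset(video_dir="data/train", num_frames=4)

inputs, target, t_idx = dataset[0]
assert inputs.shape[0] == 4  # one Y channel per input frame, no longer 3 * num_frames
assert target.shape[0] == 1  # single-channel target

# Consequence: a model consuming this dataset needs in_channels=num_frames
# (not 3 * num_frames) and a 1-channel output head.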
@@ -141,23 +144,27 @@ class VideoFrameDataset(Dataset):
         if self.transform:
             target_frame = self.transform(target_frame)

-        # Convert to tensors and normalize
+        # Convert to tensors, normalize, and convert to grayscale (Y channel)
         input_tensors = []
         for frame in input_frames:
-            tensor = transforms.ToTensor()(frame)
-            tensor = self.normalize(tensor)
-            input_tensors.append(tensor)
+            tensor = transforms.ToTensor()(frame)  # [3, H, W]
+            # Convert RGB to grayscale using weighted sum
+            # Y = 0.2989 * R + 0.5870 * G + 0.1140 * B (same as PIL)
+            gray = (0.2989 * tensor[0] + 0.5870 * tensor[1] + 0.1140 * tensor[2]).unsqueeze(0)  # [1, H, W]
+            gray = self.normalize(gray)  # normalize with single-channel stats (mean/std broadcast)
+            input_tensors.append(gray)

-        target_tensor = transforms.ToTensor()(target_frame)
-        target_tensor = self.normalize(target_tensor)
+        target_tensor = transforms.ToTensor()(target_frame)  # [3, H, W]
+        target_gray = (0.2989 * target_tensor[0] + 0.5870 * target_tensor[1] + 0.1140 * target_tensor[2]).unsqueeze(0)
+        target_gray = self.normalize(target_gray)

         # Concatenate input frames along channel dimension
-        input_concatenated = torch.cat(input_tensors, dim=0)
+        input_concatenated = torch.cat(input_tensors, dim=0)  # [num_frames, H, W]

         # Temporal index (for contrastive loss)
         temporal_idx = torch.tensor(self.num_frames, dtype=torch.long)

-        return input_concatenated, target_tensor, temporal_idx
+        return input_concatenated, target_gray, temporal_idx


 class SyntheticVideoDataset(Dataset):
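
The "same as PIL" claim in the comment above can be checked directly. A small sketch, assuming the frame lives at a placeholder path; both conversions use the ITU-R BT.601 luma weights, so they agree up to 8-bit rounding:

from PIL import Image
from torchvision import transforms

img = Image.open("frame.png").convert("RGB")  # placeholder path

rgb = transforms.ToTensor()(img)  # [3, H, W], values in [0, 1]
y_manual = (0.2989 * rgb[0] + 0.5870 * rgb[1] + 0.1140 * rgb[2]).unsqueeze(0)

y_pil = transforms.ToTensor()(img.convert("L"))  # [1, H, W]
print((y_manual - y_pil).abs().max())  # small, dominated by PIL's 8-bit quantization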
@@ -174,10 +181,12 @@ class SyntheticVideoDataset(Dataset):
         self.frame_size = frame_size
         self.is_train = is_train

-        # Normalization
+        # Normalization for Y channel (single channel)
+        y_mean = (0.485 + 0.456 + 0.406) / 3.0
+        y_std = (0.229 + 0.224 + 0.225) / 3.0
         self.normalize = transforms.Normalize(
-            mean=[0.485, 0.456, 0.406],
-            std=[0.229, 0.224, 0.225]
+            mean=[y_mean],
+            std=[y_std]
         )

     def __len__(self):