"""
Loss functions for frame prediction and representation learning
"""

import math

import torch
import torch.nn as nn
import torch.nn.functional as F


class SSIMLoss(nn.Module):
    """
    Structural Similarity Index Measure loss, returned as ``1 - SSIM``.

    Based on: https://github.com/Po-Hsun-Su/pytorch-ssim

    Args:
        window_size: side length of the Gaussian window (sigma = 1.5).
        size_average: if True, return a scalar averaged over the whole batch;
            otherwise return one value per sample (shape [B]).
        data_range: dynamic range ``L`` of the inputs, used for the SSIM
            stability constants ``C1 = (0.01 * L) ** 2`` and
            ``C2 = (0.03 * L) ** 2``. The default 1.0 reproduces the original
            constants (inputs in [0, 1]); pass 2.0 for inputs in [-1, 1].
    """

    def __init__(self, window_size=11, size_average=True, data_range=1.0):
        super().__init__()
        self.window_size = window_size
        self.size_average = size_average
        self.data_range = data_range
        # Initial channel count; forward() rebuilds the window whenever the
        # input has a different number of channels.
        self.channel = 3
        # Non-persistent buffer: follows .to()/.cuda()/.half() with the module
        # but is not written to state_dict, so existing checkpoints still load.
        self.register_buffer(
            "window", self.create_window(window_size, self.channel), persistent=False
        )

    def create_window(self, window_size, channel):
        """Return a normalized 2-D Gaussian kernel of shape [channel, 1, k, k], k = window_size."""
        sigma = 1.5
        center = window_size // 2
        gauss = torch.tensor(
            [math.exp(-((x - center) ** 2) / float(2 * sigma ** 2)) for x in range(window_size)],
            dtype=torch.float32,
        )
        gauss = gauss / gauss.sum()
        _1d_window = gauss.unsqueeze(1)
        _2d_window = _1d_window.mm(_1d_window.t()).unsqueeze(0).unsqueeze(0)
        return _2d_window.expand(channel, 1, window_size, window_size).contiguous()

    def forward(self, img1, img2):
        """
        Args:
            img1, img2: image batches of identical shape [B, C, H, W].

        Returns:
            ``1 - SSIM``: a scalar if ``size_average`` is True, otherwise a
            tensor of shape [B].
        """
        channel = img1.size(1)
        window = self.window
        if channel != self.channel:
            # Depthwise (grouped) conv needs exactly one kernel per input channel.
            window = self.create_window(self.window_size, channel)
            self.channel = channel
        if window.device != img1.device or window.dtype != img1.dtype:
            # Match dtype as well as device so float64 / half inputs work.
            window = window.to(device=img1.device, dtype=img1.dtype)
        self.window = window  # cache for subsequent calls

        padding = self.window_size // 2

        def blur(x):
            # Per-channel Gaussian-weighted local mean.
            return F.conv2d(x, window, padding=padding, groups=channel)

        mu1 = blur(img1)
        mu2 = blur(img2)
        mu1_sq = mu1.pow(2)
        mu2_sq = mu2.pow(2)
        mu1_mu2 = mu1 * mu2

        sigma1_sq = blur(img1 * img1) - mu1_sq
        sigma2_sq = blur(img2 * img2) - mu2_sq
        sigma12 = blur(img1 * img2) - mu1_mu2

        C1 = (0.01 * self.data_range) ** 2
        C2 = (0.03 * self.data_range) ** 2

        ssim_map = ((2 * mu1_mu2 + C1) * (2 * sigma12 + C2)) / (
            (mu1_sq + mu2_sq + C1) * (sigma1_sq + sigma2_sq + C2)
        )

        if self.size_average:
            return 1 - ssim_map.mean()
        return 1 - ssim_map.mean(1).mean(1).mean(1)


class FramePredictionLoss(nn.Module):
    """
    Combined loss for frame prediction:

        total = l1_weight * L1(pred, target) + ssim_weight * (1 - SSIM(pred, target))

    The SSIM term is included only when ``use_ssim`` is True.

    Args:
        l1_weight: weight of the L1 term.
        ssim_weight: weight of the SSIM term.
        use_ssim: whether to add the SSIM term.
        ssim_data_range: dynamic range passed to ``SSIMLoss``. The default 1.0
            keeps the original SSIM constants; inputs documented as [-1, 1]
            have range 2.0.
    """

    def __init__(self, l1_weight=1.0, ssim_weight=0.1, use_ssim=True, ssim_data_range=1.0):
        super().__init__()
        self.l1_weight = l1_weight
        self.ssim_weight = ssim_weight
        self.use_ssim = use_ssim
        self.l1_loss = nn.L1Loss()
        if use_ssim:
            self.ssim_loss = SSIMLoss(data_range=ssim_data_range)

    def forward(self, pred, target):
        """
        Args:
            pred: predicted frame [B, 3, H, W] in range [-1, 1]
            target: target frame [B, 3, H, W] in range [-1, 1]

        Returns:
            (total_loss, loss_dict) where loss_dict has keys 'l1', 'total'
            and, when SSIM is enabled, 'ssim'.
        """
        loss_dict = {}

        l1_loss = self.l1_loss(pred, target)
        loss_dict['l1'] = l1_loss
        total_loss = self.l1_weight * l1_loss

        if self.use_ssim:
            ssim_loss = self.ssim_loss(pred, target)
            loss_dict['ssim'] = ssim_loss
            # Out-of-place add: never mutate a tensor that is part of the graph.
            total_loss = total_loss + self.ssim_weight * ssim_loss

        loss_dict['total'] = total_loss
        return total_loss, loss_dict


class ContrastiveLoss(nn.Module):
    """
    Contrastive loss for representation learning.

    Positive pairs: samples whose temporal indices differ by exactly 1
    (adjacent frames). Negative pairs: samples whose indices differ by more
    than 2 (distant frames). Pairs with difference 0 or 2, including each
    sample with itself, are neither.

    Positive term: for every positive pair (i, j), ``-log softmax(sim[i])[j]``,
    averaged over positive pairs. The softmax runs over the full row,
    self-similarity included.
    Negative term: hinge ``relu(sim - margin)`` on negative pairs, averaged
    over all B*B entries and weighted by ``neg_weight``.

    NOTE(review): pairs are decided from temporal indices alone, so this
    presumably assumes every sample in a batch comes from the same sequence
    — confirm with the data loader.

    Args:
        temperature: divisor applied to the similarity matrix.
        margin: hinge margin for negatives (compared against sim / temperature).
        neg_weight: weight of the negative term (previously hard-coded 0.1).
        normalize: if True, L2-normalize representations first, so the
            similarity is cosine similarity. Default False keeps dot products.
    """

    def __init__(self, temperature=0.1, margin=1.0, neg_weight=0.1, normalize=False):
        super().__init__()
        self.temperature = temperature
        self.margin = margin
        self.neg_weight = neg_weight
        self.normalize = normalize
        # Kept for backward compatibility; forward() uses a matrix product.
        self.cosine_similarity = nn.CosineSimilarity(dim=-1)

    def forward(self, representations, temporal_indices):
        """
        Args:
            representations: [B, D] representation vectors
            temporal_indices: [B] temporal indices of each sample

        Returns:
            scalar contrastive loss
        """
        if representations.size(0) == 0:
            # Empty batch: zero loss that stays attached to the graph.
            return representations.sum() * 0.0

        if self.normalize:
            representations = F.normalize(representations, dim=-1)

        sim_matrix = torch.matmul(representations, representations.T) / self.temperature

        idx = temporal_indices.to(sim_matrix.device).reshape(-1)
        diff = torch.abs(idx.unsqueeze(0) - idx.unsqueeze(1))
        positive_mask = (diff == 1).float()
        negative_mask = (diff > 2).float()

        # Numerically stable InfoNCE term: exp(sim) overflows float32 once
        # sim > ~88.7, which previously turned the loss into NaN (inf / inf).
        log_probs = F.log_softmax(sim_matrix, dim=-1)
        pos_loss = -(log_probs * positive_mask).sum() / (positive_mask.sum() + 1e-8)

        # Mask after the hinge so non-negative pairs never contribute, even
        # when margin <= 0.
        neg_loss = (F.relu(sim_matrix - self.margin) * negative_mask).mean()

        return pos_loss + self.neg_weight * neg_loss


class MultiTaskLoss(nn.Module):
    """
    Multi-task loss combining frame prediction and representation learning:

        total = frame_weight * frame_loss + contrastive_weight * contrastive_loss

    The contrastive term is added only when ``use_contrastive`` is True and
    both ``representations`` and ``temporal_indices`` are given.
    """

    def __init__(self, frame_weight=1.0, contrastive_weight=0.1,
                 l1_weight=1.0, ssim_weight=0.1, use_contrastive=True, use_ssim=True):
        super().__init__()
        self.frame_weight = frame_weight
        self.contrastive_weight = contrastive_weight
        self.use_contrastive = use_contrastive

        self.frame_loss = FramePredictionLoss(
            l1_weight=l1_weight, ssim_weight=ssim_weight, use_ssim=use_ssim
        )
        if use_contrastive:
            self.contrastive_loss = ContrastiveLoss()

    def forward(self, pred_frame, target_frame, representations=None, temporal_indices=None):
        """
        Args:
            pred_frame: predicted frame [B, 3, H, W]
            target_frame: target frame [B, 3, H, W]
            representations: [B, D] representation vectors (optional)
            temporal_indices: [B] temporal indices (optional)

        Returns:
            (total_loss, loss_dict); frame-loss entries are prefixed 'frame_',
            plus 'contrastive' when that term is computed, plus 'total'.
        """
        loss_dict = {}

        frame_loss, frame_loss_dict = self.frame_loss(pred_frame, target_frame)
        loss_dict.update({f'frame_{k}': v for k, v in frame_loss_dict.items()})
        total_loss = self.frame_weight * frame_loss

        if self.use_contrastive and representations is not None and temporal_indices is not None:
            contrastive_loss = self.contrastive_loss(representations, temporal_indices)
            loss_dict['contrastive'] = contrastive_loss
            # Out-of-place add: never mutate a tensor that is part of the graph.
            total_loss = total_loss + self.contrastive_weight * contrastive_loss

        loss_dict['total'] = total_loss
        return total_loss, loss_dict