修改梯度裁剪的恶性bug,当前能进行训练,但是无论是否使用跳连接,预测帧总是输出对称的的效果,mse收敛到0.10
This commit is contained in:
@@ -20,16 +20,12 @@ from util import *
|
||||
from models import *
|
||||
from models.swiftformer_temporal import SwiftFormerTemporal_XS, SwiftFormerTemporal_S, SwiftFormerTemporal_L1, SwiftFormerTemporal_L3
|
||||
from util.video_dataset import VideoFrameDataset
|
||||
from util.frame_losses import MultiTaskLoss
|
||||
# from util.frame_losses import MultiTaskLoss
|
||||
|
||||
# Try to import TensorBoard
|
||||
try:
|
||||
from torch.utils.tensorboard import SummaryWriter
|
||||
TENSORBOARD_AVAILABLE = True
|
||||
except ImportError:
|
||||
try:
|
||||
from tensorboardX import SummaryWriter
|
||||
TENSORBOARD_AVAILABLE = True
|
||||
except ImportError:
|
||||
TENSORBOARD_AVAILABLE = False
|
||||
|
||||
@@ -57,7 +53,7 @@ def get_args_parser():
|
||||
help='Use representation head for pose/velocity prediction')
|
||||
parser.add_argument('--representation-dim', default=128, type=int,
|
||||
help='Dimension of representation vector')
|
||||
parser.add_argument('--use-skip', default=True, type=bool, help='using skip connections')
|
||||
parser.add_argument('--use-skip', default=False, type=bool, help='using skip connections')
|
||||
|
||||
# Training parameters
|
||||
parser.add_argument('--batch-size', default=32, type=int)
|
||||
@@ -328,7 +324,7 @@ def main(args):
|
||||
lr_scheduler.step(epoch)
|
||||
|
||||
# Save checkpoint
|
||||
if args.output_dir and (epoch % 2 == 0 or epoch == args.epochs - 1):
|
||||
if args.output_dir and (epoch % 1 == 0 or epoch == args.epochs - 1):
|
||||
checkpoint_path = output_dir / f'checkpoint_epoch{epoch}.pth'
|
||||
utils.save_on_master({
|
||||
'model': model_without_ddp.state_dict(),
|
||||
@@ -368,7 +364,7 @@ def main(args):
|
||||
|
||||
|
||||
def train_one_epoch(model, criterion, data_loader, optimizer, device, epoch, loss_scaler,
|
||||
clip_grad=0, clip_mode='norm', model_ema=None, writer=None,
|
||||
clip_grad=None, clip_mode='norm', model_ema=None, writer=None,
|
||||
global_step=0, args=None, **kwargs):
|
||||
model.train()
|
||||
metric_logger = utils.MetricLogger(delimiter=" ")
|
||||
@@ -403,7 +399,6 @@ def train_one_epoch(model, criterion, data_loader, optimizer, device, epoch, los
|
||||
|
||||
optimizer.zero_grad()
|
||||
|
||||
# 在反向传播前保存梯度用于诊断
|
||||
loss_scaler(loss, optimizer, clip_grad=clip_grad, clip_mode=clip_mode,
|
||||
parameters=model.parameters())
|
||||
|
||||
@@ -426,14 +421,14 @@ def train_one_epoch(model, criterion, data_loader, optimizer, device, epoch, los
|
||||
metric_logger.update(pred_std=pred_std)
|
||||
metric_logger.update(grad_norm=total_grad_norm)
|
||||
|
||||
# 每50个批次打印一次BatchNorm统计
|
||||
# # 每50个批次打印一次BatchNorm统计
|
||||
if batch_idx % 50 == 0:
|
||||
print(f"[诊断] 批次 {batch_idx}: 预测均值={pred_mean:.4f}, 预测标准差={pred_std:.4f}, 梯度范数={total_grad_norm:.4f}")
|
||||
# 检查一个BatchNorm层的运行统计
|
||||
for name, module in model.named_modules():
|
||||
if isinstance(module, torch.nn.BatchNorm2d) and 'decoder.blocks.0.bn' in name:
|
||||
print(f"[诊断] {name}: 运行均值={module.running_mean[0].item():.6f}, 运行方差={module.running_var[0].item():.6f}")
|
||||
break
|
||||
# # 检查一个BatchNorm层的运行统计
|
||||
# for name, module in model.named_modules():
|
||||
# if isinstance(module, torch.nn.BatchNorm2d) and 'decoder.blocks.0.bn' in name:
|
||||
# print(f"[诊断] {name}: 运行均值={module.running_mean[0].item():.6f}, 运行方差={module.running_var[0].item():.6f}")
|
||||
# break
|
||||
|
||||
# Log to TensorBoard
|
||||
if writer is not None:
|
||||
@@ -520,21 +515,21 @@ def evaluate(data_loader, model, criterion, device, writer=None, epoch=0):
|
||||
metric_logger.update(target_mean=target_mean)
|
||||
metric_logger.update(target_std=target_std)
|
||||
|
||||
# 第一个批次打印详细诊断信息
|
||||
if batch_idx == 0:
|
||||
print(f"[评估诊断] 批次 0:")
|
||||
print(f" 预测范围: [{pred_frames.min().item():.4f}, {pred_frames.max().item():.4f}]")
|
||||
print(f" 预测均值: {pred_mean:.4f}, 预测标准差: {pred_std:.4f}")
|
||||
print(f" 目标范围: [{target_frames.min().item():.4f}, {target_frames.max().item():.4f}]")
|
||||
print(f" 目标均值: {target_mean:.4f}, 目标标准差: {target_std:.4f}")
|
||||
# # 第一个批次打印详细诊断信息
|
||||
# if batch_idx == 0:
|
||||
# print(f"[评估诊断] 批次 0:")
|
||||
# print(f" 预测范围: [{pred_frames.min().item():.4f}, {pred_frames.max().item():.4f}]")
|
||||
# print(f" 预测均值: {pred_mean:.4f}, 预测标准差: {pred_std:.4f}")
|
||||
# print(f" 目标范围: [{target_frames.min().item():.4f}, {target_frames.max().item():.4f}]")
|
||||
# print(f" 目标均值: {target_mean:.4f}, 目标标准差: {target_std:.4f}")
|
||||
|
||||
# 检查BatchNorm运行统计
|
||||
for name, module in model.named_modules():
|
||||
if isinstance(module, torch.nn.BatchNorm2d) and 'decoder.blocks.0.bn' in name:
|
||||
print(f" {name}: 运行均值={module.running_mean[0].item():.6f}, 运行方差={module.running_var[0].item():.6f}")
|
||||
if module.running_var[0].item() < 1e-6:
|
||||
print(f" 警告: BatchNorm运行方差接近零!")
|
||||
break
|
||||
# # 检查BatchNorm运行统计
|
||||
# for name, module in model.named_modules():
|
||||
# if isinstance(module, torch.nn.BatchNorm2d) and 'decoder.blocks.0.bn' in name:
|
||||
# print(f" {name}: 运行均值={module.running_mean[0].item():.6f}, 运行方差={module.running_var[0].item():.6f}")
|
||||
# if module.running_var[0].item() < 1e-6:
|
||||
# print(f" 警告: BatchNorm运行方差接近零!")
|
||||
# break
|
||||
|
||||
# Update metrics
|
||||
metric_logger.update(loss=loss.item())
|
||||
|
||||
@@ -1,182 +0,0 @@
|
||||
"""
|
||||
Loss functions for frame prediction and representation learning
|
||||
"""
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
import math
|
||||
|
||||
|
||||
class SSIMLoss(nn.Module):
|
||||
"""
|
||||
Structural Similarity Index Measure Loss
|
||||
Based on: https://github.com/Po-Hsun-Su/pytorch-ssim
|
||||
"""
|
||||
def __init__(self, window_size=11, size_average=True):
|
||||
super().__init__()
|
||||
self.window_size = window_size
|
||||
self.size_average = size_average
|
||||
self.channel = 3
|
||||
self.window = self.create_window(window_size, self.channel)
|
||||
|
||||
def create_window(self, window_size, channel):
|
||||
def gaussian(window_size, sigma):
|
||||
gauss = torch.Tensor([math.exp(-(x - window_size//2)**2/float(2*sigma**2)) for x in range(window_size)])
|
||||
return gauss/gauss.sum()
|
||||
|
||||
_1D_window = gaussian(window_size, 1.5).unsqueeze(1)
|
||||
_2D_window = _1D_window.mm(_1D_window.t()).float().unsqueeze(0).unsqueeze(0)
|
||||
window = _2D_window.expand(channel, 1, window_size, window_size).contiguous()
|
||||
return window
|
||||
|
||||
def forward(self, img1, img2):
|
||||
# Ensure window is on correct device
|
||||
if self.window.device != img1.device:
|
||||
self.window = self.window.to(img1.device)
|
||||
|
||||
mu1 = F.conv2d(img1, self.window, padding=self.window_size//2, groups=self.channel)
|
||||
mu2 = F.conv2d(img2, self.window, padding=self.window_size//2, groups=self.channel)
|
||||
|
||||
mu1_sq = mu1.pow(2)
|
||||
mu2_sq = mu2.pow(2)
|
||||
mu1_mu2 = mu1 * mu2
|
||||
|
||||
sigma1_sq = F.conv2d(img1*img1, self.window, padding=self.window_size//2, groups=self.channel) - mu1_sq
|
||||
sigma2_sq = F.conv2d(img2*img2, self.window, padding=self.window_size//2, groups=self.channel) - mu2_sq
|
||||
sigma12 = F.conv2d(img1*img2, self.window, padding=self.window_size//2, groups=self.channel) - mu1_mu2
|
||||
|
||||
C1 = 0.01**2
|
||||
C2 = 0.03**2
|
||||
|
||||
ssim_map = ((2*mu1_mu2 + C1)*(2*sigma12 + C2)) / ((mu1_sq + mu2_sq + C1)*(sigma1_sq + sigma2_sq + C2))
|
||||
|
||||
if self.size_average:
|
||||
return 1 - ssim_map.mean()
|
||||
else:
|
||||
return 1 - ssim_map.mean(1).mean(1).mean(1)
|
||||
|
||||
|
||||
class FramePredictionLoss(nn.Module):
|
||||
"""
|
||||
Combined loss for frame prediction
|
||||
"""
|
||||
def __init__(self, l1_weight=1.0, ssim_weight=0.1, use_ssim=True):
|
||||
super().__init__()
|
||||
self.l1_weight = l1_weight
|
||||
self.ssim_weight = ssim_weight
|
||||
self.use_ssim = use_ssim
|
||||
|
||||
self.l1_loss = nn.L1Loss()
|
||||
if use_ssim:
|
||||
self.ssim_loss = SSIMLoss()
|
||||
|
||||
def forward(self, pred, target):
|
||||
"""
|
||||
Args:
|
||||
pred: predicted frame [B, 3, H, W] in range [-1, 1]
|
||||
target: target frame [B, 3, H, W] in range [-1, 1]
|
||||
Returns:
|
||||
total_loss, loss_dict
|
||||
"""
|
||||
loss_dict = {}
|
||||
|
||||
# L1 loss
|
||||
l1_loss = self.l1_loss(pred, target)
|
||||
loss_dict['l1'] = l1_loss
|
||||
total_loss = self.l1_weight * l1_loss
|
||||
|
||||
# SSIM loss
|
||||
if self.use_ssim:
|
||||
ssim_loss = self.ssim_loss(pred, target)
|
||||
loss_dict['ssim'] = ssim_loss
|
||||
total_loss += self.ssim_weight * ssim_loss
|
||||
|
||||
loss_dict['total'] = total_loss
|
||||
return total_loss, loss_dict
|
||||
|
||||
|
||||
class ContrastiveLoss(nn.Module):
|
||||
"""
|
||||
Contrastive loss for representation learning
|
||||
Positive pairs: representations from adjacent frames
|
||||
Negative pairs: representations from distant frames
|
||||
"""
|
||||
def __init__(self, temperature=0.1, margin=1.0):
|
||||
super().__init__()
|
||||
self.temperature = temperature
|
||||
self.margin = margin
|
||||
self.cosine_similarity = nn.CosineSimilarity(dim=-1)
|
||||
|
||||
def forward(self, representations, temporal_indices):
|
||||
"""
|
||||
Args:
|
||||
representations: [B, D] representation vectors
|
||||
temporal_indices: [B] temporal indices of each sample
|
||||
Returns:
|
||||
contrastive_loss
|
||||
"""
|
||||
batch_size = representations.size(0)
|
||||
|
||||
# Compute similarity matrix
|
||||
sim_matrix = torch.matmul(representations, representations.T) / self.temperature
|
||||
|
||||
# Create positive mask (adjacent frames)
|
||||
indices_expanded = temporal_indices.unsqueeze(0)
|
||||
diff = torch.abs(indices_expanded - indices_expanded.T)
|
||||
positive_mask = (diff == 1).float()
|
||||
|
||||
# Create negative mask (distant frames)
|
||||
negative_mask = (diff > 2).float()
|
||||
|
||||
# Positive loss
|
||||
pos_sim = sim_matrix * positive_mask
|
||||
pos_loss = -torch.log(torch.exp(pos_sim) / torch.exp(sim_matrix).sum(dim=-1, keepdim=True) + 1e-8)
|
||||
pos_loss = (pos_loss * positive_mask).sum() / (positive_mask.sum() + 1e-8)
|
||||
|
||||
# Negative loss (push apart)
|
||||
neg_sim = sim_matrix * negative_mask
|
||||
neg_loss = torch.relu(neg_sim - self.margin).mean()
|
||||
|
||||
return pos_loss + 0.1 * neg_loss
|
||||
|
||||
|
||||
class MultiTaskLoss(nn.Module):
|
||||
"""
|
||||
Multi-task loss combining frame prediction and representation learning
|
||||
"""
|
||||
def __init__(self, frame_weight=1.0, contrastive_weight=0.1,
|
||||
l1_weight=1.0, ssim_weight=0.1, use_contrastive=True):
|
||||
super().__init__()
|
||||
self.frame_weight = frame_weight
|
||||
self.contrastive_weight = contrastive_weight
|
||||
self.use_contrastive = use_contrastive
|
||||
|
||||
self.frame_loss = FramePredictionLoss(l1_weight=l1_weight, ssim_weight=ssim_weight)
|
||||
if use_contrastive:
|
||||
self.contrastive_loss = ContrastiveLoss()
|
||||
|
||||
def forward(self, pred_frame, target_frame, representations=None, temporal_indices=None):
|
||||
"""
|
||||
Args:
|
||||
pred_frame: predicted frame [B, 3, H, W]
|
||||
target_frame: target frame [B, 3, H, W]
|
||||
representations: [B, D] representation vectors (optional)
|
||||
temporal_indices: [B] temporal indices (optional)
|
||||
Returns:
|
||||
total_loss, loss_dict
|
||||
"""
|
||||
loss_dict = {}
|
||||
|
||||
# Frame prediction loss
|
||||
frame_loss, frame_loss_dict = self.frame_loss(pred_frame, target_frame)
|
||||
loss_dict.update({f'frame_{k}': v for k, v in frame_loss_dict.items()})
|
||||
total_loss = self.frame_weight * frame_loss
|
||||
|
||||
# Contrastive loss (if representations provided)
|
||||
if self.use_contrastive and representations is not None and temporal_indices is not None:
|
||||
contrastive_loss = self.contrastive_loss(representations, temporal_indices)
|
||||
loss_dict['contrastive'] = contrastive_loss
|
||||
total_loss += self.contrastive_weight * contrastive_loss
|
||||
|
||||
loss_dict['total'] = total_loss
|
||||
return total_loss, loss_dict
|
||||
Reference in New Issue
Block a user