修复梯度裁剪的恶性 bug。当前能够进行训练,但无论是否使用跳跃连接,预测帧总是输出对称的效果,MSE 收敛到 0.10。

This commit is contained in:
2026-01-11 10:50:11 +08:00
parent 12de74f130
commit c5502cc87c
3 changed files with 25 additions and 212 deletions

View File

@@ -20,18 +20,14 @@ from util import *
from models import *
from models.swiftformer_temporal import SwiftFormerTemporal_XS, SwiftFormerTemporal_S, SwiftFormerTemporal_L1, SwiftFormerTemporal_L3
from util.video_dataset import VideoFrameDataset
from util.frame_losses import MultiTaskLoss
# from util.frame_losses import MultiTaskLoss
# Try to import TensorBoard
try:
from torch.utils.tensorboard import SummaryWriter
TENSORBOARD_AVAILABLE = True
except ImportError:
try:
from tensorboardX import SummaryWriter
TENSORBOARD_AVAILABLE = True
except ImportError:
TENSORBOARD_AVAILABLE = False
TENSORBOARD_AVAILABLE = False
def get_args_parser():
@@ -57,7 +53,7 @@ def get_args_parser():
help='Use representation head for pose/velocity prediction')
parser.add_argument('--representation-dim', default=128, type=int,
help='Dimension of representation vector')
parser.add_argument('--use-skip', default=True, type=bool, help='using skip connections')
parser.add_argument('--use-skip', default=False, type=bool, help='using skip connections')
# Training parameters
parser.add_argument('--batch-size', default=32, type=int)
@@ -328,7 +324,7 @@ def main(args):
lr_scheduler.step(epoch)
# Save checkpoint
if args.output_dir and (epoch % 2 == 0 or epoch == args.epochs - 1):
if args.output_dir and (epoch % 1 == 0 or epoch == args.epochs - 1):
checkpoint_path = output_dir / f'checkpoint_epoch{epoch}.pth'
utils.save_on_master({
'model': model_without_ddp.state_dict(),
@@ -368,7 +364,7 @@ def main(args):
def train_one_epoch(model, criterion, data_loader, optimizer, device, epoch, loss_scaler,
clip_grad=0, clip_mode='norm', model_ema=None, writer=None,
clip_grad=None, clip_mode='norm', model_ema=None, writer=None,
global_step=0, args=None, **kwargs):
model.train()
metric_logger = utils.MetricLogger(delimiter=" ")
@@ -403,7 +399,6 @@ def train_one_epoch(model, criterion, data_loader, optimizer, device, epoch, los
optimizer.zero_grad()
# 在反向传播前保存梯度用于诊断
loss_scaler(loss, optimizer, clip_grad=clip_grad, clip_mode=clip_mode,
parameters=model.parameters())
@@ -426,14 +421,14 @@ def train_one_epoch(model, criterion, data_loader, optimizer, device, epoch, los
metric_logger.update(pred_std=pred_std)
metric_logger.update(grad_norm=total_grad_norm)
# 每50个批次打印一次BatchNorm统计
# # 每50个批次打印一次BatchNorm统计
if batch_idx % 50 == 0:
print(f"[诊断] 批次 {batch_idx}: 预测均值={pred_mean:.4f}, 预测标准差={pred_std:.4f}, 梯度范数={total_grad_norm:.4f}")
# 检查一个BatchNorm层的运行统计
for name, module in model.named_modules():
if isinstance(module, torch.nn.BatchNorm2d) and 'decoder.blocks.0.bn' in name:
print(f"[诊断] {name}: 运行均值={module.running_mean[0].item():.6f}, 运行方差={module.running_var[0].item():.6f}")
break
# # 检查一个BatchNorm层的运行统计
# for name, module in model.named_modules():
# if isinstance(module, torch.nn.BatchNorm2d) and 'decoder.blocks.0.bn' in name:
# print(f"[诊断] {name}: 运行均值={module.running_mean[0].item():.6f}, 运行方差={module.running_var[0].item():.6f}")
# break
# Log to TensorBoard
if writer is not None:
@@ -520,21 +515,21 @@ def evaluate(data_loader, model, criterion, device, writer=None, epoch=0):
metric_logger.update(target_mean=target_mean)
metric_logger.update(target_std=target_std)
# 第一个批次打印详细诊断信息
if batch_idx == 0:
print(f"[评估诊断] 批次 0:")
print(f" 预测范围: [{pred_frames.min().item():.4f}, {pred_frames.max().item():.4f}]")
print(f" 预测均值: {pred_mean:.4f}, 预测标准差: {pred_std:.4f}")
print(f" 目标范围: [{target_frames.min().item():.4f}, {target_frames.max().item():.4f}]")
print(f" 目标均值: {target_mean:.4f}, 目标标准差: {target_std:.4f}")
# # 第一个批次打印详细诊断信息
# if batch_idx == 0:
# print(f"[评估诊断] 批次 0:")
# print(f" 预测范围: [{pred_frames.min().item():.4f}, {pred_frames.max().item():.4f}]")
# print(f" 预测均值: {pred_mean:.4f}, 预测标准差: {pred_std:.4f}")
# print(f" 目标范围: [{target_frames.min().item():.4f}, {target_frames.max().item():.4f}]")
# print(f" 目标均值: {target_mean:.4f}, 目标标准差: {target_std:.4f}")
# 检查BatchNorm运行统计
for name, module in model.named_modules():
if isinstance(module, torch.nn.BatchNorm2d) and 'decoder.blocks.0.bn' in name:
print(f" {name}: 运行均值={module.running_mean[0].item():.6f}, 运行方差={module.running_var[0].item():.6f}")
if module.running_var[0].item() < 1e-6:
print(f" 警告: BatchNorm运行方差接近零!")
break
# # 检查BatchNorm运行统计
# for name, module in model.named_modules():
# if isinstance(module, torch.nn.BatchNorm2d) and 'decoder.blocks.0.bn' in name:
# print(f" {name}: 运行均值={module.running_mean[0].item():.6f}, 运行方差={module.running_var[0].item():.6f}")
# if module.running_var[0].item() < 1e-6:
# print(f" 警告: BatchNorm运行方差接近零!")
# break
# Update metrics
metric_logger.update(loss=loss.item())