Refined the skip connections: added a feature refinement layer after each upsampling decoder block; effect not yet tested.
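The refinement layer itself lives in the model definition and is not part of this diff. A minimal sketch of what such a layer could look like, assuming a residual conv-BN-ReLU design; the name `FeatureRefinement` and the exact structure are illustrative, not confirmed by this commit:

```python
import torch
import torch.nn as nn

class FeatureRefinement(nn.Module):
    """Hypothetical refinement layer placed after an upsampling decoder
    block: a residual conv-BN-ReLU pair that smooths the fused
    (upsampled + skip) features."""

    def __init__(self, channels: int):
        super().__init__()
        self.refine = nn.Sequential(
            nn.Conv2d(channels, channels, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(channels, channels, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(channels),
        )
        self.act = nn.ReLU(inplace=True)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Residual connection keeps the refinement easy to optimize.
        return self.act(x + self.refine(x))
```

In a decoder with `use_skip` enabled, this would sit as `x = refine(upsample(x) + skip)`, though the actual fusion in the model file may differ.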
@@ -57,6 +57,7 @@ def get_args_parser():
                         help='Use representation head for pose/velocity prediction')
     parser.add_argument('--representation-dim', default=128, type=int,
                         help='Dimension of representation vector')
+    parser.add_argument('--use-skip', default=True, type=bool, help='use skip connections')
 
     # Training parameters
     parser.add_argument('--batch-size', default=32, type=int)
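One caveat with the new flag: argparse's `type=bool` applies `bool()` to the raw command-line string, so any non-empty value, including `--use-skip False`, parses as `True`; only the default is reliable. A common workaround is an explicit converter, sketched here with a hypothetical `str2bool` helper:

```python
import argparse

def str2bool(v) -> bool:
    # argparse passes the raw string; map common spellings explicitly.
    if isinstance(v, bool):
        return v
    if v.lower() in ('yes', 'true', 't', '1'):
        return True
    if v.lower() in ('no', 'false', 'f', '0'):
        return False
    raise argparse.ArgumentTypeError(f'Boolean value expected, got {v!r}')

# parser.add_argument('--use-skip', default=True, type=str2bool,
#                     help='use skip connections')
```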
@@ -77,7 +78,7 @@ def get_args_parser():
                         help='SGD momentum (default: 0.9)')
     parser.add_argument('--weight-decay', type=float, default=0.05,
                         help='weight decay (default: 0.05)')
-    parser.add_argument('--lr', type=float, default=1e-3, metavar='LR',
+    parser.add_argument('--lr', type=float, default=0.1, metavar='LR',
                         help='learning rate (default: 0.1)')
 
     # Learning rate schedule parameters (required by timm's create_scheduler)
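For context, these flags exist because timm's `create_scheduler` reads scheduler hyperparameters straight off the args namespace. A sketch of the usual wiring, assuming this script follows the standard timm/DeiT pattern (the wiring itself is not shown in this diff, and the exact set of required fields varies by timm version):

```python
import argparse
import torch
from timm.scheduler import create_scheduler

# Minimal stand-ins so the call runs; the real script passes its own args/optimizer.
args = argparse.Namespace(sched='cosine', epochs=100, lr=0.1, warmup_lr=1e-3,
                          min_lr=1e-5, warmup_epochs=5, cooldown_epochs=10,
                          lr_noise=None, lr_noise_pct=0.67, lr_noise_std=1.0,
                          lr_cycle_mul=1.0, lr_cycle_limit=1,
                          decay_rate=0.1, decay_epochs=30, seed=42)
model = torch.nn.Linear(4, 4)
optimizer = torch.optim.SGD(model.parameters(), lr=args.lr)

lr_scheduler, num_epochs = create_scheduler(args, optimizer)
for epoch in range(num_epochs):
    # ... train one epoch here ...
    lr_scheduler.step(epoch)  # timm schedulers step once per epoch
```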
@@ -89,7 +90,7 @@ def get_args_parser():
                         help='learning rate noise limit percent (default: 0.67)')
     parser.add_argument('--lr-noise-std', type=float, default=1.0, metavar='STDDEV',
                         help='learning rate noise std-dev (default: 1.0)')
-    parser.add_argument('--warmup-lr', type=float, default=1e-6, metavar='LR',
+    parser.add_argument('--warmup-lr', type=float, default=1e-3, metavar='LR',
                         help='warmup learning rate (default: 1e-3)')
     parser.add_argument('--min-lr', type=float, default=1e-5, metavar='LR',
                         help='lower lr bound for cyclic schedulers that hit 0 (1e-5)')
@@ -212,6 +213,7 @@ def main(args):
         'num_frames': args.num_frames,
         'use_representation_head': args.use_representation_head,
         'representation_dim': args.representation_dim,
+        'use_skip': args.use_skip,
     }
 
     if args.model == 'SwiftFormerTemporal_XS':
@@ -373,6 +375,11 @@ def train_one_epoch(model, criterion, data_loader, optimizer, device, epoch, los
     metric_logger.add_meter('lr', utils.SmoothedValue(window_size=1, fmt='{value:.6f}'))
     header = f'Epoch: [{epoch}]'
     print_freq = 10
 
+    # Add diagnostic meters
+    metric_logger.add_meter('pred_mean', utils.SmoothedValue(window_size=1, fmt='{value:.6f}'))
+    metric_logger.add_meter('pred_std', utils.SmoothedValue(window_size=1, fmt='{value:.6f}'))
+    metric_logger.add_meter('grad_norm', utils.SmoothedValue(window_size=1, fmt='{value:.6f}'))
+
     for batch_idx, (input_frames, target_frames, temporal_indices) in enumerate(
             metric_logger.log_every(data_loader, print_freq, header)):
@@ -382,7 +389,7 @@ def train_one_epoch(model, criterion, data_loader, optimizer, device, epoch, los
         temporal_indices = temporal_indices.to(device, non_blocking=True)
 
         # Forward pass
-        with torch.cuda.amp.autocast():
+        with torch.amp.autocast(device_type='cuda'):
             pred_frames, representations = model(input_frames)
             loss, loss_dict = criterion(
                 pred_frames, target_frames,
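The autocast change tracks PyTorch's API migration: `torch.cuda.amp.autocast()` is deprecated in favor of the device-generic `torch.amp.autocast`. The grad scaler gets the same rename in PyTorch 2.3+. A self-contained sketch of the paired usage, assuming a CUDA device is available:

```python
import torch

model = torch.nn.Linear(8, 8).cuda()
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
scaler = torch.amp.GradScaler('cuda')  # replaces torch.cuda.amp.GradScaler()

inputs = torch.randn(4, 8, device='cuda')
with torch.amp.autocast(device_type='cuda'):
    loss = model(inputs).square().mean()

scaler.scale(loss).backward()  # backward on the scaled loss
scaler.step(optimizer)         # unscales grads, then steps
scaler.update()
optimizer.zero_grad()
```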
@@ -395,6 +402,8 @@ def train_one_epoch(model, criterion, data_loader, optimizer, device, epoch, los
             raise ValueError(f"Loss is {loss_value}")
 
         optimizer.zero_grad()
+
+        # Save gradients for diagnostics before the backward pass
         loss_scaler(loss, optimizer, clip_grad=clip_grad, clip_mode=clip_mode,
                     parameters=model.parameters())
 
@@ -402,6 +411,30 @@ def train_one_epoch(model, criterion, data_loader, optimizer, device, epoch, los
         if model_ema is not None:
             model_ema.update(model)
 
+        # Compute diagnostic statistics
+        pred_mean = pred_frames.mean().item()
+        pred_std = pred_frames.std().item()
+
+        # Aggregate gradient magnitude (sum of per-parameter norms, not the global L2 norm)
+        total_grad_norm = 0.0
+        for param in model.parameters():
+            if param.grad is not None:
+                total_grad_norm += param.grad.norm().item()
+
+        # Record diagnostics
+        metric_logger.update(pred_mean=pred_mean)
+        metric_logger.update(pred_std=pred_std)
+        metric_logger.update(grad_norm=total_grad_norm)
+
+        # Print BatchNorm statistics every 50 batches
+        if batch_idx % 50 == 0:
+            print(f"[diag] batch {batch_idx}: pred_mean={pred_mean:.4f}, pred_std={pred_std:.4f}, grad_norm={total_grad_norm:.4f}")
+            # Inspect the running statistics of one BatchNorm layer
+            for name, module in model.named_modules():
+                if isinstance(module, torch.nn.BatchNorm2d) and 'decoder.blocks.0.bn' in name:
+                    print(f"[diag] {name}: running_mean={module.running_mean[0].item():.6f}, running_var={module.running_var[0].item():.6f}")
+                    break
+
         # Log to TensorBoard
         if writer is not None:
             # Log scalar metrics every iteration
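Note that the diagnostic above sums per-parameter L2 norms, which overestimates the true global gradient norm. If the global L2 norm is wanted instead, it is the square root of the sum of squared per-parameter norms; a small sketch:

```python
import torch

def global_grad_norm(parameters) -> float:
    """True global L2 norm: sqrt of the sum of squared per-parameter norms."""
    total_sq = 0.0
    for p in parameters:
        if p.grad is not None:
            total_sq += p.grad.norm(2).item() ** 2
    return total_sq ** 0.5

# Equivalent via PyTorch, with clipping disabled by an infinite threshold:
# norm = torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=float('inf'))
```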
@@ -415,6 +448,11 @@ def train_one_epoch(model, criterion, data_loader, optimizer, device, epoch, los
                 else:
                     writer.add_scalar(f'train/{k}', v, global_step)
 
+            # Log diagnostic metrics
+            writer.add_scalar('train/pred_mean', pred_mean, global_step)
+            writer.add_scalar('train/pred_std', pred_std, global_step)
+            writer.add_scalar('train/grad_norm', total_grad_norm, global_step)
+
             # Log images periodically
             if args is not None and getattr(args, 'log_images', False) and global_step % getattr(args, 'image_log_freq', 100) == 0:
                 with torch.no_grad():
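The hunk is cut off just as the image-logging branch opens. Under the flags shown (`log_images`, `image_log_freq`), a typical body for that `torch.no_grad()` block might look like the following; the 4D (B, C, H, W) tensor layout is an assumption, and video-shaped tensors would need reshaping first:

```python
import torchvision

# Inside `with torch.no_grad():` log a small grid of predictions vs. targets.
grid_pred = torchvision.utils.make_grid(pred_frames[:4].float().cpu(), normalize=True)
grid_tgt = torchvision.utils.make_grid(target_frames[:4].float().cpu(), normalize=True)
writer.add_image('train/pred_frames', grid_pred, global_step)
writer.add_image('train/target_frames', grid_tgt, global_step)
```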
@@ -450,20 +488,54 @@ def evaluate(data_loader, model, criterion, device, writer=None, epoch=0):
     model.eval()
     metric_logger = utils.MetricLogger(delimiter="  ")
     header = 'Test:'
 
+    # Add diagnostic meters
+    metric_logger.add_meter('pred_mean', utils.SmoothedValue(window_size=1, fmt='{value:.6f}'))
+    metric_logger.add_meter('pred_std', utils.SmoothedValue(window_size=1, fmt='{value:.6f}'))
+    metric_logger.add_meter('target_mean', utils.SmoothedValue(window_size=1, fmt='{value:.6f}'))
+    metric_logger.add_meter('target_std', utils.SmoothedValue(window_size=1, fmt='{value:.6f}'))
+
-    for input_frames, target_frames, temporal_indices in metric_logger.log_every(data_loader, 10, header):
+    for batch_idx, (input_frames, target_frames, temporal_indices) in enumerate(metric_logger.log_every(data_loader, 10, header)):
         input_frames = input_frames.to(device, non_blocking=True)
         target_frames = target_frames.to(device, non_blocking=True)
         temporal_indices = temporal_indices.to(device, non_blocking=True)
 
         # Compute output
-        with torch.cuda.amp.autocast():
+        with torch.amp.autocast(device_type='cuda'):
             pred_frames, representations = model(input_frames)
             loss, loss_dict = criterion(
                 pred_frames, target_frames,
                 representations, temporal_indices
             )
 
+        # Compute diagnostic statistics
+        pred_mean = pred_frames.mean().item()
+        pred_std = pred_frames.std().item()
+        target_mean = target_frames.mean().item()
+        target_std = target_frames.std().item()
+
+        # Update diagnostic meters
+        metric_logger.update(pred_mean=pred_mean)
+        metric_logger.update(pred_std=pred_std)
+        metric_logger.update(target_mean=target_mean)
+        metric_logger.update(target_std=target_std)
+
+        # Print detailed diagnostics on the first batch
+        if batch_idx == 0:
+            print("[eval diag] batch 0:")
+            print(f"  pred range: [{pred_frames.min().item():.4f}, {pred_frames.max().item():.4f}]")
+            print(f"  pred mean: {pred_mean:.4f}, pred std: {pred_std:.4f}")
+            print(f"  target range: [{target_frames.min().item():.4f}, {target_frames.max().item():.4f}]")
+            print(f"  target mean: {target_mean:.4f}, target std: {target_std:.4f}")
+
+            # Check BatchNorm running statistics
+            for name, module in model.named_modules():
+                if isinstance(module, torch.nn.BatchNorm2d) and 'decoder.blocks.0.bn' in name:
+                    print(f"  {name}: running_mean={module.running_mean[0].item():.6f}, running_var={module.running_var[0].item():.6f}")
+                    if module.running_var[0].item() < 1e-6:
+                        print("  Warning: BatchNorm running variance is near zero!")
+                    break
+
         # Update metrics
         metric_logger.update(loss=loss.item())
         for k, v in loss_dict.items():
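The near-zero running-variance check matters because eval-mode BatchNorm divides by sqrt(running_var + eps); a collapsed variance inflates activations and can explain train/eval gaps. A small helper that generalizes the single-layer check above to the whole model (the name `check_bn_running_stats` is hypothetical):

```python
import torch

def check_bn_running_stats(model: torch.nn.Module, var_eps: float = 1e-6):
    """Scan all BatchNorm2d layers and flag collapsed running variance."""
    for name, module in model.named_modules():
        if isinstance(module, torch.nn.BatchNorm2d):
            var_min = module.running_var.min().item()
            if var_min < var_eps:
                print(f"[diag] {name}: min running_var={var_min:.2e} (near zero!)")
```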