Remove the residual path and shortcut; the mirroring issue still persists
@@ -45,7 +45,6 @@ def denormalize(tensor):
# [0, 1] -> [0, 255]
tensor = tensor * 255
return tensor.clamp(0, 255)
# return tensor

def minmax_denormalize(tensor):
tensor_min = tensor.min()
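
Note: a minimal sketch of how the two helpers above fit together, assuming inputs in [0, 1] for denormalize; the body of minmax_denormalize is only partially visible in this hunk, so its completion here is an assumption.

import torch

def denormalize(tensor):
    # [0, 1] -> [0, 255], as in the hunk above
    return (tensor * 255).clamp(0, 255)

def minmax_denormalize(tensor):
    # Assumed completion: rescale by the tensor's own min/max, then map to [0, 255]
    tensor_min = tensor.min()
    tensor_max = tensor.max()
    scaled = (tensor - tensor_min) / (tensor_max - tensor_min + 1e-8)
    return (scaled * 255).clamp(0, 255)

x = torch.rand(1, 1, 4, 4)
print(denormalize(x).max() <= 255)  # tensor(True)
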
@@ -76,28 +75,16 @@ def calculate_metrics(pred, target, debug=False):
if target_np.ndim == 3:
target_np = target_np.squeeze(0)

if debug:
print(f"[DEBUG] pred_np range: [{pred_np.min():.2f}, {pred_np.max():.2f}], mean: {pred_np.mean():.2f}")
print(f"[DEBUG] target_np range: [{target_np.min():.2f}, {target_np.max():.2f}], mean: {target_np.mean():.2f}")
print(f"[DEBUG] pred_np sample values (first 5): {pred_np.ravel()[:5]}")
# if debug:
# print(f"[DEBUG] pred_np range: [{pred_np.min():.2f}, {pred_np.max():.2f}], mean: {pred_np.mean():.2f}")
# print(f"[DEBUG] target_np range: [{target_np.min():.2f}, {target_np.max():.2f}], mean: {target_np.mean():.2f}")
# print(f"[DEBUG] pred_np sample values (first 5): {pred_np.ravel()[:5]}")

# Compute MSE - fix the incorrect tmp formula
# Original incorrect formula: tmp = 1 - (pred_np - target_np) / 255 * 2
# Correct formula: square the pixel differences directly
mse = np.mean((pred_np - target_np) ** 2)

# Also compute the MSE from the incorrect formula for comparison
tmp = 1 - (pred_np - target_np) / 255 * 2
wrong_mse = np.mean(tmp**2)

if debug:
print(f"[DEBUG] Correct MSE: {mse:.6f}, Wrong MSE (tmp formula): {wrong_mse:.6f}")

# Compute SSIM (data range 0-255)
data_range = 255.0
ssim_value = ssim(pred_np, target_np, data_range=data_range)

# Compute PSNR
psnr_value = psnr(target_np, pred_np, data_range=data_range)

return mse, ssim_value, psnr_value
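
Note: a quick worked example of why the old tmp formula was wrong: for a perfect prediction the correct MSE is 0, while tmp = 1 - 0 = 1 everywhere, so np.mean(tmp**2) reports 1 no matter what the images contain. The skimage.metrics imports below are an assumption about where ssim and psnr come from in this file.

import numpy as np
from skimage.metrics import structural_similarity as ssim
from skimage.metrics import peak_signal_noise_ratio as psnr

rng = np.random.default_rng(0)
target_np = rng.uniform(0, 255, size=(32, 32))
pred_np = target_np.copy()                 # perfect prediction

mse = np.mean((pred_np - target_np) ** 2)  # 0.0, as expected
tmp = 1 - (pred_np - target_np) / 255 * 2
wrong_mse = np.mean(tmp ** 2)              # 1.0, despite zero error
print(mse, wrong_mse)

# SSIM/PSNR on a slightly perturbed prediction, with data_range=255 as above
noisy_pred = np.clip(target_np + rng.normal(0, 5, target_np.shape), 0, 255)
print(ssim(noisy_pred, target_np, data_range=255.0),
      psnr(target_np, noisy_pred, data_range=255.0))
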
@@ -146,16 +133,6 @@ def save_comparison_figure(input_frames, target_frame, pred_frame, save_path,
#debug print
print(target_frame)
print(pred_frame)

# # debug print - improved to give more informative output
# if isinstance(pred_frame, np.ndarray):
# print(f"[DEBUG IMAGE] Pred frame shape: {pred_frame.shape}, range: [{pred_frame.min():.2f}, {pred_frame.max():.2f}], mean: {pred_frame.mean():.2f}")
# # Check whether a large share of values are near 127.5
# mask_near_127_5 = np.abs(pred_frame - 127.5) < 1.0
# percent_near_127_5 = np.mean(mask_near_127_5) * 100
# print(f"[DEBUG IMAGE] Percentage of values near 127.5 (±1.0): {percent_near_127_5:.2f}%")
# else:
# print(pred_frame)

plt.tight_layout()
plt.savefig(save_path, dpi=150, bbox_inches='tight')
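
Note: the commented-out block above is a check for predictions collapsing to flat mid-gray (127.5 on a 0-255 scale). A standalone sketch of that diagnostic, assuming pred_frame is a NumPy array in [0, 255]; the helper name is hypothetical.

import numpy as np

def report_gray_collapse(pred_frame, center=127.5, tol=1.0):
    # Fraction of pixels within +/- tol of mid-gray; a high percentage
    # suggests the decoder is emitting a near-constant image.
    mask = np.abs(pred_frame - center) < tol
    percent = float(np.mean(mask)) * 100
    print(f"[DEBUG IMAGE] range: [{pred_frame.min():.2f}, {pred_frame.max():.2f}], "
          f"mean: {pred_frame.mean():.2f}, near {center}: {percent:.2f}%")
    return percent

report_gray_collapse(np.full((64, 64), 127.4))  # ~100% -> collapsed output
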
@@ -216,13 +193,13 @@ def evaluate_model(model, data_loader, device, args):

# Enable debugging for the first sample
debug_mode = (batch_idx == 0 and i == 0 and total_samples == 0)
if debug_mode:
print(f"[DEBUG] Raw pred_frames range: [{pred_frames.min():.4f}, {pred_frames.max():.4f}], mean: {pred_frames.mean():.4f}")
print(f"[DEBUG] Raw target_frames range: [{target_frames.min():.4f}, {target_frames.max():.4f}], mean: {target_frames.mean():.4f}")
print(f"[DEBUG] Pred_denorm range: [{pred_denorm.min():.2f}, {pred_denorm.max():.2f}], mean: {pred_denorm.mean():.2f}")
print(f"[DEBUG] Target_denorm range: [{target_denorm.min():.2f}, {target_denorm.max():.2f}], mean: {target_denorm.mean():.2f}")
# if debug_mode:
# print(f"[DEBUG] Raw pred_frames range: [{pred_frames.min():.4f}, {pred_frames.max():.4f}], mean: {pred_frames.mean():.4f}")
# print(f"[DEBUG] Raw target_frames range: [{target_frames.min():.4f}, {target_frames.max():.4f}], mean: {target_frames.mean():.4f}")
# print(f"[DEBUG] Pred_denorm range: [{pred_denorm.min():.2f}, {pred_denorm.max():.2f}], mean: {pred_denorm.mean():.2f}")
# print(f"[DEBUG] Target_denorm range: [{target_denorm.min():.2f}, {target_denorm.max():.2f}], mean: {target_denorm.mean():.2f}")

mse, ssim_value, psnr_value = calculate_metrics(pred_i, target_i, debug=debug_mode)
mse, ssim_value, psnr_value = calculate_metrics(pred_i, target_i, debug=False)

total_mse += mse
total_ssim += ssim_value
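
Note: for context, a hedged reconstruction of the loop that surrounds the lines above: verbose debug output fires only for the very first sample of the first batch, while metrics are accumulated for every sample. The loop structure, function name, and the total_psnr accumulator are assumptions; only the debug_mode expression and the calculate_metrics call appear in the diff.

def evaluate_sketch(model, data_loader, device, denormalize, calculate_metrics):
    total_mse = total_ssim = total_psnr = 0.0
    total_samples = 0
    for batch_idx, (input_frames, target_frames) in enumerate(data_loader):
        pred_frames = model(input_frames.to(device))
        pred_denorm = denormalize(pred_frames)
        target_denorm = denormalize(target_frames.to(device))
        for i in range(pred_denorm.shape[0]):
            # Debug prints only for the first sample of the run
            debug_mode = (batch_idx == 0 and i == 0 and total_samples == 0)
            mse, ssim_value, psnr_value = calculate_metrics(
                pred_denorm[i], target_denorm[i], debug=debug_mode)
            total_mse += mse
            total_ssim += ssim_value
            total_psnr += psnr_value
            total_samples += 1
    n = max(total_samples, 1)
    return total_mse / n, total_ssim / n, total_psnr / n
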
@@ -43,7 +43,7 @@ def get_args_parser():
help='Number of input frames (T)')
parser.add_argument('--frame-size', default=224, type=int,
help='Input frame size')
parser.add_argument('--max-interval', default=4, type=int,
parser.add_argument('--max-interval', default=10, type=int,
help='Maximum interval between consecutive frames')

# Model parameters
@@ -121,7 +121,7 @@ def get_args_parser():
help='start epoch')
parser.add_argument('--eval', action='store_true',
help='Perform evaluation only')
parser.add_argument('--num-workers', default=4, type=int)
parser.add_argument('--num-workers', default=16, type=int)
parser.add_argument('--pin-mem', action='store_true',
help='Pin CPU memory in DataLoader')
parser.add_argument('--no-pin-mem', action='store_false', dest='pin_mem')
@@ -264,7 +264,7 @@ def main(args):
checkpoint = torch.hub.load_state_dict_from_url(
args.resume, map_location='cpu', check_hash=True)
else:
checkpoint = torch.load(args.resume, map_location='cpu')
checkpoint = torch.load(args.resume, map_location='cpu', weights_only=False)

model_without_ddp.load_state_dict(checkpoint['model'])
if not args.eval and 'optimizer' in checkpoint and 'lr_scheduler' in checkpoint and 'epoch' in checkpoint:
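
Note: the weights_only=False argument added above matters on recent PyTorch releases, where torch.load defaults to weights_only=True (PyTorch 2.6+) and then refuses to unpickle arbitrary Python objects such as an argparse.Namespace stored in the checkpoint. A hedged sketch of a loader that works on both old and new versions; the try/except fallback and helper name are illustrations, not the project's code.

import torch

def load_full_checkpoint(path):
    # Only for checkpoints from a trusted source: weights_only=False
    # re-enables full unpickling, which newer PyTorch disables by default.
    try:
        return torch.load(path, map_location='cpu', weights_only=False)
    except TypeError:
        # Older PyTorch without the weights_only keyword
        return torch.load(path, map_location='cpu')
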
@@ -308,7 +308,7 @@ def main(args):

train_stats, global_step = train_one_epoch(
model, criterion, data_loader_train,
optimizer, device, epoch, loss_scaler,
optimizer, device, epoch, loss_scaler, args.clip_grad, args.clip_mode,
model_ema=model_ema, writer=writer,
global_step=global_step, args=args
)
@@ -356,7 +356,7 @@ def main(args):

def train_one_epoch(model, criterion, data_loader, optimizer, device, epoch, loss_scaler,
clip_grad=None, clip_mode='norm', model_ema=None, writer=None,
clip_grad=0.01, clip_mode='norm', model_ema=None, writer=None,
global_step=0, args=None, **kwargs):
model.train()
metric_logger = utils.MetricLogger(delimiter=" ")
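
Note: the two hunks above thread args.clip_grad and args.clip_mode through to train_one_epoch and change the default clip value to 0.01. How the loss scaler applies them is not shown in this diff; below is a minimal stand-alone sketch of the two common modes using plain torch.nn.utils (the project's loss_scaler, e.g. a timm-style NativeScaler, may do this internally).

import torch
import torch.nn as nn

def clip_gradients(model, clip_grad=0.01, clip_mode='norm'):
    # Mirrors the clip_grad/clip_mode pair passed into train_one_epoch above.
    if clip_grad is None:
        return
    if clip_mode == 'norm':
        nn.utils.clip_grad_norm_(model.parameters(), max_norm=clip_grad)
    elif clip_mode == 'value':
        nn.utils.clip_grad_value_(model.parameters(), clip_value=clip_grad)

model = nn.Linear(4, 4)
model(torch.randn(2, 4)).sum().backward()
clip_gradients(model, clip_grad=0.01, clip_mode='norm')
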
@@ -11,7 +11,7 @@ from timm.layers import DropPath, trunc_normal_

class DecoderBlock(nn.Module):
"""Upsampling block for frame prediction decoder with residual connections"""
"""Upsampling block for frame prediction decoder without residual connections"""
def __init__(self, in_channels, out_channels, kernel_size=3, stride=2, padding=1, output_padding=1):
super().__init__()
# Main path: transposed convolution + two conv layers
@@ -31,28 +31,6 @@ class DecoderBlock(nn.Module):
kernel_size=3, padding=1, bias=False)
self.bn3 = nn.BatchNorm2d(out_channels)

# Residual path: used when the channel count or spatial size needs to change
self.shortcut = nn.Identity()
if in_channels != out_channels or stride != 1:
# Use a 1x1 convolution to adjust the channel count; use a transposed convolution if upsampling is needed
if stride == 1:
self.shortcut = nn.Sequential(
nn.Conv2d(in_channels, out_channels, kernel_size=1, bias=False),
nn.BatchNorm2d(out_channels)
)
else:
self.shortcut = nn.Sequential(
nn.ConvTranspose2d(
in_channels, out_channels,
kernel_size=1,
stride=stride,
padding=0,
output_padding=output_padding,
bias=False
),
nn.BatchNorm2d(out_channels)
)

# Use ReLU activation
self.activation = nn.ReLU(inplace=True)
@@ -67,13 +45,6 @@ class DecoderBlock(nn.Module):
nn.init.kaiming_normal_(self.conv1.weight, mode='fan_out', nonlinearity='relu')
nn.init.kaiming_normal_(self.conv2.weight, mode='fan_out', nonlinearity='relu')

# Initialize the shortcut
if not isinstance(self.shortcut, nn.Identity):
# The shortcut is now a Sequential, so its conv layers need to be initialized
for module in self.shortcut:
if isinstance(module, (nn.Conv2d, nn.ConvTranspose2d)):
nn.init.kaiming_normal_(module.weight, mode='fan_out', nonlinearity='relu')

# Initialize BN layers (default initialization)
for m in self.modules():
if isinstance(m, nn.BatchNorm2d):
@@ -81,8 +52,6 @@ class DecoderBlock(nn.Module):
nn.init.constant_(m.bias, 0)

def forward(self, x):
identity = self.shortcut(x)

# Main path
x = self.conv_transpose(x)
x = self.bn1(x)
@@ -94,9 +63,6 @@ class DecoderBlock(nn.Module):

x = self.conv2(x)
x = self.bn3(x)

# Residual connection
x = x + identity
x = self.activation(x)
return x
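
Note: taken together, the commit leaves DecoderBlock as a plain upsampling block: a transposed convolution followed by two 3x3 convolutions, each with BatchNorm and ReLU, and no identity added back in forward. Below is a hedged reconstruction of the post-commit block; the class name is made up and the exact placement of conv1/bn2 is inferred, since that part of the file is not in the diff.

import torch
import torch.nn as nn

class DecoderBlockNoShortcut(nn.Module):
    """Sketch of the block after the shortcut removal: main path only."""

    def __init__(self, in_channels, out_channels, kernel_size=3, stride=2,
                 padding=1, output_padding=1):
        super().__init__()
        self.conv_transpose = nn.ConvTranspose2d(
            in_channels, out_channels, kernel_size=kernel_size, stride=stride,
            padding=padding, output_padding=output_padding, bias=False)
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.conv1 = nn.Conv2d(out_channels, out_channels, kernel_size=3,
                               padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(out_channels)
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3,
                               padding=1, bias=False)
        self.bn3 = nn.BatchNorm2d(out_channels)
        self.activation = nn.ReLU(inplace=True)

    def forward(self, x):
        x = self.activation(self.bn1(self.conv_transpose(x)))
        x = self.activation(self.bn2(self.conv1(x)))
        x = self.bn3(self.conv2(x))
        return self.activation(x)  # no `x = x + identity` any more

block = DecoderBlockNoShortcut(220, 110)
print(block(torch.randn(1, 220, 7, 7)).shape)  # torch.Size([1, 110, 14, 14])
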
@@ -105,28 +71,28 @@ class FramePredictionDecoder(nn.Module):
"""Improved decoder for frame prediction"""
def __init__(self, embed_dims, output_channels=1):
super().__init__()
# Reverse the embed_dims for decoder
decoder_dims = embed_dims[::-1]
# Define decoder dimensions independently (no skip connections)
start_dim = embed_dims[-1]
decoder_dims = [start_dim // (2 ** i) for i in range(4)] # e.g., [220, 110, 55, 27] for XS

self.blocks = nn.ModuleList()

# Reorder: put the stride=4 block in the second-to-last position
# First block: stride=2 (220 -> 112)
# First block: stride=2 (decoder_dims[0] -> decoder_dims[1])
self.blocks.append(DecoderBlock(
decoder_dims[0], decoder_dims[1],
kernel_size=3, stride=2, padding=1, output_padding=1
))
# Second block: stride=2 (112 -> 56)
# Second block: stride=2 (decoder_dims[1] -> decoder_dims[2])
self.blocks.append(DecoderBlock(
decoder_dims[1], decoder_dims[2],
kernel_size=3, stride=2, padding=1, output_padding=1
))
# Third block: stride=2 (56 -> 48)
# Third block: stride=2 (decoder_dims[2] -> decoder_dims[3])
self.blocks.append(DecoderBlock(
decoder_dims[2], decoder_dims[3],
kernel_size=3, stride=2, padding=1, output_padding=1
))
# Fourth block: stride=4 (48 -> 64), placed in the second-to-last position
# Fourth block: stride=4 (decoder_dims[3] -> 64), placed in the second-to-last position
self.blocks.append(DecoderBlock(
decoder_dims[3], 64,
kernel_size=3, stride=4, padding=1, output_padding=3 # stride=4 goes here
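
Note: the new decoder_dims line decouples the decoder widths from the encoder's embed_dims: only the last encoder dimension is kept and then halved per stage. A quick check with an assumed embed_dims of [48, 56, 112, 220] (the old "220 -> 112 -> 56 -> 48" comments suggest a reversed list of roughly this shape):

embed_dims = [48, 56, 112, 220]      # assumed encoder widths for the XS variant

old_decoder_dims = embed_dims[::-1]  # [220, 112, 56, 48]  (pre-commit behaviour)
start_dim = embed_dims[-1]
new_decoder_dims = [start_dim // (2 ** i) for i in range(4)]
print(new_decoder_dims)              # [220, 110, 55, 27], matching the comment

With strides of 2, 2, 2 and 4 the four blocks upsample by a factor of 32 in total, so a 7x7 bottleneck would reach the 224x224 default frame size; the 7x7 starting resolution is an assumption, not something stated in the diff.
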
@@ -138,7 +104,7 @@ class FramePredictionDecoder(nn.Module):
nn.Conv2d(64, 32, kernel_size=3, padding=1, bias=True),
nn.ReLU(inplace=True),
nn.Conv2d(32, output_channels, kernel_size=3, padding=1, bias=True),
nn.Tanh() # Add a Tanh activation to constrain the output to the [-1, 1] range
nn.Tanh()
)

def forward(self, x):
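
Note: one thing worth cross-checking against the metrics hunks at the top: nn.Tanh() puts the decoder output in [-1, 1], while the denormalize helper above multiplies by 255 as if its input were in [0, 1]. If the targets are normalized to [-1, 1], the matching inverse is the one sketched below; a Tanh output stuck near 0 then maps straight to mid-gray 127.5, which is exactly what the 127.5 diagnostic in save_comparison_figure looks for. The helper name is hypothetical and this pairing is an assumption about the pipeline, not something the diff states.

import torch

def denormalize_from_tanh(tensor):
    # Assumed counterpart to the Tanh head: [-1, 1] -> [0, 255]
    return ((tensor + 1.0) / 2.0 * 255).clamp(0, 255)

print(denormalize_from_tanh(torch.zeros(2, 2)))  # all 127.5 -> flat mid-gray
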