From 543beefa2a7d780affa96665a1f3a1be32d96231 Mon Sep 17 00:00:00 2001 From: CaoWangrenbo Date: Fri, 16 Jan 2026 15:21:47 +0800 Subject: [PATCH] =?UTF-8?q?=E5=88=A0=E9=99=A4=E6=AE=8B=E5=B7=AE=E8=B7=AF?= =?UTF-8?q?=E5=BE=84=E5=92=8Cshortcut=EF=BC=8C=E9=95=9C=E5=83=8F=E9=97=AE?= =?UTF-8?q?=E9=A2=98=E4=BB=8D=E5=AD=98=E5=9C=A8?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- evaluate_temporal.py | 43 +++++++--------------------- main_temporal.py | 10 +++---- models/swiftformer_temporal.py | 52 ++++++---------------------------- 3 files changed, 24 insertions(+), 81 deletions(-) diff --git a/evaluate_temporal.py b/evaluate_temporal.py index 9e59292..32f1045 100644 --- a/evaluate_temporal.py +++ b/evaluate_temporal.py @@ -45,7 +45,6 @@ def denormalize(tensor): # [0, 1] -> [0, 255] tensor = tensor * 255 return tensor.clamp(0, 255) - # return tensor def minmax_denormalize(tensor): tensor_min = tensor.min() @@ -76,28 +75,16 @@ def calculate_metrics(pred, target, debug=False): if target_np.ndim == 3: target_np = target_np.squeeze(0) - if debug: - print(f"[DEBUG] pred_np range: [{pred_np.min():.2f}, {pred_np.max():.2f}], mean: {pred_np.mean():.2f}") - print(f"[DEBUG] target_np range: [{target_np.min():.2f}, {target_np.max():.2f}], mean: {target_np.mean():.2f}") - print(f"[DEBUG] pred_np sample values (first 5): {pred_np.ravel()[:5]}") + # if debug: + # print(f"[DEBUG] pred_np range: [{pred_np.min():.2f}, {pred_np.max():.2f}], mean: {pred_np.mean():.2f}") + # print(f"[DEBUG] target_np range: [{target_np.min():.2f}, {target_np.max():.2f}], mean: {target_np.mean():.2f}") + # print(f"[DEBUG] pred_np sample values (first 5): {pred_np.ravel()[:5]}") - # 计算MSE - 修复错误的tmp公式 - # 原错误公式: tmp = 1 - (pred_np - target_np) / 255 * 2 - # 正确公式: 直接计算像素差的平方 mse = np.mean((pred_np - target_np) ** 2) - # 同时计算错误公式的MSE用于对比 - tmp = 1 - (pred_np - target_np) / 255 * 2 - wrong_mse = np.mean(tmp**2) - - if debug: - print(f"[DEBUG] Correct MSE: {mse:.6f}, Wrong MSE (tmp formula): {wrong_mse:.6f}") - - # 计算SSIM (数据范围0-255) data_range = 255.0 ssim_value = ssim(pred_np, target_np, data_range=data_range) - # 计算PSNR psnr_value = psnr(target_np, pred_np, data_range=data_range) return mse, ssim_value, psnr_value @@ -146,16 +133,6 @@ def save_comparison_figure(input_frames, target_frame, pred_frame, save_path, #debug print print(target_frame) print(pred_frame) - - # # debug print - 改进为更有信息量的输出 - # if isinstance(pred_frame, np.ndarray): - # print(f"[DEBUG IMAGE] Pred frame shape: {pred_frame.shape}, range: [{pred_frame.min():.2f}, {pred_frame.max():.2f}], mean: {pred_frame.mean():.2f}") - # # 检查是否有大量值在127.5附近 - # mask_near_127_5 = np.abs(pred_frame - 127.5) < 1.0 - # percent_near_127_5 = np.mean(mask_near_127_5) * 100 - # print(f"[DEBUG IMAGE] Percentage of values near 127.5 (±1.0): {percent_near_127_5:.2f}%") - # else: - # print(pred_frame) plt.tight_layout() plt.savefig(save_path, dpi=150, bbox_inches='tight') @@ -216,13 +193,13 @@ def evaluate_model(model, data_loader, device, args): # 对第一个样本启用调试 debug_mode = (batch_idx == 0 and i == 0 and total_samples == 0) - if debug_mode: - print(f"[DEBUG] Raw pred_frames range: [{pred_frames.min():.4f}, {pred_frames.max():.4f}], mean: {pred_frames.mean():.4f}") - print(f"[DEBUG] Raw target_frames range: [{target_frames.min():.4f}, {target_frames.max():.4f}], mean: {target_frames.mean():.4f}") - print(f"[DEBUG] Pred_denorm range: [{pred_denorm.min():.2f}, {pred_denorm.max():.2f}], mean: {pred_denorm.mean():.2f}") - print(f"[DEBUG] Target_denorm range: [{target_denorm.min():.2f}, {target_denorm.max():.2f}], mean: {target_denorm.mean():.2f}") + # if debug_mode: + # print(f"[DEBUG] Raw pred_frames range: [{pred_frames.min():.4f}, {pred_frames.max():.4f}], mean: {pred_frames.mean():.4f}") + # print(f"[DEBUG] Raw target_frames range: [{target_frames.min():.4f}, {target_frames.max():.4f}], mean: {target_frames.mean():.4f}") + # print(f"[DEBUG] Pred_denorm range: [{pred_denorm.min():.2f}, {pred_denorm.max():.2f}], mean: {pred_denorm.mean():.2f}") + # print(f"[DEBUG] Target_denorm range: [{target_denorm.min():.2f}, {target_denorm.max():.2f}], mean: {target_denorm.mean():.2f}") - mse, ssim_value, psnr_value = calculate_metrics(pred_i, target_i, debug=debug_mode) + mse, ssim_value, psnr_value = calculate_metrics(pred_i, target_i, debug=False) total_mse += mse total_ssim += ssim_value diff --git a/main_temporal.py b/main_temporal.py index 90328e8..79efeaf 100644 --- a/main_temporal.py +++ b/main_temporal.py @@ -43,7 +43,7 @@ def get_args_parser(): help='Number of input frames (T)') parser.add_argument('--frame-size', default=224, type=int, help='Input frame size') - parser.add_argument('--max-interval', default=4, type=int, + parser.add_argument('--max-interval', default=10, type=int, help='Maximum interval between consecutive frames') # Model parameters @@ -121,7 +121,7 @@ def get_args_parser(): help='start epoch') parser.add_argument('--eval', action='store_true', help='Perform evaluation only') - parser.add_argument('--num-workers', default=4, type=int) + parser.add_argument('--num-workers', default=16, type=int) parser.add_argument('--pin-mem', action='store_true', help='Pin CPU memory in DataLoader') parser.add_argument('--no-pin-mem', action='store_false', dest='pin_mem') @@ -264,7 +264,7 @@ def main(args): checkpoint = torch.hub.load_state_dict_from_url( args.resume, map_location='cpu', check_hash=True) else: - checkpoint = torch.load(args.resume, map_location='cpu') + checkpoint = torch.load(args.resume, map_location='cpu', weights_only=False) model_without_ddp.load_state_dict(checkpoint['model']) if not args.eval and 'optimizer' in checkpoint and 'lr_scheduler' in checkpoint and 'epoch' in checkpoint: @@ -308,7 +308,7 @@ def main(args): train_stats, global_step = train_one_epoch( model, criterion, data_loader_train, - optimizer, device, epoch, loss_scaler, + optimizer, device, epoch, loss_scaler, args.clip_grad, args.clip_mode, model_ema=model_ema, writer=writer, global_step=global_step, args=args ) @@ -356,7 +356,7 @@ def main(args): def train_one_epoch(model, criterion, data_loader, optimizer, device, epoch, loss_scaler, - clip_grad=None, clip_mode='norm', model_ema=None, writer=None, + clip_grad=0.01, clip_mode='norm', model_ema=None, writer=None, global_step=0, args=None, **kwargs): model.train() metric_logger = utils.MetricLogger(delimiter=" ") diff --git a/models/swiftformer_temporal.py b/models/swiftformer_temporal.py index eeaa29b..b225b53 100644 --- a/models/swiftformer_temporal.py +++ b/models/swiftformer_temporal.py @@ -11,7 +11,7 @@ from timm.layers import DropPath, trunc_normal_ class DecoderBlock(nn.Module): - """Upsampling block for frame prediction decoder with residual connections""" + """Upsampling block for frame prediction decoder without residual connections""" def __init__(self, in_channels, out_channels, kernel_size=3, stride=2, padding=1, output_padding=1): super().__init__() # 主路径:反卷积 + 两个卷积层 @@ -31,28 +31,6 @@ class DecoderBlock(nn.Module): kernel_size=3, padding=1, bias=False) self.bn3 = nn.BatchNorm2d(out_channels) - # 残差路径:如果需要改变通道数或空间尺寸 - self.shortcut = nn.Identity() - if in_channels != out_channels or stride != 1: - # 使用1x1卷积调整通道数,如果需要上采样则使用反卷积 - if stride == 1: - self.shortcut = nn.Sequential( - nn.Conv2d(in_channels, out_channels, kernel_size=1, bias=False), - nn.BatchNorm2d(out_channels) - ) - else: - self.shortcut = nn.Sequential( - nn.ConvTranspose2d( - in_channels, out_channels, - kernel_size=1, - stride=stride, - padding=0, - output_padding=output_padding, - bias=False - ), - nn.BatchNorm2d(out_channels) - ) - # 使用ReLU激活函数 self.activation = nn.ReLU(inplace=True) @@ -67,13 +45,6 @@ class DecoderBlock(nn.Module): nn.init.kaiming_normal_(self.conv1.weight, mode='fan_out', nonlinearity='relu') nn.init.kaiming_normal_(self.conv2.weight, mode='fan_out', nonlinearity='relu') - # 初始化shortcut - if not isinstance(self.shortcut, nn.Identity): - # shortcut现在是Sequential,需要初始化其中的卷积层 - for module in self.shortcut: - if isinstance(module, (nn.Conv2d, nn.ConvTranspose2d)): - nn.init.kaiming_normal_(module.weight, mode='fan_out', nonlinearity='relu') - # 初始化BN层(使用默认初始化) for m in self.modules(): if isinstance(m, nn.BatchNorm2d): @@ -81,8 +52,6 @@ class DecoderBlock(nn.Module): nn.init.constant_(m.bias, 0) def forward(self, x): - identity = self.shortcut(x) - # 主路径 x = self.conv_transpose(x) x = self.bn1(x) @@ -94,9 +63,6 @@ class DecoderBlock(nn.Module): x = self.conv2(x) x = self.bn3(x) - - # 残差连接 - x = x + identity x = self.activation(x) return x @@ -105,28 +71,28 @@ class FramePredictionDecoder(nn.Module): """Improved decoder for frame prediction""" def __init__(self, embed_dims, output_channels=1): super().__init__() - # Reverse the embed_dims for decoder - decoder_dims = embed_dims[::-1] + # Define decoder dimensions independently (no skip connections) + start_dim = embed_dims[-1] + decoder_dims = [start_dim // (2 ** i) for i in range(4)] # e.g., [220, 110, 55, 27] for XS self.blocks = nn.ModuleList() - # 调整顺序:将stride=4放在倒数第二的位置 - # 第一个block:stride=2 (220 -> 112) + # 第一个block:stride=2 (decoder_dims[0] -> decoder_dims[1]) self.blocks.append(DecoderBlock( decoder_dims[0], decoder_dims[1], kernel_size=3, stride=2, padding=1, output_padding=1 )) - # 第二个block:stride=2 (112 -> 56) + # 第二个block:stride=2 (decoder_dims[1] -> decoder_dims[2]) self.blocks.append(DecoderBlock( decoder_dims[1], decoder_dims[2], kernel_size=3, stride=2, padding=1, output_padding=1 )) - # 第三个block:stride=2 (56 -> 48) + # 第三个block:stride=2 (decoder_dims[2] -> decoder_dims[3]) self.blocks.append(DecoderBlock( decoder_dims[2], decoder_dims[3], kernel_size=3, stride=2, padding=1, output_padding=1 )) - # 第四个block:stride=4 (48 -> 64),放在倒数第二的位置 + # 第四个block:stride=4 (decoder_dims[3] -> 64),放在倒数第二的位置 self.blocks.append(DecoderBlock( decoder_dims[3], 64, kernel_size=3, stride=4, padding=1, output_padding=3 # stride=4放在这里 @@ -138,7 +104,7 @@ class FramePredictionDecoder(nn.Module): nn.Conv2d(64, 32, kernel_size=3, padding=1, bias=True), nn.ReLU(inplace=True), nn.Conv2d(32, output_channels, kernel_size=3, padding=1, bias=True), - nn.Tanh() # 添加Tanh激活函数,约束输出在[-1, 1]范围内 + nn.Tanh() ) def forward(self, x):