diff --git a/evaluate_temporal.py b/evaluate_temporal.py index 3c7b2ae..c1553d7 100644 --- a/evaluate_temporal.py +++ b/evaluate_temporal.py @@ -45,6 +45,15 @@ def denormalize(tensor): # [0, 1] -> [0, 255] tensor = tensor * 255 return tensor.clamp(0, 255) + # return tensor + +def minmax_denormalize(tensor): + tensor_min = tensor.min() + tensor_max = tensor.max() + tensor = (tensor - tensor_min) / (tensor_max - tensor_min) + # tensor = tensor*2-1 + tensor = tensor*255 + return tensor.clamp(0, 255) def calculate_metrics(pred, target, debug=False): @@ -134,6 +143,10 @@ def save_comparison_figure(input_frames, target_frame, pred_frame, save_path, ax.set_title('Predicted') ax.axis('off') + #debug print + print(target_frame) + print(pred_frame) + # debug print - 改进为更有信息量的输出 if isinstance(pred_frame, np.ndarray): print(f"[DEBUG IMAGE] Pred frame shape: {pred_frame.shape}, range: [{pred_frame.min():.2f}, {pred_frame.max():.2f}], mean: {pred_frame.mean():.2f}") @@ -161,8 +174,8 @@ def evaluate_model(model, data_loader, device, args): metrics_dict: 包含所有指标的字典 sample_results: 示例结果用于可视化 """ - # model.eval() - model.train() # 临时使用训练模式 + model.eval() + # model.train() # 临时使用训练模式 # 初始化指标累加器 total_mse = 0.0 @@ -183,10 +196,11 @@ def evaluate_model(model, data_loader, device, args): target_frames = target_frames.to(device, non_blocking=True) # 前向传播 - pred_frames, _ = model(input_frames) + pred_frames = model(input_frames) # 反归一化用于指标计算 - pred_denorm = denormalize(pred_frames) # [B, 1, H, W] + # pred_denorm = minmax_denormalize(pred_frames) # [B, 1, H, W] + pred_denorm = denormalize(pred_frames) target_denorm = denormalize(target_frames) # [B, 1, H, W] batch_size = input_frames.size(0) @@ -309,8 +323,6 @@ def main(args): print(f"创建模型: {args.model}") model_kwargs = { 'num_frames': args.num_frames, - 'use_representation_head': args.use_representation_head, - 'representation_dim': args.representation_dim, } if args.model == 'SwiftFormerTemporal_XS': @@ -335,10 +347,10 @@ def main(args): except (pickle.UnpicklingError, TypeError) as e: print(f"使用weights_only=False加载失败: {e}") print("尝试使用torch.serialization.add_safe_globals...") - from argparse import Namespace - # 添加安全全局变量 - torch.serialization.add_safe_globals([Namespace]) - checkpoint = torch.load(args.resume, map_location='cpu') + # from argparse import Namespace + # # 添加安全全局变量 + # torch.serialization.add_safe_globals([Namespace]) + # checkpoint = torch.load(args.resume, map_location='cpu') # 处理状态字典(可能包含'module.'前缀) if 'model' in checkpoint: @@ -462,10 +474,6 @@ def get_args_parser(): # 模型参数 parser.add_argument('--model', default='SwiftFormerTemporal_XS', type=str, metavar='MODEL', help='要评估的模型名称') - parser.add_argument('--use-representation-head', action='store_true', - help='使用表示头进行姿态/速度预测') - parser.add_argument('--representation-dim', default=128, type=int, - help='表示向量的维度') # 评估参数 parser.add_argument('--batch-size', default=16, type=int, diff --git a/main_temporal.py b/main_temporal.py index 5a137e1..90328e8 100644 --- a/main_temporal.py +++ b/main_temporal.py @@ -49,11 +49,6 @@ def get_args_parser(): # Model parameters parser.add_argument('--model', default='SwiftFormerTemporal_XS', type=str, metavar='MODEL', help='Name of model to train') - parser.add_argument('--use-representation-head', action='store_true', - help='Use representation head for pose/velocity prediction') - parser.add_argument('--representation-dim', default=128, type=int, - help='Dimension of representation vector') - parser.add_argument('--use-skip', default=False, type=bool, help='using skip connections') # Training parameters parser.add_argument('--batch-size', default=32, type=int) @@ -207,9 +202,6 @@ def main(args): print(f"Creating model: {args.model}") model_kwargs = { 'num_frames': args.num_frames, - 'use_representation_head': args.use_representation_head, - 'representation_dim': args.representation_dim, - 'use_skip': args.use_skip, } if args.model == 'SwiftFormerTemporal_XS': @@ -258,7 +250,7 @@ def main(args): super().__init__() self.mse = nn.MSELoss() - def forward(self, pred_frame, target_frame, representations=None, temporal_indices=None): + def forward(self, pred_frame, target_frame, temporal_indices=None): loss = self.mse(pred_frame, target_frame) loss_dict = {'mse': loss} return loss, loss_dict @@ -386,10 +378,10 @@ def train_one_epoch(model, criterion, data_loader, optimizer, device, epoch, los # Forward pass with torch.amp.autocast(device_type='cuda'): - pred_frames, representations = model(input_frames) + pred_frames = model(input_frames) loss, loss_dict = criterion( pred_frames, target_frames, - representations, temporal_indices + temporal_indices ) loss_value = loss.item() @@ -452,7 +444,7 @@ def train_one_epoch(model, criterion, data_loader, optimizer, device, epoch, los if args is not None and getattr(args, 'log_images', False) and global_step % getattr(args, 'image_log_freq', 100) == 0: with torch.no_grad(): # Take first sample from batch for visualization - pred_vis, _ = model(input_frames[:1]) + pred_vis = model(input_frames[:1]) # Convert to appropriate format for TensorBoard # Assuming frames are in [B, C, H, W] format writer.add_images('train/input', input_frames[:1], global_step) @@ -497,10 +489,10 @@ def evaluate(data_loader, model, criterion, device, writer=None, epoch=0): # Compute output with torch.amp.autocast(device_type='cuda'): - pred_frames, representations = model(input_frames) + pred_frames = model(input_frames) loss, loss_dict = criterion( pred_frames, target_frames, - representations, temporal_indices + temporal_indices ) # 计算诊断指标 @@ -555,4 +547,4 @@ if __name__ == '__main__': if args.output_dir: Path(args.output_dir).mkdir(parents=True, exist_ok=True) - main(args) \ No newline at end of file + main(args) diff --git a/models/swiftformer_temporal.py b/models/swiftformer_temporal.py index e91b569..31e8406 100644 --- a/models/swiftformer_temporal.py +++ b/models/swiftformer_temporal.py @@ -93,159 +93,33 @@ class DecoderBlock(nn.Module): return x -class DecoderBlockWithSkip(nn.Module): - """Decoder block with skip connection support""" - def __init__(self, in_channels, out_channels, skip_channels=0, kernel_size=3, stride=2, padding=1, output_padding=1): - super().__init__() - # 总输入通道 = 输入通道 + skip通道 - total_in_channels = in_channels + skip_channels - - # 主路径:反卷积 + 两个卷积层 - self.conv_transpose = nn.ConvTranspose2d( - total_in_channels, out_channels, - kernel_size=kernel_size, - stride=stride, - padding=padding, - output_padding=output_padding, - bias=True - ) - self.conv1 = nn.Conv2d(out_channels, out_channels, - kernel_size=3, padding=1, bias=True) - self.conv2 = nn.Conv2d(out_channels, out_channels, - kernel_size=3, padding=1, bias=True) - - # 残差路径:如果需要改变通道数或空间尺寸 - self.shortcut = nn.Identity() - if total_in_channels != out_channels or stride != 1: - if stride == 1: - self.shortcut = nn.Conv2d(total_in_channels, out_channels, - kernel_size=1, bias=True) - else: - self.shortcut = nn.ConvTranspose2d( - total_in_channels, out_channels, - kernel_size=1, - stride=stride, - padding=0, - output_padding=output_padding, - bias=True - ) - - # 使用LeakyReLU避免死亡神经元 - self.activation = nn.LeakyReLU(0.2, inplace=True) - - # 初始化权重 - self._init_weights() - - def _init_weights(self): - # 初始化反卷积层 - nn.init.kaiming_normal_(self.conv_transpose.weight, mode='fan_out', nonlinearity='leaky_relu') - if self.conv_transpose.bias is not None: - nn.init.constant_(self.conv_transpose.bias, 0) - - # 初始化卷积层 - nn.init.kaiming_normal_(self.conv1.weight, mode='fan_out', nonlinearity='leaky_relu') - if self.conv1.bias is not None: - nn.init.constant_(self.conv1.bias, 0) - - nn.init.kaiming_normal_(self.conv2.weight, mode='fan_out', nonlinearity='leaky_relu') - if self.conv2.bias is not None: - nn.init.constant_(self.conv2.bias, 0) - - # 初始化shortcut - if not isinstance(self.shortcut, nn.Identity): - if isinstance(self.shortcut, nn.Conv2d): - nn.init.kaiming_normal_(self.shortcut.weight, mode='fan_out', nonlinearity='leaky_relu') - elif isinstance(self.shortcut, nn.ConvTranspose2d): - nn.init.kaiming_normal_(self.shortcut.weight, mode='fan_out', nonlinearity='leaky_relu') - if self.shortcut.bias is not None: - nn.init.constant_(self.shortcut.bias, 0) - - def forward(self, x, skip_feature=None): - # 如果有skip feature,将其与输入拼接 - if skip_feature is not None: - # 确保skip特征的空间尺寸与x匹配 - if skip_feature.shape[2:] != x.shape[2:]: - # 使用双线性插值进行上采样或下采样 - skip_feature = torch.nn.functional.interpolate( - skip_feature, - size=x.shape[2:], - mode='bilinear', - align_corners=False - ) - x = torch.cat([x, skip_feature], dim=1) - - identity = self.shortcut(x) - - # 主路径 - x = self.conv_transpose(x) - x = self.activation(x) - - x = self.conv1(x) - x = self.activation(x) - - x = self.conv2(x) - - # 残差连接 - x = x + identity - x = self.activation(x) - return x - - class FramePredictionDecoder(nn.Module): - """Improved decoder for frame prediction with better upsampling strategy""" - def __init__(self, embed_dims, output_channels=1, use_skip=False): + """Improved decoder for frame prediction""" + def __init__(self, embed_dims, output_channels=1): super().__init__() - self.use_skip = use_skip # Reverse the embed_dims for decoder decoder_dims = embed_dims[::-1] self.blocks = nn.ModuleList() - if use_skip: - # 使用支持skip connections的block - # 第一个block:从bottleneck到stage4,使用大步长stride=4,skip来自stage3 - self.blocks.append(DecoderBlockWithSkip( - decoder_dims[0], decoder_dims[1], - skip_channels=embed_dims[3], # stage3的通道数 - kernel_size=3, stride=4, padding=1, output_padding=3 # 改为stride=4 - )) - # 第二个block:stage4到stage3,stride=2,skip来自stage2 - self.blocks.append(DecoderBlockWithSkip( - decoder_dims[1], decoder_dims[2], - skip_channels=embed_dims[2], # stage2的通道数 - kernel_size=3, stride=2, padding=1, output_padding=1 - )) - # 第三个block:stage3到stage2,stride=2,skip来自stage1 - self.blocks.append(DecoderBlockWithSkip( - decoder_dims[2], decoder_dims[3], - skip_channels=embed_dims[1], # stage1的通道数 - kernel_size=3, stride=2, padding=1, output_padding=1 - )) - # 第四个block:stage2到stage1,stride=2,skip来自stage0 - self.blocks.append(DecoderBlockWithSkip( - decoder_dims[3], 64, # 输出到64通道 - skip_channels=embed_dims[0], # stage0的通道数 - kernel_size=3, stride=2, padding=1, output_padding=1 - )) - else: - # 使用普通的DecoderBlock,第一个block使用大步长 - self.blocks.append(DecoderBlock( - decoder_dims[0], decoder_dims[1], - kernel_size=3, stride=4, padding=1, output_padding=3 # 改为stride=4 - )) - self.blocks.append(DecoderBlock( - decoder_dims[1], decoder_dims[2], - kernel_size=3, stride=2, padding=1, output_padding=1 - )) - self.blocks.append(DecoderBlock( - decoder_dims[2], decoder_dims[3], - kernel_size=3, stride=2, padding=1, output_padding=1 - )) - # 第四个block:增加到64通道 - self.blocks.append(DecoderBlock( - decoder_dims[3], 64, - kernel_size=3, stride=2, padding=1, output_padding=1 - )) + # 使用普通的DecoderBlock,第一个block使用大步长 + self.blocks.append(DecoderBlock( + decoder_dims[0], decoder_dims[1], + kernel_size=3, stride=4, padding=1, output_padding=3 # 改为stride=4 + )) + self.blocks.append(DecoderBlock( + decoder_dims[1], decoder_dims[2], + kernel_size=3, stride=2, padding=1, output_padding=1 + )) + self.blocks.append(DecoderBlock( + decoder_dims[2], decoder_dims[3], + kernel_size=3, stride=2, padding=1, output_padding=1 + )) + # 第四个block:增加到64通道 + self.blocks.append(DecoderBlock( + decoder_dims[3], 64, + kernel_size=3, stride=2, padding=1, output_padding=1 + )) # 改进的最终输出层:不使用反卷积,只进行特征精炼 # 输入尺寸已经是目标尺寸,只需要调整通道数和进行特征融合 @@ -258,48 +132,14 @@ class FramePredictionDecoder(nn.Module): # 移除Tanh,让输出在任意范围,由损失函数和归一化处理 ) - def forward(self, x, skip_features=None): + def forward(self, x): """ Args: x: input tensor of shape [B, embed_dims[-1], H/32, W/32] - skip_features: list of encoder features from stages [stage3, stage2, stage1, stage0] - each of shape [B, C, H', W'] where C matches encoder dims """ - if self.use_skip: - if skip_features is None: - raise ValueError("skip_features must be provided when use_skip=True") - - # 确保有4个skip features - assert len(skip_features) == 4, f"Need 4 skip features, got {len(skip_features)}" - - # 反转顺序以匹配解码器:stage3, stage2, stage1, stage0 - skip_features = skip_features[::-1] - - # 调整skip特征的尺寸以匹配新的上采样策略 - adjusted_skip_features = [] - for i, skip in enumerate(skip_features): - if skip is not None: - # 计算目标尺寸:4, 2, 2, 2倍上采样 - upsample_factors = [4, 2, 2, 2] - target_height = x.shape[2] * upsample_factors[i] - target_width = x.shape[3] * upsample_factors[i] - - if skip.shape[2:] != (target_height, target_width): - skip = torch.nn.functional.interpolate( - skip, - size=(target_height, target_width), - mode='bilinear', - align_corners=False - ) - adjusted_skip_features.append(skip) - - # 四个block使用skip connections - for i in range(4): - x = self.blocks[i](x, adjusted_skip_features[i]) - else: - # 不使用skip connections - for i in range(4): - x = self.blocks[i](x) + # 不使用skip connections + for i in range(4): + x = self.blocks[i](x) # 最终输出层:只进行特征精炼,不上采样 x = self.final_block(x) @@ -316,9 +156,6 @@ class SwiftFormerTemporal(nn.Module): model_name='XS', num_frames=3, use_decoder=True, - use_skip=True, # 新增:是否使用skip connections - use_representation_head=False, - representation_dim=128, return_features=False, **kwargs): super().__init__() @@ -330,8 +167,6 @@ class SwiftFormerTemporal(nn.Module): # Store configuration self.num_frames = num_frames self.use_decoder = use_decoder - self.use_skip = use_skip # 保存skip connections设置 - self.use_representation_head = use_representation_head self.return_features = return_features # Modify stem to accept multiple frames (only Y channel) @@ -365,22 +200,9 @@ class SwiftFormerTemporal(nn.Module): if use_decoder: self.decoder = FramePredictionDecoder( embed_dims, - output_channels=1, - use_skip=use_skip # 传递skip connections设置 + output_channels=1 ) - # Representation head for pose/velocity prediction - if use_representation_head: - self.representation_head = nn.Sequential( - nn.AdaptiveAvgPool2d(1), - nn.Flatten(), - nn.Linear(embed_dims[-1], representation_dim), - nn.ReLU(), - nn.Linear(representation_dim, representation_dim) - ) - else: - self.representation_head = None - self.apply(self._init_weights) def _init_weights(self, m): @@ -400,7 +222,7 @@ class SwiftFormerTemporal(nn.Module): def forward_tokens(self, x): """Forward through encoder network, return list of stage features if return_features else final output""" - if self.return_features or self.use_skip: + if self.return_features: features = [] stage_idx = 0 for idx, block in enumerate(self.network): @@ -423,59 +245,37 @@ class SwiftFormerTemporal(nn.Module): Returns: If return_features is False: pred_frame: predicted frame [B, 1, H, W] (or None) - representation: optional representation vector [B, representation_dim] (or None) If return_features is True: - pred_frame, representation, features (list of stage features) + pred_frame, features (list of stage features) """ # Encode x = self.patch_embed(x) - if self.return_features or self.use_skip: + if self.return_features: x, features = self.forward_tokens(x) else: x = self.forward_tokens(x) x = self.norm(x) - # Get representation if needed - representation = None - if self.representation_head is not None: - representation = self.representation_head(x) - # Decode to frame pred_frame = None if self.use_decoder: - if self.use_skip: - # 提取用于skip connections的特征 - # features包含所有stage的输出,我们需要stage0, stage1, stage2, stage3 - # 根据SwiftFormer结构,应该有4个stage特征 - if len(features) >= 4: - # 取四个stage的特征:stage0, stage1, stage2, stage3 - skip_features = [features[0], features[1], features[2], features[3]] - else: - # 如果特征不够,使用可用的特征 - skip_features = features[:4] - # 如果特征仍然不够,使用None填充 - while len(skip_features) < 4: - skip_features.append(None) - - pred_frame = self.decoder(x, skip_features) - else: - pred_frame = self.decoder(x) + pred_frame = self.decoder(x) if self.return_features: - return pred_frame, representation, features + return pred_frame, features else: - return pred_frame, representation + return pred_frame # Factory functions for different model sizes -def SwiftFormerTemporal_XS(num_frames=3, use_skip=True, **kwargs): - return SwiftFormerTemporal('XS', num_frames=num_frames, use_skip=use_skip, **kwargs) +def SwiftFormerTemporal_XS(num_frames=3, **kwargs): + return SwiftFormerTemporal('XS', num_frames=num_frames, **kwargs) -def SwiftFormerTemporal_S(num_frames=3, use_skip=True, **kwargs): - return SwiftFormerTemporal('S', num_frames=num_frames, use_skip=use_skip, **kwargs) +def SwiftFormerTemporal_S(num_frames=3, **kwargs): + return SwiftFormerTemporal('S', num_frames=num_frames, **kwargs) -def SwiftFormerTemporal_L1(num_frames=3, use_skip=True, **kwargs): - return SwiftFormerTemporal('l1', num_frames=num_frames, use_skip=use_skip, **kwargs) +def SwiftFormerTemporal_L1(num_frames=3, **kwargs): + return SwiftFormerTemporal('l1', num_frames=num_frames, **kwargs) -def SwiftFormerTemporal_L3(num_frames=3, use_skip=True, **kwargs): - return SwiftFormerTemporal('l3', num_frames=num_frames, use_skip=use_skip, **kwargs) \ No newline at end of file +def SwiftFormerTemporal_L3(num_frames=3, **kwargs): + return SwiftFormerTemporal('l3', num_frames=num_frames, **kwargs) \ No newline at end of file