Improved the skip connections and added a feature-refinement layer after the last decoder block; effect not yet tested
@@ -11,26 +11,188 @@ from timm.layers import DropPath, trunc_normal_
 class DecoderBlock(nn.Module):
-    """Upsampling block for frame prediction decoder"""
+    """Upsampling block for frame prediction decoder with residual connections"""
     def __init__(self, in_channels, out_channels, kernel_size=3, stride=2, padding=1, output_padding=1):
         super().__init__()
-        self.conv = nn.ConvTranspose2d(
+        # Main path: transposed convolution + two conv layers
+        self.conv_transpose = nn.ConvTranspose2d(
             in_channels, out_channels,
             kernel_size=kernel_size,
             stride=stride,
             padding=padding,
             output_padding=output_padding,
-            bias=False
+            bias=True  # enable bias because BatchNorm was removed
         )
-        self.bn = nn.BatchNorm2d(out_channels)
-        self.relu = nn.ReLU(inplace=True)
+        self.conv1 = nn.Conv2d(out_channels, out_channels,
+                               kernel_size=3, padding=1, bias=True)
+        self.conv2 = nn.Conv2d(out_channels, out_channels,
+                               kernel_size=3, padding=1, bias=True)

+        # Residual path: adjust when the channel count or spatial size changes
+        self.shortcut = nn.Identity()
+        if in_channels != out_channels or stride != 1:
+            # Use a 1x1 conv to adjust channels; use a transposed conv when upsampling is needed
+            if stride == 1:
+                self.shortcut = nn.Conv2d(in_channels, out_channels,
+                                          kernel_size=1, bias=True)
+            else:
+                self.shortcut = nn.ConvTranspose2d(
+                    in_channels, out_channels,
+                    kernel_size=1,
+                    stride=stride,
+                    padding=0,
+                    output_padding=output_padding,
+                    bias=True
+                )

+        # LeakyReLU to avoid dead neurons
+        self.activation = nn.LeakyReLU(0.2, inplace=True)

+        # Initialize weights
+        self._init_weights()

+    def _init_weights(self):
+        # Initialize the transposed conv layer
+        nn.init.kaiming_normal_(self.conv_transpose.weight, mode='fan_out', nonlinearity='leaky_relu')
+        if self.conv_transpose.bias is not None:
+            nn.init.constant_(self.conv_transpose.bias, 0)

+        # Initialize the conv layers
+        nn.init.kaiming_normal_(self.conv1.weight, mode='fan_out', nonlinearity='leaky_relu')
+        if self.conv1.bias is not None:
+            nn.init.constant_(self.conv1.bias, 0)

+        nn.init.kaiming_normal_(self.conv2.weight, mode='fan_out', nonlinearity='leaky_relu')
+        if self.conv2.bias is not None:
+            nn.init.constant_(self.conv2.bias, 0)

+        # Initialize the shortcut
+        if not isinstance(self.shortcut, nn.Identity):
+            if isinstance(self.shortcut, nn.Conv2d):
+                nn.init.kaiming_normal_(self.shortcut.weight, mode='fan_out', nonlinearity='leaky_relu')
+            elif isinstance(self.shortcut, nn.ConvTranspose2d):
+                nn.init.kaiming_normal_(self.shortcut.weight, mode='fan_out', nonlinearity='leaky_relu')
+            if self.shortcut.bias is not None:
+                nn.init.constant_(self.shortcut.bias, 0)

     def forward(self, x):
-        return self.relu(self.bn(self.conv(x)))
+        identity = self.shortcut(x)

+        # Main path
+        x = self.conv_transpose(x)
+        x = self.activation(x)

+        x = self.conv1(x)
+        x = self.activation(x)

+        x = self.conv2(x)

+        # Residual connection
+        x = x + identity
+        x = self.activation(x)
+        return x


+class DecoderBlockWithSkip(nn.Module):
+    """Decoder block with skip connection support"""
+    def __init__(self, in_channels, out_channels, skip_channels=0, kernel_size=3, stride=2, padding=1, output_padding=1):
+        super().__init__()
+        # Total input channels = input channels + skip channels
+        total_in_channels = in_channels + skip_channels

+        # Main path: transposed convolution + two conv layers
+        self.conv_transpose = nn.ConvTranspose2d(
+            total_in_channels, out_channels,
+            kernel_size=kernel_size,
+            stride=stride,
+            padding=padding,
+            output_padding=output_padding,
+            bias=True
+        )
+        self.conv1 = nn.Conv2d(out_channels, out_channels,
+                               kernel_size=3, padding=1, bias=True)
+        self.conv2 = nn.Conv2d(out_channels, out_channels,
+                               kernel_size=3, padding=1, bias=True)

+        # Residual path: adjust when the channel count or spatial size changes
+        self.shortcut = nn.Identity()
+        if total_in_channels != out_channels or stride != 1:
+            if stride == 1:
+                self.shortcut = nn.Conv2d(total_in_channels, out_channels,
+                                          kernel_size=1, bias=True)
+            else:
+                self.shortcut = nn.ConvTranspose2d(
+                    total_in_channels, out_channels,
+                    kernel_size=1,
+                    stride=stride,
+                    padding=0,
+                    output_padding=output_padding,
+                    bias=True
+                )

+        # LeakyReLU to avoid dead neurons
+        self.activation = nn.LeakyReLU(0.2, inplace=True)

+        # Initialize weights
+        self._init_weights()

+    def _init_weights(self):
+        # Initialize the transposed conv layer
+        nn.init.kaiming_normal_(self.conv_transpose.weight, mode='fan_out', nonlinearity='leaky_relu')
+        if self.conv_transpose.bias is not None:
+            nn.init.constant_(self.conv_transpose.bias, 0)

+        # Initialize the conv layers
+        nn.init.kaiming_normal_(self.conv1.weight, mode='fan_out', nonlinearity='leaky_relu')
+        if self.conv1.bias is not None:
+            nn.init.constant_(self.conv1.bias, 0)

+        nn.init.kaiming_normal_(self.conv2.weight, mode='fan_out', nonlinearity='leaky_relu')
+        if self.conv2.bias is not None:
+            nn.init.constant_(self.conv2.bias, 0)

+        # Initialize the shortcut
+        if not isinstance(self.shortcut, nn.Identity):
+            if isinstance(self.shortcut, nn.Conv2d):
+                nn.init.kaiming_normal_(self.shortcut.weight, mode='fan_out', nonlinearity='leaky_relu')
+            elif isinstance(self.shortcut, nn.ConvTranspose2d):
+                nn.init.kaiming_normal_(self.shortcut.weight, mode='fan_out', nonlinearity='leaky_relu')
+            if self.shortcut.bias is not None:
+                nn.init.constant_(self.shortcut.bias, 0)

+    def forward(self, x, skip_feature=None):
+        # If a skip feature is provided, concatenate it with the input
+        if skip_feature is not None:
+            # Make sure the skip feature's spatial size matches x
+            if skip_feature.shape[2:] != x.shape[2:]:
+                # Use bilinear interpolation to upsample or downsample
+                skip_feature = torch.nn.functional.interpolate(
+                    skip_feature,
+                    size=x.shape[2:],
+                    mode='bilinear',
+                    align_corners=False
+                )
+            x = torch.cat([x, skip_feature], dim=1)

+        identity = self.shortcut(x)

+        # Main path
+        x = self.conv_transpose(x)
+        x = self.activation(x)

+        x = self.conv1(x)
+        x = self.activation(x)

+        x = self.conv2(x)

+        # Residual connection
+        x = x + identity
+        x = self.activation(x)
+        return x


 class FramePredictionDecoder(nn.Module):
-    """Lightweight decoder for frame prediction with optional skip connections"""
+    """Improved decoder for frame prediction with better upsampling strategy"""
     def __init__(self, embed_dims, output_channels=1, use_skip=False):
         super().__init__()
         self.use_skip = use_skip
@@ -38,65 +200,109 @@ class FramePredictionDecoder(nn.Module):
         decoder_dims = embed_dims[::-1]

         self.blocks = nn.ModuleList()
-        # First upsampling from bottleneck to stage4 resolution
-        self.blocks.append(DecoderBlock(
-            decoder_dims[0], decoder_dims[1],
-            kernel_size=3, stride=2, padding=1, output_padding=1
-        ))
-        # stage4 to stage3
-        self.blocks.append(DecoderBlock(
-            decoder_dims[1], decoder_dims[2],
-            kernel_size=3, stride=2, padding=1, output_padding=1
-        ))
-        # stage3 to stage2
-        self.blocks.append(DecoderBlock(
-            decoder_dims[2], decoder_dims[3],
-            kernel_size=3, stride=2, padding=1, output_padding=1
-        ))
-        # stage2 to original resolution (now 8x upsampling total with stride 4)
-        self.blocks.append(nn.Sequential(
-            nn.ConvTranspose2d(
-                decoder_dims[3], 32,
-                kernel_size=3, stride=4, padding=1, output_padding=3
-            ),
-            nn.BatchNorm2d(32),
-            nn.ReLU(inplace=True),
-            nn.Conv2d(32, output_channels, kernel_size=3, padding=1),
-            nn.Tanh()  # Output in [-1, 1] range
-        ))
-
-        # If using skip connections, we need to adjust input channels for each block
         if use_skip:
-            # We'll modify the first three blocks to accept concatenated features
-            # Instead of modifying existing blocks, we'll replace them with custom blocks
-            # For simplicity, we'll keep the same architecture but forward will handle concatenation
-            pass
+            # Use blocks that support skip connections
+            # First block: bottleneck to stage4, large stride=4, skip from stage3
+            self.blocks.append(DecoderBlockWithSkip(
+                decoder_dims[0], decoder_dims[1],
+                skip_channels=embed_dims[3],  # channel count of stage3
+                kernel_size=3, stride=4, padding=1, output_padding=3  # changed to stride=4
+            ))
+            # Second block: stage4 to stage3, stride=2, skip from stage2
+            self.blocks.append(DecoderBlockWithSkip(
+                decoder_dims[1], decoder_dims[2],
+                skip_channels=embed_dims[2],  # channel count of stage2
+                kernel_size=3, stride=2, padding=1, output_padding=1
+            ))
+            # Third block: stage3 to stage2, stride=2, skip from stage1
+            self.blocks.append(DecoderBlockWithSkip(
+                decoder_dims[2], decoder_dims[3],
+                skip_channels=embed_dims[1],  # channel count of stage1
+                kernel_size=3, stride=2, padding=1, output_padding=1
+            ))
+            # Fourth block: stage2 to stage1, stride=2, skip from stage0
+            self.blocks.append(DecoderBlockWithSkip(
+                decoder_dims[3], 64,  # output 64 channels
+                skip_channels=embed_dims[0],  # channel count of stage0
+                kernel_size=3, stride=2, padding=1, output_padding=1
+            ))
+        else:
+            # Use plain DecoderBlocks; the first block uses a large stride
+            self.blocks.append(DecoderBlock(
+                decoder_dims[0], decoder_dims[1],
+                kernel_size=3, stride=4, padding=1, output_padding=3  # changed to stride=4
+            ))
+            self.blocks.append(DecoderBlock(
+                decoder_dims[1], decoder_dims[2],
+                kernel_size=3, stride=2, padding=1, output_padding=1
+            ))
+            self.blocks.append(DecoderBlock(
+                decoder_dims[2], decoder_dims[3],
+                kernel_size=3, stride=2, padding=1, output_padding=1
+            ))
+            # Fourth block: expand to 64 channels
+            self.blocks.append(DecoderBlock(
+                decoder_dims[3], 64,
+                kernel_size=3, stride=2, padding=1, output_padding=1
+            ))

+        # Improved final output layer: no transposed conv, feature refinement only
+        # The input is already at the target resolution; only adjust channels and fuse features
+        self.final_block = nn.Sequential(
+            nn.Conv2d(64, 64, kernel_size=3, padding=1, bias=True),
+            nn.LeakyReLU(0.2, inplace=True),
+            nn.Conv2d(64, 32, kernel_size=3, padding=1, bias=True),
+            nn.LeakyReLU(0.2, inplace=True),
+            nn.Conv2d(32, output_channels, kernel_size=3, padding=1, bias=True)
+            # Tanh removed: the output range is left unconstrained and handled by the loss and normalization
+        )

     def forward(self, x, skip_features=None):
         """
         Args:
             x: input tensor of shape [B, embed_dims[-1], H/32, W/32]
-            skip_features: list of encoder features from stages [stage2, stage1, stage0]
-                           each of shape [B, C, H', W'] where C matches decoder dims?
+            skip_features: list of encoder features from stages [stage3, stage2, stage1, stage0]
+                           each of shape [B, C, H', W'] where C matches encoder dims
         """
-        if self.use_skip and skip_features is not None:
-            # Ensure we have exactly 3 skip features (for the first three blocks)
-            assert len(skip_features) == 3, "Need 3 skip features for skip connections"
-            # Reverse skip_features to match decoder order: stage2, stage1, stage0
-            # skip_features[0] should be stage2 (H/16), [1] stage1 (H/8), [2] stage0 (H/4)
-            skip_features = skip_features[::-1]  # Now index 0: stage2, 1: stage1, 2: stage0
+        if self.use_skip:
+            if skip_features is None:
+                raise ValueError("skip_features must be provided when use_skip=True")

+            # Make sure there are 4 skip features
+            assert len(skip_features) == 4, f"Need 4 skip features, got {len(skip_features)}"

+            # Reverse the order to match the decoder: stage3, stage2, stage1, stage0
+            skip_features = skip_features[::-1]

+            # Resize skip features to match the new upsampling strategy
+            adjusted_skip_features = []
+            for i, skip in enumerate(skip_features):
+                if skip is not None:
+                    # Target sizes correspond to 4x, 2x, 2x, 2x upsampling
+                    upsample_factors = [4, 2, 2, 2]
+                    target_height = x.shape[2] * upsample_factors[i]
+                    target_width = x.shape[3] * upsample_factors[i]

+                    if skip.shape[2:] != (target_height, target_width):
+                        skip = torch.nn.functional.interpolate(
+                            skip,
+                            size=(target_height, target_width),
+                            mode='bilinear',
+                            align_corners=False
+                        )
+                adjusted_skip_features.append(skip)

+            # All four blocks use skip connections
+            for i in range(4):
+                x = self.blocks[i](x, adjusted_skip_features[i])
+        else:
+            # No skip connections
+            for i in range(4):
+                x = self.blocks[i](x)

-        for i, block in enumerate(self.blocks):
-            if self.use_skip and skip_features is not None and i < 3:
-                # Concatenate skip feature along channel dimension
-                # Ensure spatial dimensions match (they should because of upsampling)
-                x = torch.cat([x, skip_features[i]], dim=1)
-                # Need to adjust block to accept extra channels? We'll create a separate block.
-                # For now, we'll just pass through, but this will cause channel mismatch.
-                # Instead, we should have created custom blocks with appropriate in_channels.
-                # This is a placeholder; we need to implement properly.
-                pass
-            x = block(x)
+        # Final output layer: feature refinement only, no upsampling
+        x = self.final_block(x)
         return x

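A minimal sketch of exercising the new skip-aware decoder on its own (not part of the commit; the module path, embed_dims values, and input size below are illustrative assumptions):

import torch
from swiftformer_temporal import FramePredictionDecoder  # hypothetical import path

embed_dims = [48, 96, 192, 384]            # placeholder stage widths (stage0..stage3)
decoder = FramePredictionDecoder(embed_dims, output_channels=1, use_skip=True)

x = torch.randn(1, embed_dims[-1], 7, 7)   # bottleneck feature at H/32 for a 224x224 input
skip_features = [                          # encoder outputs in stage0..stage3 order
    torch.randn(1, 48, 56, 56),            # stage0: H/4
    torch.randn(1, 96, 28, 28),            # stage1: H/8
    torch.randn(1, 192, 14, 14),           # stage2: H/16
    torch.randn(1, 384, 7, 7),             # stage3: H/32
]
pred = decoder(x, skip_features)           # the decoder reverses this list internally
print(pred.shape)                          # expected: torch.Size([1, 1, 224, 224])
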
@@ -106,10 +312,11 @@ class SwiftFormerTemporal(nn.Module):
     Input: [B, num_frames, H, W] (Y channel only)
     Output: predicted frame [B, 1, H, W] and optional representation
     """
     def __init__(self,
                  model_name='XS',
                  num_frames=3,
                  use_decoder=True,
+                 use_skip=True,  # new: whether to use skip connections
                  use_representation_head=False,
                  representation_dim=128,
                  return_features=False,
@@ -123,6 +330,7 @@ class SwiftFormerTemporal(nn.Module):
         # Store configuration
         self.num_frames = num_frames
         self.use_decoder = use_decoder
+        self.use_skip = use_skip  # store the skip-connection setting
         self.use_representation_head = use_representation_head
         self.return_features = return_features

@@ -155,7 +363,11 @@ class SwiftFormerTemporal(nn.Module):

         # Frame prediction decoder
         if use_decoder:
-            self.decoder = FramePredictionDecoder(embed_dims, output_channels=1)
+            self.decoder = FramePredictionDecoder(
+                embed_dims,
+                output_channels=1,
+                use_skip=use_skip  # pass the skip-connection setting through
+            )

         # Representation head for pose/velocity prediction
         if use_representation_head:
@@ -173,22 +385,31 @@ class SwiftFormerTemporal(nn.Module):

     def _init_weights(self, m):
         if isinstance(m, (nn.Conv2d, nn.Linear)):
-            trunc_normal_(m.weight, std=.02)
+            # Kaiming initialization, suited to ReLU/LeakyReLU
+            nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='leaky_relu')
             if m.bias is not None:
                 nn.init.constant_(m.bias, 0)
-        elif isinstance(m, (nn.LayerNorm)):
+        elif isinstance(m, nn.ConvTranspose2d):
+            # Transposed conv layers use a dedicated initialization
+            nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='leaky_relu')
+            if m.bias is not None:
+                nn.init.constant_(m.bias, 0)
+        elif isinstance(m, (nn.LayerNorm, nn.BatchNorm2d)):
             nn.init.constant_(m.bias, 0)
             nn.init.constant_(m.weight, 1.0)

     def forward_tokens(self, x):
         """Forward through encoder network, return list of stage features if return_features else final output"""
-        if self.return_features:
+        if self.return_features or self.use_skip:
             features = []
             stage_idx = 0
             for idx, block in enumerate(self.network):
                 x = block(x)
-                # Collect output after each stage (indices 0,2,4,6 correspond to stages)
+                # Collect the output of each stage (stage0, stage1, stage2, stage3)
+                # In the SwiftFormer structure, stages sit at indices 0, 2, 4, 6
                 if idx in [0, 2, 4, 6]:
                     features.append(x)
                     stage_idx += 1
             return x, features
         else:
             for block in self.network:
@@ -208,7 +429,7 @@ class SwiftFormerTemporal(nn.Module):
         """
         # Encode
         x = self.patch_embed(x)
-        if self.return_features:
+        if self.return_features or self.use_skip:
             x, features = self.forward_tokens(x)
         else:
             x = self.forward_tokens(x)
@@ -222,7 +443,23 @@ class SwiftFormerTemporal(nn.Module):
         # Decode to frame
         pred_frame = None
         if self.use_decoder:
-            pred_frame = self.decoder(x)
+            if self.use_skip:
+                # Extract the features used for skip connections
+                # features holds every stage's output; we need stage0, stage1, stage2, stage3
+                # According to the SwiftFormer structure there should be 4 stage features
+                if len(features) >= 4:
+                    # Take the four stage features: stage0, stage1, stage2, stage3
+                    skip_features = [features[0], features[1], features[2], features[3]]
+                else:
+                    # If there are not enough features, use whatever is available
+                    skip_features = features[:4]
+                    # If still not enough, pad with None
+                    while len(skip_features) < 4:
+                        skip_features.append(None)

+                pred_frame = self.decoder(x, skip_features)
+            else:
+                pred_frame = self.decoder(x)

         if self.return_features:
             return pred_frame, representation, features
@@ -231,14 +468,14 @@ class SwiftFormerTemporal(nn.Module):


 # Factory functions for different model sizes
-def SwiftFormerTemporal_XS(num_frames=3, **kwargs):
-    return SwiftFormerTemporal('XS', num_frames=num_frames, **kwargs)
+def SwiftFormerTemporal_XS(num_frames=3, use_skip=True, **kwargs):
+    return SwiftFormerTemporal('XS', num_frames=num_frames, use_skip=use_skip, **kwargs)

-def SwiftFormerTemporal_S(num_frames=3, **kwargs):
-    return SwiftFormerTemporal('S', num_frames=num_frames, **kwargs)
+def SwiftFormerTemporal_S(num_frames=3, use_skip=True, **kwargs):
+    return SwiftFormerTemporal('S', num_frames=num_frames, use_skip=use_skip, **kwargs)

-def SwiftFormerTemporal_L1(num_frames=3, **kwargs):
-    return SwiftFormerTemporal('l1', num_frames=num_frames, **kwargs)
+def SwiftFormerTemporal_L1(num_frames=3, use_skip=True, **kwargs):
+    return SwiftFormerTemporal('l1', num_frames=num_frames, use_skip=use_skip, **kwargs)

-def SwiftFormerTemporal_L3(num_frames=3, **kwargs):
-    return SwiftFormerTemporal('l3', num_frames=num_frames, **kwargs)
+def SwiftFormerTemporal_L3(num_frames=3, use_skip=True, **kwargs):
+    return SwiftFormerTemporal('l3', num_frames=num_frames, use_skip=use_skip, **kwargs)
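An end-to-end usage sketch through the factory functions (also not part of the commit; it assumes this file is importable as swiftformer_temporal, that the default return is (pred_frame, representation), and that H and W are divisible by 32 so the decoder's 4-2-2-2 upsampling restores the input resolution):

import torch
from swiftformer_temporal import SwiftFormerTemporal_XS  # hypothetical import path

model = SwiftFormerTemporal_XS(num_frames=3, use_skip=True)  # skip connections enabled (the new default)
model.eval()

frames = torch.randn(1, 3, 224, 224)  # dummy Y-channel clip: [B, num_frames, H, W]
with torch.no_grad():
    pred_frame, representation = model(frames)  # assumed return signature when return_features=False

print(pred_frame.shape)  # expected: torch.Size([1, 1, 224, 224])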