Remove the residual path and shortcut; the mirroring problem still persists
@@ -11,7 +11,7 @@ from timm.layers import DropPath, trunc_normal_
 class DecoderBlock(nn.Module):
-    """Upsampling block for frame prediction decoder with residual connections"""
+    """Upsampling block for frame prediction decoder without residual connections"""
     def __init__(self, in_channels, out_channels, kernel_size=3, stride=2, padding=1, output_padding=1):
         super().__init__()
         # Main path: transposed convolution + two conv layers
@@ -31,28 +31,6 @@ class DecoderBlock(nn.Module):
                                kernel_size=3, padding=1, bias=False)
         self.bn3 = nn.BatchNorm2d(out_channels)

-        # Residual path: used when the channel count or spatial size changes
-        self.shortcut = nn.Identity()
-        if in_channels != out_channels or stride != 1:
-            # Use a 1x1 conv to adjust channels; use a transposed conv if upsampling is needed
-            if stride == 1:
-                self.shortcut = nn.Sequential(
-                    nn.Conv2d(in_channels, out_channels, kernel_size=1, bias=False),
-                    nn.BatchNorm2d(out_channels)
-                )
-            else:
-                self.shortcut = nn.Sequential(
-                    nn.ConvTranspose2d(
-                        in_channels, out_channels,
-                        kernel_size=1,
-                        stride=stride,
-                        padding=0,
-                        output_padding=output_padding,
-                        bias=False
-                    ),
-                    nn.BatchNorm2d(out_channels)
-                )
-
         # Use ReLU as the activation function
         self.activation = nn.ReLU(inplace=True)
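One reason the shortcut was a plausible suspect: a `kernel_size=1` transposed convolution with `stride=2` writes to only every second output position, so the remaining positions stay zero (with `bias=False`) and the residual sum can pick up a grid pattern. A minimal sketch of the removed shortcut's behavior (channel count arbitrary):

```python
import torch
import torch.nn as nn

# Same configuration as the removed upsampling shortcut
shortcut = nn.ConvTranspose2d(4, 4, kernel_size=1, stride=2,
                              padding=0, output_padding=1, bias=False)
y = shortcut(torch.ones(1, 4, 8, 8))
print(tuple(y.shape))                      # (1, 4, 16, 16)
print(bool((y[..., 1::2, :] == 0).all()))  # True: odd output rows receive no input
```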
@@ -67,13 +45,6 @@ class DecoderBlock(nn.Module):
         nn.init.kaiming_normal_(self.conv1.weight, mode='fan_out', nonlinearity='relu')
         nn.init.kaiming_normal_(self.conv2.weight, mode='fan_out', nonlinearity='relu')

-        # Initialize the shortcut
-        if not isinstance(self.shortcut, nn.Identity):
-            # The shortcut is now a Sequential, so initialize the conv layers inside it
-            for module in self.shortcut:
-                if isinstance(module, (nn.Conv2d, nn.ConvTranspose2d)):
-                    nn.init.kaiming_normal_(module.weight, mode='fan_out', nonlinearity='relu')
-
         # Initialize BN layers (default initialization)
         for m in self.modules():
             if isinstance(m, nn.BatchNorm2d):
@@ -81,8 +52,6 @@ class DecoderBlock(nn.Module):
                 nn.init.constant_(m.bias, 0)

     def forward(self, x):
-        identity = self.shortcut(x)
-
         # Main path
         x = self.conv_transpose(x)
         x = self.bn1(x)
@@ -94,9 +63,6 @@ class DecoderBlock(nn.Module):
         x = self.conv2(x)
         x = self.bn3(x)

-        # Residual connection
-        x = x + identity
-
         x = self.activation(x)
         return x

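The (stride, output_padding) pairs used throughout, (2, 1) and (4, 3) with kernel_size=3 and padding=1, give exact 2x and 4x upsampling under PyTorch's transposed-conv size formula H_out = (H_in - 1) * stride - 2 * padding + kernel_size + output_padding. A quick sanity check:

```python
import torch
import torch.nn as nn

for stride, output_padding in [(2, 1), (4, 3)]:
    up = nn.ConvTranspose2d(8, 8, kernel_size=3, stride=stride,
                            padding=1, output_padding=output_padding)
    y = up(torch.randn(1, 8, 14, 14))
    # (14 - 1)*2 - 2 + 3 + 1 = 28 and (14 - 1)*4 - 2 + 3 + 3 = 56
    assert y.shape[-1] == stride * 14
    print(stride, tuple(y.shape))
```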
@@ -105,28 +71,28 @@ class FramePredictionDecoder(nn.Module):
     """Improved decoder for frame prediction"""
     def __init__(self, embed_dims, output_channels=1):
         super().__init__()
-        # Reverse the embed_dims for decoder
-        decoder_dims = embed_dims[::-1]
+        # Define decoder dimensions independently (no skip connections)
+        start_dim = embed_dims[-1]
+        decoder_dims = [start_dim // (2 ** i) for i in range(4)]  # e.g., [220, 110, 55, 27] for XS

         self.blocks = nn.ModuleList()

         # Reordered: the stride=4 block is placed second to last
-        # First block: stride=2 (220 -> 112)
+        # First block: stride=2 (decoder_dims[0] -> decoder_dims[1])
         self.blocks.append(DecoderBlock(
             decoder_dims[0], decoder_dims[1],
             kernel_size=3, stride=2, padding=1, output_padding=1
         ))
-        # Second block: stride=2 (112 -> 56)
+        # Second block: stride=2 (decoder_dims[1] -> decoder_dims[2])
         self.blocks.append(DecoderBlock(
             decoder_dims[1], decoder_dims[2],
             kernel_size=3, stride=2, padding=1, output_padding=1
         ))
-        # Third block: stride=2 (56 -> 48)
+        # Third block: stride=2 (decoder_dims[2] -> decoder_dims[3])
         self.blocks.append(DecoderBlock(
             decoder_dims[2], decoder_dims[3],
             kernel_size=3, stride=2, padding=1, output_padding=1
         ))
-        # Fourth block: stride=4 (48 -> 64), placed second to last
+        # Fourth block: stride=4 (decoder_dims[3] -> 64), placed second to last
         self.blocks.append(DecoderBlock(
             decoder_dims[3], 64,
             kernel_size=3, stride=4, padding=1, output_padding=3  # the stride=4 block goes here
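For reference, the new channel schedule halves at each stage regardless of the encoder's dims. A minimal check, assuming the XS variant uses embed_dims = [48, 56, 112, 220] (inferred from the old "220 -> 112 -> 56 -> 48" comments, not stated in this diff):

```python
embed_dims = [48, 56, 112, 220]  # assumed XS dims; only embed_dims[-1] matters now
start_dim = embed_dims[-1]
decoder_dims = [start_dim // (2 ** i) for i in range(4)]
print(decoder_dims)  # [220, 110, 55, 27]
# Total spatial upsampling over the four blocks: 2 * 2 * 2 * 4 = 32x
```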
@@ -138,7 +104,7 @@ class FramePredictionDecoder(nn.Module):
             nn.Conv2d(64, 32, kernel_size=3, padding=1, bias=True),
             nn.ReLU(inplace=True),
             nn.Conv2d(32, output_channels, kernel_size=3, padding=1, bias=True),
-            nn.Tanh()  # Add Tanh to constrain the output to the [-1, 1] range
+            nn.Tanh()
         )

     def forward(self, x):
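A hypothetical smoke test for the whole decoder, assuming the encoder reduces 224x224 frames to a 7x7 feature map and that forward() (not shown in this diff) runs the four blocks followed by the output head:

```python
import torch

decoder = FramePredictionDecoder(embed_dims=[48, 56, 112, 220], output_channels=1)
out = decoder(torch.randn(2, 220, 7, 7))
print(out.shape)  # expected: torch.Size([2, 1, 224, 224]); 7 * 32 = 224
# The Tanh head keeps predictions in [-1, 1], matching frames normalized to that range.
```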