First runnable version; the loss computation still has issues and training does not converge

2026-01-08 09:43:23 +08:00
parent efd76bccd2
commit f7601e9170
11 changed files with 656 additions and 63 deletions


@@ -6,9 +6,9 @@ import copy
 import torch
 import torch.nn as nn
 from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
-from timm.models.layers import DropPath, trunc_normal_
-from timm.models.registry import register_model
-from timm.models.layers.helpers import to_2tuple
+from timm.layers import DropPath, trunc_normal_
+from timm.models import register_model
+from timm.layers import to_2tuple
 import einops
 SwiftFormer_width = {
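This hunk (and the matching one in the next file) migrates the imports to the layout used by recent timm releases, where DropPath, trunc_normal_ and to_2tuple live in timm.layers and register_model is exported from timm.models. If the code still needs to run against an older timm, a version-tolerant import is a common pattern; the sketch below is illustrative and not part of this commit.

# Illustrative sketch, not part of this commit: fall back to the pre-0.9
# import paths when the new timm.layers module is unavailable.
try:
    from timm.layers import DropPath, trunc_normal_, to_2tuple   # newer timm
    from timm.models import register_model
except ImportError:
    from timm.models.layers import DropPath, trunc_normal_       # older timm
    from timm.models.layers.helpers import to_2tuple
    from timm.models.registry import register_model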


@@ -7,7 +7,7 @@ from .swiftformer import (
     SwiftFormer, SwiftFormer_depth, SwiftFormer_width,
     stem, Embedding, Stage
 )
-from timm.models.layers import DropPath, trunc_normal_
+from timm.layers import DropPath, trunc_normal_
 class DecoderBlock(nn.Module):
@@ -31,7 +31,7 @@ class DecoderBlock(nn.Module):
 class FramePredictionDecoder(nn.Module):
     """Lightweight decoder for frame prediction with optional skip connections"""
-    def __init__(self, embed_dims, output_channels=3, use_skip=False):
+    def __init__(self, embed_dims, output_channels=1, use_skip=False):
         super().__init__()
         self.use_skip = use_skip
         # Reverse the embed_dims for decoder
@@ -53,11 +53,11 @@ class FramePredictionDecoder(nn.Module):
             decoder_dims[2], decoder_dims[3],
             kernel_size=3, stride=2, padding=1, output_padding=1
         ))
-        # stage2 to original resolution (4x upsampling total)
+        # stage2 to original resolution (now 8x upsampling total with stride 4)
         self.blocks.append(nn.Sequential(
             nn.ConvTranspose2d(
                 decoder_dims[3], 32,
-                kernel_size=3, stride=2, padding=1, output_padding=1
+                kernel_size=3, stride=4, padding=1, output_padding=3
             ),
             nn.BatchNorm2d(32),
             nn.ReLU(inplace=True),
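The stride/output_padding change in the last decoder block is what produces the 8x total upsampling mentioned in the new comment: for nn.ConvTranspose2d, H_out = (H_in - 1)*stride - 2*padding + kernel_size + output_padding, so stride=4, padding=1, kernel_size=3, output_padding=3 gives H_out = 4*H_in, on top of the 2x block before it. A minimal check (illustrative only; the channel counts and input size below are made up):

# Illustrative check, not part of this commit: the new parameters upsample 4x.
# H_out = (56 - 1)*4 - 2*1 + 3 + 3 = 224
import torch
import torch.nn as nn

up = nn.ConvTranspose2d(64, 32, kernel_size=3, stride=4, padding=1, output_padding=3)
feat = torch.randn(1, 64, 56, 56)                      # hypothetical stage feature map
assert up(feat).shape[-2:] == torch.Size([224, 224])   # 56 -> 224, i.e. 4x spatially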
@@ -104,7 +104,7 @@ class SwiftFormerTemporal(nn.Module):
     """
     SwiftFormer with temporal input for frame prediction.
     Input: [B, num_frames, H, W] (Y channel only)
-    Output: predicted frame [B, 3, H, W] and optional representation
+    Output: predicted frame [B, 1, H, W] and optional representation
     """
     def __init__(self,
                  model_name='XS',
@@ -155,7 +155,7 @@ class SwiftFormerTemporal(nn.Module):
         # Frame prediction decoder
         if use_decoder:
-            self.decoder = FramePredictionDecoder(embed_dims, output_channels=3)
+            self.decoder = FramePredictionDecoder(embed_dims, output_channels=1)
         # Representation head for pose/velocity prediction
         if use_representation_head:
@@ -201,7 +201,7 @@ class SwiftFormerTemporal(nn.Module):
             x: input frames of shape [B, num_frames, H, W]
         Returns:
             If return_features is False:
-                pred_frame: predicted frame [B, 3, H, W] (or None)
+                pred_frame: predicted frame [B, 1, H, W] (or None)
                 representation: optional representation vector [B, representation_dim] (or None)
             If return_features is True:
                 pred_frame, representation, features (list of stage features)
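With these changes the model's contract is Y-channel in, Y-channel out: it consumes [B, num_frames, H, W] grayscale frames and emits a single-channel prediction [B, 1, H, W]. A hypothetical usage sketch of that contract follows; the constructor keywords beyond model_name and the frame count are assumptions based on attributes referenced in this diff, not confirmed defaults.

# Hypothetical usage sketch exercising the documented shapes; argument names
# other than model_name are assumed from attributes referenced in this diff,
# and the frame count must match however the model was configured.
import torch

model = SwiftFormerTemporal(model_name='XS', use_decoder=True,
                            use_representation_head=False)
frames = torch.randn(2, 4, 224, 224)          # [B, num_frames, H, W], Y channel only
pred_frame, representation = model(frames)    # assuming return_features defaults to False
print(pred_frame.shape)                       # expected per docstring: torch.Size([2, 1, 224, 224])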