更新归一化方式,当前直接映射,不利用均值标准差进行标准化
This commit is contained in:
@@ -19,7 +19,7 @@ from timm.utils import NativeScaler, get_state_dict, ModelEma
|
||||
from util import *
|
||||
from models import *
|
||||
from models.swiftformer_temporal import SwiftFormerTemporal_XS, SwiftFormerTemporal_S, SwiftFormerTemporal_L1, SwiftFormerTemporal_L3
|
||||
from util.video_dataset import VideoFrameDataset, SyntheticVideoDataset
|
||||
from util.video_dataset import VideoFrameDataset
|
||||
from util.frame_losses import MultiTaskLoss
|
||||
|
||||
# Try to import TensorBoard
|
||||
@@ -47,7 +47,7 @@ def get_args_parser():
|
||||
help='Number of input frames (T)')
|
||||
parser.add_argument('--frame-size', default=224, type=int,
|
||||
help='Input frame size')
|
||||
parser.add_argument('--max-interval', default=1, type=int,
|
||||
parser.add_argument('--max-interval', default=4, type=int,
|
||||
help='Maximum interval between consecutive frames')
|
||||
|
||||
# Model parameters
|
||||
@@ -109,10 +109,10 @@ def get_args_parser():
|
||||
help='Weight for frame prediction loss')
|
||||
parser.add_argument('--contrastive-weight', type=float, default=0.1,
|
||||
help='Weight for contrastive loss')
|
||||
parser.add_argument('--l1-weight', type=float, default=1.0,
|
||||
help='Weight for L1 loss')
|
||||
parser.add_argument('--ssim-weight', type=float, default=0.1,
|
||||
help='Weight for SSIM loss')
|
||||
# parser.add_argument('--l1-weight', type=float, default=1.0,
|
||||
# help='Weight for L1 loss')
|
||||
# parser.add_argument('--ssim-weight', type=float, default=0.1,
|
||||
# help='Weight for SSIM loss')
|
||||
parser.add_argument('--no-contrastive', action='store_true',
|
||||
help='Disable contrastive loss')
|
||||
parser.add_argument('--no-ssim', action='store_true',
|
||||
@@ -326,7 +326,7 @@ def main(args):
|
||||
lr_scheduler.step(epoch)
|
||||
|
||||
# Save checkpoint
|
||||
if args.output_dir and (epoch % 10 == 0 or epoch == args.epochs - 1):
|
||||
if args.output_dir and (epoch % 2 == 0 or epoch == args.epochs - 1):
|
||||
checkpoint_path = output_dir / f'checkpoint_epoch{epoch}.pth'
|
||||
utils.save_on_master({
|
||||
'model': model_without_ddp.state_dict(),
|
||||
|
||||
Reference in New Issue
Block a user