删除残差路径和shortcut,镜像问题仍存在
This commit is contained in:
@@ -43,7 +43,7 @@ def get_args_parser():
|
||||
help='Number of input frames (T)')
|
||||
parser.add_argument('--frame-size', default=224, type=int,
|
||||
help='Input frame size')
|
||||
parser.add_argument('--max-interval', default=4, type=int,
|
||||
parser.add_argument('--max-interval', default=10, type=int,
|
||||
help='Maximum interval between consecutive frames')
|
||||
|
||||
# Model parameters
|
||||
@@ -121,7 +121,7 @@ def get_args_parser():
|
||||
help='start epoch')
|
||||
parser.add_argument('--eval', action='store_true',
|
||||
help='Perform evaluation only')
|
||||
parser.add_argument('--num-workers', default=4, type=int)
|
||||
parser.add_argument('--num-workers', default=16, type=int)
|
||||
parser.add_argument('--pin-mem', action='store_true',
|
||||
help='Pin CPU memory in DataLoader')
|
||||
parser.add_argument('--no-pin-mem', action='store_false', dest='pin_mem')
|
||||
@@ -264,7 +264,7 @@ def main(args):
|
||||
checkpoint = torch.hub.load_state_dict_from_url(
|
||||
args.resume, map_location='cpu', check_hash=True)
|
||||
else:
|
||||
checkpoint = torch.load(args.resume, map_location='cpu')
|
||||
checkpoint = torch.load(args.resume, map_location='cpu', weights_only=False)
|
||||
|
||||
model_without_ddp.load_state_dict(checkpoint['model'])
|
||||
if not args.eval and 'optimizer' in checkpoint and 'lr_scheduler' in checkpoint and 'epoch' in checkpoint:
|
||||
@@ -308,7 +308,7 @@ def main(args):
|
||||
|
||||
train_stats, global_step = train_one_epoch(
|
||||
model, criterion, data_loader_train,
|
||||
optimizer, device, epoch, loss_scaler,
|
||||
optimizer, device, epoch, loss_scaler, args.clip_grad, args.clip_mode,
|
||||
model_ema=model_ema, writer=writer,
|
||||
global_step=global_step, args=args
|
||||
)
|
||||
@@ -356,7 +356,7 @@ def main(args):
|
||||
|
||||
|
||||
def train_one_epoch(model, criterion, data_loader, optimizer, device, epoch, loss_scaler,
|
||||
clip_grad=None, clip_mode='norm', model_ema=None, writer=None,
|
||||
clip_grad=0.01, clip_mode='norm', model_ema=None, writer=None,
|
||||
global_step=0, args=None, **kwargs):
|
||||
model.train()
|
||||
metric_logger = utils.MetricLogger(delimiter=" ")
|
||||
|
||||
Reference in New Issue
Block a user