Initial release of SwiftFormer

This commit is contained in:
amshaker
2023-03-26 23:31:59 +04:00
commit 574907f49b
19 changed files with 1751 additions and 0 deletions

142
README.md Normal file

@@ -0,0 +1,142 @@
# SwiftFormer
### **SwiftFormer: Efficient Additive Attention for Transformer-based Real-time Mobile Vision Applications**
[Abdelrahman Shaker](https://scholar.google.com/citations?hl=en&user=eEz4Wu4AAAAJ),
[Muhammad Maaz](https://scholar.google.com/citations?user=vTy9Te8AAAAJ&hl=en&authuser=1&oi=sra),
[Hanoona Rasheed](https://scholar.google.com/citations?user=yhDdEuEAAAAJ&hl=en&authuser=1&oi=sra),
[Salman Khan](https://salman-h-khan.github.io),
[Ming-Hsuan Yang](https://scholar.google.com/citations?user=p9-ohHsAAAAJ&hl=en),
and [Fahad Shahbaz Khan](https://scholar.google.es/citations?user=zvaeYnUAAAAJ&hl=en)
<!-- [![Website](https://img.shields.io/badge/Project-Website-87CEEB)](site_url) -->
[![paper](https://img.shields.io/badge/arXiv-Paper-<COLOR>.svg)](arxiv_link)
<!-- [![video](https://img.shields.io/badge/Video-Presentation-F9D371)](youtube_link) -->
<!-- [![slides](https://img.shields.io/badge/Presentation-Slides-B762C1)](presentation) -->
## :rocket: News
* **(Mar 27, 2023):** Classification training and evaluation codes along with pre-trained models are released.
<hr />
<p align="center">
<img src="images/Swiftformer_performance.png" width=60%> <br>
Comparison of our SwiftFormer models with the state of the art on ImageNet-1K. Latency is measured on the iPhone 14 Neural Engine (iOS 16).
</p>
<p align="center">
<img src="images/attentions_comparison.png" width=99%> <br>
</p>
<p align="left">
Comparison with different self-attention modules. (a) is typical self-attention. (b) is transpose self-attention, where the attention operation is applied across the channel dimensions (d×d) instead of the spatial dimension (n×n). (c) is the separable self-attention of MobileViT-v2: element-wise operations compute a context vector from the interactions of the Q and K matrices, and the context vector is then multiplied by the V matrix to produce the final output. (d) is our proposed efficient additive self-attention: the query matrix is multiplied by learnable weights and pooled to produce global queries, and the key matrix is then element-wise multiplied by the broadcast global queries, resulting in the global context representation.
</p>
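For intuition, here is a minimal PyTorch sketch of the efficient additive attention idea in (d). It follows the textual description above (learned scoring of the queries, a pooled global query, an element-wise interaction with the keys, then a linear transform and a query residual) rather than reproducing the released module in `models/swiftformer.py`; all names and shapes are illustrative.

```python
import torch
import torch.nn as nn

class AdditiveAttentionSketch(nn.Module):
    """Illustrative sketch of efficient additive attention (not the released module)."""
    def __init__(self, dim):
        super().__init__()
        self.to_query = nn.Linear(dim, dim)
        self.to_key = nn.Linear(dim, dim)
        self.w_g = nn.Parameter(torch.randn(dim, 1))  # learnable scoring vector
        self.scale = dim ** -0.5
        self.proj = nn.Linear(dim, dim)

    def forward(self, x):                                # x: [B, N, D] tokens
        q = self.to_query(x)
        k = self.to_key(x)
        scores = (q @ self.w_g) * self.scale             # [B, N, 1] per-token score
        alpha = scores.softmax(dim=1)                    # weights over the N tokens
        global_q = (alpha * q).sum(dim=1, keepdim=True)  # pooled global query [B, 1, D]
        context = k * global_q                           # broadcast element-wise interaction
        return self.proj(context) + q                    # linear transform + query residual

# AdditiveAttentionSketch(64)(torch.randn(2, 196, 64)).shape -> torch.Size([2, 196, 64])
```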
<details>
<summary>
<font size="+1">Abstract</font>
</summary>
Self-attention has become a de facto choice for capturing global context in various vision applications. However, its quadratic computational complexity with respect to image resolution limits its use in real-time applications, especially for deployment on resource-constrained mobile devices. Although hybrid approaches have been proposed to combine the advantages of convolutions and self-attention for a better speed-accuracy trade-off, the expensive matrix multiplication operations in self-attention remain a bottleneck. In this work, we introduce a novel efficient additive attention mechanism that effectively replaces the quadratic matrix multiplication operations with linear element-wise multiplications. Our design shows that the key-value interaction can be replaced with a linear layer without sacrificing any accuracy. Unlike previous state-of-the-art methods, our efficient formulation of self-attention enables its usage at all stages of the network. Using our proposed efficient additive attention, we build a series of models called "SwiftFormer" which achieve state-of-the-art performance in terms of both accuracy and mobile inference speed. Our small variant achieves 78.5% top-1 ImageNet-1K accuracy with only 0.8 ms latency on iPhone 14, which is more accurate and 2x faster than MobileViT-v2.
</details>
<br>
## Classification on ImageNet-1K
### Models
| Model | Top-1 accuracy | #params | GMACs | Latency | Ckpt | CoreML|
|:---------------|:----:|:---:|:--:|:--:|:--:|:--:|
| SwiftFormer-XS | 75.7% | 3.5M | 0.4G | 0.7ms | [XS](https://drive.google.com/file/d/15Ils-U96pQePXQXx2MpmaI-yAceFAr2x/view?usp=sharing) | [XS](https://drive.google.com/file/d/1tZVxtbtAZoLLoDc5qqoUGulilksomLeK/view?usp=sharing) |
| SwiftFormer-S | 78.5% | 6.1M | 1.0G | 0.8ms | [S](https://drive.google.com/file/d/1_0eWwgsejtS0bWGBQS3gwAtYjXdPRGlu/view?usp=sharing) | [S](https://drive.google.com/file/d/13EOCZmtvbMR2V6UjezSZnbBz2_-59Fva/view?usp=sharing) |
| SwiftFormer-L1 | 80.9% | 12.1M | 1.6G | 1.1ms | [L1](https://drive.google.com/file/d/1jlwrwWQ0SQzDRc5adtWIwIut5d1g9EsM/view?usp=sharing) | [L1](https://drive.google.com/file/d/1c3VUsi4q7QQ2ykXVS2d4iCRL478fWF3e/view?usp=sharing) |
| SwiftFormer-L3 | 83.0% | 26.5M | 4.0G | 1.9ms | [L3](https://drive.google.com/file/d/1ypBcjx04ShmPYRhhjBRubiVjbExUgSa7/view?usp=sharing) | [L3](https://drive.google.com/file/d/1svahgIjh7da781jHOHjX58mtzCzYXSsJ/view?usp=sharing) |
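As a quick-start sketch, the checkpoints above can be evaluated from Python roughly as follows. This assumes the repository root is on `PYTHONPATH`, that the downloaded file stores its weights under the `'model'` key (as `main.py` expects for `--resume`), and that the local path is a placeholder:

```python
import torch
from timm.models import create_model
import models  # registers the SwiftFormer_* variants with timm

# Build SwiftFormer-XS (distillation head enabled, as in training) and load weights.
model = create_model('SwiftFormer_XS', num_classes=1000)
ckpt = torch.load('weights/SwiftFormer_XS_ckpt.pth', map_location='cpu')
model.load_state_dict(ckpt['model'])
model.eval()  # in eval mode the classifier and distillation heads are averaged

with torch.no_grad():
    logits = model(torch.randn(1, 3, 224, 224))  # [1, 1000]
```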
## Detection and Segmentation Qualitative Results
<p align="center">
<img src="images/detection_seg.png" width=100%> <br>
</p>
<p align="center">
<img src="images/semantic_seg.png" width=100%> <br>
</p>
## Latency Measurement
The latencies reported for SwiftFormer on iPhone 14 (iOS 16) are measured with the benchmark tool from [Xcode 14](https://developer.apple.com/videos/play/wwdc2022/10027/).
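This commit does not ship an export script, so the following is only a rough sketch of how one might produce a CoreML model to benchmark with Xcode; the `coremltools` conversion call and output name are assumptions, while the `CoreMLConversion` flag comes from `models/swiftformer.py` (it skips `F.normalize`, which the code notes is unsupported on the ANE). Weight loading is omitted for brevity.

```python
import torch
import coremltools as ct  # assumed dependency, not in requirements.txt
from timm.models import create_model
import models
from models import swiftformer

swiftformer.CoreMLConversion = True  # keep the traced graph ANE-friendly

model = create_model('SwiftFormer_XS', num_classes=1000).eval()
example = torch.randn(1, 3, 224, 224)
traced = torch.jit.trace(model, example)

mlmodel = ct.convert(traced,
                     inputs=[ct.TensorType(name="input", shape=example.shape)],
                     convert_to="mlprogram")
mlmodel.save("SwiftFormer_XS.mlpackage")  # profile with Xcode 14's performance report
```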
## ImageNet
### Prerequisites
A `conda` virtual environment is recommended.
```shell
conda create --name=swiftformer python=3.9
conda activate swiftformer
pip install torch==1.11.0+cu113 torchvision==0.12.0+cu113 --extra-index-url https://download.pytorch.org/whl/cu113
pip install timm
```
### Data preparation
Download and extract the ImageNet train and val images from http://image-net.org. The training and validation data are expected to be in the `train` and `val` folders, respectively:
```
|-- /path/to/imagenet/
|-- train
|-- val
```
### Single machine multi-GPU training
We provide a training script for all models in `dist_train.sh` using PyTorch distributed data parallel (DDP).
To train SwiftFormer models on an 8-GPU machine:
```
sh dist_train.sh /path/to/imagenet 8
```
Note: specify which model command you want to run in the script. To reproduce the results of the paper, use a 16-GPU machine with a per-GPU batch size of 128 or an 8-GPU machine with a per-GPU batch size of 256. AutoAugment, CutMix, and MixUp are disabled for SwiftFormer-XS only.
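For example, to use a per-GPU batch size of 256 on 8 GPUs, the SwiftFormer-S line in `dist_train.sh` could be adjusted roughly as follows (data path and output directory are placeholders):

```shell
python -m torch.distributed.launch --nproc_per_node=8 --use_env main.py \
    --model SwiftFormer_S --batch-size 256 \
    --data-path /path/to/imagenet --output_dir SwiftFormer_S_results
```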
### Multi-node training
On a Slurm-managed cluster, multi-node training can be launched as
```
sbatch slurm_train.sh /path/to/imagenet SwiftFormer_XS
```
Note: specify Slurm-specific parameters in the `slurm_train.sh` script.
### Testing
We provide an example test script `dist_test.sh` using PyTorch distributed data parallel (DDP).
For example, to test SwiftFormer-XS on an 8-GPU machine (the arguments are the ImageNet path, model name, checkpoint path, and number of GPUs):
```
sh dist_test.sh /path/to/imagenet SwiftFormer_XS weights/SwiftFormer_XS_ckpt.pth 8
```
## Citation
If you use our work, please consider citing us:
```BibTeX
@article{Shaker2023SwiftFormer,
title={SwiftFormer: Efficient Additive Attention for Transformer-based Real-time Mobile Vision Applications},
author={Shaker, Abdelrahman and Maaz, Muhammad and Rasheed, Hanoona and Khan, Salman and Yang, Ming-Hsuan and Khan, Fahad Shahbaz},
journal={arXiv preprint arXiv:X.X},
year={2023}
}
```
## Contact:
If you have any questions, please create an issue on this repository or contact us at abdelrahman.youssief@mbzuai.ac.ae.
## Acknowledgement
Our codebase is based on the [LeViT](https://github.com/facebookresearch/LeViT) and [EfficientFormer](https://github.com/snap-research/EfficientFormer) repositories. We thank the authors for their open-source implementations.
## Our Related Works
- EdgeNeXt: Efficiently Amalgamated CNN-Transformer Architecture for Mobile Vision Applications, CADL'22, ECCV. [Paper](https://arxiv.org/abs/2206.10589) | [Code](https://github.com/mmaaz60/EdgeNeXt).

11
dist_test.sh Normal file

@@ -0,0 +1,11 @@
#!/usr/bin/env bash
IMAGENET_PATH=$1
MODEL=$2
CHECKPOINT=$3
nGPUs=$4
python -m torch.distributed.launch --master_addr="127.0.0.1" --master_port=1234 --nproc_per_node=$nGPUs --use_env main.py --model "$MODEL" \
--resume "$CHECKPOINT" --eval \
--data-path "$IMAGENET_PATH" \
--output_dir SwiftFormer_test

21
dist_train.sh Normal file

@@ -0,0 +1,21 @@
#!/usr/bin/env bash
IMAGENET_PATH=$1
nGPUs=$2
## SwiftFormer-XS
python -m torch.distributed.launch --nproc_per_node=$nGPUs --use_env main.py --model SwiftFormer_XS --aa="" --mixup 0 --cutmix 0 --data-path "$IMAGENET_PATH" \
--output_dir SwiftFormer_XS_results
## SwiftFormer-S
python -m torch.distributed.launch --nproc_per_node=$nGPUs --use_env main.py --model SwiftFormer_S --data-path "$IMAGENET_PATH" \
--output_dir SwiftFormer_S_results
## SwiftFormer-L1
python -m torch.distributed.launch --nproc_per_node=$nGPUs --use_env main.py --model SwiftFormer_L1 --data-path "$IMAGENET_PATH" \
--output_dir SwiftFormer_L1_results
## SwiftFormer-L3
python -m torch.distributed.launch --nproc_per_node=$nGPUs --use_env main.py --model SwiftFormer_L3 --data-path "$IMAGENET_PATH" \
--output_dir SwiftFormer_L3_results

BIN (binary image file, not shown; 477 KiB)
BIN (binary image file, not shown; 669 KiB)
BIN images/detection_seg.png Normal file (binary file not shown; 3.7 MiB)
BIN images/semantic_seg.png Normal file (binary file not shown; 3.9 MiB)

412
main.py Normal file

@@ -0,0 +1,412 @@
import argparse
import datetime
import numpy as np
import time
import torch
import torch.backends.cudnn as cudnn
import json
from pathlib import Path
from timm.data import Mixup
from timm.models import create_model
from timm.loss import LabelSmoothingCrossEntropy, SoftTargetCrossEntropy
from timm.scheduler import create_scheduler
from timm.optim import create_optimizer
from timm.utils import NativeScaler, get_state_dict, ModelEma
from util import *
from models import *
def get_args_parser():
parser = argparse.ArgumentParser(
'SwiftFormer training and evaluation script', add_help=False)
parser.add_argument('--batch-size', default=128, type=int)
parser.add_argument('--epochs', default=300, type=int)
# Model parameters
parser.add_argument('--model', default='SwiftFormer_XS', type=str, metavar='MODEL',
help='Name of model to train')
parser.add_argument('--input-size', default=224,
type=int, help='images input size')
parser.add_argument('--model-ema', action='store_true')
parser.add_argument(
'--no-model-ema', action='store_false', dest='model_ema')
parser.set_defaults(model_ema=True)
parser.add_argument('--model-ema-decay', type=float,
default=0.99996, help='')
parser.add_argument('--model-ema-force-cpu',
action='store_true', default=False, help='')
# Optimizer parameters
parser.add_argument('--opt', default='adamw', type=str, metavar='OPTIMIZER',
help='Optimizer (default: "adamw")')
parser.add_argument('--opt-eps', default=1e-8, type=float, metavar='EPSILON',
help='Optimizer Epsilon (default: 1e-8)')
parser.add_argument('--opt-betas', default=None, type=float, nargs='+', metavar='BETA',
help='Optimizer Betas (default: None, use opt default)')
parser.add_argument('--clip-grad', type=float, default=0.01, metavar='NORM',
help='Clip gradient norm (default: 0.01)')
parser.add_argument('--clip-mode', type=str, default='agc',
help='Gradient clipping mode. One of ("norm", "value", "agc")')
parser.add_argument('--momentum', type=float, default=0.9, metavar='M',
help='SGD momentum (default: 0.9)')
parser.add_argument('--weight-decay', type=float, default=0.025,
help='weight decay (default: 0.025)')
# Learning rate schedule parameters
parser.add_argument('--sched', default='cosine', type=str, metavar='SCHEDULER',
help='LR scheduler (default: "cosine")')
parser.add_argument('--lr', type=float, default=2e-3, metavar='LR',
help='learning rate (default: 2e-3)')
parser.add_argument('--lr-noise', type=float, nargs='+', default=None, metavar='pct, pct',
help='learning rate noise on/off epoch percentages')
parser.add_argument('--lr-noise-pct', type=float, default=0.67, metavar='PERCENT',
help='learning rate noise limit percent (default: 0.67)')
parser.add_argument('--lr-noise-std', type=float, default=1.0, metavar='STDDEV',
help='learning rate noise std-dev (default: 1.0)')
parser.add_argument('--warmup-lr', type=float, default=1e-6, metavar='LR',
help='warmup learning rate (default: 1e-6)')
parser.add_argument('--min-lr', type=float, default=1e-5, metavar='LR',
help='lower lr bound for cyclic schedulers that hit 0 (1e-5)')
parser.add_argument('--decay-epochs', type=float, default=30, metavar='N',
help='epoch interval to decay LR')
parser.add_argument('--warmup-epochs', type=int, default=5, metavar='N',
help='epochs to warmup LR, if scheduler supports')
parser.add_argument('--cooldown-epochs', type=int, default=10, metavar='N',
help='epochs to cooldown LR at min_lr, after cyclic schedule ends')
parser.add_argument('--patience-epochs', type=int, default=10, metavar='N',
help='patience epochs for Plateau LR scheduler (default: 10)')
parser.add_argument('--decay-rate', '--dr', type=float, default=0.1, metavar='RATE',
help='LR decay rate (default: 0.1)')
# Augmentation parameters
parser.add_argument('--color-jitter', type=float, default=0.4, metavar='PCT',
help='Color jitter factor (default: 0.4)')
parser.add_argument('--aa', type=str, default='rand-m9-mstd0.5-inc1', metavar='NAME',
help='Use AutoAugment policy. "v0" or "original" (default: rand-m9-mstd0.5-inc1)')
parser.add_argument('--smoothing', type=float, default=0.1,
help='Label smoothing (default: 0.1)')
parser.add_argument('--train-interpolation', type=str, default='bicubic',
help='Training interpolation (random, bilinear, bicubic default: "bicubic")')
parser.add_argument('--repeated-aug', action='store_true')
parser.add_argument('--no-repeated-aug',
action='store_false', dest='repeated_aug')
parser.set_defaults(repeated_aug=True)
# * Random Erase params
parser.add_argument('--reprob', type=float, default=0.25, metavar='PCT',
help='Random erase prob (default: 0.25)')
parser.add_argument('--remode', type=str, default='pixel',
help='Random erase mode (default: "pixel")')
parser.add_argument('--recount', type=int, default=1,
help='Random erase count (default: 1)')
parser.add_argument('--resplit', action='store_true', default=False,
help='Do not random erase first (clean) augmentation split')
# * Mixup params
parser.add_argument('--mixup', type=float, default=0.8,
help='mixup alpha, mixup enabled if > 0. (default: 0.8)')
parser.add_argument('--cutmix', type=float, default=1.0,
help='cutmix alpha, cutmix enabled if > 0. (default: 1.0)')
parser.add_argument('--cutmix-minmax', type=float, nargs='+', default=None,
help='cutmix min/max ratio, overrides alpha and enables cutmix if set (default: None)')
parser.add_argument('--mixup-prob', type=float, default=1.0,
help='Probability of performing mixup or cutmix when either/both is enabled')
parser.add_argument('--mixup-switch-prob', type=float, default=0.5,
help='Probability of switching to cutmix when both mixup and cutmix enabled')
parser.add_argument('--mixup-mode', type=str, default='batch',
help='How to apply mixup/cutmix params. Per "batch", "pair", or "elem"')
# Distillation parameters
parser.add_argument('--teacher-model', default='regnety_160', type=str, metavar='MODEL',
help='Name of teacher model to train (default: "regnety_160")')
parser.add_argument('--teacher-path', type=str,
default='https://dl.fbaipublicfiles.com/deit/regnety_160-a5fe301d.pth')
parser.add_argument('--distillation-type', default='hard',
choices=['none', 'soft', 'hard'], type=str, help="")
parser.add_argument('--distillation-alpha',
default=0.5, type=float, help="")
parser.add_argument('--distillation-tau', default=1.0, type=float, help="")
# * Finetuning params
parser.add_argument('--finetune', default='',
help='finetune from checkpoint')
# Dataset parameters
parser.add_argument('--data-path', default='./imagenet', type=str,
help='dataset path')
parser.add_argument('--data-set', default='IMNET', choices=['CIFAR', 'IMNET', 'INAT', 'INAT19'],
type=str, help='Image Net dataset path')
parser.add_argument('--inat-category', default='name',
choices=['kingdom', 'phylum', 'class', 'order',
'supercategory', 'family', 'genus', 'name'],
type=str, help='semantic granularity')
parser.add_argument('--output_dir', default='',
help='path where to save, empty for no saving')
parser.add_argument('--device', default='cuda',
help='device to use for training / testing')
parser.add_argument('--seed', default=0, type=int)
parser.add_argument('--resume', default='', help='resume from checkpoint')
parser.add_argument('--start_epoch', default=0, type=int, metavar='N',
help='start epoch')
parser.add_argument('--eval', action='store_true',
help='Perform evaluation only')
parser.add_argument('--dist-eval', action='store_true',
default=False, help='Enabling distributed evaluation')
parser.add_argument('--num_workers', default=10, type=int)
parser.add_argument('--pin-mem', action='store_true',
help='Pin CPU memory in DataLoader for more efficient (sometimes) transfer to GPU.')
parser.add_argument('--no-pin-mem', action='store_false', dest='pin_mem',
help='')
parser.set_defaults(pin_mem=True)
# distributed training parameters
parser.add_argument('--world_size', default=1, type=int,
help='number of distributed processes')
parser.add_argument('--dist_url', default='env://',
help='url used to set up distributed training')
return parser
def main(args):
utils.init_distributed_mode(args)
print(args)
if args.distillation_type != 'none' and args.finetune and not args.eval:
raise NotImplementedError(
"Finetuning with distillation not yet supported")
device = torch.device(args.device)
# Fix the seed for reproducibility
seed = args.seed + utils.get_rank()
torch.manual_seed(seed)
np.random.seed(seed)
cudnn.benchmark = True
dataset_train, args.nb_classes = build_dataset(is_train=True, args=args)
dataset_val, _ = build_dataset(is_train=False, args=args)
num_tasks = utils.get_world_size()
global_rank = utils.get_rank()
if args.repeated_aug:
sampler_train = RASampler(
dataset_train, num_replicas=num_tasks, rank=global_rank, shuffle=True
)
else:
sampler_train = torch.utils.data.DistributedSampler(
dataset_train, num_replicas=num_tasks, rank=global_rank, shuffle=True
)
if args.dist_eval:
if len(dataset_val) % num_tasks != 0:
print('Warning: Enabling distributed evaluation with an eval dataset not divisible by process number. '
'This will slightly alter validation results as extra duplicate entries are added to achieve '
'equal num of samples per-process.')
sampler_val = torch.utils.data.DistributedSampler(
dataset_val, num_replicas=num_tasks, rank=global_rank, shuffle=False)
else:
sampler_val = torch.utils.data.SequentialSampler(dataset_val)
data_loader_train = torch.utils.data.DataLoader(
dataset_train, sampler=sampler_train,
batch_size=args.batch_size,
num_workers=args.num_workers,
pin_memory=args.pin_mem,
drop_last=True,
)
data_loader_val = torch.utils.data.DataLoader(
dataset_val, sampler=sampler_val,
batch_size=int(1.5 * args.batch_size),
num_workers=args.num_workers,
pin_memory=args.pin_mem,
drop_last=False
)
mixup_fn = None
mixup_active = args.mixup > 0 or args.cutmix > 0. or args.cutmix_minmax is not None
if mixup_active:
mixup_fn = Mixup(
mixup_alpha=args.mixup, cutmix_alpha=args.cutmix, cutmix_minmax=args.cutmix_minmax,
prob=args.mixup_prob, switch_prob=args.mixup_switch_prob, mode=args.mixup_mode,
label_smoothing=args.smoothing, num_classes=args.nb_classes)
print(f"Creating model: {args.model}")
model = create_model(
args.model,
num_classes=args.nb_classes,
distillation=(args.distillation_type != 'none'),
pretrained=args.eval,
fuse=args.eval,
)
if args.finetune:
if args.finetune.startswith('https'):
checkpoint = torch.hub.load_state_dict_from_url(
args.finetune, map_location='cpu', check_hash=True)
else:
checkpoint = torch.load(args.finetune, map_location='cpu')
checkpoint_model = checkpoint['model']
state_dict = model.state_dict()
for k in ['head.weight', 'head.bias',
'head_dist.weight', 'head_dist.bias']:
if k in checkpoint_model and checkpoint_model[k].shape != state_dict[k].shape:
print(f"Removing key {k} from pretrained checkpoint")
del checkpoint_model[k]
model.load_state_dict(checkpoint_model, strict=False)
model.to(device)
model_ema = None
if args.model_ema:
# Important to create EMA model after cuda(), DP wrapper, and AMP but
# before SyncBN and DDP wrapper
model_ema = ModelEma(
model,
decay=args.model_ema_decay,
device='cpu' if args.model_ema_force_cpu else '',
resume='')
model_without_ddp = model
if args.distributed:
model = torch.nn.parallel.DistributedDataParallel(
model, device_ids=[args.gpu])
model_without_ddp = model.module
n_parameters = sum(p.numel()
for p in model.parameters() if p.requires_grad)
print('number of params:', n_parameters)
# better not to scale up lr for AdamW optimizer
# linear_scaled_lr = args.lr * args.batch_size * utils.get_world_size() / 512.0
# args.lr = linear_scaled_lr
optimizer = create_optimizer(args, model_without_ddp)
loss_scaler = NativeScaler()
lr_scheduler, _ = create_scheduler(args, optimizer)
if args.mixup > 0.:
# smoothing is handled with mixup label transform
criterion = SoftTargetCrossEntropy()
elif args.smoothing:
criterion = LabelSmoothingCrossEntropy(smoothing=args.smoothing)
else:
criterion = torch.nn.CrossEntropyLoss()
teacher_model = None
if args.distillation_type != 'none':
assert args.teacher_path, 'need to specify teacher-path when using distillation'
print(f"Creating teacher model: {args.teacher_model}")
teacher_model = create_model(
args.teacher_model,
pretrained=False,
num_classes=args.nb_classes,
global_pool='avg',
)
if args.teacher_path.startswith('https'):
checkpoint = torch.hub.load_state_dict_from_url(
args.teacher_path, map_location='cpu', check_hash=True)
else:
checkpoint = torch.load(args.teacher_path, map_location='cpu')
teacher_model.load_state_dict(checkpoint['model'])
teacher_model.to(device)
teacher_model.eval()
# Wrap the criterion in our custom DistillationLoss, which
# just dispatches to the original criterion if args.distillation_type is
# 'none'
criterion = DistillationLoss(
criterion, teacher_model, args.distillation_type, args.distillation_alpha, args.distillation_tau
)
output_dir = Path(args.output_dir)
if args.resume:
if args.resume.startswith('https'):
checkpoint = torch.hub.load_state_dict_from_url(
args.resume, map_location='cpu', check_hash=True)
else:
checkpoint = torch.load(args.resume, map_location='cpu')
model_without_ddp.load_state_dict(checkpoint['model'])
if not args.eval and 'optimizer' in checkpoint and 'lr_scheduler' in checkpoint and 'epoch' in checkpoint:
optimizer.load_state_dict(checkpoint['optimizer'])
lr_scheduler.load_state_dict(checkpoint['lr_scheduler'])
args.start_epoch = checkpoint['epoch'] + 1
if args.model_ema:
utils._load_checkpoint_for_ema(
model_ema, checkpoint['model_ema'])
if 'scaler' in checkpoint:
loss_scaler.load_state_dict(checkpoint['scaler'])
if args.eval:
test_stats = evaluate(data_loader_val, model, device)
print(
f"Accuracy of the network on the {len(dataset_val)} test images: {test_stats['acc1']:.1f}%")
return
print(f"Start training for {args.epochs} epochs")
start_time = time.time()
max_accuracy = 0.0
for epoch in range(args.start_epoch, args.epochs):
if args.distributed:
data_loader_train.sampler.set_epoch(epoch)
train_stats = train_one_epoch(
model, criterion, data_loader_train,
optimizer, device, epoch, loss_scaler,
args.clip_grad, args.clip_mode, model_ema, mixup_fn,
set_training_mode=args.finetune == '' # keep in eval mode during finetuning
)
lr_scheduler.step(epoch)
if args.output_dir:
checkpoint_paths = [output_dir / 'checkpoint.pth']
for checkpoint_path in checkpoint_paths:
utils.save_on_master({
'model': model_without_ddp.state_dict(),
'optimizer': optimizer.state_dict(),
'lr_scheduler': lr_scheduler.state_dict(),
'epoch': epoch,
'model_ema': get_state_dict(model_ema),
'scaler': loss_scaler.state_dict(),
'args': args,
}, checkpoint_path)
if epoch % 20 == 19:
test_stats = evaluate(data_loader_val, model, device)
print(
f"Accuracy of the network on the {len(dataset_val)} test images: {test_stats['acc1']:.1f}%")
max_accuracy = max(max_accuracy, test_stats["acc1"])
print(f'Max accuracy: {max_accuracy:.2f}%')
log_stats = {**{f'train_{k}': v for k, v in train_stats.items()},
**{f'test_{k}': v for k, v in test_stats.items()},
'epoch': epoch,
'n_parameters': n_parameters}
else:
log_stats = {**{f'train_{k}': v for k, v in train_stats.items()},
'epoch': epoch,
'n_parameters': n_parameters}
if args.output_dir and utils.is_main_process():
with (output_dir / "log.txt").open("a") as f:
f.write(json.dumps(log_stats) + "\n")
total_time = time.time() - start_time
total_time_str = str(datetime.timedelta(seconds=int(total_time)))
print('Training time {}'.format(total_time_str))
if __name__ == '__main__':
parser = argparse.ArgumentParser(
'SwiftFormer training and evaluation script', parents=[get_args_parser()])
args = parser.parse_args()
if args.output_dir:
Path(args.output_dir).mkdir(parents=True, exist_ok=True)
main(args)

1
models/__init__.py Normal file

@@ -0,0 +1 @@
from .swiftformer import SwiftFormer_XS, SwiftFormer_S, SwiftFormer_L1, SwiftFormer_L3

507
models/swiftformer.py Normal file

@@ -0,0 +1,507 @@
"""
SwiftFormer
"""
import os
import copy
import torch
import torch.nn as nn
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from timm.models.layers import DropPath, trunc_normal_
from timm.models.registry import register_model
from timm.models.layers.helpers import to_2tuple
import einops
SwiftFormer_width = {
'XS': [48, 56, 112, 220],
'S': [48, 64, 168, 224],
'l1': [48, 96, 192, 384],
'l3': [64, 128, 320, 512],
}
SwiftFormer_depth = {
'XS': [3, 3, 6, 4],
'S': [3, 3, 9, 6],
'l1': [4, 3, 10, 5],
'l3': [4, 4, 12, 6],
}
CoreMLConversion = False
def stem(in_chs, out_chs):
"""
Stem Layer that is implemented by two layers of conv.
Output: sequence of layers with final shape of [B, C, H/4, W/4]
"""
return nn.Sequential(
nn.Conv2d(in_chs, out_chs // 2, kernel_size=3, stride=2, padding=1),
nn.BatchNorm2d(out_chs // 2),
nn.ReLU(),
nn.Conv2d(out_chs // 2, out_chs, kernel_size=3, stride=2, padding=1),
nn.BatchNorm2d(out_chs),
nn.ReLU(), )
class Embedding(nn.Module):
"""
Patch Embedding that is implemented by a layer of conv.
Input: tensor in shape [B, C, H, W]
Output: tensor in shape [B, C, H/stride, W/stride]
"""
def __init__(self, patch_size=16, stride=16, padding=0,
in_chans=3, embed_dim=768, norm_layer=nn.BatchNorm2d):
super().__init__()
patch_size = to_2tuple(patch_size)
stride = to_2tuple(stride)
padding = to_2tuple(padding)
self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size,
stride=stride, padding=padding)
self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()
def forward(self, x):
x = self.proj(x)
x = self.norm(x)
return x
class ConvEncoder(nn.Module):
"""
Implementation of ConvEncoder with 3*3 and 1*1 convolutions.
Input: tensor with shape [B, C, H, W]
Output: tensor with shape [B, C, H, W]
"""
def __init__(self, dim, hidden_dim=64, kernel_size=3, drop_path=0., use_layer_scale=True):
super().__init__()
self.dwconv = nn.Conv2d(dim, dim, kernel_size=kernel_size, padding=kernel_size // 2, groups=dim)
self.norm = nn.BatchNorm2d(dim)
self.pwconv1 = nn.Conv2d(dim, hidden_dim, kernel_size=1)
self.act = nn.GELU()
self.pwconv2 = nn.Conv2d(hidden_dim, dim, kernel_size=1)
self.drop_path = DropPath(drop_path) if drop_path > 0. \
else nn.Identity()
self.use_layer_scale = use_layer_scale
if use_layer_scale:
self.layer_scale = nn.Parameter(torch.ones(dim).unsqueeze(-1).unsqueeze(-1), requires_grad=True)
self.apply(self._init_weights)
def _init_weights(self, m):
if isinstance(m, nn.Conv2d):
trunc_normal_(m.weight, std=.02)
if m.bias is not None:
nn.init.constant_(m.bias, 0)
def forward(self, x):
input = x
x = self.dwconv(x)
x = self.norm(x)
x = self.pwconv1(x)
x = self.act(x)
x = self.pwconv2(x)
if self.use_layer_scale:
x = input + self.drop_path(self.layer_scale * x)
else:
x = input + self.drop_path(x)
return x
class Mlp(nn.Module):
"""
Implementation of MLP layer with 1*1 convolutions.
Input: tensor with shape [B, C, H, W]
Output: tensor with shape [B, C, H, W]
"""
def __init__(self, in_features, hidden_features=None,
out_features=None, act_layer=nn.GELU, drop=0.):
super().__init__()
out_features = out_features or in_features
hidden_features = hidden_features or in_features
self.norm1 = nn.BatchNorm2d(in_features)
self.fc1 = nn.Conv2d(in_features, hidden_features, 1)
self.act = act_layer()
self.fc2 = nn.Conv2d(hidden_features, out_features, 1)
self.drop = nn.Dropout(drop)
self.apply(self._init_weights)
def _init_weights(self, m):
if isinstance(m, nn.Conv2d):
trunc_normal_(m.weight, std=.02)
if m.bias is not None:
nn.init.constant_(m.bias, 0)
def forward(self, x):
x = self.norm1(x)
x = self.fc1(x)
x = self.act(x)
x = self.drop(x)
x = self.fc2(x)
x = self.drop(x)
return x
class EfficientAdditiveAttnetion(nn.Module):
"""
Efficient Additive Attention module for SwiftFormer.
Input: tensor of tokens in shape [B, N, D]
Output: tensor of tokens in shape [B, N, D]
"""
def __init__(self, in_dims=512, token_dim=256, num_heads=2):
super().__init__()
self.to_query = nn.Linear(in_dims, token_dim * num_heads)
self.to_key = nn.Linear(in_dims, token_dim * num_heads)
self.w_g = nn.Parameter(torch.randn(token_dim * num_heads, 1))
self.scale_factor = token_dim ** -0.5
self.Proj = nn.Linear(token_dim * num_heads, token_dim * num_heads)
self.final = nn.Linear(token_dim * num_heads, token_dim)
def forward(self, x):
query = self.to_query(x)
key = self.to_key(x)
if not CoreMLConversion:
# torch.nn.functional.normalize is not supported by the ANE of iPhone devices,
# so it is skipped during CoreML conversion. Using it improves accuracy by ~0.1-0.2%.
query = torch.nn.functional.normalize(query, dim=-1)
key = torch.nn.functional.normalize(key, dim=-1)
query_weight = query @ self.w_g
A = query_weight * self.scale_factor
A = A.softmax(dim=-1)
G = torch.sum(A * query, dim=1)
G = einops.repeat(
G, "b d -> b repeat d", repeat=key.shape[1]
)
out = self.Proj(G * key) + query
out = self.final(out)
return out
class SwiftFormerLocalRepresentation(nn.Module):
"""
Local Representation module for SwiftFormer that is implemented by 3*3 depth-wise and point-wise convolutions.
Input: tensor in shape [B, C, H, W]
Output: tensor in shape [B, C, H, W]
"""
def __init__(self, dim, kernel_size=3, drop_path=0., use_layer_scale=True):
super().__init__()
self.dwconv = nn.Conv2d(dim, dim, kernel_size=kernel_size, padding=kernel_size // 2, groups=dim)
self.norm = nn.BatchNorm2d(dim)
self.pwconv1 = nn.Conv2d(dim, dim, kernel_size=1)
self.act = nn.GELU()
self.pwconv2 = nn.Conv2d(dim, dim, kernel_size=1)
self.drop_path = DropPath(drop_path) if drop_path > 0. \
else nn.Identity()
self.use_layer_scale = use_layer_scale
if use_layer_scale:
self.layer_scale = nn.Parameter(torch.ones(dim).unsqueeze(-1).unsqueeze(-1), requires_grad=True)
self.apply(self._init_weights)
def _init_weights(self, m):
if isinstance(m, nn.Conv2d):
trunc_normal_(m.weight, std=.02)
if m.bias is not None:
nn.init.constant_(m.bias, 0)
def forward(self, x):
input = x
x = self.dwconv(x)
x = self.norm(x)
x = self.pwconv1(x)
x = self.act(x)
x = self.pwconv2(x)
if self.use_layer_scale:
x = input + self.drop_path(self.layer_scale * x)
else:
x = input + self.drop_path(x)
return x
class SwiftFormerEncoder(nn.Module):
"""
SwiftFormer encoder block. It consists of (1) a local representation module, (2) efficient additive attention, and (3) an MLP block.
Input: tensor in shape [B, C, H, W]
Output: tensor in shape [B, C, H, W]
"""
def __init__(self, dim, mlp_ratio=4.,
act_layer=nn.GELU,
drop=0., drop_path=0.,
use_layer_scale=True, layer_scale_init_value=1e-5):
super().__init__()
self.local_representation = SwiftFormerLocalRepresentation(dim=dim, kernel_size=3, drop_path=0.,
use_layer_scale=True)
self.attn = EfficientAdditiveAttnetion(in_dims=dim, token_dim=dim, num_heads=1)
self.linear = Mlp(in_features=dim, hidden_features=int(dim * mlp_ratio), act_layer=act_layer, drop=drop)
self.drop_path = DropPath(drop_path) if drop_path > 0. \
else nn.Identity()
self.use_layer_scale = use_layer_scale
if use_layer_scale:
self.layer_scale_1 = nn.Parameter(
layer_scale_init_value * torch.ones(dim).unsqueeze(-1).unsqueeze(-1), requires_grad=True)
self.layer_scale_2 = nn.Parameter(
layer_scale_init_value * torch.ones(dim).unsqueeze(-1).unsqueeze(-1), requires_grad=True)
def forward(self, x):
x = self.local_representation(x)
B, C, H, W = x.shape
if self.use_layer_scale:
x = x + self.drop_path(
self.layer_scale_1 * self.attn(x.permute(0, 2, 3, 1).reshape(B, H * W, C)).reshape(B, H, W, C).permute(
0, 3, 1, 2))
x = x + self.drop_path(self.layer_scale_2 * self.linear(x))
else:
x = x + self.drop_path(
self.attn(x.permute(0, 2, 3, 1).reshape(B, H * W, C)).reshape(B, H, W, C).permute(0, 3, 1, 2))
x = x + self.drop_path(self.linear(x))
return x
def Stage(dim, index, layers, mlp_ratio=4.,
act_layer=nn.GELU,
drop_rate=.0, drop_path_rate=0.,
use_layer_scale=True, layer_scale_init_value=1e-5, vit_num=1):
"""
Implementation of each SwiftFormer stage. SwiftFormerEncoder is used for the last `vit_num` block(s) of each stage, while ConvEncoder is used for the remaining blocks.
Input: tensor in shape [B, C, H, W]
Output: tensor in shape [B, C, H, W]
"""
blocks = []
for block_idx in range(layers[index]):
block_dpr = drop_path_rate * (block_idx + sum(layers[:index])) / (sum(layers) - 1)
if layers[index] - block_idx <= vit_num:
blocks.append(SwiftFormerEncoder(
dim, mlp_ratio=mlp_ratio,
act_layer=act_layer, drop_path=block_dpr,
use_layer_scale=use_layer_scale,
layer_scale_init_value=layer_scale_init_value))
else:
blocks.append(ConvEncoder(dim=dim, hidden_dim=int(mlp_ratio * dim), kernel_size=3))
blocks = nn.Sequential(*blocks)
return blocks
class SwiftFormer(nn.Module):
def __init__(self, layers, embed_dims=None,
mlp_ratios=4, downsamples=None,
act_layer=nn.GELU,
num_classes=1000,
down_patch_size=3, down_stride=2, down_pad=1,
drop_rate=0., drop_path_rate=0.,
use_layer_scale=True, layer_scale_init_value=1e-5,
fork_feat=False,
init_cfg=None,
pretrained=None,
vit_num=1,
distillation=True,
**kwargs):
super().__init__()
if not fork_feat:
self.num_classes = num_classes
self.fork_feat = fork_feat
self.patch_embed = stem(3, embed_dims[0])
network = []
for i in range(len(layers)):
stage = Stage(embed_dims[i], i, layers, mlp_ratio=mlp_ratios,
act_layer=act_layer,
drop_rate=drop_rate,
drop_path_rate=drop_path_rate,
use_layer_scale=use_layer_scale,
layer_scale_init_value=layer_scale_init_value,
vit_num=vit_num)
network.append(stage)
if i >= len(layers) - 1:
break
if downsamples[i] or embed_dims[i] != embed_dims[i + 1]:
# downsampling between two stages
network.append(
Embedding(
patch_size=down_patch_size, stride=down_stride,
padding=down_pad,
in_chans=embed_dims[i], embed_dim=embed_dims[i + 1]
)
)
self.network = nn.ModuleList(network)
if self.fork_feat:
# add a norm layer for each output
self.out_indices = [0, 2, 4, 6]
for i_emb, i_layer in enumerate(self.out_indices):
if i_emb == 0 and os.environ.get('FORK_LAST3', None):
layer = nn.Identity()
else:
layer = nn.BatchNorm2d(embed_dims[i_emb])
layer_name = f'norm{i_layer}'
self.add_module(layer_name, layer)
else:
# Classifier head
self.norm = nn.BatchNorm2d(embed_dims[-1])
self.head = nn.Linear(
embed_dims[-1], num_classes) if num_classes > 0 \
else nn.Identity()
self.dist = distillation
if self.dist:
self.dist_head = nn.Linear(
embed_dims[-1], num_classes) if num_classes > 0 \
else nn.Identity()
# self.apply(self.cls_init_weights)
self.apply(self._init_weights)
self.init_cfg = copy.deepcopy(init_cfg)
# load pre-trained model
if self.fork_feat and (
self.init_cfg is not None or pretrained is not None):
self.init_weights()
# init for mmdetection or mmsegmentation by loading
# imagenet pre-trained weights
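# Note: get_root_logger and _load_checkpoint are provided by the mmcv/mmseg
# toolkits and are only needed when fork_feat=True for dense prediction.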
def init_weights(self, pretrained=None):
logger = get_root_logger()
if self.init_cfg is None and pretrained is None:
logger.warn(f'No pre-trained weights for '
f'{self.__class__.__name__}, '
f'training start from scratch')
pass
else:
assert 'checkpoint' in self.init_cfg, f'Only support ' \
f'specify `Pretrained` in ' \
f'`init_cfg` in ' \
f'{self.__class__.__name__} '
if self.init_cfg is not None:
ckpt_path = self.init_cfg['checkpoint']
elif pretrained is not None:
ckpt_path = pretrained
ckpt = _load_checkpoint(
ckpt_path, logger=logger, map_location='cpu')
if 'state_dict' in ckpt:
_state_dict = ckpt['state_dict']
elif 'model' in ckpt:
_state_dict = ckpt['model']
else:
_state_dict = ckpt
state_dict = _state_dict
missing_keys, unexpected_keys = \
self.load_state_dict(state_dict, False)
def _init_weights(self, m):
if isinstance(m, (nn.Conv2d, nn.Linear)):
trunc_normal_(m.weight, std=.02)
if m.bias is not None:
nn.init.constant_(m.bias, 0)
elif isinstance(m, (nn.LayerNorm)):
nn.init.constant_(m.bias, 0)
nn.init.constant_(m.weight, 1.0)
def forward_tokens(self, x):
outs = []
for idx, block in enumerate(self.network):
x = block(x)
if self.fork_feat and idx in self.out_indices:
norm_layer = getattr(self, f'norm{idx}')
x_out = norm_layer(x)
outs.append(x_out)
if self.fork_feat:
return outs
return x
def forward(self, x):
x = self.patch_embed(x)
x = self.forward_tokens(x)
if self.fork_feat:
# Output features of four stages for dense prediction
return x
x = self.norm(x)
if self.dist:
cls_out = self.head(x.flatten(2).mean(-1)), self.dist_head(x.flatten(2).mean(-1))
if not self.training:
cls_out = (cls_out[0] + cls_out[1]) / 2
else:
cls_out = self.head(x.mean(-2))
# For image classification
return cls_out
def _cfg(url='', **kwargs):
return {
'url': url,
'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': None,
'crop_pct': .95, 'interpolation': 'bicubic',
'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD,
'classifier': 'head',
**kwargs
}
@register_model
def SwiftFormer_XS(pretrained=False, **kwargs):
model = SwiftFormer(
layers=SwiftFormer_depth['XS'],
embed_dims=SwiftFormer_width['XS'],
downsamples=[True, True, True, True],
vit_num=1,
**kwargs)
model.default_cfg = _cfg(crop_pct=0.9)
return model
@register_model
def SwiftFormer_S(pretrained=False, **kwargs):
model = SwiftFormer(
layers=SwiftFormer_depth['S'],
embed_dims=SwiftFormer_width['S'],
downsamples=[True, True, True, True],
vit_num=1,
**kwargs)
model.default_cfg = _cfg(crop_pct=0.9)
return model
@register_model
def SwiftFormer_L1(pretrained=False, **kwargs):
model = SwiftFormer(
layers=SwiftFormer_depth['l1'],
embed_dims=SwiftFormer_width['l1'],
downsamples=[True, True, True, True],
vit_num=1,
**kwargs)
model.default_cfg = _cfg(crop_pct=0.9)
return model
@register_model
def SwiftFormer_L3(pretrained=False, **kwargs):
model = SwiftFormer(
layers=SwiftFormer_depth['l3'],
embed_dims=SwiftFormer_width['l3'],
downsamples=[True, True, True, True],
vit_num=1,
**kwargs)
model.default_cfg = _cfg(crop_pct=0.9)
return model

3
requirements.txt Normal file

@@ -0,0 +1,3 @@
torch==1.11.0+cu113
torchvision==0.12.0+cu113
timm==0.5.4

23
slurm_train.sh Normal file

@@ -0,0 +1,23 @@
#!/bin/sh
#SBATCH --job-name=swiftformer
#SBATCH --partition=your_partition
#SBATCH --time=48:00:00
#SBATCH --nodes=4
#SBATCH --ntasks=16
#SBATCH --cpus-per-task=16
#SBATCH --gres=gpu:4
#SBATCH --mem-per-cpu=8000
IMAGENET_PATH=$1
MODEL=$2
srun python main.py --model "$MODEL" \
--data-path "$IMAGENET_PATH" \
--batch-size 128 \
--epochs 300 \
--aa="" --mixup 0 --cutmix 0
## Note: Disable aa, mixup, and cutmix for SwiftFormer-XS only
## By default, this script requests a total of 16 GPUs on 4 nodes. The batch size per GPU is set to 128,
## which sums to 128*16 = 2048 in total.

6
util/__init__.py Normal file

@@ -0,0 +1,6 @@
import util.utils as utils
from .datasets import build_dataset
from .engine import train_one_epoch, evaluate
from .losses import DistillationLoss
from .samplers import RASampler

120
util/datasets.py Normal file

@@ -0,0 +1,120 @@
import os
import json
from torchvision import datasets, transforms
from torchvision.datasets.folder import ImageFolder, default_loader
import torch
from timm.data.constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from timm.data import create_transform
class INatDataset(ImageFolder):
def __init__(self, root, train=True, year=2018, transform=None, target_transform=None, category='name',
loader=default_loader):
super().__init__(root, transform, target_transform, loader)
self.transform = transform
self.loader = loader
self.target_transform = target_transform
self.year = year
# assert category in ['kingdom','phylum','class','order','supercategory','family','genus','name']
path_json = os.path.join(
root, f'{"train" if train else "val"}{year}.json')
with open(path_json) as json_file:
data = json.load(json_file)
with open(os.path.join(root, 'categories.json')) as json_file:
data_catg = json.load(json_file)
path_json_for_targeter = os.path.join(root, f"train{year}.json")
with open(path_json_for_targeter) as json_file:
data_for_targeter = json.load(json_file)
targeter = {}
indexer = 0
for elem in data_for_targeter['annotations']:
king = []
king.append(data_catg[int(elem['category_id'])][category])
if king[0] not in targeter.keys():
targeter[king[0]] = indexer
indexer += 1
self.nb_classes = len(targeter)
self.samples = []
for elem in data['images']:
cut = elem['file_name'].split('/')
target_current = int(cut[2])
path_current = os.path.join(root, cut[0], cut[2], cut[3])
categors = data_catg[target_current]
target_current_true = targeter[categors[category]]
self.samples.append((path_current, target_current_true))
# __getitem__ and __len__ inherited from ImageFolder
def build_dataset(is_train, args):
transform = build_transform(is_train, args)
if args.data_set == 'CIFAR':
dataset = datasets.CIFAR100(
args.data_path, train=is_train, transform=transform)
nb_classes = 100
elif args.data_set == 'IMNET':
root = os.path.join(args.data_path, 'train' if is_train else 'val')
dataset = datasets.ImageFolder(root, transform=transform)
nb_classes = 1000
elif args.data_set == 'FLOWERS':
root = os.path.join(args.data_path, 'train' if is_train else 'test')
dataset = datasets.ImageFolder(root, transform=transform)
if is_train:
dataset = torch.utils.data.ConcatDataset(
[dataset for _ in range(100)])
nb_classes = 102
elif args.data_set == 'INAT':
dataset = INatDataset(args.data_path, train=is_train, year=2018,
category=args.inat_category, transform=transform)
nb_classes = dataset.nb_classes
elif args.data_set == 'INAT19':
dataset = INatDataset(args.data_path, train=is_train, year=2019,
category=args.inat_category, transform=transform)
nb_classes = dataset.nb_classes
else:
raise NotImplementedError
return dataset, nb_classes
def build_transform(is_train, args):
resize_im = args.input_size > 32
if is_train:
# This should always dispatch to transforms_imagenet_train
transform = create_transform(
input_size=args.input_size,
is_training=True,
color_jitter=args.color_jitter,
auto_augment=args.aa,
interpolation=args.train_interpolation,
re_prob=args.reprob,
re_mode=args.remode,
re_count=args.recount,
)
if not resize_im:
# Replace RandomResizedCropAndInterpolation with RandomCrop
transform.transforms[0] = transforms.RandomCrop(
args.input_size, padding=4)
return transform
t = []
if resize_im:
size = int((256 / 224) * args.input_size)
t.append(
# to maintain same ratio w.r.t. 224 images
transforms.Resize(size, interpolation=3),
)
t.append(transforms.CenterCrop(args.input_size))
t.append(transforms.ToTensor())
t.append(transforms.Normalize(IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD))
return transforms.Compose(t)

101
util/engine.py Normal file

@@ -0,0 +1,101 @@
"""
Train and eval functions used in main.py
"""
import math
import sys
from typing import Iterable, Optional
import torch
from timm.data import Mixup
from timm.utils import accuracy, ModelEma
from .losses import DistillationLoss
import util.utils as utils
def train_one_epoch(model: torch.nn.Module, criterion: DistillationLoss,
data_loader: Iterable, optimizer: torch.optim.Optimizer,
device: torch.device, epoch: int, loss_scaler,
clip_grad: float = 0,
clip_mode: str = 'norm',
model_ema: Optional[ModelEma] = None, mixup_fn: Optional[Mixup] = None,
set_training_mode=True):
model.train(set_training_mode)
metric_logger = utils.MetricLogger(delimiter=" ")
metric_logger.add_meter('lr', utils.SmoothedValue(
window_size=1, fmt='{value:.6f}'))
header = 'Epoch: [{}]'.format(epoch)
print_freq = 100
for samples, targets in metric_logger.log_every(
data_loader, print_freq, header):
samples = samples.to(device, non_blocking=True)
targets = targets.to(device, non_blocking=True)
if mixup_fn is not None:
samples, targets = mixup_fn(samples, targets)
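# Note: autocast (AMP) is left disabled below; NativeScaler still applies
# loss scaling and gradient clipping.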
if True: # with torch.cuda.amp.autocast():
outputs = model(samples)
loss = criterion(samples, outputs, targets)
loss_value = loss.item()
if not math.isfinite(loss_value):
print("Loss is {}, stopping training".format(loss_value))
sys.exit(1)
optimizer.zero_grad()
# This attribute is added by timm on one optimizer (adahessian)
is_second_order = hasattr(
optimizer, 'is_second_order') and optimizer.is_second_order
loss_scaler(loss, optimizer, clip_grad=clip_grad, clip_mode=clip_mode,
parameters=model.parameters(), create_graph=is_second_order)
torch.cuda.synchronize()
if model_ema is not None:
model_ema.update(model)
metric_logger.update(loss=loss_value)
metric_logger.update(lr=optimizer.param_groups[0]["lr"])
# gather the stats from all processes
metric_logger.synchronize_between_processes()
print("Averaged stats:", metric_logger)
return {k: meter.global_avg for k, meter in metric_logger.meters.items()}
@torch.no_grad()
def evaluate(data_loader, model, device):
criterion = torch.nn.CrossEntropyLoss()
metric_logger = utils.MetricLogger(delimiter=" ")
header = 'Test:'
# Switch to evaluation mode
model.eval()
for images, target in metric_logger.log_every(data_loader, 10, header):
images = images.to(device, non_blocking=True)
target = target.to(device, non_blocking=True)
# Compute output
with torch.cuda.amp.autocast():
output = model(images)
loss = criterion(output, target)
acc1, acc5 = accuracy(output, target, topk=(1, 5))
batch_size = images.shape[0]
metric_logger.update(loss=loss.item())
metric_logger.meters['acc1'].update(acc1.item(), n=batch_size)
metric_logger.meters['acc5'].update(acc5.item(), n=batch_size)
# Gather the stats from all processes
metric_logger.synchronize_between_processes()
print('* Acc@1 {top1.global_avg:.3f} Acc@5 {top5.global_avg:.3f} loss {losses.global_avg:.3f}'
.format(top1=metric_logger.acc1, top5=metric_logger.acc5, losses=metric_logger.loss))
print(output.mean().item(), output.std().item())
return {k: meter.global_avg for k, meter in metric_logger.meters.items()}

64
util/losses.py Normal file

@@ -0,0 +1,64 @@
"""
Implements the knowledge distillation loss
"""
import torch
from torch.nn import functional as F
class DistillationLoss(torch.nn.Module):
"""
This module wraps a standard criterion and adds an extra knowledge distillation loss by
taking a teacher model prediction and using it as additional supervision.
"""
def __init__(self, base_criterion: torch.nn.Module, teacher_model: torch.nn.Module,
distillation_type: str, alpha: float, tau: float):
super().__init__()
self.base_criterion = base_criterion
self.teacher_model = teacher_model
assert distillation_type in ['none', 'soft', 'hard']
self.distillation_type = distillation_type
self.alpha = alpha
self.tau = tau
def forward(self, inputs, outputs, labels):
"""
Args:
inputs: The original inputs that are feed to the teacher model
outputs: the outputs of the model to be trained. It is expected to be
either a Tensor, or a Tuple[Tensor, Tensor], with the original output
in the first position and the distillation predictions as the second output
labels: the labels for the base criterion
"""
outputs_kd = None
if not isinstance(outputs, torch.Tensor):
# assume that the model outputs a tuple of [outputs, outputs_kd]
outputs, outputs_kd = outputs
base_loss = self.base_criterion(outputs, labels)
if self.distillation_type == 'none':
return base_loss
if outputs_kd is None:
raise ValueError("When knowledge distillation is enabled, the model is "
"expected to return a Tuple[Tensor, Tensor] with the output of the "
"class_token and the dist_token")
# Don't backprop through the teacher
with torch.no_grad():
teacher_outputs = self.teacher_model(inputs)
if self.distillation_type == 'soft':
T = self.tau
# taken from https://github.com/peterliht/knowledge-distillation-pytorch/blob/master/model/net.py#L100
# with slight modifications
distillation_loss = F.kl_div(
F.log_softmax(outputs_kd / T, dim=1),
F.log_softmax(teacher_outputs / T, dim=1),
reduction='sum',
log_target=True
) * (T * T) / outputs_kd.numel()
elif self.distillation_type == 'hard':
distillation_loss = F.cross_entropy(
outputs_kd, teacher_outputs.argmax(dim=1))
loss = base_loss * (1 - self.alpha) + distillation_loss * self.alpha
return loss

60
util/samplers.py Normal file

@@ -0,0 +1,60 @@
import torch
import torch.distributed as dist
import math
class RASampler(torch.utils.data.Sampler):
"""Sampler that restricts data loading to a subset of the dataset for distributed,
with repeated augmentation.
It ensures that each augmented version of a sample is visible to a
different process (GPU).
Heavily based on torch.utils.data.DistributedSampler
"""
def __init__(self, dataset, num_replicas=None, rank=None, shuffle=True):
if num_replicas is None:
if not dist.is_available():
raise RuntimeError(
"Requires distributed package to be available")
num_replicas = dist.get_world_size()
if rank is None:
if not dist.is_available():
raise RuntimeError(
"Requires distributed package to be available")
rank = dist.get_rank()
self.dataset = dataset
self.num_replicas = num_replicas
self.rank = rank
self.epoch = 0
self.num_samples = int(
math.ceil(len(self.dataset) * 3.0 / self.num_replicas))
self.total_size = self.num_samples * self.num_replicas
self.num_selected_samples = int(math.floor(
len(self.dataset) // 256 * 256 / self.num_replicas))
self.shuffle = shuffle
def __iter__(self):
# Deterministically shuffle based on epoch
g = torch.Generator()
g.manual_seed(self.epoch)
if self.shuffle:
indices = torch.randperm(len(self.dataset), generator=g).tolist()
else:
indices = list(range(len(self.dataset)))
# Repeat each index 3 times for repeated augmentation, then pad to make it evenly divisible
indices = [ele for ele in indices for i in range(3)]
indices += indices[:(self.total_size - len(indices))]
assert len(indices) == self.total_size
# Subsample
indices = indices[self.rank:self.total_size:self.num_replicas]
assert len(indices) == self.num_samples
return iter(indices[:self.num_selected_samples])
def __len__(self):
return self.num_selected_samples
def set_epoch(self, epoch):
self.epoch = epoch

280
util/utils.py Normal file

@@ -0,0 +1,280 @@
"""
Misc functions, including distributed helpers.
Mostly copy-paste from torchvision references.
"""
import io
import os
import time
from collections import defaultdict, deque
import datetime
import torch
import torch.distributed as dist
import subprocess
class SmoothedValue(object):
"""Track a series of values and provide access to smoothed values over a
window or the global series average.
"""
def __init__(self, window_size=20, fmt=None):
if fmt is None:
fmt = "{median:.4f} ({global_avg:.4f})"
self.deque = deque(maxlen=window_size)
self.total = 0.0
self.count = 0
self.fmt = fmt
def update(self, value, n=1):
self.deque.append(value)
self.count += n
self.total += value * n
def synchronize_between_processes(self):
"""
Warning: does not synchronize the deque!
"""
if not is_dist_avail_and_initialized():
return
t = torch.tensor([self.count, self.total],
dtype=torch.float64, device='cuda')
dist.barrier()
dist.all_reduce(t)
t = t.tolist()
self.count = int(t[0])
self.total = t[1]
@property
def median(self):
d = torch.tensor(list(self.deque))
return d.median().item()
@property
def avg(self):
d = torch.tensor(list(self.deque), dtype=torch.float32)
return d.mean().item()
@property
def global_avg(self):
return self.total / self.count
@property
def max(self):
return max(self.deque)
@property
def value(self):
return self.deque[-1]
def __str__(self):
return self.fmt.format(
median=self.median,
avg=self.avg,
global_avg=self.global_avg,
max=self.max,
value=self.value)
class MetricLogger(object):
def __init__(self, delimiter="\t"):
self.meters = defaultdict(SmoothedValue)
self.delimiter = delimiter
def update(self, **kwargs):
for k, v in kwargs.items():
if isinstance(v, torch.Tensor):
v = v.item()
assert isinstance(v, (float, int))
self.meters[k].update(v)
def __getattr__(self, attr):
if attr in self.meters:
return self.meters[attr]
if attr in self.__dict__:
return self.__dict__[attr]
raise AttributeError("'{}' object has no attribute '{}'".format(
type(self).__name__, attr))
def __str__(self):
loss_str = []
for name, meter in self.meters.items():
loss_str.append(
"{}: {}".format(name, str(meter))
)
return self.delimiter.join(loss_str)
def synchronize_between_processes(self):
for meter in self.meters.values():
meter.synchronize_between_processes()
def add_meter(self, name, meter):
self.meters[name] = meter
def log_every(self, iterable, print_freq, header=None):
i = 0
if not header:
header = ''
start_time = time.time()
end = time.time()
iter_time = SmoothedValue(fmt='{avg:.4f}')
data_time = SmoothedValue(fmt='{avg:.4f}')
space_fmt = ':' + str(len(str(len(iterable)))) + 'd'
log_msg = [
header,
'[{0' + space_fmt + '}/{1}]',
'eta: {eta}',
'{meters}',
'time: {time}',
'data: {data}'
]
if torch.cuda.is_available():
log_msg.append('max mem: {memory:.0f}')
log_msg = self.delimiter.join(log_msg)
MB = 1024.0 * 1024.0
for obj in iterable:
data_time.update(time.time() - end)
yield obj
iter_time.update(time.time() - end)
if i % print_freq == 0 or i == len(iterable) - 1:
eta_seconds = iter_time.global_avg * (len(iterable) - i)
eta_string = str(datetime.timedelta(seconds=int(eta_seconds)))
if torch.cuda.is_available():
print(log_msg.format(
i, len(iterable), eta=eta_string,
meters=str(self),
time=str(iter_time), data=str(data_time),
memory=torch.cuda.max_memory_allocated() / MB))
else:
print(log_msg.format(
i, len(iterable), eta=eta_string,
meters=str(self),
time=str(iter_time), data=str(data_time)))
i += 1
end = time.time()
total_time = time.time() - start_time
total_time_str = str(datetime.timedelta(seconds=int(total_time)))
print('{} Total time: {} ({:.4f} s / it)'.format(
header, total_time_str, total_time / len(iterable)))
def _load_checkpoint_for_ema(model_ema, checkpoint):
"""
Workaround for ModelEma._load_checkpoint to accept an already-loaded object
"""
mem_file = io.BytesIO()
torch.save(checkpoint, mem_file)
mem_file.seek(0)
model_ema._load_checkpoint(mem_file)
def setup_for_distributed(is_master):
"""
This function disables printing when not in the master process
"""
import builtins as __builtin__
builtin_print = __builtin__.print
def print(*args, **kwargs):
force = kwargs.pop('force', False)
if is_master or force:
builtin_print(*args, **kwargs)
__builtin__.print = print
def is_dist_avail_and_initialized():
if not dist.is_available():
return False
if not dist.is_initialized():
return False
return True
def get_world_size():
if not is_dist_avail_and_initialized():
return 1
return dist.get_world_size()
def get_rank():
if not is_dist_avail_and_initialized():
return 0
return dist.get_rank()
def is_main_process():
return get_rank() == 0
def save_on_master(*args, **kwargs):
if is_main_process():
torch.save(*args, **kwargs)
def init_distributed_mode(args):
if 'RANK' in os.environ and 'WORLD_SIZE' in os.environ:
args.rank = int(os.environ["RANK"])
args.world_size = int(os.environ['WORLD_SIZE'])
args.gpu = int(os.environ['LOCAL_RANK'])
args.dist_url = 'env://'
os.environ['LOCAL_SIZE'] = str(torch.cuda.device_count())
print('Using distributed mode: 1')
elif 'SLURM_PROCID' in os.environ:
proc_id = int(os.environ['SLURM_PROCID'])
ntasks = int(os.environ['SLURM_NTASKS'])
node_list = os.environ['SLURM_NODELIST']
num_gpus = torch.cuda.device_count()
addr = subprocess.getoutput(
'scontrol show hostname {} | head -n1'.format(node_list))
os.environ['MASTER_PORT'] = os.environ.get('MASTER_PORT', '29500')
os.environ['MASTER_ADDR'] = addr
os.environ['WORLD_SIZE'] = str(ntasks)
os.environ['RANK'] = str(proc_id)
os.environ['LOCAL_RANK'] = str(proc_id % num_gpus)
os.environ['LOCAL_SIZE'] = str(num_gpus)
args.dist_url = 'env://'
args.world_size = ntasks
args.rank = proc_id
args.gpu = proc_id % num_gpus
print('Using distributed mode: slurm')
print(f"world: {os.environ['WORLD_SIZE']}, rank:{os.environ['RANK']},"
f" local_rank{os.environ['LOCAL_RANK']}, local_size{os.environ['LOCAL_SIZE']}")
else:
print('Not using distributed mode')
args.distributed = False
return
args.distributed = True
torch.cuda.set_device(args.gpu)
args.dist_backend = 'nccl'
print('| distributed init (rank {}): {}'.format(
args.rank, args.dist_url), flush=True)
torch.distributed.init_process_group(backend=args.dist_backend, init_method=args.dist_url,
world_size=args.world_size, rank=args.rank)
torch.distributed.barrier()
setup_for_distributed(args.rank == 0)
def replace_batchnorm(net):
for child_name, child in net.named_children():
if hasattr(child, 'fuse'):
setattr(net, child_name, child.fuse())
elif isinstance(child, torch.nn.Conv2d):
child.bias = torch.nn.Parameter(torch.zeros(child.weight.size(0)))
elif isinstance(child, torch.nn.BatchNorm2d):
setattr(net, child_name, torch.nn.Identity())
else:
replace_batchnorm(child)
def replace_layernorm(net):
import apex
for child_name, child in net.named_children():
if isinstance(child, torch.nn.LayerNorm):
setattr(net, child_name, apex.normalization.FusedLayerNorm(
child.weight.size(0)))
else:
replace_layernorm(child)