commit 574907f49b72c975634673e8cf5b6b9bfc7c1114 Author: amshaker Date: Sun Mar 26 23:31:59 2023 +0400 Initial release of SwiftFormer diff --git a/README.md b/README.md new file mode 100644 index 0000000..a2f2aa7 --- /dev/null +++ b/README.md @@ -0,0 +1,142 @@ +# SwiftFormer +### **SwiftFormer: Efficient Additive Attention for Transformer-based Real-time Mobile Vision Applications** + +[Abdelrahman Shaker](https://scholar.google.com/citations?hl=en&user=eEz4Wu4AAAAJ), +[Muhammad Maaz](https://scholar.google.com/citations?user=vTy9Te8AAAAJ&hl=en&authuser=1&oi=sra), +[Hanoona Rasheed](https://scholar.google.com/citations?user=yhDdEuEAAAAJ&hl=en&authuser=1&oi=sra), +[Salman Khan](https://salman-h-khan.github.io), +[Ming-Hsuan Yang](https://scholar.google.com/citations?user=p9-ohHsAAAAJ&hl=en), +and [Fahad Shahbaz Khan](https://scholar.google.es/citations?user=zvaeYnUAAAAJ&hl=en) + + +[![paper](https://img.shields.io/badge/arXiv-Paper-.svg)](arxiv_link) + + + +## :rocket: News +* **(Mar 27, 2023):** Classification training and evaluation codes along with pre-trained models are released. + +
+ +

+
+ Comparison of our SwiftFormer models with state-of-the-art methods on ImageNet-1K. The latency is measured on the iPhone 14 Neural Engine (iOS 16).

+ +

+
+

+

+ Comparison with different self-attention modules. (a) is a typical self-attention. (b) is the transpose self-attention, where the self-attention operation is applied across the channel feature dimensions (d×d) instead of the spatial dimension (n×n). (c) is the separable self-attention of MobileViT-v2, which uses element-wise operations to compute the context vector from the interactions of the Q and K matrices. The context vector is then multiplied by the V matrix to produce the final output. (d) is our proposed efficient additive self-attention. Here, the query matrix is multiplied by learnable weights and pooled to produce global queries. Then, the matrix K is element-wise multiplied by the broadcasted global queries, resulting in the global context representation.

+ +
+
+ Abstract
+
+Self-attention has become a de facto choice for capturing global context in various vision applications. However, its quadratic computational complexity with respect to image resolution limits its use in real-time applications, especially for deployment on resource-constrained mobile devices. Although hybrid approaches have been proposed to combine the advantages of convolutions and self-attention for a better speed-accuracy trade-off, the expensive matrix multiplication operations in self-attention remain a bottleneck. In this work, we introduce a novel efficient additive attention mechanism that effectively replaces the quadratic matrix multiplication operations with linear element-wise multiplications. Our design shows that the key-value interaction can be replaced with a linear layer without sacrificing any accuracy. Unlike previous state-of-the-art methods, our efficient formulation of self-attention enables its usage at all stages of the network. Using our proposed efficient additive attention, we build a series of models called "SwiftFormer" which achieves state-of-the-art performance in terms of both accuracy and mobile inference speed. Our small variant achieves 78.5% top-1 ImageNet-1K accuracy with only 0.8 ms latency on iPhone 14, which is more accurate and 2x faster compared to MobileViT-v2.
+
+ +
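To make the mechanism described above concrete, here is a minimal PyTorch sketch of the efficient additive attention computation as the paper describes it: a learnable vector produces one score per token, the scores pool the queries into a single global query, and the keys interact with that global query element-wise. The class and variable names are ours, and the sketch simplifies a few details; the released module is `EfficientAdditiveAttnetion` in `models/swiftformer.py` further down in this commit.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class AdditiveAttentionSketch(nn.Module):
    """Illustrative sketch only; see EfficientAdditiveAttnetion in models/swiftformer.py for the released module."""

    def __init__(self, dim):
        super().__init__()
        self.to_query = nn.Linear(dim, dim)
        self.to_key = nn.Linear(dim, dim)
        self.w_g = nn.Parameter(torch.randn(dim, 1))   # learnable attention weights
        self.scale = dim ** -0.5
        self.proj = nn.Linear(dim, dim)

    def forward(self, x):                              # x: [B, n, d] flattened tokens
        q = F.normalize(self.to_query(x), dim=-1)      # [B, n, d]
        k = F.normalize(self.to_key(x), dim=-1)        # [B, n, d]
        attn = torch.softmax((q @ self.w_g) * self.scale, dim=1)  # [B, n, 1], one score per token
        global_q = (attn * q).sum(dim=1, keepdim=True)            # pooled global query, [B, 1, d]
        context = k * global_q                                    # broadcast element-wise product, [B, n, d]
        return self.proj(context) + q                  # no n x n attention matrix is ever materialized


tokens = torch.randn(2, 196, 64)                       # e.g. a 14x14 feature map with 64 channels
print(AdditiveAttentionSketch(64)(tokens).shape)       # torch.Size([2, 196, 64])
```

Every step is a matrix-vector product or an element-wise operation, so the cost grows linearly with the number of tokens n rather than quadratically, which is what allows the block to be used at all four stages of the network.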
+ + + +## Classification on ImageNet-1K + +### Models + +| Model | Top-1 accuracy | #params | GMACs | Latency | Ckpt | CoreML| +|:---------------|:----:|:---:|:--:|:--:|:--:|:--:| +| SwiftFormer-XS | 75.7% | 3.5M | 0.4G | 0.7ms | [XS](https://drive.google.com/file/d/15Ils-U96pQePXQXx2MpmaI-yAceFAr2x/view?usp=sharing) | [XS](https://drive.google.com/file/d/1tZVxtbtAZoLLoDc5qqoUGulilksomLeK/view?usp=sharing) | +| SwiftFormer-S | 78.5% | 6.1M | 1.0G | 0.8ms | [S](https://drive.google.com/file/d/1_0eWwgsejtS0bWGBQS3gwAtYjXdPRGlu/view?usp=sharing) | [S](https://drive.google.com/file/d/13EOCZmtvbMR2V6UjezSZnbBz2_-59Fva/view?usp=sharing) | +| SwiftFormer-L1 | 80.9% | 12.1M | 1.6G | 1.1ms | [L1](https://drive.google.com/file/d/1jlwrwWQ0SQzDRc5adtWIwIut5d1g9EsM/view?usp=sharing) | [L1](https://drive.google.com/file/d/1c3VUsi4q7QQ2ykXVS2d4iCRL478fWF3e/view?usp=sharing) | +| SwiftFormer-L3 | 83.0% | 26.5M | 4.0G | 1.9ms | [L3](https://drive.google.com/file/d/1ypBcjx04ShmPYRhhjBRubiVjbExUgSa7/view?usp=sharing) | [L3](https://drive.google.com/file/d/1svahgIjh7da781jHOHjX58mtzCzYXSsJ/view?usp=sharing) | + + +## Detection and Segmentation Qualitative Results + +

+
+

+

+
+
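The detection and segmentation results shown here use SwiftFormer as a feature backbone. The released `models/swiftformer.py` (later in this commit) supports this through its `fork_feat=True` path, which skips the classification head and returns one normalized feature map per stage. The snippet below is only a sketch of that path with random weights, not an official detection or segmentation configuration.

```python
import torch

from models.swiftformer import SwiftFormer_XS

# fork_feat=True returns the output of every stage instead of classification logits,
# which is how the backbone is plugged into dense-prediction frameworks.
backbone = SwiftFormer_XS(fork_feat=True)
backbone.eval()

with torch.no_grad():
    features = backbone(torch.randn(1, 3, 224, 224))

for f in features:
    print(tuple(f.shape))
# Expected strides of 4, 8, 16, and 32 for a 224x224 input with the XS widths (48/56/112/220):
# (1, 48, 56, 56), (1, 56, 28, 28), (1, 112, 14, 14), (1, 220, 7, 7)
```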

+
+## Latency Measurement
+
+The latency reported in SwiftFormer for iPhone 14 (iOS 16) uses the benchmark tool from [Xcode 14](https://developer.apple.com/videos/play/wwdc2022/10027/) (a rough Core ML export sketch is included at the end of this README).
+
+## ImageNet
+
+### Prerequisites
+A `conda` virtual environment is recommended.
+
+```shell
+conda create --name=swiftformer python=3.9
+conda activate swiftformer
+
+pip install torch==1.11.0+cu113 torchvision==0.12.0+cu113 --extra-index-url https://download.pytorch.org/whl/cu113
+pip install timm
+```
+
+### Data preparation
+
+Download and extract the ImageNet train and val images from http://image-net.org. The training and validation data are expected to be in the `train` and `val` folders, respectively:
+```
+|-- /path/to/imagenet/
+    |-- train
+    |-- val
+```
+
+### Single machine multi-GPU training
+
+We provide a training script for all models in `dist_train.sh` using PyTorch distributed data parallel (DDP).
+
+To train SwiftFormer models on an 8-GPU machine:
+
+```
+sh dist_train.sh /path/to/imagenet 8
+```
+
+Note: select which model command to run inside the script. To reproduce the results of the paper, use a 16-GPU machine with a per-GPU batch size of 128 or an 8-GPU machine with a per-GPU batch size of 256. AutoAugment, CutMix, and MixUp are disabled for SwiftFormer-XS only.
+
+### Multi-node training
+
+On a Slurm-managed cluster, multi-node training can be launched as:
+
+```
+sbatch slurm_train.sh /path/to/imagenet SwiftFormer_XS
+```
+
+Note: specify the Slurm-specific parameters in the `slurm_train.sh` script.
+
+### Testing
+
+We provide an example test script `dist_test.sh` using PyTorch distributed data parallel (DDP).
+For example, to test SwiftFormer-XS on an 8-GPU machine (the arguments follow the order expected by the script: data path, model, checkpoint, number of GPUs):
+
+```
+sh dist_test.sh /path/to/imagenet SwiftFormer_XS weights/SwiftFormer_XS_ckpt.pth 8
+```
+
+## Citation
+If you use our work, please consider citing us:
+```BibTeX
+@article{Shaker2023SwiftFormer,
+  title={SwiftFormer: Efficient Additive Attention for Transformer-based Real-time Mobile Vision Applications},
+  author={Shaker, Abdelrahman and Maaz, Muhammad and Rasheed, Hanoona and Khan, Salman and Yang, Ming-Hsuan and Khan, Fahad Shahbaz},
+  journal={arXiv preprint arXiv:X.X},
+  year={2023}
+}
+```
+
+## Contact
+If you have any questions, please create an issue on this repository or contact us at abdelrahman.youssief@mbzuai.ac.ae.
+
+
+## Acknowledgement
+Our codebase builds on the [LeViT](https://github.com/facebookresearch/LeViT) and [EfficientFormer](https://github.com/snap-research/EfficientFormer) repositories. We thank the authors for their open-source implementations.
+
+## Our Related Works
+
+- EdgeNeXt: Efficiently Amalgamated CNN-Transformer Architecture for Mobile Vision Applications, CADL'22, ECCV. [Paper](https://arxiv.org/abs/2206.10589) | [Code](https://github.com/mmaaz60/EdgeNeXt).
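The latency column in the models table comes from profiling the Core ML models with Xcode's performance report on an iPhone 14. The conversion step itself is not part of this commit; the sketch below shows one plausible way to produce such a package with coremltools, and the coremltools calls and file names are assumptions rather than project code. Only the module-level `CoreMLConversion` flag comes from `models/swiftformer.py`, where it skips the `F.normalize` calls that the code notes are unsupported on the Apple Neural Engine.

```python
# Hypothetical Core ML export sketch. The coremltools usage and file names are assumptions,
# not part of this release; only the CoreMLConversion flag comes from models/swiftformer.py.
import torch
import coremltools as ct

import models.swiftformer as swiftformer

swiftformer.CoreMLConversion = True       # skip F.normalize, which the Apple Neural Engine does not support
model = swiftformer.SwiftFormer_XS()
# ...load a trained checkpoint into `model` here (omitted)...
model.eval()

traced = torch.jit.trace(model, torch.randn(1, 3, 224, 224))
mlmodel = ct.convert(traced, inputs=[ct.TensorType(name='image', shape=(1, 3, 224, 224))])
mlmodel.save('SwiftFormer_XS.mlpackage')  # profile this package with Xcode's performance report on device
```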
diff --git a/dist_test.sh b/dist_test.sh new file mode 100644 index 0000000..d644366 --- /dev/null +++ b/dist_test.sh @@ -0,0 +1,11 @@ +#!/usr/bin/env bash + +IMAGENET_PATH=$1 +MODEL=$2 +CHECKPOINT=$3 +nGPUs=$4 + +python -m torch.distributed.launch --master_addr="127.0.0.1" --master_port=1234 --nproc_per_node=$nGPUs --use_env main.py --model "$MODEL" \ +--resume $CHECKPOINT --eval \ +--data-path "$IMAGENET_PATH" \ +--output_dir SwiftFormer_test diff --git a/dist_train.sh b/dist_train.sh new file mode 100644 index 0000000..0f81d00 --- /dev/null +++ b/dist_train.sh @@ -0,0 +1,21 @@ + +#!/usr/bin/env bash + +IMAGENET_PATH=$1 +nGPUs=$2 + +## SwiftFormer-XS +python -m torch.distributed.launch --nproc_per_node=$nGPUs --use_env main.py --model SwiftFormer_XS --aa="" --mixup 0 --cutmix 0 --data-path "$IMAGENET_PATH" \ +--output_dir SwiftFormer_XS_results + +## SwiftFormer-S +python -m torch.distributed.launch --nproc_per_node=$nGPUs --use_env main.py --model SwiftFormer_S --data-path "$IMAGENET_PATH" \ +--output_dir SwiftFormer_S_results + +## SwiftFormer-L1 +python -m torch.distributed.launch --nproc_per_node=$nGPUs --use_env main.py --model SwiftFormer_L1 --data-path "$IMAGENET_PATH" \ +--output_dir SwiftFormer_L1_results + +## SwiftFormer-L3 +python -m torch.distributed.launch --nproc_per_node=$nGPUs --use_env main.py --model SwiftFormer_L3 --data-path "$IMAGENET_PATH" \ +--output_dir SwiftFormer_L3_results diff --git a/images/Swiftformer_performance.png b/images/Swiftformer_performance.png new file mode 100644 index 0000000..ff2406d Binary files /dev/null and b/images/Swiftformer_performance.png differ diff --git a/images/attention_comparison.pdf b/images/attention_comparison.pdf new file mode 100644 index 0000000..3ab4736 Binary files /dev/null and b/images/attention_comparison.pdf differ diff --git a/images/attentions_comparison.png b/images/attentions_comparison.png new file mode 100644 index 0000000..f65c0cb Binary files /dev/null and b/images/attentions_comparison.png differ diff --git a/images/detection_seg.png b/images/detection_seg.png new file mode 100644 index 0000000..966cba5 Binary files /dev/null and b/images/detection_seg.png differ diff --git a/images/semantic_seg.png b/images/semantic_seg.png new file mode 100644 index 0000000..987c173 Binary files /dev/null and b/images/semantic_seg.png differ diff --git a/main.py b/main.py new file mode 100644 index 0000000..5d56a3e --- /dev/null +++ b/main.py @@ -0,0 +1,412 @@ +import argparse +import datetime +import numpy as np +import time +import torch +import torch.backends.cudnn as cudnn +import json +from pathlib import Path + +from timm.data import Mixup +from timm.models import create_model +from timm.loss import LabelSmoothingCrossEntropy, SoftTargetCrossEntropy +from timm.scheduler import create_scheduler +from timm.optim import create_optimizer +from timm.utils import NativeScaler, get_state_dict, ModelEma + +from util import * +from models import * + + +def get_args_parser(): + parser = argparse.ArgumentParser( + 'SwiftFormer training and evaluation script', add_help=False) + parser.add_argument('--batch-size', default=128, type=int) + parser.add_argument('--epochs', default=300, type=int) + + # Model parameters + parser.add_argument('--model', default='SwiftFormer_XS', type=str, metavar='MODEL', + help='Name of model to train') + parser.add_argument('--input-size', default=224, + type=int, help='images input size') + + parser.add_argument('--model-ema', action='store_true') + parser.add_argument( + '--no-model-ema', 
action='store_false', dest='model_ema') + parser.set_defaults(model_ema=True) + parser.add_argument('--model-ema-decay', type=float, + default=0.99996, help='') + parser.add_argument('--model-ema-force-cpu', + action='store_true', default=False, help='') + + # Optimizer parameters + parser.add_argument('--opt', default='adamw', type=str, metavar='OPTIMIZER', + help='Optimizer (default: "adamw"') + parser.add_argument('--opt-eps', default=1e-8, type=float, metavar='EPSILON', + help='Optimizer Epsilon (default: 1e-8)') + parser.add_argument('--opt-betas', default=None, type=float, nargs='+', metavar='BETA', + help='Optimizer Betas (default: None, use opt default)') + parser.add_argument('--clip-grad', type=float, default=0.01, metavar='NORM', + help='Clip gradient norm (default: None, no clipping)') + parser.add_argument('--clip-mode', type=str, default='agc', + help='Gradient clipping mode. One of ("norm", "value", "agc")') + parser.add_argument('--momentum', type=float, default=0.9, metavar='M', + help='SGD momentum (default: 0.9)') + parser.add_argument('--weight-decay', type=float, default=0.025, + help='weight decay (default: 0.025)') + # Learning rate schedule parameters + parser.add_argument('--sched', default='cosine', type=str, metavar='SCHEDULER', + help='LR scheduler (default: "cosine"') + parser.add_argument('--lr', type=float, default=2e-3, metavar='LR', + help='learning rate (default: 2e-3)') + parser.add_argument('--lr-noise', type=float, nargs='+', default=None, metavar='pct, pct', + help='learning rate noise on/off epoch percentages') + parser.add_argument('--lr-noise-pct', type=float, default=0.67, metavar='PERCENT', + help='learning rate noise limit percent (default: 0.67)') + parser.add_argument('--lr-noise-std', type=float, default=1.0, metavar='STDDEV', + help='learning rate noise std-dev (default: 1.0)') + parser.add_argument('--warmup-lr', type=float, default=1e-6, metavar='LR', + help='warmup learning rate (default: 1e-6)') + parser.add_argument('--min-lr', type=float, default=1e-5, metavar='LR', + help='lower lr bound for cyclic schedulers that hit 0 (1e-5)') + + parser.add_argument('--decay-epochs', type=float, default=30, metavar='N', + help='epoch interval to decay LR') + parser.add_argument('--warmup-epochs', type=int, default=5, metavar='N', + help='epochs to warmup LR, if scheduler supports') + parser.add_argument('--cooldown-epochs', type=int, default=10, metavar='N', + help='epochs to cooldown LR at min_lr, after cyclic schedule ends') + parser.add_argument('--patience-epochs', type=int, default=10, metavar='N', + help='patience epochs for Plateau LR scheduler (default: 10') + parser.add_argument('--decay-rate', '--dr', type=float, default=0.1, metavar='RATE', + help='LR decay rate (default: 0.1)') + + # Augmentation parameters + parser.add_argument('--color-jitter', type=float, default=0.4, metavar='PCT', + help='Color jitter factor (default: 0.4)') + parser.add_argument('--aa', type=str, default='rand-m9-mstd0.5-inc1', metavar='NAME', + help='Use AutoAugment policy. "v0" or "original". 
" + \ + "(default: rand-m9-mstd0.5-inc1)'), + parser.add_argument('--smoothing', type=float, default=0.1, + help='Label smoothing (default: 0.1)') + parser.add_argument('--train-interpolation', type=str, default='bicubic', + help='Training interpolation (random, bilinear, bicubic default: "bicubic")') + + parser.add_argument('--repeated-aug', action='store_true') + parser.add_argument('--no-repeated-aug', + action='store_false', dest='repeated_aug') + parser.set_defaults(repeated_aug=True) + + # * Random Erase params + parser.add_argument('--reprob', type=float, default=0.25, metavar='PCT', + help='Random erase prob (default: 0.25)') + parser.add_argument('--remode', type=str, default='pixel', + help='Random erase mode (default: "pixel")') + parser.add_argument('--recount', type=int, default=1, + help='Random erase count (default: 1)') + parser.add_argument('--resplit', action='store_true', default=False, + help='Do not random erase first (clean) augmentation split') + + # * Mixup params + parser.add_argument('--mixup', type=float, default=0.8, + help='mixup alpha, mixup enabled if > 0. (default: 0.8)') + parser.add_argument('--cutmix', type=float, default=1.0, + help='cutmix alpha, cutmix enabled if > 0. (default: 1.0)') + parser.add_argument('--cutmix-minmax', type=float, nargs='+', default=None, + help='cutmix min/max ratio, overrides alpha and enables cutmix if set (default: None)') + parser.add_argument('--mixup-prob', type=float, default=1.0, + help='Probability of performing mixup or cutmix when either/both is enabled') + parser.add_argument('--mixup-switch-prob', type=float, default=0.5, + help='Probability of switching to cutmix when both mixup and cutmix enabled') + parser.add_argument('--mixup-mode', type=str, default='batch', + help='How to apply mixup/cutmix params. 
Per "batch", "pair", or "elem"') + + # Distillation parameters + parser.add_argument('--teacher-model', default='regnety_160', type=str, metavar='MODEL', + help='Name of teacher model to train (default: "regnety_160"') + parser.add_argument('--teacher-path', type=str, + default='https://dl.fbaipublicfiles.com/deit/regnety_160-a5fe301d.pth') + parser.add_argument('--distillation-type', default='hard', + choices=['none', 'soft', 'hard'], type=str, help="") + parser.add_argument('--distillation-alpha', + default=0.5, type=float, help="") + parser.add_argument('--distillation-tau', default=1.0, type=float, help="") + + # * Finetuning params + parser.add_argument('--finetune', default='', + help='finetune from checkpoint') + + # Dataset parameters + parser.add_argument('--data-path', default='./imagenet', type=str, + help='dataset path') + parser.add_argument('--data-set', default='IMNET', choices=['CIFAR', 'IMNET', 'INAT', 'INAT19'], + type=str, help='Image Net dataset path') + parser.add_argument('--inat-category', default='name', + choices=['kingdom', 'phylum', 'class', 'order', + 'supercategory', 'family', 'genus', 'name'], + type=str, help='semantic granularity') + + parser.add_argument('--output_dir', default='', + help='path where to save, empty for no saving') + parser.add_argument('--device', default='cuda', + help='device to use for training / testing') + parser.add_argument('--seed', default=0, type=int) + parser.add_argument('--resume', default='', help='resume from checkpoint') + parser.add_argument('--start_epoch', default=0, type=int, metavar='N', + help='start epoch') + parser.add_argument('--eval', action='store_true', + help='Perform evaluation only') + parser.add_argument('--dist-eval', action='store_true', + default=False, help='Enabling distributed evaluation') + parser.add_argument('--num_workers', default=10, type=int) + parser.add_argument('--pin-mem', action='store_true', + help='Pin CPU memory in DataLoader for more efficient (sometimes) transfer to GPU.') + parser.add_argument('--no-pin-mem', action='store_false', dest='pin_mem', + help='') + parser.set_defaults(pin_mem=True) + + # distributed training parameters + parser.add_argument('--world_size', default=1, type=int, + help='number of distributed processes') + parser.add_argument('--dist_url', default='env://', + help='url used to set up distributed training') + return parser + + +def main(args): + utils.init_distributed_mode(args) + + print(args) + + if args.distillation_type != 'none' and args.finetune and not args.eval: + raise NotImplementedError( + "Finetuning with distillation not yet supported") + + device = torch.device(args.device) + + # Fix the seed for reproducibility + seed = args.seed + utils.get_rank() + torch.manual_seed(seed) + np.random.seed(seed) + + cudnn.benchmark = True + + dataset_train, args.nb_classes = build_dataset(is_train=True, args=args) + dataset_val, _ = build_dataset(is_train=False, args=args) + + num_tasks = utils.get_world_size() + global_rank = utils.get_rank() + if args.repeated_aug: + sampler_train = RASampler( + dataset_train, num_replicas=num_tasks, rank=global_rank, shuffle=True + ) + else: + sampler_train = torch.utils.data.DistributedSampler( + dataset_train, num_replicas=num_tasks, rank=global_rank, shuffle=True + ) + if args.dist_eval: + if len(dataset_val) % num_tasks != 0: + print('Warning: Enabling distributed evaluation with an eval dataset not divisible by process number. 
' + 'This will slightly alter validation results as extra duplicate entries are added to achieve ' + 'equal num of samples per-process.') + sampler_val = torch.utils.data.DistributedSampler( + dataset_val, num_replicas=num_tasks, rank=global_rank, shuffle=False) + else: + sampler_val = torch.utils.data.SequentialSampler(dataset_val) + + data_loader_train = torch.utils.data.DataLoader( + dataset_train, sampler=sampler_train, + batch_size=args.batch_size, + num_workers=args.num_workers, + pin_memory=args.pin_mem, + drop_last=True, + ) + + data_loader_val = torch.utils.data.DataLoader( + dataset_val, sampler=sampler_val, + batch_size=int(1.5 * args.batch_size), + num_workers=args.num_workers, + pin_memory=args.pin_mem, + drop_last=False + ) + + mixup_fn = None + mixup_active = args.mixup > 0 or args.cutmix > 0. or args.cutmix_minmax is not None + if mixup_active: + mixup_fn = Mixup( + mixup_alpha=args.mixup, cutmix_alpha=args.cutmix, cutmix_minmax=args.cutmix_minmax, + prob=args.mixup_prob, switch_prob=args.mixup_switch_prob, mode=args.mixup_mode, + label_smoothing=args.smoothing, num_classes=args.nb_classes) + + print(f"Creating model: {args.model}") + model = create_model( + args.model, + num_classes=args.nb_classes, + distillation=(args.distillation_type != 'none'), + pretrained=args.eval, + fuse=args.eval, + ) + + if args.finetune: + if args.finetune.startswith('https'): + checkpoint = torch.hub.load_state_dict_from_url( + args.finetune, map_location='cpu', check_hash=True) + else: + checkpoint = torch.load(args.finetune, map_location='cpu') + + checkpoint_model = checkpoint['model'] + state_dict = model.state_dict() + for k in ['head.weight', 'head.bias', + 'head_dist.weight', 'head_dist.bias']: + if k in checkpoint_model and checkpoint_model[k].shape != state_dict[k].shape: + print(f"Removing key {k} from pretrained checkpoint") + del checkpoint_model[k] + + model.load_state_dict(checkpoint_model, strict=False) + + model.to(device) + + model_ema = None + if args.model_ema: + # Important to create EMA model after cuda(), DP wrapper, and AMP but + # before SyncBN and DDP wrapper + model_ema = ModelEma( + model, + decay=args.model_ema_decay, + device='cpu' if args.model_ema_force_cpu else '', + resume='') + + model_without_ddp = model + if args.distributed: + model = torch.nn.parallel.DistributedDataParallel( + model, device_ids=[args.gpu]) + model_without_ddp = model.module + n_parameters = sum(p.numel() + for p in model.parameters() if p.requires_grad) + print('number of params:', n_parameters) + + # better not to scale up lr for AdamW optimizer + # linear_scaled_lr = args.lr * args.batch_size * utils.get_world_size() / 512.0 + # args.lr = linear_scaled_lr + + optimizer = create_optimizer(args, model_without_ddp) + loss_scaler = NativeScaler() + + lr_scheduler, _ = create_scheduler(args, optimizer) + + if args.mixup > 0.: + # smoothing is handled with mixup label transform + criterion = SoftTargetCrossEntropy() + elif args.smoothing: + criterion = LabelSmoothingCrossEntropy(smoothing=args.smoothing) + else: + criterion = torch.nn.CrossEntropyLoss() + + teacher_model = None + if args.distillation_type != 'none': + assert args.teacher_path, 'need to specify teacher-path when using distillation' + print(f"Creating teacher model: {args.teacher_model}") + teacher_model = create_model( + args.teacher_model, + pretrained=False, + num_classes=args.nb_classes, + global_pool='avg', + ) + if args.teacher_path.startswith('https'): + checkpoint = torch.hub.load_state_dict_from_url( + 
args.teacher_path, map_location='cpu', check_hash=True) + else: + checkpoint = torch.load(args.teacher_path, map_location='cpu') + teacher_model.load_state_dict(checkpoint['model']) + teacher_model.to(device) + teacher_model.eval() + + # Wrap the criterion in our custom DistillationLoss, which + # just dispatches to the original criterion if args.distillation_type is + # 'none' + criterion = DistillationLoss( + criterion, teacher_model, args.distillation_type, args.distillation_alpha, args.distillation_tau + ) + + output_dir = Path(args.output_dir) + if args.resume: + if args.resume.startswith('https'): + checkpoint = torch.hub.load_state_dict_from_url( + args.resume, map_location='cpu', check_hash=True) + else: + checkpoint = torch.load(args.resume, map_location='cpu') + model_without_ddp.load_state_dict(checkpoint['model']) + if not args.eval and 'optimizer' in checkpoint and 'lr_scheduler' in checkpoint and 'epoch' in checkpoint: + optimizer.load_state_dict(checkpoint['optimizer']) + lr_scheduler.load_state_dict(checkpoint['lr_scheduler']) + args.start_epoch = checkpoint['epoch'] + 1 + if args.model_ema: + utils._load_checkpoint_for_ema( + model_ema, checkpoint['model_ema']) + if 'scaler' in checkpoint: + loss_scaler.load_state_dict(checkpoint['scaler']) + if args.eval: + test_stats = evaluate(data_loader_val, model, device) + print( + f"Accuracy of the network on the {len(dataset_val)} test images: {test_stats['acc1']:.1f}%") + return + + print(f"Start training for {args.epochs} epochs") + start_time = time.time() + max_accuracy = 0.0 + for epoch in range(args.start_epoch, args.epochs): + if args.distributed: + data_loader_train.sampler.set_epoch(epoch) + + train_stats = train_one_epoch( + model, criterion, data_loader_train, + optimizer, device, epoch, loss_scaler, + args.clip_grad, args.clip_mode, model_ema, mixup_fn, + set_training_mode=args.finetune == '' # keep in eval mode during finetuning + ) + + lr_scheduler.step(epoch) + if args.output_dir: + checkpoint_paths = [output_dir / 'checkpoint.pth'] + for checkpoint_path in checkpoint_paths: + utils.save_on_master({ + 'model': model_without_ddp.state_dict(), + 'optimizer': optimizer.state_dict(), + 'lr_scheduler': lr_scheduler.state_dict(), + 'epoch': epoch, + 'model_ema': get_state_dict(model_ema), + 'scaler': loss_scaler.state_dict(), + 'args': args, + }, checkpoint_path) + + if epoch % 20 == 19: + test_stats = evaluate(data_loader_val, model, device) + print( + f"Accuracy of the network on the {len(dataset_val)} test images: {test_stats['acc1']:.1f}%") + max_accuracy = max(max_accuracy, test_stats["acc1"]) + print(f'Max accuracy: {max_accuracy:.2f}%') + log_stats = {**{f'train_{k}': v for k, v in train_stats.items()}, + **{f'test_{k}': v for k, v in test_stats.items()}, + 'epoch': epoch, + 'n_parameters': n_parameters} + else: + log_stats = {**{f'train_{k}': v for k, v in train_stats.items()}, + 'epoch': epoch, + 'n_parameters': n_parameters} + + if args.output_dir and utils.is_main_process(): + with (output_dir / "log.txt").open("a") as f: + f.write(json.dumps(log_stats) + "\n") + + total_time = time.time() - start_time + total_time_str = str(datetime.timedelta(seconds=int(total_time))) + print('Training time {}'.format(total_time_str)) + + +if __name__ == '__main__': + parser = argparse.ArgumentParser( + 'SwiftFormer training and evaluation script', parents=[get_args_parser()]) + args = parser.parse_args() + if args.output_dir: + Path(args.output_dir).mkdir(parents=True, exist_ok=True) + main(args) diff --git 
a/models/__init__.py b/models/__init__.py new file mode 100644 index 0000000..b07c562 --- /dev/null +++ b/models/__init__.py @@ -0,0 +1 @@ +from .swiftformer import SwiftFormer_XS, SwiftFormer_S, SwiftFormer_L1, SwiftFormer_L3 diff --git a/models/swiftformer.py b/models/swiftformer.py new file mode 100644 index 0000000..7243fae --- /dev/null +++ b/models/swiftformer.py @@ -0,0 +1,507 @@ +""" +SwiftFormer +""" +import os +import copy +import torch +import torch.nn as nn +from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD +from timm.models.layers import DropPath, trunc_normal_ +from timm.models.registry import register_model +from timm.models.layers.helpers import to_2tuple +import einops + +SwiftFormer_width = { + 'XS': [48, 56, 112, 220], + 'S': [48, 64, 168, 224], + 'l1': [48, 96, 192, 384], + 'l3': [64, 128, 320, 512], +} + +SwiftFormer_depth = { + 'XS': [3, 3, 6, 4], + 'S': [3, 3, 9, 6], + 'l1': [4, 3, 10, 5], + 'l3': [4, 4, 12, 6], +} + +CoreMLConversion = False + + +def stem(in_chs, out_chs): + """ + Stem Layer that is implemented by two layers of conv. + Output: sequence of layers with final shape of [B, C, H/4, W/4] + """ + return nn.Sequential( + nn.Conv2d(in_chs, out_chs // 2, kernel_size=3, stride=2, padding=1), + nn.BatchNorm2d(out_chs // 2), + nn.ReLU(), + nn.Conv2d(out_chs // 2, out_chs, kernel_size=3, stride=2, padding=1), + nn.BatchNorm2d(out_chs), + nn.ReLU(), ) + + +class Embedding(nn.Module): + """ + Patch Embedding that is implemented by a layer of conv. + Input: tensor in shape [B, C, H, W] + Output: tensor in shape [B, C, H/stride, W/stride] + """ + + def __init__(self, patch_size=16, stride=16, padding=0, + in_chans=3, embed_dim=768, norm_layer=nn.BatchNorm2d): + super().__init__() + patch_size = to_2tuple(patch_size) + stride = to_2tuple(stride) + padding = to_2tuple(padding) + self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, + stride=stride, padding=padding) + self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity() + + def forward(self, x): + x = self.proj(x) + x = self.norm(x) + return x + + +class ConvEncoder(nn.Module): + """ + Implementation of ConvEncoder with 3*3 and 1*1 convolutions. + Input: tensor with shape [B, C, H, W] + Output: tensor with shape [B, C, H, W] + """ + + def __init__(self, dim, hidden_dim=64, kernel_size=3, drop_path=0., use_layer_scale=True): + super().__init__() + self.dwconv = nn.Conv2d(dim, dim, kernel_size=kernel_size, padding=kernel_size // 2, groups=dim) + self.norm = nn.BatchNorm2d(dim) + self.pwconv1 = nn.Conv2d(dim, hidden_dim, kernel_size=1) + self.act = nn.GELU() + self.pwconv2 = nn.Conv2d(hidden_dim, dim, kernel_size=1) + self.drop_path = DropPath(drop_path) if drop_path > 0. \ + else nn.Identity() + self.use_layer_scale = use_layer_scale + if use_layer_scale: + self.layer_scale = nn.Parameter(torch.ones(dim).unsqueeze(-1).unsqueeze(-1), requires_grad=True) + self.apply(self._init_weights) + + def _init_weights(self, m): + if isinstance(m, nn.Conv2d): + trunc_normal_(m.weight, std=.02) + if m.bias is not None: + nn.init.constant_(m.bias, 0) + + def forward(self, x): + input = x + x = self.dwconv(x) + x = self.norm(x) + x = self.pwconv1(x) + x = self.act(x) + x = self.pwconv2(x) + if self.use_layer_scale: + x = input + self.drop_path(self.layer_scale * x) + else: + x = input + self.drop_path(x) + return x + + +class Mlp(nn.Module): + """ + Implementation of MLP layer with 1*1 convolutions. 
+ Input: tensor with shape [B, C, H, W] + Output: tensor with shape [B, C, H, W] + """ + + def __init__(self, in_features, hidden_features=None, + out_features=None, act_layer=nn.GELU, drop=0.): + super().__init__() + out_features = out_features or in_features + hidden_features = hidden_features or in_features + self.norm1 = nn.BatchNorm2d(in_features) + self.fc1 = nn.Conv2d(in_features, hidden_features, 1) + self.act = act_layer() + self.fc2 = nn.Conv2d(hidden_features, out_features, 1) + self.drop = nn.Dropout(drop) + self.apply(self._init_weights) + + def _init_weights(self, m): + if isinstance(m, nn.Conv2d): + trunc_normal_(m.weight, std=.02) + if m.bias is not None: + nn.init.constant_(m.bias, 0) + + def forward(self, x): + x = self.norm1(x) + x = self.fc1(x) + x = self.act(x) + x = self.drop(x) + x = self.fc2(x) + x = self.drop(x) + return x + + +class EfficientAdditiveAttnetion(nn.Module): + """ + Efficient Additive Attention module for SwiftFormer. + Input: tensor in shape [B, C, H, W] + Output: tensor in shape [B, C, H, W] + """ + + def __init__(self, in_dims=512, token_dim=256, num_heads=2): + super().__init__() + + self.to_query = nn.Linear(in_dims, token_dim * num_heads) + self.to_key = nn.Linear(in_dims, token_dim * num_heads) + + self.w_g = nn.Parameter(torch.randn(token_dim * num_heads, 1)) + self.scale_factor = token_dim ** -0.5 + self.Proj = nn.Linear(token_dim * num_heads, token_dim * num_heads) + self.final = nn.Linear(token_dim * num_heads, token_dim) + + def forward(self, x): + query = self.to_query(x) + key = self.to_key(x) + + if not CoreMLConversion: + # torch.nn.functional.normalize is not supported by the ANE of iPhone devices. + # Using this layer improves the accuracy by ~0.1-0.2% + query = torch.nn.functional.normalize(query, dim=-1) + key = torch.nn.functional.normalize(key, dim=-1) + + query_weight = query @ self.w_g + A = query_weight * self.scale_factor + + A = A.softmax(dim=-1) + + G = torch.sum(A * query, dim=1) + + G = einops.repeat( + G, "b d -> b repeat d", repeat=key.shape[1] + ) + + out = self.Proj(G * key) + query + + out = self.final(out) + + return out + + +class SwiftFormerLocalRepresentation(nn.Module): + """ + Local Representation module for SwiftFormer that is implemented by 3*3 depth-wise and point-wise convolutions. + Input: tensor in shape [B, C, H, W] + Output: tensor in shape [B, C, H, W] + """ + + def __init__(self, dim, kernel_size=3, drop_path=0., use_layer_scale=True): + super().__init__() + self.dwconv = nn.Conv2d(dim, dim, kernel_size=kernel_size, padding=kernel_size // 2, groups=dim) + self.norm = nn.BatchNorm2d(dim) + self.pwconv1 = nn.Conv2d(dim, dim, kernel_size=1) + self.act = nn.GELU() + self.pwconv2 = nn.Conv2d(dim, dim, kernel_size=1) + self.drop_path = DropPath(drop_path) if drop_path > 0. \ + else nn.Identity() + self.use_layer_scale = use_layer_scale + if use_layer_scale: + self.layer_scale = nn.Parameter(torch.ones(dim).unsqueeze(-1).unsqueeze(-1), requires_grad=True) + self.apply(self._init_weights) + + def _init_weights(self, m): + if isinstance(m, nn.Conv2d): + trunc_normal_(m.weight, std=.02) + if m.bias is not None: + nn.init.constant_(m.bias, 0) + + def forward(self, x): + input = x + x = self.dwconv(x) + x = self.norm(x) + x = self.pwconv1(x) + x = self.act(x) + x = self.pwconv2(x) + if self.use_layer_scale: + x = input + self.drop_path(self.layer_scale * x) + else: + x = input + self.drop_path(x) + return x + + +class SwiftFormerEncoder(nn.Module): + """ + SwiftFormer Encoder Block for SwiftFormer. 
It consists of (1) Local representation module, (2) EfficientAdditiveAttention, and (3) MLP block. + Input: tensor in shape [B, C, H, W] + Output: tensor in shape [B, C, H, W] + """ + + def __init__(self, dim, mlp_ratio=4., + act_layer=nn.GELU, + drop=0., drop_path=0., + use_layer_scale=True, layer_scale_init_value=1e-5): + + super().__init__() + + self.local_representation = SwiftFormerLocalRepresentation(dim=dim, kernel_size=3, drop_path=0., + use_layer_scale=True) + self.attn = EfficientAdditiveAttnetion(in_dims=dim, token_dim=dim, num_heads=1) + self.linear = Mlp(in_features=dim, hidden_features=int(dim * mlp_ratio), act_layer=act_layer, drop=drop) + self.drop_path = DropPath(drop_path) if drop_path > 0. \ + else nn.Identity() + self.use_layer_scale = use_layer_scale + if use_layer_scale: + self.layer_scale_1 = nn.Parameter( + layer_scale_init_value * torch.ones(dim).unsqueeze(-1).unsqueeze(-1), requires_grad=True) + self.layer_scale_2 = nn.Parameter( + layer_scale_init_value * torch.ones(dim).unsqueeze(-1).unsqueeze(-1), requires_grad=True) + + def forward(self, x): + x = self.local_representation(x) + B, C, H, W = x.shape + if self.use_layer_scale: + x = x + self.drop_path( + self.layer_scale_1 * self.attn(x.permute(0, 2, 3, 1).reshape(B, H * W, C)).reshape(B, H, W, C).permute( + 0, 3, 1, 2)) + x = x + self.drop_path(self.layer_scale_2 * self.linear(x)) + + else: + x = x + self.drop_path( + self.attn(x.permute(0, 2, 3, 1).reshape(B, H * W, C)).reshape(B, H, W, C).permute(0, 3, 1, 2)) + x = x + self.drop_path(self.linear(x)) + return x + + +def Stage(dim, index, layers, mlp_ratio=4., + act_layer=nn.GELU, + drop_rate=.0, drop_path_rate=0., + use_layer_scale=True, layer_scale_init_value=1e-5, vit_num=1): + """ + Implementation of each SwiftFormer stages. Here, SwiftFormerEncoder used as the last block in all stages, while ConvEncoder used in the rest of the blocks. 
+ Input: tensor in shape [B, C, H, W] + Output: tensor in shape [B, C, H, W] + """ + blocks = [] + + for block_idx in range(layers[index]): + block_dpr = drop_path_rate * (block_idx + sum(layers[:index])) / (sum(layers) - 1) + + if layers[index] - block_idx <= vit_num: + blocks.append(SwiftFormerEncoder( + dim, mlp_ratio=mlp_ratio, + act_layer=act_layer, drop_path=block_dpr, + use_layer_scale=use_layer_scale, + layer_scale_init_value=layer_scale_init_value)) + + else: + blocks.append(ConvEncoder(dim=dim, hidden_dim=int(mlp_ratio * dim), kernel_size=3)) + + blocks = nn.Sequential(*blocks) + return blocks + + +class SwiftFormer(nn.Module): + + def __init__(self, layers, embed_dims=None, + mlp_ratios=4, downsamples=None, + act_layer=nn.GELU, + num_classes=1000, + down_patch_size=3, down_stride=2, down_pad=1, + drop_rate=0., drop_path_rate=0., + use_layer_scale=True, layer_scale_init_value=1e-5, + fork_feat=False, + init_cfg=None, + pretrained=None, + vit_num=1, + distillation=True, + **kwargs): + super().__init__() + + if not fork_feat: + self.num_classes = num_classes + self.fork_feat = fork_feat + + self.patch_embed = stem(3, embed_dims[0]) + + network = [] + for i in range(len(layers)): + stage = Stage(embed_dims[i], i, layers, mlp_ratio=mlp_ratios, + act_layer=act_layer, + drop_rate=drop_rate, + drop_path_rate=drop_path_rate, + use_layer_scale=use_layer_scale, + layer_scale_init_value=layer_scale_init_value, + vit_num=vit_num) + network.append(stage) + if i >= len(layers) - 1: + break + if downsamples[i] or embed_dims[i] != embed_dims[i + 1]: + # downsampling between two stages + network.append( + Embedding( + patch_size=down_patch_size, stride=down_stride, + padding=down_pad, + in_chans=embed_dims[i], embed_dim=embed_dims[i + 1] + ) + ) + + self.network = nn.ModuleList(network) + + if self.fork_feat: + # add a norm layer for each output + self.out_indices = [0, 2, 4, 6] + for i_emb, i_layer in enumerate(self.out_indices): + if i_emb == 0 and os.environ.get('FORK_LAST3', None): + layer = nn.Identity() + else: + layer = nn.BatchNorm2d(embed_dims[i_emb]) + layer_name = f'norm{i_layer}' + self.add_module(layer_name, layer) + else: + # Classifier head + self.norm = nn.BatchNorm2d(embed_dims[-1]) + self.head = nn.Linear( + embed_dims[-1], num_classes) if num_classes > 0 \ + else nn.Identity() + self.dist = distillation + if self.dist: + self.dist_head = nn.Linear( + embed_dims[-1], num_classes) if num_classes > 0 \ + else nn.Identity() + + # self.apply(self.cls_init_weights) + self.apply(self._init_weights) + + self.init_cfg = copy.deepcopy(init_cfg) + # load pre-trained model + if self.fork_feat and ( + self.init_cfg is not None or pretrained is not None): + self.init_weights() + + # init for mmdetection or mmsegmentation by loading + # imagenet pre-trained weights + def init_weights(self, pretrained=None): + logger = get_root_logger() + if self.init_cfg is None and pretrained is None: + logger.warn(f'No pre-trained weights for ' + f'{self.__class__.__name__}, ' + f'training start from scratch') + pass + else: + assert 'checkpoint' in self.init_cfg, f'Only support ' \ + f'specify `Pretrained` in ' \ + f'`init_cfg` in ' \ + f'{self.__class__.__name__} ' + if self.init_cfg is not None: + ckpt_path = self.init_cfg['checkpoint'] + elif pretrained is not None: + ckpt_path = pretrained + + ckpt = _load_checkpoint( + ckpt_path, logger=logger, map_location='cpu') + if 'state_dict' in ckpt: + _state_dict = ckpt['state_dict'] + elif 'model' in ckpt: + _state_dict = ckpt['model'] + else: + _state_dict = 
ckpt + + state_dict = _state_dict + missing_keys, unexpected_keys = \ + self.load_state_dict(state_dict, False) + + def _init_weights(self, m): + if isinstance(m, (nn.Conv2d, nn.Linear)): + trunc_normal_(m.weight, std=.02) + if m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, (nn.LayerNorm)): + nn.init.constant_(m.bias, 0) + nn.init.constant_(m.weight, 1.0) + + def forward_tokens(self, x): + outs = [] + for idx, block in enumerate(self.network): + x = block(x) + if self.fork_feat and idx in self.out_indices: + norm_layer = getattr(self, f'norm{idx}') + x_out = norm_layer(x) + outs.append(x_out) + if self.fork_feat: + return outs + return x + + def forward(self, x): + x = self.patch_embed(x) + x = self.forward_tokens(x) + if self.fork_feat: + # Output features of four stages for dense prediction + return x + + x = self.norm(x) + if self.dist: + cls_out = self.head(x.flatten(2).mean(-1)), self.dist_head(x.flatten(2).mean(-1)) + if not self.training: + cls_out = (cls_out[0] + cls_out[1]) / 2 + else: + cls_out = self.head(x.mean(-2)) + # For image classification + return cls_out + + +def _cfg(url='', **kwargs): + return { + 'url': url, + 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': None, + 'crop_pct': .95, 'interpolation': 'bicubic', + 'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD, + 'classifier': 'head', + **kwargs + } + + +@register_model +def SwiftFormer_XS(pretrained=False, **kwargs): + model = SwiftFormer( + layers=SwiftFormer_depth['XS'], + embed_dims=SwiftFormer_width['XS'], + downsamples=[True, True, True, True], + vit_num=1, + **kwargs) + model.default_cfg = _cfg(crop_pct=0.9) + return model + + +@register_model +def SwiftFormer_S(pretrained=False, **kwargs): + model = SwiftFormer( + layers=SwiftFormer_depth['S'], + embed_dims=SwiftFormer_width['S'], + downsamples=[True, True, True, True], + vit_num=1, + **kwargs) + model.default_cfg = _cfg(crop_pct=0.9) + return model + + +@register_model +def SwiftFormer_L1(pretrained=False, **kwargs): + model = SwiftFormer( + layers=SwiftFormer_depth['l1'], + embed_dims=SwiftFormer_width['l1'], + downsamples=[True, True, True, True], + vit_num=1, + **kwargs) + model.default_cfg = _cfg(crop_pct=0.9) + return model + + +@register_model +def SwiftFormer_L3(pretrained=False, **kwargs): + model = SwiftFormer( + layers=SwiftFormer_depth['l3'], + embed_dims=SwiftFormer_width['l3'], + downsamples=[True, True, True, True], + vit_num=1, + **kwargs) + model.default_cfg = _cfg(crop_pct=0.9) + return model diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000..dac9a80 --- /dev/null +++ b/requirements.txt @@ -0,0 +1,3 @@ +torch==1.11.0+cu113 +torchvision==0.12.0+cu113 +timm==0.5.4 diff --git a/slurm_train.sh b/slurm_train.sh new file mode 100644 index 0000000..17c4650 --- /dev/null +++ b/slurm_train.sh @@ -0,0 +1,23 @@ +#!/bin/sh +#SBATCH --job-name=swiftformer +#SBATCH --partition=your_partition +#SBATCH --time=48:00:00 +#SBATCH --nodes=4 +#SBATCH --ntasks=16 +#SBATCH --cpus-per-task=16 +#SBATCH --gres=gpu:4 +#SBATCH --mem-per-cpu=8000 + +IMAGENET_PATH=$1 +MODEL=$2 + +srun python main.py --model "$MODEL" \ +--data-path "$IMAGENET_PATH" \ +--batch-size 128 \ +--epochs 300 \ +--aa="" --mixup 0 --cutmix 0 + + +## Note: Disable aa, mixup, and cutmix for SwiftFormer-XS only +## By default, this script requests total 16 GPUs on 4 nodes. The batch size per gpu is set to 128, +## tha sums to 128*16=2048 in total. 
\ No newline at end of file diff --git a/util/__init__.py b/util/__init__.py new file mode 100644 index 0000000..9ad1376 --- /dev/null +++ b/util/__init__.py @@ -0,0 +1,6 @@ +import util.utils as utils +from .datasets import build_dataset +from .engine import train_one_epoch, evaluate +from .losses import DistillationLoss +from .samplers import RASampler + diff --git a/util/datasets.py b/util/datasets.py new file mode 100644 index 0000000..0aa6ea6 --- /dev/null +++ b/util/datasets.py @@ -0,0 +1,120 @@ +import os +import json + +from torchvision import datasets, transforms +from torchvision.datasets.folder import ImageFolder, default_loader +import torch + +from timm.data.constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD +from timm.data import create_transform + + +class INatDataset(ImageFolder): + def __init__(self, root, train=True, year=2018, transform=None, target_transform=None, category='name', + loader=default_loader): + super().__init__(root, transform, target_transform, loader) + self.transform = transform + self.loader = loader + self.target_transform = target_transform + self.year = year + # assert category in ['kingdom','phylum','class','order','supercategory','family','genus','name'] + path_json = os.path.join( + root, f'{"train" if train else "val"}{year}.json') + with open(path_json) as json_file: + data = json.load(json_file) + + with open(os.path.join(root, 'categories.json')) as json_file: + data_catg = json.load(json_file) + + path_json_for_targeter = os.path.join(root, f"train{year}.json") + + with open(path_json_for_targeter) as json_file: + data_for_targeter = json.load(json_file) + + targeter = {} + indexer = 0 + for elem in data_for_targeter['annotations']: + king = [] + king.append(data_catg[int(elem['category_id'])][category]) + if king[0] not in targeter.keys(): + targeter[king[0]] = indexer + indexer += 1 + self.nb_classes = len(targeter) + + self.samples = [] + for elem in data['images']: + cut = elem['file_name'].split('/') + target_current = int(cut[2]) + path_current = os.path.join(root, cut[0], cut[2], cut[3]) + + categors = data_catg[target_current] + target_current_true = targeter[categors[category]] + self.samples.append((path_current, target_current_true)) + + # __getitem__ and __len__ inherited from ImageFolder + + +def build_dataset(is_train, args): + transform = build_transform(is_train, args) + + if args.data_set == 'CIFAR': + dataset = datasets.CIFAR100( + args.data_path, train=is_train, transform=transform) + nb_classes = 100 + elif args.data_set == 'IMNET': + root = os.path.join(args.data_path, 'train' if is_train else 'val') + dataset = datasets.ImageFolder(root, transform=transform) + nb_classes = 1000 + elif args.data_set == 'FLOWERS': + root = os.path.join(args.data_path, 'train' if is_train else 'test') + dataset = datasets.ImageFolder(root, transform=transform) + if is_train: + dataset = torch.utils.data.ConcatDataset( + [dataset for _ in range(100)]) + nb_classes = 102 + elif args.data_set == 'INAT': + dataset = INatDataset(args.data_path, train=is_train, year=2018, + category=args.inat_category, transform=transform) + nb_classes = dataset.nb_classes + elif args.data_set == 'INAT19': + dataset = INatDataset(args.data_path, train=is_train, year=2019, + category=args.inat_category, transform=transform) + nb_classes = dataset.nb_classes + else: + raise NotImplementedError + + return dataset, nb_classes + + +def build_transform(is_train, args): + resize_im = args.input_size > 32 + if is_train: + # This should always dispatch to 
transforms_imagenet_train + transform = create_transform( + input_size=args.input_size, + is_training=True, + color_jitter=args.color_jitter, + auto_augment=args.aa, + interpolation=args.train_interpolation, + re_prob=args.reprob, + re_mode=args.remode, + re_count=args.recount, + ) + if not resize_im: + # Replace RandomResizedCropAndInterpolation with RandomCrop + transform.transforms[0] = transforms.RandomCrop( + args.input_size, padding=4) + return transform + + t = [] + if resize_im: + size = int((256 / 224) * args.input_size) + t.append( + # to maintain same ratio w.r.t. 224 images + transforms.Resize(size, interpolation=3), + ) + t.append(transforms.CenterCrop(args.input_size)) + + t.append(transforms.ToTensor()) + t.append(transforms.Normalize(IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD)) + return transforms.Compose(t) diff --git a/util/engine.py b/util/engine.py new file mode 100644 index 0000000..f25839f --- /dev/null +++ b/util/engine.py @@ -0,0 +1,101 @@ +""" +Train and eval functions used in main.py +""" +import math +import sys +from typing import Iterable, Optional + +import torch + +from timm.data import Mixup +from timm.utils import accuracy, ModelEma + +from .losses import DistillationLoss +import util.utils as utils + + +def train_one_epoch(model: torch.nn.Module, criterion: DistillationLoss, + data_loader: Iterable, optimizer: torch.optim.Optimizer, + device: torch.device, epoch: int, loss_scaler, + clip_grad: float = 0, + clip_mode: str = 'norm', + model_ema: Optional[ModelEma] = None, mixup_fn: Optional[Mixup] = None, + set_training_mode=True): + model.train(set_training_mode) + metric_logger = utils.MetricLogger(delimiter=" ") + metric_logger.add_meter('lr', utils.SmoothedValue( + window_size=1, fmt='{value:.6f}')) + header = 'Epoch: [{}]'.format(epoch) + print_freq = 100 + + for samples, targets in metric_logger.log_every( + data_loader, print_freq, header): + samples = samples.to(device, non_blocking=True) + targets = targets.to(device, non_blocking=True) + + if mixup_fn is not None: + samples, targets = mixup_fn(samples, targets) + + if True: # with torch.cuda.amp.autocast(): + outputs = model(samples) + loss = criterion(samples, outputs, targets) + + loss_value = loss.item() + + if not math.isfinite(loss_value): + print("Loss is {}, stopping training".format(loss_value)) + sys.exit(1) + + optimizer.zero_grad() + + # This attribute is added by timm on one optimizer (adahessian) + is_second_order = hasattr( + optimizer, 'is_second_order') and optimizer.is_second_order + loss_scaler(loss, optimizer, clip_grad=clip_grad, clip_mode=clip_mode, + parameters=model.parameters(), create_graph=is_second_order) + + torch.cuda.synchronize() + if model_ema is not None: + model_ema.update(model) + + metric_logger.update(loss=loss_value) + metric_logger.update(lr=optimizer.param_groups[0]["lr"]) + # gather the stats from all processes + metric_logger.synchronize_between_processes() + print("Averaged stats:", metric_logger) + return {k: meter.global_avg for k, meter in metric_logger.meters.items()} + + +@torch.no_grad() +def evaluate(data_loader, model, device): + criterion = torch.nn.CrossEntropyLoss() + + metric_logger = utils.MetricLogger(delimiter=" ") + header = 'Test:' + + # Switch to evaluation mode + model.eval() + + for images, target in metric_logger.log_every(data_loader, 10, header): + images = images.to(device, non_blocking=True) + target = target.to(device, non_blocking=True) + + # Compute output + with torch.cuda.amp.autocast(): + output = model(images) + loss = 
criterion(output, target) + + acc1, acc5 = accuracy(output, target, topk=(1, 5)) + + batch_size = images.shape[0] + metric_logger.update(loss=loss.item()) + metric_logger.meters['acc1'].update(acc1.item(), n=batch_size) + metric_logger.meters['acc5'].update(acc5.item(), n=batch_size) + + # Gather the stats from all processes + metric_logger.synchronize_between_processes() + print('* Acc@1 {top1.global_avg:.3f} Acc@5 {top5.global_avg:.3f} loss {losses.global_avg:.3f}' + .format(top1=metric_logger.acc1, top5=metric_logger.acc5, losses=metric_logger.loss)) + print(output.mean().item(), output.std().item()) + + return {k: meter.global_avg for k, meter in metric_logger.meters.items()} diff --git a/util/losses.py b/util/losses.py new file mode 100644 index 0000000..c332136 --- /dev/null +++ b/util/losses.py @@ -0,0 +1,64 @@ +""" +Implements the knowledge distillation loss +""" +import torch +from torch.nn import functional as F + + +class DistillationLoss(torch.nn.Module): + """ + This module wraps a standard criterion and adds an extra knowledge distillation loss by + taking a teacher model prediction and using it as additional supervision. + """ + + def __init__(self, base_criterion: torch.nn.Module, teacher_model: torch.nn.Module, + distillation_type: str, alpha: float, tau: float): + super().__init__() + self.base_criterion = base_criterion + self.teacher_model = teacher_model + assert distillation_type in ['none', 'soft', 'hard'] + self.distillation_type = distillation_type + self.alpha = alpha + self.tau = tau + + def forward(self, inputs, outputs, labels): + """ + Args: + inputs: The original inputs that are feed to the teacher model + outputs: the outputs of the model to be trained. It is expected to be + either a Tensor, or a Tuple[Tensor, Tensor], with the original output + in the first position and the distillation predictions as the second output + labels: the labels for the base criterion + """ + outputs_kd = None + if not isinstance(outputs, torch.Tensor): + # assume that the model outputs a tuple of [outputs, outputs_kd] + outputs, outputs_kd = outputs + base_loss = self.base_criterion(outputs, labels) + if self.distillation_type == 'none': + return base_loss + + if outputs_kd is None: + raise ValueError("When knowledge distillation is enabled, the model is " + "expected to return a Tuple[Tensor, Tensor] with the output of the " + "class_token and the dist_token") + # Don't backprop throught the teacher + with torch.no_grad(): + teacher_outputs = self.teacher_model(inputs) + + if self.distillation_type == 'soft': + T = self.tau + # taken from https://github.com/peterliht/knowledge-distillation-pytorch/blob/master/model/net.py#L100 + # with slight modifications + distillation_loss = F.kl_div( + F.log_softmax(outputs_kd / T, dim=1), + F.log_softmax(teacher_outputs / T, dim=1), + reduction='sum', + log_target=True + ) * (T * T) / outputs_kd.numel() + elif self.distillation_type == 'hard': + distillation_loss = F.cross_entropy( + outputs_kd, teacher_outputs.argmax(dim=1)) + + loss = base_loss * (1 - self.alpha) + distillation_loss * self.alpha + return loss diff --git a/util/samplers.py b/util/samplers.py new file mode 100644 index 0000000..84545d1 --- /dev/null +++ b/util/samplers.py @@ -0,0 +1,60 @@ +import torch +import torch.distributed as dist +import math + + +class RASampler(torch.utils.data.Sampler): + """Sampler that restricts data loading to a subset of the dataset for distributed, + with repeated augmentation. 
+ It ensures that different each augmented version of a sample will be visible to a + different process (GPU) + Heavily based on torch.utils.data.DistributedSampler + """ + + def __init__(self, dataset, num_replicas=None, rank=None, shuffle=True): + if num_replicas is None: + if not dist.is_available(): + raise RuntimeError( + "Requires distributed package to be available") + num_replicas = dist.get_world_size() + if rank is None: + if not dist.is_available(): + raise RuntimeError( + "Requires distributed package to be available") + rank = dist.get_rank() + self.dataset = dataset + self.num_replicas = num_replicas + self.rank = rank + self.epoch = 0 + self.num_samples = int( + math.ceil(len(self.dataset) * 3.0 / self.num_replicas)) + self.total_size = self.num_samples * self.num_replicas + self.num_selected_samples = int(math.floor( + len(self.dataset) // 256 * 256 / self.num_replicas)) + self.shuffle = shuffle + + def __iter__(self): + # Deterministically shuffle based on epoch + g = torch.Generator() + g.manual_seed(self.epoch) + if self.shuffle: + indices = torch.randperm(len(self.dataset), generator=g).tolist() + else: + indices = list(range(len(self.dataset))) + + # Add extra samples to make it evenly divisible + indices = [ele for ele in indices for i in range(3)] + indices += indices[:(self.total_size - len(indices))] + assert len(indices) == self.total_size + + # Subsample + indices = indices[self.rank:self.total_size:self.num_replicas] + assert len(indices) == self.num_samples + + return iter(indices[:self.num_selected_samples]) + + def __len__(self): + return self.num_selected_samples + + def set_epoch(self, epoch): + self.epoch = epoch diff --git a/util/utils.py b/util/utils.py new file mode 100644 index 0000000..70b3858 --- /dev/null +++ b/util/utils.py @@ -0,0 +1,280 @@ +""" +Misc functions, including distributed helpers. + +Mostly copy-paste from torchvision references. +""" +import io +import os +import time +from collections import defaultdict, deque +import datetime + +import torch +import torch.distributed as dist +import subprocess + + +class SmoothedValue(object): + """Track a series of values and provide access to smoothed values over a + window or the global series average. + """ + + def __init__(self, window_size=20, fmt=None): + if fmt is None: + fmt = "{median:.4f} ({global_avg:.4f})" + self.deque = deque(maxlen=window_size) + self.total = 0.0 + self.count = 0 + self.fmt = fmt + + def update(self, value, n=1): + self.deque.append(value) + self.count += n + self.total += value * n + + def synchronize_between_processes(self): + """ + Warning: does not synchronize the deque! 
+ """ + if not is_dist_avail_and_initialized(): + return + t = torch.tensor([self.count, self.total], + dtype=torch.float64, device='cuda') + dist.barrier() + dist.all_reduce(t) + t = t.tolist() + self.count = int(t[0]) + self.total = t[1] + + @property + def median(self): + d = torch.tensor(list(self.deque)) + return d.median().item() + + @property + def avg(self): + d = torch.tensor(list(self.deque), dtype=torch.float32) + return d.mean().item() + + @property + def global_avg(self): + return self.total / self.count + + @property + def max(self): + return max(self.deque) + + @property + def value(self): + return self.deque[-1] + + def __str__(self): + return self.fmt.format( + median=self.median, + avg=self.avg, + global_avg=self.global_avg, + max=self.max, + value=self.value) + + +class MetricLogger(object): + def __init__(self, delimiter="\t"): + self.meters = defaultdict(SmoothedValue) + self.delimiter = delimiter + + def update(self, **kwargs): + for k, v in kwargs.items(): + if isinstance(v, torch.Tensor): + v = v.item() + assert isinstance(v, (float, int)) + self.meters[k].update(v) + + def __getattr__(self, attr): + if attr in self.meters: + return self.meters[attr] + if attr in self.__dict__: + return self.__dict__[attr] + raise AttributeError("'{}' object has no attribute '{}'".format( + type(self).__name__, attr)) + + def __str__(self): + loss_str = [] + for name, meter in self.meters.items(): + loss_str.append( + "{}: {}".format(name, str(meter)) + ) + return self.delimiter.join(loss_str) + + def synchronize_between_processes(self): + for meter in self.meters.values(): + meter.synchronize_between_processes() + + def add_meter(self, name, meter): + self.meters[name] = meter + + def log_every(self, iterable, print_freq, header=None): + i = 0 + if not header: + header = '' + start_time = time.time() + end = time.time() + iter_time = SmoothedValue(fmt='{avg:.4f}') + data_time = SmoothedValue(fmt='{avg:.4f}') + space_fmt = ':' + str(len(str(len(iterable)))) + 'd' + log_msg = [ + header, + '[{0' + space_fmt + '}/{1}]', + 'eta: {eta}', + '{meters}', + 'time: {time}', + 'data: {data}' + ] + if torch.cuda.is_available(): + log_msg.append('max mem: {memory:.0f}') + log_msg = self.delimiter.join(log_msg) + MB = 1024.0 * 1024.0 + for obj in iterable: + data_time.update(time.time() - end) + yield obj + iter_time.update(time.time() - end) + if i % print_freq == 0 or i == len(iterable) - 1: + eta_seconds = iter_time.global_avg * (len(iterable) - i) + eta_string = str(datetime.timedelta(seconds=int(eta_seconds))) + if torch.cuda.is_available(): + print(log_msg.format( + i, len(iterable), eta=eta_string, + meters=str(self), + time=str(iter_time), data=str(data_time), + memory=torch.cuda.max_memory_allocated() / MB)) + else: + print(log_msg.format( + i, len(iterable), eta=eta_string, + meters=str(self), + time=str(iter_time), data=str(data_time))) + i += 1 + end = time.time() + total_time = time.time() - start_time + total_time_str = str(datetime.timedelta(seconds=int(total_time))) + print('{} Total time: {} ({:.4f} s / it)'.format( + header, total_time_str, total_time / len(iterable))) + + +def _load_checkpoint_for_ema(model_ema, checkpoint): + """ + Workaround for ModelEma._load_checkpoint to accept an already-loaded object + """ + mem_file = io.BytesIO() + torch.save(checkpoint, mem_file) + mem_file.seek(0) + model_ema._load_checkpoint(mem_file) + + +def setup_for_distributed(is_master): + """ + This function disables printing when not in master process + """ + import builtins as __builtin__ 
+ builtin_print = __builtin__.print + + def print(*args, **kwargs): + force = kwargs.pop('force', False) + if is_master or force: + builtin_print(*args, **kwargs) + + __builtin__.print = print + + +def is_dist_avail_and_initialized(): + if not dist.is_available(): + return False + if not dist.is_initialized(): + return False + return True + + +def get_world_size(): + if not is_dist_avail_and_initialized(): + return 1 + return dist.get_world_size() + + +def get_rank(): + if not is_dist_avail_and_initialized(): + return 0 + return dist.get_rank() + + +def is_main_process(): + return get_rank() == 0 + + +def save_on_master(*args, **kwargs): + if is_main_process(): + torch.save(*args, **kwargs) + + +def init_distributed_mode(args): + if 'RANK' in os.environ and 'WORLD_SIZE' in os.environ: + args.rank = int(os.environ["RANK"]) + args.world_size = int(os.environ['WORLD_SIZE']) + args.gpu = int(os.environ['LOCAL_RANK']) + args.dist_url = 'env://' + os.environ['LOCAL_SIZE'] = str(torch.cuda.device_count()) + print('Using distributed mode: 1') + elif 'SLURM_PROCID' in os.environ: + proc_id = int(os.environ['SLURM_PROCID']) + ntasks = int(os.environ['SLURM_NTASKS']) + node_list = os.environ['SLURM_NODELIST'] + num_gpus = torch.cuda.device_count() + addr = subprocess.getoutput( + 'scontrol show hostname {} | head -n1'.format(node_list)) + os.environ['MASTER_PORT'] = os.environ.get('MASTER_PORT', '29500') + os.environ['MASTER_ADDR'] = addr + os.environ['WORLD_SIZE'] = str(ntasks) + os.environ['RANK'] = str(proc_id) + os.environ['LOCAL_RANK'] = str(proc_id % num_gpus) + os.environ['LOCAL_SIZE'] = str(num_gpus) + args.dist_url = 'env://' + args.world_size = ntasks + args.rank = proc_id + args.gpu = proc_id % num_gpus + print('Using distributed mode: slurm') + print(f"world: {os.environ['WORLD_SIZE']}, rank:{os.environ['RANK']}," + f" local_rank{os.environ['LOCAL_RANK']}, local_size{os.environ['LOCAL_SIZE']}") + else: + print('Not using distributed mode') + args.distributed = False + return + + args.distributed = True + + torch.cuda.set_device(args.gpu) + args.dist_backend = 'nccl' + print('| distributed init (rank {}): {}'.format( + args.rank, args.dist_url), flush=True) + torch.distributed.init_process_group(backend=args.dist_backend, init_method=args.dist_url, + world_size=args.world_size, rank=args.rank) + torch.distributed.barrier() + setup_for_distributed(args.rank == 0) + + +def replace_batchnorm(net): + for child_name, child in net.named_children(): + if hasattr(child, 'fuse'): + setattr(net, child_name, child.fuse()) + elif isinstance(child, torch.nn.Conv2d): + child.bias = torch.nn.Parameter(torch.zeros(child.weight.size(0))) + elif isinstance(child, torch.nn.BatchNorm2d): + setattr(net, child_name, torch.nn.Identity()) + else: + replace_batchnorm(child) + + +def replace_layernorm(net): + import apex + for child_name, child in net.named_children(): + if isinstance(child, torch.nn.LayerNorm): + setattr(net, child_name, apex.normalization.FusedLayerNorm( + child.weight.size(0))) + else: + replace_layernorm(child)
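As a closing usage note for the pieces above (the model registry in `models/swiftformer.py` and its default `_cfg` preprocessing), here is a minimal single-image inference sketch outside the distributed pipeline. It is not part of the release: the checkpoint and image paths are placeholders, and it assumes the weights are stored under the `'model'` key, which is the format `main.py` saves.

```python
# Minimal single-image inference sketch; paths are placeholders and not shipped with this commit.
import torch
from PIL import Image
from timm.data import create_transform

from models.swiftformer import SwiftFormer_XS

model = SwiftFormer_XS()                    # at eval time the head and distillation head are averaged
ckpt = torch.load('weights/SwiftFormer_XS_ckpt.pth', map_location='cpu')
model.load_state_dict(ckpt.get('model', ckpt), strict=False)
model.eval()

# Mirror the default eval preprocessing from _cfg: bicubic interpolation and a 0.9 center crop.
transform = create_transform(input_size=224, interpolation='bicubic', crop_pct=0.9)
image = transform(Image.open('example.jpg').convert('RGB')).unsqueeze(0)

with torch.no_grad():
    logits = model(image)
print('predicted ImageNet class id:', logits.argmax(dim=1).item())
```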