Initial release of SwiftFormer
6
util/__init__.py
Normal file
@@ -0,0 +1,6 @@
import util.utils as utils
from .datasets import build_dataset
from .engine import train_one_epoch, evaluate
from .losses import DistillationLoss
from .samplers import RASampler
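The re-exports above define the package's public surface. As a quick orientation (a sketch of the implied import style, not code from this commit), a training script would typically pull these names in as follows.

# Sketch: the import surface exposed by util/__init__.py (illustrative only).
import util.utils as utils
from util import (build_dataset, train_one_epoch, evaluate,
                  DistillationLoss, RASampler)
# Helpers such as utils.init_distributed_mode, utils.MetricLogger and
# utils.replace_batchnorm are reached through the `utils` alias, mirroring
# the way engine.py imports it.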
120
util/datasets.py
Normal file
@@ -0,0 +1,120 @@
import os
import json

from torchvision import datasets, transforms
from torchvision.datasets.folder import ImageFolder, default_loader
import torch

from timm.data.constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from timm.data import create_transform


class INatDataset(ImageFolder):
    def __init__(self, root, train=True, year=2018, transform=None, target_transform=None, category='name',
                 loader=default_loader):
        super().__init__(root, transform, target_transform, loader)
        self.transform = transform
        self.loader = loader
        self.target_transform = target_transform
        self.year = year
        # assert category in ['kingdom','phylum','class','order','supercategory','family','genus','name']
        path_json = os.path.join(
            root, f'{"train" if train else "val"}{year}.json')
        with open(path_json) as json_file:
            data = json.load(json_file)

        with open(os.path.join(root, 'categories.json')) as json_file:
            data_catg = json.load(json_file)

        path_json_for_targeter = os.path.join(root, f"train{year}.json")

        with open(path_json_for_targeter) as json_file:
            data_for_targeter = json.load(json_file)

        # Map each category (by the chosen `category` field) to a contiguous
        # class index
        targeter = {}
        indexer = 0
        for elem in data_for_targeter['annotations']:
            king = []
            king.append(data_catg[int(elem['category_id'])][category])
            if king[0] not in targeter.keys():
                targeter[king[0]] = indexer
                indexer += 1
        self.nb_classes = len(targeter)

        self.samples = []
        for elem in data['images']:
            cut = elem['file_name'].split('/')
            target_current = int(cut[2])
            path_current = os.path.join(root, cut[0], cut[2], cut[3])

            categors = data_catg[target_current]
            target_current_true = targeter[categors[category]]
            self.samples.append((path_current, target_current_true))

    # __getitem__ and __len__ inherited from ImageFolder


def build_dataset(is_train, args):
    transform = build_transform(is_train, args)

    if args.data_set == 'CIFAR':
        dataset = datasets.CIFAR100(
            args.data_path, train=is_train, transform=transform)
        nb_classes = 100
    elif args.data_set == 'IMNET':
        root = os.path.join(args.data_path, 'train' if is_train else 'val')
        dataset = datasets.ImageFolder(root, transform=transform)
        nb_classes = 1000
    elif args.data_set == 'FLOWERS':
        root = os.path.join(args.data_path, 'train' if is_train else 'test')
        dataset = datasets.ImageFolder(root, transform=transform)
        if is_train:
            # The training split is repeated 100 times, so one epoch makes
            # 100 passes over the data
            dataset = torch.utils.data.ConcatDataset(
                [dataset for _ in range(100)])
        nb_classes = 102
    elif args.data_set == 'INAT':
        dataset = INatDataset(args.data_path, train=is_train, year=2018,
                              category=args.inat_category, transform=transform)
        nb_classes = dataset.nb_classes
    elif args.data_set == 'INAT19':
        dataset = INatDataset(args.data_path, train=is_train, year=2019,
                              category=args.inat_category, transform=transform)
        nb_classes = dataset.nb_classes
    else:
        raise NotImplementedError

    return dataset, nb_classes


def build_transform(is_train, args):
    resize_im = args.input_size > 32
    if is_train:
        # This should always dispatch to transforms_imagenet_train
        transform = create_transform(
            input_size=args.input_size,
            is_training=True,
            color_jitter=args.color_jitter,
            auto_augment=args.aa,
            interpolation=args.train_interpolation,
            re_prob=args.reprob,
            re_mode=args.remode,
            re_count=args.recount,
        )
        if not resize_im:
            # Replace RandomResizedCropAndInterpolation with RandomCrop
            transform.transforms[0] = transforms.RandomCrop(
                args.input_size, padding=4)
        return transform

    t = []
    if resize_im:
        size = int((256 / 224) * args.input_size)
        t.append(
            # to maintain same ratio w.r.t. 224 images
            transforms.Resize(size, interpolation=3),
        )
    t.append(transforms.CenterCrop(args.input_size))

    t.append(transforms.ToTensor())
    t.append(transforms.Normalize(IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD))
    return transforms.Compose(t)
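build_dataset and build_transform only read a small set of attributes from args. The sketch below lists those fields for the IMNET branch; the concrete values are placeholders chosen for illustration and are not defaults defined in this commit.

# Sketch: the args fields consumed by build_dataset / build_transform above.
# All values are illustrative placeholders, not defaults from this repository.
from argparse import Namespace
from util.datasets import build_dataset

args = Namespace(
    data_set='IMNET',              # CIFAR / IMNET / FLOWERS / INAT / INAT19
    data_path='/path/to/imagenet', # root containing train/ and val/ folders
    input_size=224,
    color_jitter=0.4,
    aa='rand-m9-mstd0.5-inc1',     # auto-augment policy string passed to timm
    train_interpolation='bicubic',
    reprob=0.25,                   # random-erasing probability / mode / count
    remode='pixel',
    recount=1,
    inat_category='name',          # only read by the INAT / INAT19 branches
)

dataset_train, nb_classes = build_dataset(is_train=True, args=args)
dataset_val, _ = build_dataset(is_train=False, args=args)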
101
util/engine.py
Normal file
@@ -0,0 +1,101 @@
"""
Train and eval functions used in main.py
"""
import math
import sys
from typing import Iterable, Optional

import torch

from timm.data import Mixup
from timm.utils import accuracy, ModelEma

from .losses import DistillationLoss
import util.utils as utils


def train_one_epoch(model: torch.nn.Module, criterion: DistillationLoss,
                    data_loader: Iterable, optimizer: torch.optim.Optimizer,
                    device: torch.device, epoch: int, loss_scaler,
                    clip_grad: float = 0,
                    clip_mode: str = 'norm',
                    model_ema: Optional[ModelEma] = None, mixup_fn: Optional[Mixup] = None,
                    set_training_mode=True):
    model.train(set_training_mode)
    metric_logger = utils.MetricLogger(delimiter="  ")
    metric_logger.add_meter('lr', utils.SmoothedValue(
        window_size=1, fmt='{value:.6f}'))
    header = 'Epoch: [{}]'.format(epoch)
    print_freq = 100

    for samples, targets in metric_logger.log_every(
            data_loader, print_freq, header):
        samples = samples.to(device, non_blocking=True)
        targets = targets.to(device, non_blocking=True)

        if mixup_fn is not None:
            samples, targets = mixup_fn(samples, targets)

        if True:  # with torch.cuda.amp.autocast():
            outputs = model(samples)
            loss = criterion(samples, outputs, targets)

        loss_value = loss.item()

        if not math.isfinite(loss_value):
            print("Loss is {}, stopping training".format(loss_value))
            sys.exit(1)

        optimizer.zero_grad()

        # This attribute is added by timm on one optimizer (adahessian)
        is_second_order = hasattr(
            optimizer, 'is_second_order') and optimizer.is_second_order
        # loss_scaler is expected to run the backward pass, optionally clip
        # gradients, and step the optimizer (timm's NativeScaler follows this
        # calling convention)
        loss_scaler(loss, optimizer, clip_grad=clip_grad, clip_mode=clip_mode,
                    parameters=model.parameters(), create_graph=is_second_order)

        torch.cuda.synchronize()
        if model_ema is not None:
            model_ema.update(model)

        metric_logger.update(loss=loss_value)
        metric_logger.update(lr=optimizer.param_groups[0]["lr"])
    # gather the stats from all processes
    metric_logger.synchronize_between_processes()
    print("Averaged stats:", metric_logger)
    return {k: meter.global_avg for k, meter in metric_logger.meters.items()}


@torch.no_grad()
def evaluate(data_loader, model, device):
    criterion = torch.nn.CrossEntropyLoss()

    metric_logger = utils.MetricLogger(delimiter="  ")
    header = 'Test:'

    # Switch to evaluation mode
    model.eval()

    for images, target in metric_logger.log_every(data_loader, 10, header):
        images = images.to(device, non_blocking=True)
        target = target.to(device, non_blocking=True)

        # Compute output
        with torch.cuda.amp.autocast():
            output = model(images)
            loss = criterion(output, target)

        acc1, acc5 = accuracy(output, target, topk=(1, 5))

        batch_size = images.shape[0]
        metric_logger.update(loss=loss.item())
        metric_logger.meters['acc1'].update(acc1.item(), n=batch_size)
        metric_logger.meters['acc5'].update(acc5.item(), n=batch_size)

    # Gather the stats from all processes
    metric_logger.synchronize_between_processes()
    print('* Acc@1 {top1.global_avg:.3f} Acc@5 {top5.global_avg:.3f} loss {losses.global_avg:.3f}'
          .format(top1=metric_logger.acc1, top5=metric_logger.acc5, losses=metric_logger.loss))
    # Debug: statistics of the logits from the last evaluated batch
    print(output.mean().item(), output.std().item())

    return {k: meter.global_avg for k, meter in metric_logger.meters.items()}
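The loss_scaler argument above only needs to be callable as loss_scaler(loss, optimizer, clip_grad=..., clip_mode=..., parameters=..., create_graph=...); timm's NativeScaler follows that convention. The sketch below drives one train/eval round end to end with a toy model and random tensors standing in for a real SwiftFormer model and ImageNet loaders; it assumes a CUDA device, since train_one_epoch and evaluate call torch.cuda.synchronize() and autocast.

# Sketch: driving train_one_epoch / evaluate with toy data (CUDA required).
import torch
from torch.utils.data import DataLoader, TensorDataset
from timm.utils import NativeScaler
from util import train_one_epoch, evaluate, DistillationLoss

device = torch.device('cuda')
model = torch.nn.Sequential(torch.nn.Flatten(),
                            torch.nn.Linear(3 * 32 * 32, 10)).to(device)
data = TensorDataset(torch.randn(64, 3, 32, 32), torch.randint(0, 10, (64,)))
loader = DataLoader(data, batch_size=16)

# 'none' distillation makes DistillationLoss behave like plain cross-entropy.
criterion = DistillationLoss(torch.nn.CrossEntropyLoss(), teacher_model=None,
                             distillation_type='none', alpha=0.0, tau=1.0)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
loss_scaler = NativeScaler()   # matches the loss_scaler(...) call used above

for epoch in range(2):
    train_one_epoch(model, criterion, loader, optimizer, device, epoch,
                    loss_scaler, clip_grad=None, model_ema=None, mixup_fn=None)
    stats = evaluate(loader, model, device)
    print('epoch', epoch, 'acc@1', stats['acc1'])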
64
util/losses.py
Normal file
@@ -0,0 +1,64 @@
"""
Implements the knowledge distillation loss
"""
import torch
from torch.nn import functional as F


class DistillationLoss(torch.nn.Module):
    """
    This module wraps a standard criterion and adds an extra knowledge distillation loss by
    taking a teacher model prediction and using it as additional supervision.
    """

    def __init__(self, base_criterion: torch.nn.Module, teacher_model: torch.nn.Module,
                 distillation_type: str, alpha: float, tau: float):
        super().__init__()
        self.base_criterion = base_criterion
        self.teacher_model = teacher_model
        assert distillation_type in ['none', 'soft', 'hard']
        self.distillation_type = distillation_type
        self.alpha = alpha
        self.tau = tau

    def forward(self, inputs, outputs, labels):
        """
        Args:
            inputs: The original inputs that are fed to the teacher model
            outputs: the outputs of the model to be trained. It is expected to be
                either a Tensor, or a Tuple[Tensor, Tensor], with the original output
                in the first position and the distillation predictions as the second output
            labels: the labels for the base criterion
        """
        outputs_kd = None
        if not isinstance(outputs, torch.Tensor):
            # assume that the model outputs a tuple of [outputs, outputs_kd]
            outputs, outputs_kd = outputs
        base_loss = self.base_criterion(outputs, labels)
        if self.distillation_type == 'none':
            return base_loss

        if outputs_kd is None:
            raise ValueError("When knowledge distillation is enabled, the model is "
                             "expected to return a Tuple[Tensor, Tensor] with the output of the "
                             "class_token and the dist_token")
        # Don't backprop through the teacher
        with torch.no_grad():
            teacher_outputs = self.teacher_model(inputs)

        if self.distillation_type == 'soft':
            T = self.tau
            # taken from https://github.com/peterliht/knowledge-distillation-pytorch/blob/master/model/net.py#L100
            # with slight modifications
            distillation_loss = F.kl_div(
                F.log_softmax(outputs_kd / T, dim=1),
                F.log_softmax(teacher_outputs / T, dim=1),
                reduction='sum',
                log_target=True
            ) * (T * T) / outputs_kd.numel()
        elif self.distillation_type == 'hard':
            distillation_loss = F.cross_entropy(
                outputs_kd, teacher_outputs.argmax(dim=1))

        loss = base_loss * (1 - self.alpha) + distillation_loss * self.alpha
        return loss
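In engine.py the wrapped criterion is called as criterion(samples, outputs, targets), and when distillation is active the model is expected to return a pair (class output, distillation output) as described in the forward() docstring. Below is a minimal sketch with stand-in linear models, not code from this commit; the same logits are reused for both outputs purely to keep the example small.

# Sketch: using DistillationLoss with a frozen stand-in teacher.
import torch
from util.losses import DistillationLoss

student = torch.nn.Linear(16, 10)          # stand-in for the model being trained
teacher = torch.nn.Linear(16, 10).eval()   # stand-in for a pretrained teacher
for p in teacher.parameters():
    p.requires_grad_(False)                # the teacher is never updated

criterion = DistillationLoss(base_criterion=torch.nn.CrossEntropyLoss(),
                             teacher_model=teacher,
                             distillation_type='hard', alpha=0.5, tau=1.0)

inputs = torch.randn(4, 16)
labels = torch.randint(0, 10, (4,))
logits = student(inputs)
# When distillation != 'none' the loss expects (outputs, outputs_kd); the same
# logits are reused here only for illustration.
loss = criterion(inputs, (logits, logits), labels)
loss.backward()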
60
util/samplers.py
Normal file
@@ -0,0 +1,60 @@
import torch
import torch.distributed as dist
import math


class RASampler(torch.utils.data.Sampler):
    """Sampler that restricts data loading to a subset of the dataset for distributed training,
    with repeated augmentation.
    It ensures that each augmented version of a sample will be visible to a
    different process (GPU).
    Heavily based on torch.utils.data.DistributedSampler
    """

    def __init__(self, dataset, num_replicas=None, rank=None, shuffle=True):
        if num_replicas is None:
            if not dist.is_available():
                raise RuntimeError(
                    "Requires distributed package to be available")
            num_replicas = dist.get_world_size()
        if rank is None:
            if not dist.is_available():
                raise RuntimeError(
                    "Requires distributed package to be available")
            rank = dist.get_rank()
        self.dataset = dataset
        self.num_replicas = num_replicas
        self.rank = rank
        self.epoch = 0
        self.num_samples = int(
            math.ceil(len(self.dataset) * 3.0 / self.num_replicas))
        self.total_size = self.num_samples * self.num_replicas
        self.num_selected_samples = int(math.floor(
            len(self.dataset) // 256 * 256 / self.num_replicas))
        self.shuffle = shuffle

    def __iter__(self):
        # Deterministically shuffle based on epoch
        g = torch.Generator()
        g.manual_seed(self.epoch)
        if self.shuffle:
            indices = torch.randperm(len(self.dataset), generator=g).tolist()
        else:
            indices = list(range(len(self.dataset)))

        # Repeat each index 3 times (repeated augmentation) and add extra
        # samples to make the list evenly divisible across processes
        indices = [ele for ele in indices for i in range(3)]
        indices += indices[:(self.total_size - len(indices))]
        assert len(indices) == self.total_size

        # Subsample
        indices = indices[self.rank:self.total_size:self.num_replicas]
        assert len(indices) == self.num_samples

        return iter(indices[:self.num_selected_samples])

    def __len__(self):
        return self.num_selected_samples

    def set_epoch(self, epoch):
        self.epoch = epoch
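RASampler is used as a drop-in replacement for torch.utils.data.DistributedSampler on the training set, and like that sampler it needs set_epoch() before each epoch so the shuffle changes. The sketch below assumes the process group has already been initialised (for example via utils.init_distributed_mode from util/utils.py); dataset_train and num_epochs are placeholders.

# Sketch: a distributed training loader backed by RASampler.
import torch
import util.utils as utils
from util.samplers import RASampler

sampler_train = RASampler(dataset_train,
                          num_replicas=utils.get_world_size(),
                          rank=utils.get_rank(), shuffle=True)
loader_train = torch.utils.data.DataLoader(
    dataset_train, sampler=sampler_train, batch_size=128,
    num_workers=8, pin_memory=True, drop_last=True)

for epoch in range(num_epochs):
    sampler_train.set_epoch(epoch)   # re-seed the deterministic per-epoch shuffle
    # ... train_one_epoch(model, criterion, loader_train, ...)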
280
util/utils.py
Normal file
@@ -0,0 +1,280 @@
"""
Misc functions, including distributed helpers.

Mostly copy-paste from torchvision references.
"""
import io
import os
import time
from collections import defaultdict, deque
import datetime

import torch
import torch.distributed as dist
import subprocess


class SmoothedValue(object):
    """Track a series of values and provide access to smoothed values over a
    window or the global series average.
    """

    def __init__(self, window_size=20, fmt=None):
        if fmt is None:
            fmt = "{median:.4f} ({global_avg:.4f})"
        self.deque = deque(maxlen=window_size)
        self.total = 0.0
        self.count = 0
        self.fmt = fmt

    def update(self, value, n=1):
        self.deque.append(value)
        self.count += n
        self.total += value * n

    def synchronize_between_processes(self):
        """
        Warning: does not synchronize the deque!
        """
        if not is_dist_avail_and_initialized():
            return
        t = torch.tensor([self.count, self.total],
                         dtype=torch.float64, device='cuda')
        dist.barrier()
        dist.all_reduce(t)
        t = t.tolist()
        self.count = int(t[0])
        self.total = t[1]

    @property
    def median(self):
        d = torch.tensor(list(self.deque))
        return d.median().item()

    @property
    def avg(self):
        d = torch.tensor(list(self.deque), dtype=torch.float32)
        return d.mean().item()

    @property
    def global_avg(self):
        return self.total / self.count

    @property
    def max(self):
        return max(self.deque)

    @property
    def value(self):
        return self.deque[-1]

    def __str__(self):
        return self.fmt.format(
            median=self.median,
            avg=self.avg,
            global_avg=self.global_avg,
            max=self.max,
            value=self.value)


class MetricLogger(object):
    def __init__(self, delimiter="\t"):
        self.meters = defaultdict(SmoothedValue)
        self.delimiter = delimiter

    def update(self, **kwargs):
        for k, v in kwargs.items():
            if isinstance(v, torch.Tensor):
                v = v.item()
            assert isinstance(v, (float, int))
            self.meters[k].update(v)

    def __getattr__(self, attr):
        if attr in self.meters:
            return self.meters[attr]
        if attr in self.__dict__:
            return self.__dict__[attr]
        raise AttributeError("'{}' object has no attribute '{}'".format(
            type(self).__name__, attr))

    def __str__(self):
        loss_str = []
        for name, meter in self.meters.items():
            loss_str.append(
                "{}: {}".format(name, str(meter))
            )
        return self.delimiter.join(loss_str)

    def synchronize_between_processes(self):
        for meter in self.meters.values():
            meter.synchronize_between_processes()

    def add_meter(self, name, meter):
        self.meters[name] = meter

    def log_every(self, iterable, print_freq, header=None):
        i = 0
        if not header:
            header = ''
        start_time = time.time()
        end = time.time()
        iter_time = SmoothedValue(fmt='{avg:.4f}')
        data_time = SmoothedValue(fmt='{avg:.4f}')
        space_fmt = ':' + str(len(str(len(iterable)))) + 'd'
        log_msg = [
            header,
            '[{0' + space_fmt + '}/{1}]',
            'eta: {eta}',
            '{meters}',
            'time: {time}',
            'data: {data}'
        ]
        if torch.cuda.is_available():
            log_msg.append('max mem: {memory:.0f}')
        log_msg = self.delimiter.join(log_msg)
        MB = 1024.0 * 1024.0
        for obj in iterable:
            data_time.update(time.time() - end)
            yield obj
            iter_time.update(time.time() - end)
            if i % print_freq == 0 or i == len(iterable) - 1:
                eta_seconds = iter_time.global_avg * (len(iterable) - i)
                eta_string = str(datetime.timedelta(seconds=int(eta_seconds)))
                if torch.cuda.is_available():
                    print(log_msg.format(
                        i, len(iterable), eta=eta_string,
                        meters=str(self),
                        time=str(iter_time), data=str(data_time),
                        memory=torch.cuda.max_memory_allocated() / MB))
                else:
                    print(log_msg.format(
                        i, len(iterable), eta=eta_string,
                        meters=str(self),
                        time=str(iter_time), data=str(data_time)))
            i += 1
            end = time.time()
        total_time = time.time() - start_time
        total_time_str = str(datetime.timedelta(seconds=int(total_time)))
        print('{} Total time: {} ({:.4f} s / it)'.format(
            header, total_time_str, total_time / len(iterable)))


def _load_checkpoint_for_ema(model_ema, checkpoint):
    """
    Workaround for ModelEma._load_checkpoint to accept an already-loaded object
    """
    mem_file = io.BytesIO()
    torch.save(checkpoint, mem_file)
    mem_file.seek(0)
    model_ema._load_checkpoint(mem_file)


def setup_for_distributed(is_master):
    """
    This function disables printing when not in the master process
    """
    import builtins as __builtin__
    builtin_print = __builtin__.print

    def print(*args, **kwargs):
        force = kwargs.pop('force', False)
        if is_master or force:
            builtin_print(*args, **kwargs)

    __builtin__.print = print


def is_dist_avail_and_initialized():
    if not dist.is_available():
        return False
    if not dist.is_initialized():
        return False
    return True


def get_world_size():
    if not is_dist_avail_and_initialized():
        return 1
    return dist.get_world_size()


def get_rank():
    if not is_dist_avail_and_initialized():
        return 0
    return dist.get_rank()


def is_main_process():
    return get_rank() == 0


def save_on_master(*args, **kwargs):
    if is_main_process():
        torch.save(*args, **kwargs)


def init_distributed_mode(args):
    if 'RANK' in os.environ and 'WORLD_SIZE' in os.environ:
        args.rank = int(os.environ["RANK"])
        args.world_size = int(os.environ['WORLD_SIZE'])
        args.gpu = int(os.environ['LOCAL_RANK'])
        args.dist_url = 'env://'
        os.environ['LOCAL_SIZE'] = str(torch.cuda.device_count())
        print('Using distributed mode: 1')
    elif 'SLURM_PROCID' in os.environ:
        proc_id = int(os.environ['SLURM_PROCID'])
        ntasks = int(os.environ['SLURM_NTASKS'])
        node_list = os.environ['SLURM_NODELIST']
        num_gpus = torch.cuda.device_count()
        addr = subprocess.getoutput(
            'scontrol show hostname {} | head -n1'.format(node_list))
        os.environ['MASTER_PORT'] = os.environ.get('MASTER_PORT', '29500')
        os.environ['MASTER_ADDR'] = addr
        os.environ['WORLD_SIZE'] = str(ntasks)
        os.environ['RANK'] = str(proc_id)
        os.environ['LOCAL_RANK'] = str(proc_id % num_gpus)
        os.environ['LOCAL_SIZE'] = str(num_gpus)
        args.dist_url = 'env://'
        args.world_size = ntasks
        args.rank = proc_id
        args.gpu = proc_id % num_gpus
        print('Using distributed mode: slurm')
        print(f"world: {os.environ['WORLD_SIZE']}, rank:{os.environ['RANK']},"
              f" local_rank{os.environ['LOCAL_RANK']}, local_size{os.environ['LOCAL_SIZE']}")
    else:
        print('Not using distributed mode')
        args.distributed = False
        return

    args.distributed = True

    torch.cuda.set_device(args.gpu)
    args.dist_backend = 'nccl'
    print('| distributed init (rank {}): {}'.format(
        args.rank, args.dist_url), flush=True)
    torch.distributed.init_process_group(backend=args.dist_backend, init_method=args.dist_url,
                                         world_size=args.world_size, rank=args.rank)
    torch.distributed.barrier()
    setup_for_distributed(args.rank == 0)


def replace_batchnorm(net):
    # Recursively fuse submodules that provide a fuse() method and replace any
    # remaining BatchNorm2d layers with Identity (inference-time simplification)
    for child_name, child in net.named_children():
        if hasattr(child, 'fuse'):
            setattr(net, child_name, child.fuse())
        elif isinstance(child, torch.nn.Conv2d):
            child.bias = torch.nn.Parameter(torch.zeros(child.weight.size(0)))
        elif isinstance(child, torch.nn.BatchNorm2d):
            setattr(net, child_name, torch.nn.Identity())
        else:
            replace_batchnorm(child)


def replace_layernorm(net):
    import apex
    for child_name, child in net.named_children():
        if isinstance(child, torch.nn.LayerNorm):
            setattr(net, child_name, apex.normalization.FusedLayerNorm(
                child.weight.size(0)))
        else:
            replace_layernorm(child)
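replace_batchnorm is an inference-time helper: children that expose a fuse() method are fused in place, and any remaining BatchNorm2d is swapped for an Identity (its statistics are not folded into neighbouring layers by this helper). A small self-contained sketch follows; the Sequential below is a toy stand-in, not a SwiftFormer model.

# Sketch: simplifying a model for inference with replace_batchnorm.
import torch
import util.utils as utils

# Toy stand-in with no fuse() method: replace_batchnorm resets the conv bias
# to zeros and swaps the BatchNorm2d for an Identity; blocks that do define
# fuse() would instead be fused into a single operator.
model = torch.nn.Sequential(torch.nn.Conv2d(3, 8, 3, padding=1),
                            torch.nn.BatchNorm2d(8),
                            torch.nn.ReLU())
model.eval()
with torch.no_grad():
    utils.replace_batchnorm(model)
print(model)   # the BatchNorm2d entry is now an Identity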