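1. Problem Overview
The mainstream approach to multi-GPU training in PyTorch today is torch.nn.parallel.DistributedDataParallel() (DDP). In some special cases, however, it cannot be used, particularly for GAN-related training. If the loss is WGAN-GP (LP) or DRAGAN, it includes a gradient-based penalty computed with torch.autograd.grad(), and unfortunately, in my experiments, using that function under DDP raised the following error: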
File "/home/work/anaconda3/envs/xxxxx_py/lib/python3.6/site-packages/torch/autograd/__init__.py", line 93, in backward
allow_unreachable=True) # allow_unreachable flag
RuntimeError: derivative for batch_norm_backward_elemt is not implemented
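For context, the gradient penalty mentioned above can be sketched roughly as follows. This is a minimal illustration in the WGAN-GP style, not the code from the failing experiment: the names `gradient_penalty` and `critic` are hypothetical, and the toy critic runs on CPU. The relevant point is `create_graph=True`, which makes the later `backward()` call differentiate through a gradient (a double backward); the error above suggests this second-order derivative is not implemented for `batch_norm_backward_elemt`.

```python
import torch
import torch.nn as nn


def gradient_penalty(critic, real, fake):
    """WGAN-GP style penalty: mean of (||grad_x D(x_hat)||_2 - 1)^2 over the batch."""
    # One random interpolation coefficient per sample, broadcast over other dims.
    eps = torch.rand(real.size(0), *([1] * (real.dim() - 1)), device=real.device)
    x_hat = (eps * real + (1.0 - eps) * fake).detach().requires_grad_(True)
    scores = critic(x_hat)
    # create_graph=True keeps the graph of the gradient itself, so the penalty
    # can be backpropagated to the critic parameters (double backward).
    grads, = torch.autograd.grad(
        outputs=scores,
        inputs=x_hat,
        grad_outputs=torch.ones_like(scores),
        create_graph=True,
    )
    grad_norm = grads.reshape(grads.size(0), -1).norm(2, dim=1)
    return ((grad_norm - 1.0) ** 2).mean()


# Toy critic and data (assumptions for illustration only).
torch.manual_seed(0)
critic = nn.Sequential(nn.Linear(8, 16), nn.LeakyReLU(0.2), nn.Linear(16, 1))
real, fake = torch.randn(4, 8), torch.randn(4, 8)
penalty = gradient_penalty(critic, real, fake)
penalty.backward()  # differentiates through the gradient computed above
```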
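If parallel (single-machine, multi-GPU) computation is needed, the only remaining option is torch.nn.DataParallel(). This brings another problem, however: the load is severely unbalanced. With this method the primary GPU uses considerably more memory, while the other GPUs use only part of theirs, so the batch size cannot be increased any further. The figure below shows the data flow for this approach:
As the figure above shows, data is gathered onto one device both during the computation on the input data and during the loss computation, which easily exhausts the memory of the primary GPU. One way to solve this is to run both the forward pass and the loss computation in parallel, as in the following flow: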

This addresses the low GPU utilization. Load-balancing code that can serve as a reference is given in the next section.
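The idea of computing the loss per replica and combining only the scalar results can be checked with a small CPU sketch. It uses `torch.chunk` as a stand-in for scattering across GPUs, which is an assumption for illustration; no actual devices are involved. With equally sized chunks and a mean-reduced loss, the average of the per-chunk losses matches the full-batch loss, so only one scalar per replica has to be moved to the primary device.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
num_replicas = 2                     # stands in for the number of GPUs
logits = torch.randn(8, 5)           # batch of 8 samples, 5 classes
target = torch.randint(0, 5, (8,))

# Reference: loss over the full batch, as computed on a single device.
full_loss = F.cross_entropy(logits, target)

# Balanced scheme: each replica computes the loss on its own chunk ...
chunk_losses = [
    F.cross_entropy(chunk_logits, chunk_target)
    for chunk_logits, chunk_target in zip(
        logits.chunk(num_replicas), target.chunk(num_replicas)
    )
]
# ... and only the scalar losses are summed and divided by the replica count.
balanced_loss = torch.stack(chunk_losses).sum() / len(chunk_losses)
```

The equality relies on every chunk contributing the same number of elements to the mean. With unequal chunks, or with an `ignore_index` that excludes different numbers of targets per chunk, the averaged value can differ slightly from the full-batch loss.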
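2. Code Implementation
Building on the work described above, the parallel procedure is collected into a single file here, so that it can conveniently be used as a module.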
##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
## Created by: Hang Zhang
## ECE Department, Rutgers University
## Email: zhang.hang@rutgers.edu
## Copyright (c) 2017
##
## This source code is licensed under the MIT-style license found in the
## LICENSE file in the root directory of this source tree
##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
"""Encoding Data Parallel"""
import threading
import functools
import torch
from torch.autograd import Variable, Function
import torch.cuda.comm as comm
from torch.nn.parallel.data_parallel import DataParallel
from torch.nn.parallel.parallel_apply import get_a_var
from torch.nn.parallel._functions import ReduceAddCoalesced, Broadcast
torch_ver = torch.__version__[:3]
# Only names that are actually defined in this file are exported.
__all__ = ['allreduce', 'DataParallelModel', 'my_DataParallelCriterion']


def allreduce(*inputs):
    """Cross GPU all reduce autograd operation for calculate mean and
    variance in SyncBN.
    """
    return AllReduce.apply(*inputs)


class AllReduce(Function):
    @staticmethod
    def forward(ctx, num_inputs, *inputs):
        ctx.num_inputs = num_inputs
        ctx.target_gpus = [inputs[i].get_device() for i in range(0, len(inputs), num_inputs)]
        inputs = [inputs[i:i + num_inputs]
                  for i in range(0, len(inputs), num_inputs)]
        # sort before reduce sum
        inputs = sorted(inputs, key=lambda i: i[0].get_device())
        results = comm.reduce_add_coalesced(inputs, ctx.target_gpus[0])
        outputs = comm.broadcast_coalesced(results, ctx.target_gpus)
        return tuple([t for tensors in outputs for t in tensors])

    @staticmethod
    def backward(ctx, *inputs):
        inputs = [i.data for i in inputs]
        inputs = [inputs[i:i + ctx.num_inputs]
                  for i in range(0, len(inputs), ctx.num_inputs)]
        results = comm.reduce_add_coalesced(inputs, ctx.target_gpus[0])
        outputs = comm.broadcast_coalesced(results, ctx.target_gpus)
        return (None,) + tuple([Variable(t) for tensors in outputs for t in tensors])


class Reduce(Function):
    @staticmethod
    def forward(ctx, *inputs):
        ctx.target_gpus = [inputs[i].get_device() for i in range(len(inputs))]
        inputs = sorted(inputs, key=lambda i: i.get_device())
        return comm.reduce_add(inputs)

    @staticmethod
    def backward(ctx, gradOutput):
        return Broadcast.apply(ctx.target_gpus, gradOutput)


class DataParallelModel(DataParallel):
    """Implements data parallelism at the module level.

    This container parallelizes the application of the given module by
    splitting the input across the specified devices by chunking in the
    batch dimension.
    In the forward pass, the module is replicated on each device,
    and each replica handles a portion of the input. During the backwards
    pass, gradients from each replica are summed into the original module.
    Note that the outputs are not gathered, please use compatible
    :class:`encoding.parallel.DataParallelCriterion`.

    The batch size should be larger than the number of GPUs used. It should
    also be an integer multiple of the number of GPUs so that each chunk is
    the same size (so that each GPU processes the same number of samples).

    Args:
        module: module to be parallelized
        device_ids: CUDA devices (default: all devices)

    Reference:
        Hang Zhang, Kristin Dana, Jianping Shi, Zhongyue Zhang, Xiaogang Wang, Ambrish Tyagi,
        Amit Agrawal. “Context Encoding for Semantic Segmentation.”
        *The IEEE Conference on Computer Vision and Pattern Recognition (CVPR) 2018*

    Example::

        >>> net = encoding.nn.DataParallelModel(model, device_ids=[0, 1, 2])
        >>> y = net(x)
    """
    def gather(self, outputs, output_device):
        # Outputs stay on their own devices; the criterion consumes them there.
        return outputs

    def replicate(self, module, device_ids):
        modules = super(DataParallelModel, self).replicate(module, device_ids)
        return modules

    def forward(self, inputs, **kwargs):
        if kwargs.get('parallel', False):
            kwargs.pop('parallel', None)  # this key is unexpected
            if isinstance(inputs, torch.Tensor):
                return super().forward(inputs, **kwargs)
            # inputs is already a per-device sequence: apply one replica to each item
            replicas = self.replicate(self.module, self.device_ids[:len(inputs)])
            outputs = self.parallel_apply(replicas, inputs, kwargs if kwargs else None)
            return self.gather(outputs, self.output_device)
        else:  # not using parallel or evaluating
            return self.module(inputs)


class my_DataParallelCriterion(DataParallel):
    """
    Calculate loss in multiple-GPUs, which balance the memory usage for
    Semantic Segmentation.

    The targets are split across the specified devices by chunking in
    the batch dimension. Please use together with
    :class:`encoding.parallel.DataParallelModel`.

    Reference:
        Hang Zhang, Kristin Dana, Jianping Shi, Zhongyue Zhang, Xiaogang Wang, Ambrish Tyagi,
        Amit Agrawal. “Context Encoding for Semantic Segmentation.”
        *The IEEE Conference on Computer Vision and Pattern Recognition (CVPR) 2018*

    Example::

        >>> net = encoding.nn.DataParallelModel(model, device_ids=[0, 1, 2])
        >>> criterion = encoding.nn.DataParallelCriterion(criterion, device_ids=[0, 1, 2])
        >>> y = net(x)
        >>> loss = criterion(y, target)
    """
    def forward(self, inputs, *targets, **kwargs):
        if not self.device_ids:
            return self.module(inputs, *targets, **kwargs)
        is_target_scattered = kwargs.get('is_target_scattered', False)
        kwargs.pop('is_target_scattered', None)  # this key is unexpected
        if not is_target_scattered:
            # inputs are already scattered by the model; scatter the targets to match
            targets, kwargs = self.scatter(targets, kwargs, self.device_ids)
        if len(self.device_ids) == 1:
            if is_target_scattered:
                targets = (targets,)
                kwargs = (kwargs,)
            return self.module(inputs, *targets[0], **kwargs[0])
        if is_target_scattered:
            targets = targets[0]
        replicas = self.replicate(self.module, self.device_ids[:len(inputs)])
        outputs = _criterion_parallel_apply(replicas, inputs, targets)
        # Sum the per-device losses onto one device and average them.
        return Reduce.apply(*outputs) / len(outputs)


def _criterion_parallel_apply(modules, inputs, targets, kwargs_tup=None, devices=None):
    assert len(modules) == len(inputs)
    assert len(targets) == len(inputs)
    if kwargs_tup:
        assert len(modules) == len(kwargs_tup)
    else:
        kwargs_tup = ({},) * len(modules)
    if devices is not None:
        assert len(modules) == len(devices)
    else:
        devices = [None] * len(modules)

    lock = threading.Lock()
    results = {}
    if torch_ver != "0.3":
        grad_enabled = torch.is_grad_enabled()

    def _worker(i, module, input, target, kwargs, device=None):
        if torch_ver != "0.3":
            # propagate the caller's grad mode into the worker thread
            torch.set_grad_enabled(grad_enabled)
        if device is None:
            device = get_a_var(input).get_device()
        try:
            if not isinstance(input, tuple):
                input = (input,)
            if not isinstance(target, tuple):
                target = (target,)
            with torch.cuda.device(device):
                output = module(*(input + target), **kwargs)
            with lock:
                results[i] = output
        except Exception as e:
            # store the exception so it can be re-raised in the calling thread
            with lock:
                results[i] = e

    if len(modules) > 1:
        threads = [threading.Thread(target=_worker,
                                    args=(i, module, input, target,
                                          kwargs, device),)
                   for i, (module, input, target, kwargs, device) in
                   enumerate(zip(modules, inputs, targets, kwargs_tup, devices))]
        for thread in threads:
            thread.start()
        for thread in threads:
            thread.join()
    else:
        # the original call omitted targets[0], which shifted the remaining arguments
        _worker(0, modules[0], inputs[0], targets[0], kwargs_tup[0], devices[0])

    outputs = []
    for i in range(len(inputs)):
        output = results[i]
        if isinstance(output, Exception):
            raise output
        outputs.append(output)
    return outputs
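The threading pattern used by `_criterion_parallel_apply` (one worker per replica, results stored by index under a lock, exceptions captured and re-raised in the caller) can be isolated in a device-free sketch. The helper name `parallel_apply_sketch` and the toy functions are hypothetical and only illustrate the control flow; they do not touch any GPU.

```python
import threading


def parallel_apply_sketch(funcs, args_list):
    """Run funcs[i](*args_list[i]) in one thread each; return results in order."""
    assert len(funcs) == len(args_list)
    lock = threading.Lock()
    results = {}

    def _worker(i, func, args):
        try:
            output = func(*args)
        except Exception as e:
            # Keep the exception so the caller can re-raise it after join().
            output = e
        with lock:
            results[i] = output

    threads = [
        threading.Thread(target=_worker, args=(i, func, args))
        for i, (func, args) in enumerate(zip(funcs, args_list))
    ]
    for thread in threads:
        thread.start()
    for thread in threads:
        thread.join()

    outputs = []
    for i in range(len(funcs)):
        if isinstance(results[i], Exception):
            raise results[i]
        outputs.append(results[i])
    return outputs


# Two toy "replicas" working on different arguments.
outputs = parallel_apply_sketch(
    [lambda a, b: a + b, lambda a, b: a * b],
    [(1, 2), (3, 4)],
)
```

Indexing the results by worker position, rather than appending them as threads finish, keeps the output order aligned with the device order regardless of which thread completes first.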
3. Usage Example
import torch
import torch.nn as nn
import torch.nn.functional as F

from my_parallel import DataParallelModel, my_DataParallelCriterion


class CriterionCE(nn.Module):
    def __init__(self, ignore_index=255, use_weight=True, reduce=True):
        super(CriterionCE, self).__init__()
        self.ignore_index = ignore_index
        # `reduce` is deprecated in CrossEntropyLoss; map it to `reduction`
        self.criterion = torch.nn.CrossEntropyLoss(
            ignore_index=ignore_index,
            reduction='mean' if reduce else 'none')
        if not reduce:
            print("disabled the reduce.")

    def forward(self, preds, target):
        h, w = target.size(1), target.size(2)
        # F.upsample is deprecated in favor of F.interpolate
        scale_pred = F.interpolate(input=preds, size=(h, w), mode='bilinear', align_corners=True)
        loss = self.criterion(scale_pred, target)
        return loss


# define model (DemoModel, train_loader and optimizer are assumed to be defined elsewhere)
model = DemoModel()
model = DataParallelModel(model)
model = model.cuda()

# define loss
criterion = CriterionCE()
criterion = my_DataParallelCriterion(criterion).cuda()

# train step
for i, (images, target) in enumerate(train_loader):
    images = images.cuda()
    target = target.cuda()
    output = model(images, parallel=True)
    optimizer.zero_grad()
    # is_target_scattered: whether the targets have already been scattered across the GPUs
    loss = criterion(output, target, is_target_scattered=False)
    loss.backward()
    optimizer.step()
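This article discusses the unbalanced load encountered when using DataParallel for multi-GPU training in PyTorch, particularly in GAN training, where certain operations are incompatible with DDP. After analyzing how DataParallel works and how it uses memory, it presents an improved scheme intended to raise GPU utilization, and provides a load-balancing implementation and a usage example.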