Load balancing with torch.nn.DataParallel in PyTorch

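This article looks at the unbalanced GPU load that appears when training on multiple GPUs with DataParallel in PyTorch. The issue matters most in GAN training, where some operations are incompatible with DDP and DataParallel is the fallback. The article analyses how DataParallel works and where its memory goes, then presents an improved scheme intended to raise GPU utilization, together with a load-balanced implementation and a usage example.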

1. Problem overview

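The mainstream way to train on multiple GPUs in PyTorch today is torch.nn.parallel.DistributedDataParallel() (DDP). In some special cases, however, it cannot be used, particularly in GAN-related training. If the loss is WGAN-GP (LP) or DRAGAN, it includes a gradient-based penalty computed with torch.autograd.grad(), and unfortunately, in my experiments this function raised the following error under DDP: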

File "/home/work/anaconda3/envs/xxxxx_py/lib/python3.6/site-packages/torch/autograd/__init__.py", line 93, in backward
    allow_unreachable=True)  # allow_unreachable flag
RuntimeError: derivative for batch_norm_backward_elemt is not implemented
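
For context, the gradient penalty mentioned above can be sketched as follows. This is a minimal CPU-only illustration, not the training code behind the traceback: the `gradient_penalty` helper and the toy critic are hypothetical names introduced here. The point to notice is `create_graph=True`: the penalty is itself backpropagated, so every operation in the critic needs a derivative of its backward pass, which appears to be what the error message above reports as missing.

```python
import torch
import torch.nn as nn


def gradient_penalty(critic, real, fake):
    """WGAN-GP style penalty: mean over the batch of (||grad critic(x_hat)||_2 - 1)^2."""
    # One random interpolation coefficient per sample, broadcast over the other dims.
    eps = torch.rand(real.size(0), *([1] * (real.dim() - 1)), device=real.device)
    x_hat = (eps * real + (1.0 - eps) * fake).requires_grad_(True)
    scores = critic(x_hat)
    # create_graph=True keeps the graph so the penalty can be backpropagated later.
    grads, = torch.autograd.grad(
        outputs=scores,
        inputs=x_hat,
        grad_outputs=torch.ones_like(scores),
        create_graph=True,
    )
    grad_norm = grads.reshape(grads.size(0), -1).norm(2, dim=1)
    return ((grad_norm - 1.0) ** 2).mean()


# Hypothetical toy critic, used only to exercise the function on CPU.
toy_critic = nn.Sequential(nn.Linear(8, 16), nn.LeakyReLU(0.2), nn.Linear(16, 1))
real = torch.randn(4, 8)
fake = torch.randn(4, 8)
penalty = gradient_penalty(toy_critic, real, fake)
penalty.backward()  # second-order backward through the critic's weights
```

In a real training loop the penalty would be scaled by a coefficient and added to the critic loss before calling `backward()`.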

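For parallel (single-machine, multi-GPU) computation the only remaining option is therefore torch.nn.DataParallel(). This brings another problem: the load is extremely unbalanced. The primary GPU ends up holding considerably more memory than the others, which are only partly used, so the batch size cannot be increased any further. The figure below shows the data flow under this scheme:

[Figure: data flow of torch.nn.DataParallel]

As the figure shows, data is gathered onto the primary GPU both when the network outputs are computed and when the loss is computed, which easily exhausts the primary GPU's memory. One way to address this is to run both the forward pass and the loss computation in parallel, with the following flow:

[Figure: data flow with the forward pass and the loss both computed in parallel]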

This addresses the low GPU utilization. Some load-balancing code that can serve as a reference:

  1. PyTorch-Encoding
  2. Balanced-DataParallel
  3. parallel.py
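
To check on a given machine whether memory really is unbalanced, per-device usage can be printed from inside the training process. The helper below is a small sketch (`gpu_memory_report` is a name introduced here, not part of any of the projects listed above). It only counts memory managed by PyTorch's caching allocator in the current process, so the numbers will typically be lower than what nvidia-smi shows.

```python
import torch


def gpu_memory_report():
    """Return [(device_index, allocated_MiB, reserved_MiB), ...] for every visible GPU."""
    report = []
    for idx in range(torch.cuda.device_count()):
        allocated = torch.cuda.memory_allocated(idx) / 1024 ** 2
        reserved = torch.cuda.memory_reserved(idx) / 1024 ** 2
        report.append((idx, allocated, reserved))
    return report


# Prints nothing on a machine without CUDA devices.
for idx, allocated, reserved in gpu_memory_report():
    print(f"cuda:{idx}: allocated {allocated:.1f} MiB, reserved {reserved:.1f} MiB")
```

Calling it once after a training step with plain DataParallel and once with the balanced variant gives a direct comparison of the primary GPU against the others.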

2. Implementation

Building on the work listed above, the parallel logic is collected here into a single file so that it can be used conveniently as a module.

##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
## Created by: Hang Zhang
## ECE Department, Rutgers University
## Email: zhang.hang@rutgers.edu
## Copyright (c) 2017
##
## This source code is licensed under the MIT-style license found in the
## LICENSE file in the root directory of this source tree
##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++

"""Encoding Data Parallel"""
import threading
import functools
import torch
from torch.autograd import Variable, Function
import torch.cuda.comm as comm
from torch.nn.parallel.data_parallel import DataParallel
from torch.nn.parallel.parallel_apply import get_a_var
from torch.nn.parallel._functions import ReduceAddCoalesced, Broadcast

torch_ver = torch.__version__[:3]

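# Only list names that are actually defined in this file, otherwise
# `from my_parallel import *` fails with an AttributeError.
__all__ = ['allreduce', 'DataParallelModel', 'my_DataParallelCriterion']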

def allreduce(*inputs):
    """Cross GPU all reduce autograd operation for calculate mean and
    variance in SyncBN.
    """
    return AllReduce.apply(*inputs)

class AllReduce(Function):
    @staticmethod
    def forward(ctx, num_inputs, *inputs):
        ctx.num_inputs = num_inputs
        ctx.target_gpus = [inputs[i].get_device() for i in range(0, len(inputs), num_inputs)]
        inputs = [inputs[i:i + num_inputs]
                 for i in range(0, len(inputs), num_inputs)]
        # sort before reduce sum
        inputs = sorted(inputs, key=lambda i: i[0].get_device())
        results = comm.reduce_add_coalesced(inputs, ctx.target_gpus[0])
        outputs = comm.broadcast_coalesced(results, ctx.target_gpus)
        return tuple([t for tensors in outputs for t in tensors])

    @staticmethod
    def backward(ctx, *inputs):
        inputs = [i.data for i in inputs]
        inputs = [inputs[i:i + ctx.num_inputs]
                 for i in range(0, len(inputs), ctx.num_inputs)]
        results = comm.reduce_add_coalesced(inputs, ctx.target_gpus[0])
        outputs = comm.broadcast_coalesced(results, ctx.target_gpus)
        return (None,) + tuple([Variable(t) for tensors in outputs for t in tensors])

class Reduce(Function):
    @staticmethod
    def forward(ctx, *inputs):
        ctx.target_gpus = [inputs[i].get_device() for i in range(len(inputs))]
        inputs = sorted(inputs, key=lambda i: i.get_device())
        return comm.reduce_add(inputs)

    @staticmethod
    def backward(ctx, gradOutput):
        return Broadcast.apply(ctx.target_gpus, gradOutput)


class DataParallelModel(DataParallel):
    """Implements data parallelism at the module level.

    This container parallelizes the application of the given module by
    splitting the input across the specified devices by chunking in the
    batch dimension.
    In the forward pass, the module is replicated on each device,
    and each replica handles a portion of the input. During the backwards pass, gradients from each replica are summed into the original module.
    Note that the outputs are not gathered, please use compatible
    :class:`encoding.parallel.DataParallelCriterion`.

    The batch size should be larger than the number of GPUs used. It should
    also be an integer multiple of the number of GPUs so that each chunk is
    the same size (so that each GPU processes the same number of samples).

    Args:
        module: module to be parallelized
        device_ids: CUDA devices (default: all devices)

    Reference:
        Hang Zhang, Kristin Dana, Jianping Shi, Zhongyue Zhang, Xiaogang Wang, Ambrish Tyagi,
        Amit Agrawal. "Context Encoding for Semantic Segmentation."
        *The IEEE Conference on Computer Vision and Pattern Recognition (CVPR) 2018*

    Example::

        >>> net = encoding.nn.DataParallelModel(model, device_ids=[0, 1, 2])
        >>> y = net(x)
    """
    def gather(self, outputs, output_device):
        return outputs

    def replicate(self, module, device_ids):
        modules = super(DataParallelModel, self).replicate(module, device_ids)
        return modules

    def forward(self, inputs, **kwargs):
        if kwargs.get('parallel', False):
            kwargs.pop('parallel', None)  # this key is unexpected
            if isinstance(inputs, torch.Tensor):
                return super().forward(inputs, **kwargs)
            replicas = self.replicate(self.module, self.device_ids[:len(inputs)])
            outputs = self.parallel_apply(replicas, inputs, kwargs if kwargs else None)
            return self.gather(outputs, self.output_device)
        else:  # not using parallel or evaluating
            return self.module(inputs)


class my_DataParallelCriterion(DataParallel):
    """
    Calculate loss in multiple-GPUs, which balance the memory usage for
    Semantic Segmentation.

    The targets are split across the specified devices by chunking in
    the batch dimension. Please use together with :class:`encoding.parallel.DataParallelModel`.

    Reference:
        Hang Zhang, Kristin Dana, Jianping Shi, Zhongyue Zhang, Xiaogang Wang, Ambrish Tyagi,
        Amit Agrawal. "Context Encoding for Semantic Segmentation."
        *The IEEE Conference on Computer Vision and Pattern Recognition (CVPR) 2018*

    Example::

        >>> net = encoding.nn.DataParallelModel(model, device_ids=[0, 1, 2])
        >>> criterion = encoding.nn.DataParallelCriterion(criterion, device_ids=[0, 1, 2])
        >>> y = net(x)
        >>> loss = criterion(y, target)
    """
    def forward(self, inputs, *targets, **kwargs):
        if not self.device_ids:
            return self.module(inputs, *targets, **kwargs)

        is_target_scattered = kwargs.get('is_target_scattered', False)
        kwargs.pop('is_target_scattered', None)  # this key is unexpected

        if not is_target_scattered:
            targets, kwargs = self.scatter(targets, kwargs, self.device_ids)

        if len(self.device_ids) == 1:
            if is_target_scattered:
                targets = (targets,)
                kwargs = (kwargs,)
            return self.module(inputs, *targets[0], **kwargs[0])

        if is_target_scattered:
            targets = targets[0]

        replicas = self.replicate(self.module, self.device_ids[:len(inputs)])
        outputs = _criterion_parallel_apply(replicas, inputs, targets)
        return Reduce.apply(*outputs) / len(outputs)


def _criterion_parallel_apply(modules, inputs, targets, kwargs_tup=None, devices=None):
    assert len(modules) == len(inputs)
    assert len(targets) == len(inputs)
    if kwargs_tup:
        assert len(modules) == len(kwargs_tup)
    else:
        kwargs_tup = ({},) * len(modules)
    if devices is not None:
        assert len(modules) == len(devices)
    else:
        devices = [None] * len(modules)

    lock = threading.Lock()
    results = {}
    if torch_ver != "0.3":
        grad_enabled = torch.is_grad_enabled()

    def _worker(i, module, input, target, kwargs, device=None):
        #import pdb;pdb.set_trace()
        if torch_ver != "0.3":
            torch.set_grad_enabled(grad_enabled)
        if device is None:
            device = get_a_var(input).get_device()
        try:
            if not isinstance(input, tuple):
                input = (input,)
            if not isinstance(target, tuple):
                target = (target,)
            with torch.cuda.device(device):
                output = module(*(input + target), **kwargs)
            with lock:
                results[i] = output
        except Exception as e:
            with lock:
                results[i] = e

    if len(modules) > 1:
        threads = [threading.Thread(target=_worker,
                                    args=(i, module, input, target,
                                          kwargs, device),)
                   for i, (module, input, target, kwargs, device) in
                   enumerate(zip(modules, inputs, targets, kwargs_tup, devices))]

        for thread in threads:
            thread.start()
        for thread in threads:
            thread.join()
    else:
        # Single-module path: the target must be passed as well, matching _worker's signature.
        _worker(0, modules[0], inputs[0], targets[0], kwargs_tup[0], devices[0])

    outputs = []
    for i in range(len(inputs)):
        output = results[i]
        if isinstance(output, Exception):
            raise output
        outputs.append(output)
    return outputs
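
The criterion above returns the sum of the per-device losses divided by the number of devices. The following CPU-only sketch illustrates why the docstring asks for a batch size that is an integer multiple of the GPU count: with mean-reduced losses, averaging per-chunk means reproduces the full-batch mean only when all chunks have the same size. The tensors and the `mean_of_chunk_losses` helper are hypothetical, and `torch.chunk` stands in for the scatter step.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
logits = torch.randn(8, 5)            # hypothetical batch: 8 samples, 5 classes
target = torch.randint(0, 5, (8,))

full_loss = F.cross_entropy(logits, target)


def mean_of_chunk_losses(logits, target, num_chunks):
    """Average of per-chunk mean losses, mimicking sum(losses) / len(losses)."""
    losses = [F.cross_entropy(l, t)
              for l, t in zip(logits.chunk(num_chunks), target.chunk(num_chunks))]
    return torch.stack(losses).sum() / len(losses)


even = mean_of_chunk_losses(logits, target, 2)    # chunks of 4 and 4 samples
uneven = mean_of_chunk_losses(logits, target, 3)  # chunks of 3, 3 and 2 samples

print(f"full batch: {full_loss.item():.6f}")
print(f"2 chunks:   {even.item():.6f}")
print(f"3 chunks:   {uneven.item():.6f}")
```

With equal chunks the two values agree up to floating-point error. With unequal chunks the samples in the smaller chunk are weighted more heavily, so the result is in general different. A loss that uses ignore_index averages over the non-ignored targets only, so a similar discrepancy can appear even with equal chunk sizes.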

3. Usage example

import torch
import torch.nn as nn
import torch.nn.functional as F

from my_parallel import DataParallelModel, my_DataParallelCriterion

class CriterionCE(nn.Module):
    def __init__(self, ignore_index=255, use_weight=True, reduce=True):
        super(CriterionCE, self).__init__()
        self.ignore_index = ignore_index
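        # The `reduce` argument of CrossEntropyLoss is deprecated; map it onto `reduction`.
        self.criterion = torch.nn.CrossEntropyLoss(
            ignore_index=ignore_index, reduction='mean' if reduce else 'none')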
        if not reduce:
            print("disabled the reduce.")

    def forward(self, preds, target):
        h, w = target.size(1), target.size(2)

        # F.upsample is deprecated; F.interpolate takes the same arguments.
        scale_pred = F.interpolate(input=preds, size=(h, w), mode='bilinear', align_corners=True)
        loss = self.criterion(scale_pred, target)

        return loss

# define model (DemoModel stands in for the user's own nn.Module)
model = DemoModel()
model = DataParallelModel(model)
model = model.cuda()

# define loss
criterion = CriterionCE()
criterion = my_DataParallelCriterion(criterion).cuda()

# train step (train_loader and optimizer are assumed to be defined elsewhere)
for i, (images, target) in enumerate(train_loader):
    images = images.cuda()
    target = target.cuda()
    output = model(images, parallel=True)
    optimizer.zero_grad()
    # is_target_scattered: whether target has already been scattered across the GPUs
    loss = criterion(output, target, is_target_scattered=False)
    loss.backward()
    optimizer.step()