
Commit 2007b52

Add benchmarks (deepspeedai#254)

Authored by awan-10, lekurile, and molly-smith
Co-authored-by: Lev Kurilenko <[email protected]>
Co-authored-by: Molly Smith <[email protected]>

1 parent 60be252 commit 2007b52

File tree

18 files changed: +1428 −1 lines changed

benchmarks/README.md

Lines changed: 1 addition & 1 deletion

@@ -1 +1 @@
-The new home for DeepSpeed benchmarks. TODO: Move DS benchmarks to this repo.
+All benchmarks that use the DeepSpeed library are maintained in this folder. We welcome contributions in this space!
benchmarks/communication/README.md

Lines changed: 89 additions & 0 deletions

@@ -0,0 +1,89 @@

# The DeepSpeed Communication Benchmarking Suite

The intent of these benchmarks is to measure the communication latency/bandwidth of DeepSpeed and/or PyTorch distributed communication operations at the Python layer. These benchmarks are complementary to C-level comms benchmarks like [OSU Micro-Benchmarks](https://mvapich.cse.ohio-state.edu/benchmarks/) and [NCCL Tests](https://github.com/NVIDIA/nccl-tests) in that users can:

- Easily debug which layer of the communication software stack a hang or performance degradation originates from
- Measure the expected communication performance of either DeepSpeed comms or pure PyTorch distributed

To run benchmarks, there are two options:

1. Run a single communication operation:

For example, run with a single large message size (calculated to barely fit within GPU memory):
<pre>
deepspeed all_reduce.py
</pre>

Scan across message sizes:
<pre>
deepspeed all_reduce.py --scan
</pre>

Benchmark pure PyTorch distributed comms (without importing or using DeepSpeed) with MPI:
<pre>
mpirun -np 16 --hostfile ${HOSTFILE} -x LD_LIBRARY_PATH -x PATH -x LD_PRELOAD python all_reduce.py --scan --dist="torch"
</pre>

or with Slurm:
<pre>
srun -n 16 python all_reduce.py --scan --dist="torch"
</pre>

2. Run all available communication benchmarks:

<pre>
deepspeed run_all.py
</pre>

Like the individual benchmarks, `run_all.py` supports scanning arguments for the max message size, bw-unit, etc. Simply pass the desired arguments to `run_all.py` and they'll be propagated to each comm op.

<pre>
usage: ds_bench [-h] [--local_rank LOCAL_RANK] [--trials TRIALS] [--warmups WARMUPS] [--maxsize MAXSIZE] [--async-op] [--bw-unit {Gbps,GBps}] [--backend {nccl}] [--dist {deepspeed,torch}] [--scan] [--raw] [--all-reduce] [--all-gather] [--all-to-all]
                [--pt2pt] [--broadcast] [--dtype DTYPE] [--mem-factor MEM_FACTOR] [--debug]

optional arguments:
  -h, --help            show this help message and exit
  --local_rank LOCAL_RANK
  --trials TRIALS       Number of timed iterations
  --warmups WARMUPS     Number of warmup (non-timed) iterations
  --maxsize MAXSIZE     Max message size as a power of 2
  --async-op            Enables non-blocking communication
  --bw-unit {Gbps,GBps}
  --backend {nccl}      Communication library to use
  --dist {deepspeed,torch}
                        Distributed DL framework to use
  --scan                Enables scanning all message sizes
  --raw                 Print the message size and latency without units
  --all-reduce          Run all_reduce
  --all-gather          Run all_gather
  --all-to-all          Run all_to_all
  --pt2pt               Run pt2pt
  --broadcast           Run broadcast
  --dtype DTYPE         PyTorch tensor dtype
  --mem-factor MEM_FACTOR
                        Proportion of max available GPU memory to use for single-size evals
  --debug               Enables all_to_all debug prints
</pre>
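
For example, an illustrative invocation (the flag values here are arbitrary; all flags are documented in the usage listing above) that scans message sizes up to the `--maxsize` exponent and reports bandwidth in Gbps:

<pre>
deepspeed run_all.py --scan --maxsize 24 --bw-unit Gbps --trials 10
</pre>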
Note that `ds_bench` is a pre-packaged wrapper around `run_all.py`. Users can pass the same arguments as well:

<pre>
<path to deepspeed>/bin/ds_bench --scan --trials=10
</pre>

Finally, users can choose specific communication operations to run in `run_all.py` or `ds_bench` by passing them as arguments (all operations are run by default). For example:

<pre>
deepspeed run_all.py --scan --all-reduce --all-to-all --broadcast
</pre>


# Adding Communication Benchmarks

To add new communication benchmarks, follow this general procedure (a sketch of step 2 follows the list):

1. Copy a similar benchmark file (e.g. to add `reduce_scatter`, copy `all_reduce.py` as a template)
2. Add a new bw formula in `utils.get_bw`, a new maximum tensor element formula in `utils.max_numel`, and a new arg in `utils.benchmark_parser`
3. Replace the comm op calls in the new file with find-and-replace
4. Find a good default `mem_factor` for use in the `run_<collective>_single()` function
5. Add the new comm op to `run_all.py`
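
As an illustration of step 2, below is a minimal, hypothetical sketch of the bandwidth bookkeeping a `reduce_scatter` benchmark would need. It is not the repository's `utils.get_bw` implementation (that file is not part of this diff); only the call shape used by the benchmarks in this commit, `get_bw(comm_op, size, duration, args)`, is taken from the source, and the function and flag names here are placeholders.

<pre>
# Hypothetical sketch only -- illustrates the kind of formula step 2 asks for.
# The (n - 1) / n bus-bandwidth factor is the standard NCCL-tests convention for
# reduce_scatter; the actual utils.get_bw in this repo may be organized differently.

def reduce_scatter_bw(size, duration, world_size):
    """Return (algorithmic throughput, bus bandwidth) in bytes/second.

    size: message size in bytes; duration: averaged per-trial time in seconds.
    """
    tput = size / duration
    busbw = tput * (world_size - 1) / world_size
    return tput, busbw


if __name__ == "__main__":
    # Worked example: 1 GiB message, 10 ms average latency, 16 ranks.
    tput, busbw = reduce_scatter_bw(size=2**30, duration=0.010, world_size=16)
    print(f"algbw = {tput / 1e9:.2f} GB/s, busbw = {busbw / 1e9:.2f} GB/s")

# Step 2 also adds a CLI flag in utils.benchmark_parser (argparse), mirroring the
# existing ops, e.g.: parser.add_argument('--reduce-scatter', action='store_true')
</pre>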
Lines changed: 1 addition & 0 deletions

@@ -0,0 +1 @@
'''Copyright The Microsoft DeepSpeed Team'''
benchmarks/communication/all_gather.py

Lines changed: 136 additions & 0 deletions

@@ -0,0 +1,136 @@

# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team

import torch
import sys, os, time

COMMS_BENCH_DIR = os.path.join(os.path.dirname(__file__), "../")
sys.path.append(COMMS_BENCH_DIR)

from communication.utils import *
from communication.constants import *
from deepspeed.accelerator import get_accelerator
from deepspeed.comm import TorchBackend


# Run all_gather and print metrics
def timed_all_gather(input, output, args):
    if args.dist == 'torch':
        import torch.distributed as dist

        all_gather_func = TorchBackend.get_all_gather_function()
    elif args.dist == 'deepspeed':
        import deepspeed.comm as dist

        all_gather_func = dist.allgather_fn

    sync_all()
    # Warmups, establish connections, etc.
    for i in range(args.warmups):
        all_gather_func(output, input, group=None, async_op=args.async_op)
    sync_all()

    # time the actual comm op trials times and average it
    pre = time.perf_counter()
    for i in range(args.trials):
        all_gather_func(output, input, group=None, async_op=args.async_op)
    sync_all()
    duration = time.perf_counter() - pre

    # maintain and clean performance data
    avg_duration = duration / args.trials
    size = input.element_size() * input.nelement()
    tput, busbw = get_bw('all_gather', size, avg_duration, args)
    tput_str, busbw_str, duration_str = get_metric_strings(args, tput, busbw, avg_duration)
    desc = f'{input.nelement()}x{input.element_size()}'

    if not args.raw:
        size = convert_size(size)

    print_rank_0(f"{size:<20} {desc:25s} {duration_str:20s} {tput_str:20s} {busbw_str:20s}")


def run_all_gather(local_rank, args):
    if args.dist == 'torch':
        import torch.distributed as dist
    elif args.dist == 'deepspeed':
        import deepspeed.comm as dist

    # Prepare benchmark header
    print_header(args, 'all_gather')
    global_rank = dist.get_rank()
    world_size = dist.get_world_size()

    if args.scan:
        # Create list of message sizes
        M_LIST = []
        for x in (2**p for p in range(1, args.maxsize)):
            M_LIST.append(x)

        sync_all()
        # loop over various tensor sizes
        for M in M_LIST:
            global_rank = dist.get_rank()
            try:
                mat = torch.ones(world_size, M,
                                 dtype=getattr(torch, args.dtype)).to(get_accelerator().device_name(local_rank))
                sync_all()
                input = ((mat.mul_(float(global_rank))).view(-1))
                # Delete original mat to avoid OOM
                del mat
                get_accelerator().empty_cache()
                output = torch.zeros(input.nelement() * world_size,
                                     dtype=getattr(torch, args.dtype)).to(get_accelerator().device_name(local_rank))
            except RuntimeError as e:
                if 'out of memory' in str(e):
                    if dist.get_rank() == 0:
                        print('WARNING: Ran out of GPU memory. Exiting comm op.')
                    sync_all()
                    break
                else:
                    raise e
            sync_all()
            timed_all_gather(input, output, args)
    else:
        # all_gather_into_tensor saves memory
        if ((args.dist == 'torch' or args.dist == 'deepspeed') and dist.has_all_gather_into_tensor()):
            mem_factor = args.mem_factor + 0.2
        else:
            mem_factor = args.mem_factor
        # Send the biggest message size our GPUs can fit. If you're facing OOM errors, reduce the mem_factor
        sync_all()
        elements_per_gpu = max_numel(comm_op='all_gather',
                                     dtype=getattr(torch, args.dtype),
                                     mem_factor=mem_factor,
                                     local_rank=local_rank,
                                     args=args)
        try:
            mat = torch.ones(elements_per_gpu,
                             dtype=getattr(torch, args.dtype)).to(get_accelerator().device_name(local_rank))
            # multiply each GPU's tensor by the rank to ease debugging
            input = ((mat.mul_(float(global_rank))).view(-1))
            # Delete original mat to avoid OOM
            del mat
            get_accelerator().empty_cache()
            output = torch.zeros(elements_per_gpu * world_size,
                                 dtype=getattr(torch, args.dtype)).to(get_accelerator().device_name(local_rank))
        except RuntimeError as e:
            if 'out of memory' in str(e):
                if dist.get_rank() == 0:
                    print('WARNING: Ran out of GPU memory. Try to reduce the --mem-factor argument!')
                sync_all()
                return
            else:
                raise e

        sync_all()
        timed_all_gather(input, output, args)


if __name__ == "__main__":
    args = benchmark_parser().parse_args()
    rank = args.local_rank
    init_processes(local_rank=rank, args=args)
    run_all_gather(local_rank=rank, args=args)
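
By analogy with the `all_reduce` commands in the README above, and since this file follows the same launcher pattern, it can presumably be run directly, e.g. `deepspeed all_gather.py --scan`.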
benchmarks/communication/all_reduce.py

Lines changed: 114 additions & 0 deletions

@@ -0,0 +1,114 @@

# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team

import torch
import sys, os, time

COMMS_BENCH_DIR = os.path.join(os.path.dirname(__file__), "../")
sys.path.append(COMMS_BENCH_DIR)

from communication.utils import *
from communication.constants import *
from deepspeed.accelerator import get_accelerator


def timed_all_reduce(input, args):
    if args.dist == 'torch':
        import torch.distributed as dist
    elif args.dist == 'deepspeed':
        import deepspeed.comm as dist

    sync_all()
    # Warmups, establish connections, etc.
    for i in range(args.warmups):
        dist.all_reduce(input, async_op=args.async_op)
    sync_all()

    # time the actual comm op trials times and average it
    pre = time.perf_counter()
    for i in range(args.trials):
        dist.all_reduce(input, async_op=args.async_op)
    sync_all()
    duration = time.perf_counter() - pre

    # maintain and clean performance data
    avg_duration = duration / args.trials
    size = input.element_size() * input.nelement()
    n = dist.get_world_size()
    tput, busbw = get_bw('all_reduce', size, avg_duration, args)
    tput_str, busbw_str, duration_str = get_metric_strings(args, tput, busbw, avg_duration)
    desc = f'{input.nelement()}x{input.element_size()}'

    if not args.raw:
        size = convert_size(size)

    print_rank_0(f"{size:<20} {desc:25s} {duration_str:20s} {tput_str:20s} {busbw_str:20s}")


def run_all_reduce(local_rank, args):
    if args.dist == 'torch':
        import torch.distributed as dist
    elif args.dist == 'deepspeed':
        import deepspeed.comm as dist

    # Prepare benchmark header
    print_header(args, 'all_reduce')

    world_size = dist.get_world_size()
    global_rank = dist.get_rank()

    if args.scan:
        M_LIST = []
        for x in (2**p for p in range(1, args.maxsize)):
            M_LIST.append(x)

        sync_all()
        # loop over various tensor sizes
        for M in M_LIST:
            global_rank = dist.get_rank()
            try:
                mat = torch.ones(world_size, M,
                                 dtype=getattr(torch, args.dtype)).to(get_accelerator().device_name(local_rank))
                sync_all()
                input = ((mat.mul_(float(global_rank))).view(-1))
            except RuntimeError as e:
                if 'out of memory' in str(e):
                    if dist.get_rank() == 0:
                        print('WARNING: Ran out of GPU memory. Exiting comm op.')
                    sync_all()
                    break
                else:
                    raise e
            sync_all()
            timed_all_reduce(input, args)
    else:
        # Send the biggest message size our GPUs can fit. If you're facing OOM errors, reduce the mem_factor
        # Don't need output tensor, so we double mem_factor
        elements_per_gpu = max_numel(comm_op='all_reduce',
                                     dtype=getattr(torch, args.dtype),
                                     mem_factor=args.mem_factor * 2,
                                     local_rank=local_rank,
                                     args=args)
        try:
            mat = torch.ones(elements_per_gpu,
                             dtype=getattr(torch, args.dtype)).to(get_accelerator().device_name(local_rank))
            input = ((mat.mul_(float(global_rank))).view(-1))
        except RuntimeError as e:
            if 'out of memory' in str(e):
                if dist.get_rank() == 0:
                    print('WARNING: Ran out of GPU memory. Try to reduce the --mem-factor argument!')
                sync_all()
                return
            else:
                raise e
        sync_all()
        timed_all_reduce(input, args)


if __name__ == "__main__":
    args = benchmark_parser().parse_args()
    rank = args.local_rank
    init_processes(local_rank=rank, args=args)
    run_all_reduce(local_rank=rank, args=args)
