
Commit a0a80fc

Author: Shaden Smith
Pipeline parallelism example (deepspeedai#50)
1 parent 38f0952 commit a0a80fc

File tree

pipeline_parallelism/alexnet.py
pipeline_parallelism/ds_config.json
pipeline_parallelism/run.sh
pipeline_parallelism/train.py

4 files changed: +226 -0 lines

pipeline_parallelism/alexnet.py

Lines changed: 47 additions & 0 deletions
#
# Implementation of AlexNet for illustrative purposes. The train.py driver
# can import AlexNet from here or directly from torchvision.
#
# Taken from torchvision.models.alexnet:
# https://pytorch.org/docs/1.6.0/_modules/torchvision/models/alexnet.html#alexnet


import torch
import torch.nn as nn


class AlexNet(nn.Module):
    def __init__(self, num_classes=1000):
        super(AlexNet, self).__init__()
        self.features = nn.Sequential(
            nn.Conv2d(3, 64, kernel_size=11, stride=4, padding=2),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2),
            nn.Conv2d(64, 192, kernel_size=5, padding=2),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2),
            nn.Conv2d(192, 384, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(384, 256, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(256, 256, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2),
        )
        self.avgpool = nn.AdaptiveAvgPool2d((6, 6))
        self.classifier = nn.Sequential(
            nn.Dropout(),
            nn.Linear(256 * 6 * 6, 4096),
            nn.ReLU(inplace=True),
            nn.Dropout(),
            nn.Linear(4096, 4096),
            nn.ReLU(inplace=True),
            nn.Linear(4096, num_classes),
        )

    def forward(self, x):
        x = self.features(x)
        x = self.avgpool(x)
        x = torch.flatten(x, 1)
        x = self.classifier(x)
        return x
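
Since this class mirrors torchvision's AlexNet, it can be sanity-checked on its own before any distributed setup. A minimal sketch (the 2x3x224x224 input is an illustrative choice matching the 224x224 CenterCrop used by train.py):

# Minimal smoke test for alexnet.py; needs only torch.
import torch
from alexnet import AlexNet

model = AlexNet(num_classes=10)
x = torch.randn(2, 3, 224, 224)  # batch of two 224x224 RGB images
print(model(x).shape)            # expected: torch.Size([2, 10])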

pipeline_parallelism/ds_config.json

Lines changed: 19 additions & 0 deletions

{
    "train_batch_size" : 256,
    "train_micro_batch_size_per_gpu" : 8,

    "optimizer": {
        "type": "Adam",
        "params": {
            "lr": 0.001,
            "betas": [
                0.9,
                0.999
            ],
            "eps": 1e-8
        }
    },

    "steps_per_print" : 10,
    "wall_clock_breakdown" : false
}
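
A note on the batch sizing: DeepSpeed requires train_batch_size = train_micro_batch_size_per_gpu x gradient_accumulation_steps x (number of data-parallel replicas), and infers the omitted gradient_accumulation_steps from the other two values. A small sketch of the arithmetic; the topology is an assumption matching run.sh on a two-GPU job:

# Sketch: how gradient accumulation falls out of ds_config.json.
# Assumed topology: 2 GPUs split into 2 pipeline stages, so
# world_size // pipeline_parallel_size = 1 data-parallel replica.
train_batch_size = 256        # from ds_config.json
micro_batch_per_gpu = 8       # from ds_config.json
data_parallel_replicas = 2 // 2

grad_accum_steps = train_batch_size // (micro_batch_per_gpu * data_parallel_replicas)
print(grad_accum_steps)  # 32 micro-batches accumulated per optimizer step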

pipeline_parallelism/run.sh

Lines changed: 3 additions & 0 deletions
#!/bin/bash

deepspeed train.py --deepspeed_config=ds_config.json -p 2 --steps=200
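
For context: the deepspeed launcher starts one worker process per visible GPU and passes each one the --local_rank argument that get_args() in train.py expects. --deepspeed_config points the engine at ds_config.json above, and -p 2 requests two pipeline stages.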

pipeline_parallelism/train.py

Lines changed: 157 additions & 0 deletions
#!/usr/bin/env python3

import argparse

import torch
import torch.distributed as dist

import torchvision
import torchvision.transforms as transforms
from torchvision.models import AlexNet
from torchvision.models import vgg19

import deepspeed
from deepspeed.pipe import PipelineModule
from deepspeed.utils import RepeatingLoader


def cifar_trainset(local_rank, dl_path='/tmp/cifar10-data'):
    transform = transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])

    # Ensure only one rank downloads.
    # Note: if the download path is not on a shared filesystem, remove the semaphore
    # and switch to args.local_rank
    dist.barrier()
    if local_rank != 0:
        dist.barrier()
    trainset = torchvision.datasets.CIFAR10(root=dl_path,
                                            train=True,
                                            download=True,
                                            transform=transform)
    if local_rank == 0:
        dist.barrier()
    return trainset


def get_args():
    parser = argparse.ArgumentParser(description='CIFAR')
    parser.add_argument('--local_rank',
                        type=int,
                        default=-1,
                        help='local rank passed from distributed launcher')
    parser.add_argument('-s',
                        '--steps',
                        type=int,
                        default=100,
                        help='quit after this many steps')
    parser.add_argument('-p',
                        '--pipeline-parallel-size',
                        type=int,
                        default=2,
                        help='pipeline parallelism')
    parser.add_argument('--backend',
                        type=str,
                        default='nccl',
                        help='distributed backend')
    parser.add_argument('--seed', type=int, default=1138, help='PRNG seed')
    parser = deepspeed.add_config_arguments(parser)
    args = parser.parse_args()
    return args


def train_base(args):
    torch.manual_seed(args.seed)

    # VGG also works :-)
    #net = vgg19(num_classes=10)
    net = AlexNet(num_classes=10)

    trainset = cifar_trainset(args.local_rank)

    engine, _, dataloader, __ = deepspeed.initialize(
        args=args,
        model=net,
        model_parameters=[p for p in net.parameters() if p.requires_grad],
        training_data=trainset)

    dataloader = RepeatingLoader(dataloader)
    data_iter = iter(dataloader)

    rank = dist.get_rank()
    gas = engine.gradient_accumulation_steps()

    criterion = torch.nn.CrossEntropyLoss()

    total_steps = args.steps * gas
    step = 0
    for micro_step in range(total_steps):
        batch = next(data_iter)
        inputs = batch[0].to(engine.device)
        labels = batch[1].to(engine.device)

        outputs = engine(inputs)
        loss = criterion(outputs, labels)
        engine.backward(loss)
        engine.step()

        if micro_step % gas == 0:
            step += 1
            if rank == 0 and (step % 10 == 0):
                print(f'step: {step:3d} / {args.steps:3d} loss: {loss}')


def join_layers(vision_model):
    layers = [
        *vision_model.features,
        vision_model.avgpool,
        lambda x: torch.flatten(x, 1),
        *vision_model.classifier,
    ]
    return layers


def train_pipe(args, part='parameters'):
    torch.manual_seed(args.seed)
    deepspeed.runtime.utils.set_random_seed(args.seed)

    #
    # Build the model
    #

    # VGG also works :-)
    #net = vgg19(num_classes=10)
    net = AlexNet(num_classes=10)
    net = PipelineModule(layers=join_layers(net),
                         loss_fn=torch.nn.CrossEntropyLoss(),
                         num_stages=args.pipeline_parallel_size,
                         partition_method=part,
                         activation_checkpoint_interval=0)

    trainset = cifar_trainset(args.local_rank)

    engine, _, _, _ = deepspeed.initialize(
        args=args,
        model=net,
        model_parameters=[p for p in net.parameters() if p.requires_grad],
        training_data=trainset)

    for step in range(args.steps):
        loss = engine.train_batch()


if __name__ == '__main__':
    args = get_args()

    torch.cuda.set_device(args.local_rank)
    dist.init_process_group(backend=args.backend)

    if args.pipeline_parallel_size == 0:
        train_base(args)
    else:
        train_pipe(args)
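
The pivotal step above is join_layers(): PipelineModule consumes a flat, ordered sequence of callables and cuts it into num_stages contiguous partitions (with partition_method='parameters', the cut points are chosen to balance trainable parameter counts across stages). The flattened view can be inspected without any distributed setup; a sketch, with join_layers copied inline from train.py so that only alexnet.py needs to be importable:

# Sketch: inspect the flat layer list that PipelineModule will partition.
# No distributed setup is needed for this inspection, though PipelineModule
# itself can only be constructed inside a distributed run.
import torch
from alexnet import AlexNet

def join_layers(vision_model):
    # Copied from train.py: flatten the nested model into one ordered list.
    return [
        *vision_model.features,
        vision_model.avgpool,
        lambda x: torch.flatten(x, 1),
        *vision_model.classifier,
    ]

layers = join_layers(AlexNet(num_classes=10))
print(len(layers))  # 22 callables: 13 feature layers, avgpool, flatten, 7 classifier layers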
