
Commit 01746bf

fix checkpointing links, small grammar update (#1651)
Fixes a link to point to the blog post instead of an internal one, and makes a small grammar fix.
1 parent 947ab9e commit 01746bf

File tree

1 file changed: +13 −13 lines changed

_posts/2024-06-12-reducing-checkpointing-times.md (+13 −13)

@@ -4,9 +4,9 @@ title: "Reducing Model Checkpointing Times by Over 10x with PyTorch Distributed
author: "Meta: Lucas Pasqualin, Less Wright, Iris Zhang (PyTorch), Chien-Chin Huang; IBM Research: Swaminathan Sundararaman, Saransh Gupta, Raghu Ganti"
---

-**Summary:** With PyTorch distributed’s new asynchronous checkpointing feature, developed with feedback from IBM, we show how IBM Research Team is able to implement and reduce effective checkpointing time by a factor of 10-20x. Example: 7B model ‘down time’ for a checkpoint goes from an average of 148.8 seconds to 6.3 seconds, or 23.62x faster.
+**Summary:** With PyTorch distributed’s new asynchronous checkpointing feature, developed with feedback from IBM, we show how IBM Research Team is able to implement and reduce effective checkpointing time by a factor of 10-20x. Example: 7B model ‘down time’ for a checkpoint goes from an average of 148.8 seconds to 6.3 seconds, or 23.62x faster.

-This directly translates into either more net training progress for every given 24 hour period while continuing to robustly checkpoint or more frequent checkpoints to shorten recovery window/time.
+This directly translates into either more net training progress for every given 24 hour period while continuing to robustly checkpoint or more frequent checkpoints to shorten recovery window/time.

In this note, we showcase the usage code and architecture that makes asynchronous checkpointing possible, along with timing results verified by IBM’s Research team.

@@ -18,49 +18,49 @@ Model checkpointing is a vital part of large model training, but checkpointing i

Thus, the inherent tension between robustness to failures vs training progress plays out as a tradeoff, but now with asynchronous checkpointing, PyTorch Distributed is able to significantly reduce this tension and enable frequent checkpoint with minimal impact to the overall training time.

-For background, it was almost exactly [a year ago](https://fb.workplace.com/notes/1332777457600375) that we showcased how distributed checkpointing had massively sped up checkpointing times from the original torch.save() functionality. As IBM Research had noted, torch.save could take up to 30 minutes to checkpoint a single 11B model (PyTorch 1.13).
+For background, it was almost exactly [a year ago](https://pytorch.org/blog/performant-distributed-checkpointing/) that we showcased how distributed checkpointing had massively sped up checkpointing times from the original torch.save() functionality. As IBM Research had noted, torch.save could take up to 30 minutes to checkpoint a single 11B model (PyTorch 1.13).

With advancements in distributed checkpointing, checkpoints could be done in under 4 minutes for up to 30B model sizes.

-With asynchronous checkpointing, the training time lost due to checkpointing now moves to under 30 seconds, and often as short as 6 seconds.
+With asynchronous checkpointing, the training time lost due to checkpointing now moves to under 30 seconds, and often as short as 6 seconds.

To be clear, asynchronous checkpointing does not compress the actual serialization checkpointing time as the previous update showcased. Rather it moves the final checkpointing process off the critical path (to cpu threads) to allow GPU training to continue while finalizing the checkpoint under separate threads.

-However, to the user, the effect is nearly the same in that down time for training due to checkpointing is substantially reduced, in many cases by 10x or even 20x.
+However, to the user, the effect is nearly the same in that down time for training due to checkpointing is substantially reduced, in many cases by 10x or even 20x.


![Async Dist Checkpointing](/assets/images/reducing-checkpointing-times/fg2.png){:style="width:100%"}


-As the above speedup chart shows, asynchronous checkpointing produces a 10x to 23x further improvement over the previous large improvements from a year ago.
+As the above speedup chart shows, asynchronous checkpointing produces a 10x to 23x further improvement over the previous large improvements from a year ago.


-## How does Asynchronous Checkpointing work?
+## How does Asynchronous Checkpointing work?

Asynchronous checkpointing modularizes the checkpointing process into two parts rather than one monolithic process. The first phase copies the data from each gpu/rank from GPU to CPU. This is the visible downtime to the user and can take from 6 - 14 seconds for 7B-13B model sizes. The second phase asynchronously copies the data from CPU memory to disk to persist the checkpoint.


Once data is copied to CPU in the first phase, the GPU is free to immediately resume training. Hence with asynchronous checkpointing the downtime for checkpointing is simply the time needed to copy over the latest model states to CPU.

-At the same time that training resumes, non-blocking CPU threads work with the freshly arrived data in memory to complete the full checkpointing/serialization process to disk (i.e. persistent save).
+At the same time that training resumes, non-blocking CPU threads work with the freshly arrived data in memory to complete the full checkpointing/serialization process to disk (i.e. persistent save).

![flow diagram](/assets/images/reducing-checkpointing-times/fg3.png){:style="width:100%"}



-Note that PyTorch’s Distributed Checkpointer relies on collective communication calls to per-rank metadata necessary to optimize saves, as well as a final synchronization which marks checkpointing as complete and makes the action atomic. This can interfere with distributed training (as distributed training also relies upon similar calls to synchronize training across multiple GPUs) if the Checkpointing thread utilizes the same process group used for training.
+Note that PyTorch’s Distributed Checkpointer relies on collective communication calls for per-rank metadata necessary to optimize saves, as well as a final synchronization which marks checkpointing as complete and makes the action atomic. This can interfere with distributed training (as distributed training also relies upon similar calls to synchronize training across multiple GPUs) if the Checkpointing thread utilizes the same process group used for training.

Specifically, a race condition between the calls could potentially cause training and asynch checkpointing save threads to wait on collective calls at the same time, resulting in a true collective hang.

-We avoided this scenario by initializing a separate process group for async checkpointing. This separates the checkpointing collectives into their own logical process group, which thus ensures it will not interfere with collective calls in the main training threads.
+We avoided this scenario by initializing a separate process group for async checkpointing. This separates the checkpointing collectives into their own logical process group, which thus ensures it will not interfere with collective calls in the main training threads.


-## How do I use Asynchronous Checkpointing in my training?
+## How do I use Asynchronous Checkpointing in my training?

Usage of Asynchronous checkpointing is relatively straightforward. Using the latest nightly version of PyTorch, you will want to initialize your process group with both nccl and gloo. Gloo is required for the cpu threads portion.

From there, create a duplicate process group which the asynchronous checkpointing will utilize.
-Then train as usual but at the point when you want to checkpoint, use the asynchronous save api, passing in the states to save, the checkpoint id and the checkpoint process group.
+Then train as usual but at the point when you want to checkpoint, use the asynchronous save api, passing in the states to save, the checkpoint id and the checkpoint process group.

![Code snippet](/assets/images/reducing-checkpointing-times/fg4.png){:style="width:100%"}
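Since the actual snippet lives in the fg4.png image referenced above, here is a minimal sketch of the flow the post describes, assuming a recent PyTorch build that ships torch.distributed.checkpoint.async_save; the model, data loader, and training-step helpers below are hypothetical placeholders, not part of the original post.

```python
import torch.distributed as dist
import torch.distributed.checkpoint as dcp

# Initialize the default process group with both backends: nccl for the GPU
# collectives used by training, gloo for the CPU-side checkpointing threads.
dist.init_process_group(backend="cpu:gloo,cuda:nccl")

# Duplicate (gloo) process group dedicated to async checkpointing, so its
# collectives never contend with training collectives on the default group.
checkpoint_pg = dist.new_group(backend="gloo")

model = build_model()          # hypothetical helper
optimizer = build_optimizer()  # hypothetical helper
checkpoint_interval = 100      # arbitrary: checkpoint every 100 steps
checkpoint_future = None

for step, batch in enumerate(dataloader):  # hypothetical data loader
    train_step(model, optimizer, batch)    # hypothetical helper

    if step % checkpoint_interval == 0:
        # Wait for any in-flight async save to finish before starting a new one.
        if checkpoint_future is not None:
            checkpoint_future.result()
        state_dict = {
            "model": model.state_dict(),
            "optimizer": optimizer.state_dict(),
        }
        # Blocks only for the GPU -> CPU copy; serialization to disk continues
        # on CPU threads while training resumes.
        checkpoint_future = dcp.async_save(
            state_dict,
            checkpoint_id=f"checkpoint_step_{step}",
            process_group=checkpoint_pg,
        )
```

The returned future is the handle for confirming that the persistent save has completed before the next checkpoint is taken.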

@@ -81,4 +81,4 @@ The last frontier - zero overhead checkpointing where even the < 30 seconds i

This would effectively move large model training to where checkpointing has no disruption or downtime enabling both more robustness (as checkpoints could be taken more frequently) and faster training progress due to no downtime for checkpointing.

-Source code link: [https://github.com/pytorch/pytorch/blob/main/torch/distributed/checkpoint/state_dict_saver.py](https://github.com/pytorch/pytorch/blob/main/torch/distributed/checkpoint/state_dict_saver.py)
+Source code link: [https://github.com/pytorch/pytorch/blob/main/torch/distributed/checkpoint/state_dict_saver.py](https://github.com/pytorch/pytorch/blob/main/torch/distributed/checkpoint/state_dict_saver.py)
