"""
(prototype) Accelerating ``torch.save`` and ``torch.load`` with GPUDirect Storage
=================================================================================

GPUDirect Storage enables a direct data path for direct memory access transfers
between GPU memory and storage, avoiding a bounce buffer through the CPU.

In version **2.7**, we introduced new prototype APIs in ``torch.cuda.gds`` that serve as thin wrappers around
the `cuFile APIs <https://docs.nvidia.com/gpudirect-storage/api-reference-guide/index.html#cufile-io-api>`_
and can be used with ``torch.Tensor`` to achieve improved I/O performance.

In this tutorial, we demonstrate how to use the ``torch.cuda.gds`` APIs in conjunction with
checkpoints generated by ``torch.save`` and ``torch.load`` on a local filesystem.

.. grid:: 2

    .. grid-item-card:: :octicon:`mortar-board;1em;` What you will learn
       :class-card: card-prerequisites

       * Understand how to use the ``torch.cuda.gds`` APIs in conjunction with
         checkpoints generated by ``torch.save`` and ``torch.load`` on a local filesystem

    .. grid-item-card:: :octicon:`list-unordered;1em;` Prerequisites
       :class-card: card-prerequisites

       * PyTorch v2.7.0 or later
       * GPUDirect Storage must be installed per
         `the documentation <https://docs.nvidia.com/gpudirect-storage/troubleshooting-guide/contents.html>`_
       * Ensure that the filesystem that you are saving/loading to supports GPUDirect Storage.
"""
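
################################################################################
# Before diving in, it is worth checking that a CUDA device is actually
# available, since the ``torch.cuda.gds`` APIs require one. This minimal check
# does not verify that GPUDirect Storage itself is usable on your filesystem;
# refer to the documentation linked above for that.

import torch

if not torch.cuda.is_available():
    raise RuntimeError("This tutorial requires a CUDA-enabled PyTorch build and a CUDA device.")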

################################################################################
# Using GPUDirect Storage with ``torch.save`` and ``torch.load``
# ------------------------------------------------------------------------------------
# GPUDirect Storage requires a storage alignment of 4KB. You can configure this
# with ``torch.utils.serialization.config.save.storage_alignment``:

import torch
from torch.utils.serialization import config as serialization_config

serialization_config.save.storage_alignment = 4096

################################################################################
# The steps involved in the process are as follows:
#
# * Write the checkpoint file without any actual data. This reserves the space on disk.
# * Read the offsets for the storage associated with each tensor in the checkpoint using ``FakeTensor``.
# * Use ``GDSFile`` to write the appropriate data at these offsets.
#
# Given a state dictionary of tensors that are on the GPU, one can use the ``torch.serialization.skip_data`` context
# manager to save a checkpoint that contains all relevant metadata except the storage bytes. For each ``torch.Storage``
# in the state dictionary, space will be reserved within the checkpoint for the storage bytes.

import torch.nn as nn

m = nn.Linear(5, 10, device='cuda')
sd = m.state_dict()

with torch.serialization.skip_data():
    torch.save(sd, "checkpoint.pt")

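################################################################################
# Even though no storage bytes were written, the checkpoint already occupies
# its full size on disk: space for each (aligned) storage has been reserved. As
# a quick sanity check, we can inspect the file size:

import os

print(f"Reserved checkpoint size: {os.path.getsize('checkpoint.pt')} bytes")
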
################################################################################
# We can get the offsets that each storage should be written to within the checkpoint by loading under
# a ``FakeTensorMode``. A FakeTensor is a tensor that has metadata (such as sizes, strides, dtype, device)
# information about the tensor but does not have any storage bytes. The following snippet will not materialize
# any data but will tag each ``FakeTensor`` with the offset within the checkpoint that
# corresponds to the tensor.
#
# If you are continuously saving the same state dictionary during training, you
# would only need to obtain the offsets once and the same offsets can be re-used. Similarly, if a tensor is going to
# be saved or loaded repeatedly, you can use ``torch.cuda.gds.gds_register_buffer``, which wraps
# ``cuFileBufRegister``, to register the storages as GDS buffers (see the sketch after the save loop below).
#
# Note that ``torch.cuda.gds.GdsFile.save_storage`` binds to the synchronous ``cuFileWrite`` API,
# so no synchronization is needed afterwards.


import os
from torch._subclasses.fake_tensor import FakeTensorMode

with FakeTensorMode() as mode:
    fake_sd = torch.load("checkpoint.pt")

for k, v in fake_sd.items():
    print(f"key={k}, offset={v.untyped_storage()._checkpoint_offset}")

f = torch.cuda.gds.GdsFile("checkpoint.pt", os.O_RDWR)

for k, v in sd.items():
    offset = fake_sd[k].untyped_storage()._checkpoint_offset
    # save_storage is a wrapper around `cuFileWrite`
    f.save_storage(v.untyped_storage(), offset)

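################################################################################
# As mentioned above, if the same storages will be written to or read from
# repeatedly, they can be registered as GDS buffers once up front so that the
# registration cost is not paid on every I/O call. The sketch below assumes the
# prototype ``torch.cuda.gds.gds_register_buffer`` and
# ``torch.cuda.gds.gds_deregister_buffer`` APIs.

# Register each storage as a GDS buffer (wraps `cuFileBufRegister`)
for v in sd.values():
    torch.cuda.gds.gds_register_buffer(v.untyped_storage())

# ... repeated f.save_storage(...) / f.load_storage(...) calls would go here ...

# Deregister the buffers once the repeated I/O is done (wraps `cuFileBufDeregister`)
for v in sd.values():
    torch.cuda.gds.gds_deregister_buffer(v.untyped_storage())
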
################################################################################
# We verify the correctness of the saved checkpoint by loading it with ``torch.load`` and comparing.

sd_loaded = torch.load("checkpoint.pt")
for k, v in sd_loaded.items():
    assert torch.equal(v, sd[k])

################################################################################
# The loading flow is the inverse: you can use ``torch.load`` with the ``torch.serialization.skip_data`` context
# manager to load everything except the storage bytes. This means that any tensors in the checkpoint will be
# created but their storages will be empty (as if the tensors were created via ``torch.empty``).

with torch.serialization.skip_data():
    sd_loaded = torch.load("checkpoint.pt")

################################################################################
# We once again use the ``FakeTensorMode`` to get the checkpoint offsets and
# ascertain that the loaded checkpoint is the same as the saved checkpoint.
#
# Similar to ``torch.cuda.gds.GdsFile.save_storage``, ``torch.cuda.gds.GdsFile.load_storage``
# binds to the synchronous ``cuFileRead`` API, so no synchronization is needed afterwards.

for k, v in sd_loaded.items():
    # The storage bytes have not been loaded yet, so the values should differ
    assert not torch.equal(v, sd[k])
    offset = fake_sd[k].untyped_storage()._checkpoint_offset
    # load_storage is a wrapper around `cuFileRead`
    f.load_storage(v.untyped_storage(), offset)

for k, v in sd_loaded.items():
    assert torch.equal(v, sd[k])

# Release the cuFile handle
del f
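
################################################################################
# Finally, we remove the checkpoint file that was created for this tutorial.

os.remove("checkpoint.pt")
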
##########################################################
# Conclusion
# ==========
#
# In this tutorial we have demonstrated how to use the prototype ``torch.cuda.gds`` APIs
# in conjunction with ``torch.save`` and ``torch.load`` on a local filesystem. Please
# file an issue in the PyTorch GitHub repo if you have any feedback.