Skip to content

Commit dd7f12e

Browse files
committed
Consolidate DeviceGuard into DeviceMesh header
Signed-off-by: Matthew Cong <[email protected]>
1 parent 72406ff commit dd7f12e

File tree

3 files changed

+23
-45
lines changed

3 files changed

+23
-45
lines changed

nanovdb/nanovdb/CMakeLists.txt

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -172,7 +172,6 @@ set(NANOVDB_INCLUDE_FILES
172172
# NanoVDB cuda header files
173173
set(NANOVDB_INCLUDE_CUDA_FILES
174174
cuda/DeviceBuffer.h
175-
cuda/DeviceGuard.h
176175
cuda/DeviceMesh.h
177176
cuda/DeviceStreamMap.h
178177
cuda/GridHandle.cuh

nanovdb/nanovdb/cuda/DeviceGuard.h

Lines changed: 0 additions & 41 deletions
This file was deleted.

nanovdb/nanovdb/cuda/DeviceMesh.h

Lines changed: 23 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,6 @@
1515
#include <algorithm>
1616

1717
#include <nanovdb/util/cuda/Util.h>
18-
#include <nanovdb/cuda/DeviceGuard.h>
1918
#ifdef NANOVDB_USE_NCCL
2019
#include <nccl.h>
2120
#endif
@@ -24,6 +23,27 @@ namespace nanovdb {
2423

2524
namespace cuda {
2625

26+
namespace detail {
27+
28+
/// @brief RAII class that caches/restores the current device at construction/destruction
29+
class DeviceGuard {
30+
public:
31+
DeviceGuard() { cudaGetDevice(&deviceId); }
32+
~DeviceGuard() { cudaSetDevice(deviceId); }
33+
34+
/// @{
35+
/// @brief DeviceGuard is not copyable nor movable
36+
DeviceGuard(const DeviceGuard&) = delete;
37+
DeviceGuard& operator=(const DeviceGuard&) = delete;
38+
DeviceGuard(DeviceGuard&& other) = delete;
39+
DeviceGuard& operator=(DeviceGuard&& other) = delete;
40+
/// @}
41+
private:
42+
int deviceId = -1;
43+
};
44+
45+
}
46+
2747
/// @brief POD struct representing a device id and a stream on that device
2848
struct DeviceNode
2949
{
@@ -90,7 +110,7 @@ class DeviceMesh
90110

91111
inline DeviceMesh::DeviceMesh()
92112
{
93-
DeviceGuard deviceGuard;
113+
detail::DeviceGuard deviceGuard;
94114

95115
int deviceCount = -1;
96116
cudaGetDeviceCount(&deviceCount);
@@ -135,7 +155,7 @@ inline DeviceMesh::DeviceMesh(DeviceMesh&& other) noexcept
135155

136156
inline DeviceMesh::~DeviceMesh()
137157
{
138-
DeviceGuard deviceGuard;
158+
detail::DeviceGuard deviceGuard;
139159

140160
#ifdef NANOVDB_USE_NCCL
141161
std::for_each(mComms.begin(), mComms.end(), [](ncclComm_t comm) {

0 commit comments

Comments
 (0)