@@ -182,6 +182,11 @@ bool HostAccessibleDeviceAllocator::isSupported()
     return true;
 }
 
+std::string HostAccessibleDeviceAllocator::getUnSupportedReason()
+{
+    return TopologyDetector::getInstance().getNoCurrentGpuMemoryNumaIdReason();
+}
+
 void HostAccessibleDeviceAllocator::init()
 {
     TLLM_CHECK(mIsInited == false);
@@ -54,6 +54,8 @@ class HostAccessibleDeviceAllocator
      */
     static bool isSupported();
 
+    static std::string getUnSupportedReason();
+
     /**
     * @brief Allocate host accessible memory on the device.
     *
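Reviewer note: the boolean check stays cheap on the happy path, while the new accessor makes a failed check debuggable. A minimal caller-side sketch of combining the two entry points (the real wiring in moeLoadBalancer.cpp follows below; reportAllocatorSupport is a hypothetical helper, not part of this change):

// Hypothetical sketch: gate a feature on allocator support and surface
// the detector's diagnostic string when the check fails.
#include <cstdio>

void reportAllocatorSupport()
{
    if (!HostAccessibleDeviceAllocator::isSupported())
    {
        // getUnSupportedReason() returns the accumulated topology-detection
        // log plus the current GPU -> memory-NUMA mapping state.
        std::fprintf(stderr, "host-accessible device memory unavailable: %s\n",
            HostAccessibleDeviceAllocator::getUnSupportedReason().c_str());
        return;
    }
    // ... proceed with host-accessible allocations ...
}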
3 changes: 2 additions & 1 deletion cpp/tensorrt_llm/runtime/moeLoadBalancer/moeLoadBalancer.cpp
@@ -844,7 +844,8 @@ MoeLoadBalancer::MoeLoadBalancer(int epRank, int epSize, int layerUpdatesPerIter
     int numaGpuCount = topologyDetector.getGpuCountUnderNuma(currentGpuNumaId);
     HostAccessibleDeviceAllocator::getInstance().IncRefCount();
     TLLM_CHECK_WITH_INFO(layerUpdatesPerIter == 0 || HostAccessibleDeviceAllocator::getInstance().isSupported(),
-        "HostAccessibleDeviceAllocator is not supported on current platform, please install gdrcopy(gdrdrv).");
+        "HostAccessibleDeviceAllocator is not supported on current platform, please install gdrcopy(gdrdrv). %s",
+        HostAccessibleDeviceAllocator::getInstance().getUnSupportedReason().c_str());
     TLLM_CHECK_WITH_INFO(
         numaCpuCount > 0 && numaGpuCount > 0, "numaCpuCount=%d, numaGpuCount=%d", numaCpuCount, numaGpuCount);
     int cpuCountPerGpu = std::max(1, numaCpuCount / numaGpuCount);
44 changes: 44 additions & 0 deletions cpp/tensorrt_llm/runtime/moeLoadBalancer/topologyDetector.cpp
@@ -90,6 +90,8 @@ void TopologyDetector::detectCpuTopology()
     mCpuArchitecture = "unknown";
 #endif
 
+    mDebugStringStream << "CPU Architecture: " << mCpuArchitecture << "\n";
+
     // Detect NUMA topology on Linux systems using libnuma
 #ifdef __linux__
     if (numa_available() == -1)
@@ -106,13 +108,17 @@
         // Failed to get max node, fall back to default behavior
         TLLM_LOG_WARNING("Failed to get max NUMA node. Falling back to default CPU topology detection.");
         mNumaToCpuCountMap[0] = std::thread::hardware_concurrency();
+        mDebugStringStream << "Failed to get max NUMA node. Falling back to default CPU topology detection.\n";
         return;
     }
 
+    mDebugStringStream << "Max NUMA node: " << maxNode << "\n";
+
     mNumaToCpuCountMap.clear(); // Clear before re-populating
     std::map<int, int> tempNumaToCpuCountMap;
     for (int i = 0; i <= maxNode; ++i)
     {
+        mDebugStringStream << "Querying NUMA node " << i << "\n";
         struct bitmask* cpus = numa_allocate_cpumask();
         if (!cpus)
         {
@@ -129,6 +135,7 @@
             if (numa_bitmask_isbitset(cpus, cpu_idx))
             {
                 cpuCount++;
+                mDebugStringStream << "CPU " << cpu_idx << " is on NUMA node " << i << "\n";
             }
         }
         if (cpuCount > 0)
@@ -140,6 +147,11 @@
         // In this case, we simply don't add it to our map, effectively skipping it.
 
         numa_free_cpumask(cpus); // Always free the allocated mask
+
+        // Detect the memory size of the current NUMA node.
+
+        auto memorySize = numa_node_size64(i, NULL);
+        mDebugStringStream << "Memory size of NUMA node " << i << " is " << memorySize << "\n";
     }
     mNumaToCpuCountMap = tempNumaToCpuCountMap;
 
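Reviewer note: the detection loop above leans entirely on libnuma. A standalone sketch of the same enumeration pattern, for experimenting outside the detector (Linux only; link with -lnuma; error handling trimmed):

// Standalone sketch of the NUMA enumeration used by detectCpuTopology():
// count CPUs per node via a cpumask and query each node's memory size.
#include <numa.h>
#include <cstdio>

int main()
{
    if (numa_available() == -1)
        return 1; // libnuma reports NUMA is unusable on this system
    int maxNode = numa_max_node();
    for (int i = 0; i <= maxNode; ++i)
    {
        struct bitmask* cpus = numa_allocate_cpumask();
        if (!cpus)
            continue;
        int cpuCount = 0;
        if (numa_node_to_cpus(i, cpus) == 0)
        {
            for (unsigned cpu = 0; cpu < cpus->size; ++cpu)
            {
                if (numa_bitmask_isbitset(cpus, cpu))
                {
                    cpuCount++;
                }
            }
        }
        numa_free_cpumask(cpus);
        long long memorySize = numa_node_size64(i, NULL); // bytes; -1 on error
        std::printf("NUMA node %d: %d CPUs, %lld bytes\n", i, cpuCount, memorySize);
    }
    return 0;
}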
@@ -218,9 +230,16 @@ void TopologyDetector::detectGpuTopology()
     }
     int hasMemoryNumaConfig = 0;
     TLLM_CUDA_CHECK(cudaDeviceGetAttribute(&hasMemoryNumaConfig, cudaDevAttrNumaConfig, deviceId));
+    mDebugStringStream << "[Init] GPU[" << deviceId << "] hasMemoryNumaConfig=" << hasMemoryNumaConfig << ", ";
     if (hasMemoryNumaConfig == cudaDeviceNumaConfigNumaNode)
     {
         TLLM_CUDA_CHECK(cudaDeviceGetAttribute(&numaMemoryNode, cudaDevAttrNumaId, deviceId));
+        mDebugStringStream << "numaMemoryNode=" << numaMemoryNode << "\n";
     }
+    else
+    {
+        mDebugStringStream << "numaMemoryNode=-1"
+                           << "\n";
+    }
 #endif
 
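Reviewer note: these attributes only report a NUMA node on platforms whose GPU memory is exposed to the OS as NUMA memory (e.g. Grace-class coherent C2C systems). A minimal sketch of the same query in isolation (assumes a CUDA toolkit recent enough to define cudaDevAttrNumaConfig, i.e. 12.2+):

// Minimal sketch: ask CUDA whether deviceId's memory is backed by a host
// NUMA node, returning that node's ID or -1 (mirrors the logic above).
#include <cuda_runtime.h>

int queryGpuMemoryNumaNode(int deviceId)
{
    int numaConfig = 0;
    if (cudaDeviceGetAttribute(&numaConfig, cudaDevAttrNumaConfig, deviceId) != cudaSuccess)
        return -1;
    if (numaConfig != cudaDeviceNumaConfigNumaNode)
        return -1; // GPU memory is not exposed as a NUMA node
    int numaId = -1;
    if (cudaDeviceGetAttribute(&numaId, cudaDevAttrNumaId, deviceId) != cudaSuccess)
        return -1;
    return numaId;
}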
@@ -427,6 +446,31 @@ int TopologyDetector::getCurrentGpuMemoryNumaId()
     return -1;
 }
 
+std::string TopologyDetector::getNoCurrentGpuMemoryNumaIdReason()
+{
+    int currentDevice = -1;
+    TLLM_CUDA_CHECK(cudaGetDevice(&currentDevice));
+    std::string reason;
+    reason += mDebugStringStream.str();
+    reason += "Current GPU=" + std::to_string(currentDevice) + ", mGpuMemoryToNumaMap={";
+    for (auto it = mGpuMemoryToNumaMap.begin(); it != mGpuMemoryToNumaMap.end(); ++it)
+    {
+        reason += "GPU[" + std::to_string(it->first) + "] memory NUMA Node=" + std::to_string(it->second) + ", ";
+    }
+    reason += "}";
+    auto itGpuToNuma = mGpuMemoryToNumaMap.find(currentDevice);
+    if (itGpuToNuma != mGpuMemoryToNumaMap.end())
+    {
+        reason += ", FOUND GPU[" + std::to_string(itGpuToNuma->first)
+            + "] memory NUMA Node=" + std::to_string(itGpuToNuma->second) + ", ";
+    }
+    else
+    {
+        reason += ", NOT FOUND GPU[" + std::to_string(currentDevice) + "] memory NUMA Node";
+    }
+    return reason;
+}
+
 int TopologyDetector::getGpuCountUnderNuma(int numaId)
 {
     auto it = mNumaToGpuMap.find(numaId);
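Reviewer note: with the log accumulated in mDebugStringStream during detection, the returned reason reads roughly like the following on a machine whose GPU memory reports no NUMA backing (illustrative values, not captured output; assumes mGpuMemoryToNumaMap only records GPUs whose memory maps to a NUMA node):

CPU Architecture: x86_64
Max NUMA node: 1
Querying NUMA node 0
Memory size of NUMA node 0 is 270063157248
[Init] GPU[0] hasMemoryNumaConfig=0, numaMemoryNode=-1
Current GPU=0, mGpuMemoryToNumaMap={}, NOT FOUND GPU[0] memory NUMA Node

This is the text the new %s in moeLoadBalancer.cpp appends to the gdrcopy install hint.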
5 changes: 5 additions & 0 deletions cpp/tensorrt_llm/runtime/moeLoadBalancer/topologyDetector.h
@@ -18,6 +18,7 @@
 
 #include <map>
 #include <mutex>
+#include <sstream>
 #include <string>
 #include <vector>
 #ifdef __linux__
@@ -60,6 +61,8 @@ class TopologyDetector
     // Returns -1 if it doesn't have NUMA ID.
     int getCurrentGpuMemoryNumaId();
 
+    std::string getNoCurrentGpuMemoryNumaIdReason();
+
     // Returns the number of GPUs associated with the given NUMA node ID.
     int getGpuCountUnderNuma(int numaId);
 
@@ -104,6 +107,8 @@ class TopologyDetector
     // Precomputed CPU affinity masks
     std::map<int, struct bitmask*> mGpuStrictCpuMasks; // GPU ID -> Strict CPU mask
 #endif
+
+    std::stringstream mDebugStringStream;
 };
 
 } // namespace tensorrt_llm::runtime
4 changes: 2 additions & 2 deletions tests/integration/test_lists/test-db/l0_rtx_pro_6000.yml
@@ -88,8 +88,8 @@ l0_rtx_pro_6000:
 - accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_nvfp4_4gpus[moe_backend=CUTLASS-mtp_nextn=2-tp2pp2-fp8kv=True-attention_dp=True-cuda_graph=True-overlap_scheduler=True-torch_compile=False]
 - accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_nvfp4_4gpus[moe_backend=CUTLASS-mtp_nextn=2-pp4-fp8kv=False-attention_dp=False-cuda_graph=False-overlap_scheduler=False-torch_compile=False]
 - accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_nvfp4_4gpus[moe_backend=CUTLASS-mtp_nextn=2-pp4-fp8kv=True-attention_dp=True-cuda_graph=True-overlap_scheduler=True-torch_compile=False]
-- accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_nvfp4_4gpus_online_eplb[fp8kv=True]
-- accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_bfloat16_4gpus_online_eplb[mtp_nextn=2]
+# - accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_nvfp4_4gpus_online_eplb[fp8kv=True] # need gdrdrv on non Grace(C2C) platform
+# - accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_bfloat16_4gpus_online_eplb[mtp_nextn=2] # need gdrdrv on non Grace(C2C) platform
 # - accuracy/test_disaggregated_serving.py::TestQwen3_8B::test_auto_dtype[False] # hopper only
 # - accuracy/test_disaggregated_serving.py::TestQwen3_8B::test_auto_dtype[True]
 - accuracy/test_disaggregated_serving.py::TestDeepSeekV3Lite::test_guided_decoding[xgrammar-mtp_nextn=0]
3 changes: 0 additions & 3 deletions tests/integration/test_lists/waives.txt
@@ -295,8 +295,6 @@ accuracy/test_cli_flow.py::TestLongAlpaca7B::test_auto_dtype SKIP (https://nvbug
 accuracy/test_llm_api.py::TestPhi4MiniInstruct::test_fp8 SKIP (https://nvbugs/5465143)
 accuracy/test_llm_api_pytorch.py::TestEXAONE4::test_auto_dtype SKIP (https://nvbugs/5481090)
 accuracy/test_llm_api_pytorch.py::TestLlama4MaverickInstruct::test_fp8_eagle3[tp8-torch_compile=False] SKIP (https://nvbugs/5483534)
-accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_bfloat16_4gpus_online_eplb[mtp_nextn=2] SKIP (https://nvbugs/5444687)
-accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_nvfp4_4gpus_online_eplb[fp8kv=True] SKIP (https://nvbugs/5444687)
 accuracy/test_disaggregated_serving.py::TestLlama3_1_8BInstruct::test_ngram SKIP (https://nvbugs/5488118)
 test_e2e.py::test_trtllm_bench_iteration_log[TRT-streaming-meta-llama/Llama-3.1-8B-llama-3.1-model/Meta-Llama-3.1-8B] SKIP (https://nvbugs/5448523)
 cpp/test_unit_tests.py::test_unit_tests[kernels-80] SKIP (https://nvbugs/5504078)
@@ -332,7 +330,6 @@ full:L20/accuracy/test_llm_api_pytorch.py::TestLlama3_2_1B::test_fp8_prequantize
 full:L20/accuracy/test_llm_api_pytorch.py::TestMinistral8BInstruct::test_fp8 SKIP (https://nvbugs/5542862)
 full:L40S/accuracy/test_llm_api_pytorch.py::TestLlama3_2_1B::test_fp8_prequantized SKIP (https://nvbugs/5542862)
 full:L40S/accuracy/test_llm_api_pytorch.py::TestMinistral8BInstruct::test_fp8 SKIP (https://nvbugs/5542862)
-accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_bfloat16_4gpus_online_eplb[mtp_nextn=2] SKIP (https://nvbugs/5543035)
 unittest/_torch/multi_gpu_modeling/test_llama3.py::test_llama_3_3 SKIP (https://nvbugs/5536131)
 accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_nvfp4_4gpus[moe_backend=TRTLLM-mtp_nextn=2-ep4-fp8kv=True-attention_dp=True-cuda_graph=True-overlap_scheduler=True-torch_compile=False] SKIP (https://nvbugs/5541494)
 accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_w4_4gpus[ep4-trtllm-auto] SKIP (https://nvbugs/5541494)